|  | 
| 12 | 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
| 13 | 13 | # See the License for the specific language governing permissions and | 
| 14 | 14 | # limitations under the License. | 
|  | 15 | +import copy | 
| 15 | 16 | import gc | 
| 16 | 17 | import os | 
| 17 | 18 | import sys | 
| @@ -162,6 +163,56 @@ def test_with_alpha_in_state_dict(self): | 
| 162 | 163 |         ) | 
| 163 | 164 |         self.assertFalse(np.allclose(images_lora_with_alpha, images_lora, atol=1e-3, rtol=1e-3)) | 
| 164 | 165 | 
 | 
|  | 166 | +    def test_lora_expansion_works_for_absent_keys(self): | 
|  | 167 | +        components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) | 
|  | 168 | +        pipe = self.pipeline_class(**components) | 
|  | 169 | +        pipe = pipe.to(torch_device) | 
|  | 170 | +        pipe.set_progress_bar_config(disable=None) | 
|  | 171 | +        _, _, inputs = self.get_dummy_inputs(with_generator=False) | 
|  | 172 | + | 
|  | 173 | +        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images | 
|  | 174 | +        self.assertTrue(output_no_lora.shape == self.output_shape) | 
|  | 175 | + | 
|  | 176 | +        # Modify the config to have a layer which won't be present in the second LoRA we will load. | 
|  | 177 | +        modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config) | 
|  | 178 | +        modified_denoiser_lora_config.target_modules.add("x_embedder") | 
|  | 179 | + | 
|  | 180 | +        pipe.transformer.add_adapter(modified_denoiser_lora_config) | 
|  | 181 | +        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer") | 
|  | 182 | + | 
|  | 183 | +        images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images | 
|  | 184 | +        self.assertFalse( | 
|  | 185 | +            np.allclose(images_lora, output_no_lora, atol=1e-3, rtol=1e-3), | 
|  | 186 | +            "LoRA should lead to different results.", | 
|  | 187 | +        ) | 
|  | 188 | + | 
|  | 189 | +        with tempfile.TemporaryDirectory() as tmpdirname: | 
|  | 190 | +            denoiser_state_dict = get_peft_model_state_dict(pipe.transformer) | 
|  | 191 | +            self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict) | 
|  | 192 | + | 
|  | 193 | +            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) | 
|  | 194 | +            pipe.unload_lora_weights() | 
|  | 195 | +            # Modify the state dict to exclude "x_embedder" related LoRA params. | 
|  | 196 | +            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) | 
|  | 197 | +            lora_state_dict_without_xembedder = {k: v for k, v in lora_state_dict.items() if "x_embedder" not in k} | 
|  | 198 | +            pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="two") | 
|  | 199 | + | 
|  | 200 | +            # Load state dict with `x_embedder`. | 
|  | 201 | +            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="one") | 
|  | 202 | + | 
|  | 203 | +        pipe.set_adapters(["one", "two"]) | 
|  | 204 | +        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer") | 
|  | 205 | +        images_lora_with_absent_keys = pipe(**inputs, generator=torch.manual_seed(0)).images | 
|  | 206 | + | 
|  | 207 | +        self.assertFalse( | 
|  | 208 | +            np.allclose(images_lora, images_lora_with_absent_keys, atol=1e-3, rtol=1e-3), | 
|  | 209 | +            "Different LoRAs should lead to different results.", | 
|  | 210 | +        ) | 
|  | 211 | +        self.assertFalse( | 
|  | 212 | +            np.allclose(output_no_lora, images_lora_with_absent_keys, atol=1e-3, rtol=1e-3), | 
|  | 213 | +            "LoRA should lead to different results.", | 
|  | 214 | +        ) | 
|  | 215 | + | 
| 165 | 216 |     @unittest.skip("Not supported in Flux.") | 
| 166 | 217 |     def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): | 
| 167 | 218 |         pass | 
|  | 
0 commit comments