@@ -187,7 +187,7 @@ def test_flux2_consistency(self, seed=0):
187187
188188 flat_output = output .cpu ().flatten ()
189189 generated_slice = torch .cat ([flat_output [:8 ], flat_output [- 8 :]])
190- self .assertTrue (torch .allclose (generated_slice , expected_slice ))
190+ self .assertTrue (torch .allclose (generated_slice , expected_slice , atol = 1e-4 ))
191191
192192 def test_gradient_checkpointing_is_applied (self ):
193193 expected_set = {"Flux2Transformer2DModel" }
@@ -200,7 +200,7 @@ def test_lora_exclude_modules(self):
200200 from peft import LoraConfig , get_peft_model_state_dict , inject_adapter_in_model , set_peft_model_state_dict
201201
202202 lora_rank = 4
203- target_module = "single_transformer_blocks.0.proj_out "
203+ target_module = "single_transformer_blocks.0.attn.to_out "
204204 adapter_name = "foo"
205205 init_dict , _ = self .prepare_init_args_and_inputs_for_common ()
206206 model = self .model_class (** init_dict ).to (torch_device )
@@ -213,14 +213,14 @@ def test_lora_exclude_modules(self):
213213 }
214214 # Passing exclude_modules should no longer be necessary (or even passing target_modules, for that matter).
215215 config = LoraConfig (
216- r = lora_rank , target_modules = ["single_transformer_blocks.0.proj_out" ], exclude_modules = ["proj_out " ]
216+ r = lora_rank , target_modules = [target_module ], exclude_modules = ["to_out " ]
217217 )
218218 inject_adapter_in_model (config , model , adapter_name = adapter_name , state_dict = lora_state_dict )
219219 set_peft_model_state_dict (model , lora_state_dict , adapter_name )
220220 retrieved_lora_state_dict = get_peft_model_state_dict (model , adapter_name = adapter_name )
221221 assert len (retrieved_lora_state_dict ) == len (lora_state_dict )
222- assert (retrieved_lora_state_dict ["single_transformer_blocks.0.proj_out .lora_A.weight" ] == 22 ).all ()
223- assert (retrieved_lora_state_dict ["single_transformer_blocks.0.proj_out .lora_B.weight" ] == 33 ).all ()
222+ assert (retrieved_lora_state_dict [f" { target_module } .lora_A.weight" ] == 22 ).all ()
223+ assert (retrieved_lora_state_dict [f" { target_module } .lora_B.weight" ] == 33 ).all ()
224224
225225
226226class Flux2TransformerCompileTests (TorchCompileTesterMixin , unittest .TestCase ):
0 commit comments