Skip to content

Commit 05fcddb

Browse files
committed
Get remaining Flux 2 transformer tests passing
1 parent f6059b7 commit 05fcddb

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

tests/models/transformers/test_models_transformer_flux2.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ def test_flux2_consistency(self, seed=0):
187187

188188
flat_output = output.cpu().flatten()
189189
generated_slice = torch.cat([flat_output[:8], flat_output[-8:]])
190-
self.assertTrue(torch.allclose(generated_slice, expected_slice))
190+
self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-4))
191191

192192
def test_gradient_checkpointing_is_applied(self):
193193
expected_set = {"Flux2Transformer2DModel"}
@@ -200,7 +200,7 @@ def test_lora_exclude_modules(self):
200200
from peft import LoraConfig, get_peft_model_state_dict, inject_adapter_in_model, set_peft_model_state_dict
201201

202202
lora_rank = 4
203-
target_module = "single_transformer_blocks.0.proj_out"
203+
target_module = "single_transformer_blocks.0.attn.to_out"
204204
adapter_name = "foo"
205205
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
206206
model = self.model_class(**init_dict).to(torch_device)
@@ -213,14 +213,14 @@ def test_lora_exclude_modules(self):
213213
}
214214
# Passing exclude_modules should no longer be necessary (or even passing target_modules, for that matter).
215215
config = LoraConfig(
216-
r=lora_rank, target_modules=["single_transformer_blocks.0.proj_out"], exclude_modules=["proj_out"]
216+
r=lora_rank, target_modules=[target_module], exclude_modules=["to_out"]
217217
)
218218
inject_adapter_in_model(config, model, adapter_name=adapter_name, state_dict=lora_state_dict)
219219
set_peft_model_state_dict(model, lora_state_dict, adapter_name)
220220
retrieved_lora_state_dict = get_peft_model_state_dict(model, adapter_name=adapter_name)
221221
assert len(retrieved_lora_state_dict) == len(lora_state_dict)
222-
assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_A.weight"] == 22).all()
223-
assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_B.weight"] == 33).all()
222+
assert (retrieved_lora_state_dict[f"{target_module}.lora_A.weight"] == 22).all()
223+
assert (retrieved_lora_state_dict[f"{target_module}.lora_B.weight"] == 33).all()
224224

225225

226226
class Flux2TransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):

0 commit comments

Comments
 (0)