Commit 9f1c83f

fixes
1 parent 4da07a7

2 files changed (+12, -9 lines)

tests/models/test_modeling_common.py

Lines changed: 8 additions & 9 deletions
@@ -1350,7 +1350,6 @@ def test_model_parallelism(self):
                 new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
                 # Making sure part of the model will actually end up offloaded
                 self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})
-                print(f" new_model.hf_device_map:{new_model.hf_device_map}")
 
                 self.check_device_map_is_respected(new_model, new_model.hf_device_map)
 
@@ -2019,6 +2018,8 @@ class LoraHotSwappingForModelTesterMixin:
 
     """
 
+    different_shapes_for_compilation = None
+
     def tearDown(self):
         # It is critical that the dynamo cache is reset for each test. Otherwise, if the test re-uses the same model,
         # there will be recompilation errors, as torch caches the model when run in the same process.
@@ -2116,29 +2117,27 @@ def check_model_hotswap(
                 model = torch.compile(model, mode="reduce-overhead", dynamic=different_resolutions is not None)
 
             with torch.inference_mode():
-                output0_after = model(**inputs_dict)["sample"]
-
                 # additionally check if dynamic compilation works.
                 if different_resolutions is not None:
                     for height, width in self.different_shapes_for_compilation:
                         new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
                         _ = model(**new_inputs_dict)
-
-            assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol)
+                else:
+                    output0_after = model(**inputs_dict)["sample"]
+                    assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol)
 
             # hotswap the 2nd adapter
             model.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None)
 
             # we need to call forward to potentially trigger recompilation
             with torch.inference_mode():
-                output1_after = model(**inputs_dict)["sample"]
-
                 if different_resolutions is not None:
                     for height, width in self.different_shapes_for_compilation:
                         new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
                         _ = model(**new_inputs_dict)
-
-            assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol)
+                else:
+                    output1_after = model(**inputs_dict)["sample"]
+                    assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol)
 
             # check error when not passing valid adapter name
             name = "does-not-exist"
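For reference, the pattern the hot-swap check follows: the model is compiled with dynamic shapes only when a tester class defines different_shapes_for_compilation, in which case forward passes are run at each resolution without comparing outputs; otherwise the compiled output is checked against the pre-compile reference. A minimal stand-alone sketch of that control flow, using a toy nn.Linear instead of a diffusers model (run_compiled_check and its arguments are illustrative names, not part of the repository):

import torch
import torch.nn as nn

def run_compiled_check(model, reference_input, reference_output, shapes=None, tol=1e-4):
    # Compile dynamically only when several input shapes will be exercised.
    compiled = torch.compile(model, mode="reduce-overhead", dynamic=shapes is not None)
    with torch.inference_mode():
        if shapes is not None:
            # Dynamic path: only verify that forward passes at different resolutions succeed.
            for height, width in shapes:
                _ = compiled(torch.randn(1, height * width, 8))
        else:
            # Static path: the compiled output must match the pre-compile reference.
            output_after = compiled(reference_input)
            assert torch.allclose(reference_output, output_after, atol=tol, rtol=tol)

model = nn.Linear(8, 8)
x = torch.randn(1, 16, 8)
with torch.inference_mode():
    reference = model(x)
run_compiled_check(model, x, reference)  # static-shape comparison
run_compiled_check(model, x, reference, shapes=[(4, 4), (4, 8), (8, 8)])  # dynamic-shape sweep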

tests/models/transformers/test_models_transformer_flux.py

Lines changed: 4 additions & 0 deletions
@@ -186,6 +186,10 @@ def prepare_dummy_input(self, height, width):
 
 class FluxTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
     model_class = FluxTransformer2DModel
+    different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
 
     def prepare_init_args_and_inputs_for_common(self):
         return FluxTransformerTests().prepare_init_args_and_inputs_for_common()
+
+    def prepare_dummy_input(self, height, width):
+        return FluxTransformerTests().prepare_dummy_input(height=height, width=width)
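The opt-in pattern shown here generalizes to other model testers: set different_shapes_for_compilation as a class attribute (the mixin default is None) and expose prepare_dummy_input(height, width) so inputs can be rebuilt at each resolution. A self-contained toy sketch of the class-attribute override, with placeholder names (HotSwapMixinSketch and the two test classes below are not diffusers classes):

import unittest

class HotSwapMixinSketch:
    # Mirrors the mixin default: None means the dynamic-compilation path is skipped.
    different_shapes_for_compilation = None

    def resolutions(self):
        return self.different_shapes_for_compilation

class StaticShapeTests(HotSwapMixinSketch, unittest.TestCase):
    def test_default_disables_dynamic_path(self):
        self.assertIsNone(self.resolutions())

class DynamicShapeTests(HotSwapMixinSketch, unittest.TestCase):
    # Subclasses opt in by overriding the class attribute, as the Flux tester does.
    different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]

    def test_override_is_picked_up(self):
        self.assertEqual(self.resolutions(), [(4, 4), (4, 8), (8, 8)])

if __name__ == "__main__":
    unittest.main()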
