Skip to content

Commit 4da07a7

Browse files
committed
add resolution changes tests to hotswapping test suite.
1 parent 76ec3d1 commit 4da07a7

File tree

1 file changed

+36
-2
lines changed

1 file changed

+36
-2
lines changed

tests/models/test_modeling_common.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2046,7 +2046,9 @@ def get_linear_module_name_other_than_attn(self, model):
20462046
]
20472047
return linear_names[0]
20482048

2049-
def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_modules1=None):
2049+
def check_model_hotswap(
2050+
self, do_compile, rank0, rank1, target_modules0, target_modules1=None, different_resolutions=None
2051+
):
20502052
"""
20512053
Check that hotswapping works on a small unet.
20522054
@@ -2056,6 +2058,7 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_
20562058
- hotswap the second adapter
20572059
- check that the outputs are correct
20582060
- optionally compile the model
2061+
- optionally check if recompilations happen on different shapes
20592062
20602063
Note: We set rank == alpha here because save_lora_adapter does not save the alpha scalings, thus the test would
20612064
fail if the values are different. Since rank != alpha does not matter for the purpose of this test, this is
@@ -2110,10 +2113,17 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_
21102113
model.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)
21112114

21122115
if do_compile:
2113-
model = torch.compile(model, mode="reduce-overhead")
2116+
model = torch.compile(model, mode="reduce-overhead", dynamic=different_resolutions is not None)
21142117

21152118
with torch.inference_mode():
21162119
output0_after = model(**inputs_dict)["sample"]
2120+
2121+
# additionally check if dynamic compilation works.
2122+
if different_resolutions is not None:
2123+
for height, width in self.different_shapes_for_compilation:
2124+
new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
2125+
_ = model(**new_inputs_dict)
2126+
21172127
assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol)
21182128

21192129
# hotswap the 2nd adapter
@@ -2122,6 +2132,12 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_
21222132
# we need to call forward to potentially trigger recompilation
21232133
with torch.inference_mode():
21242134
output1_after = model(**inputs_dict)["sample"]
2135+
2136+
if different_resolutions is not None:
2137+
for height, width in self.different_shapes_for_compilation:
2138+
new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
2139+
_ = model(**new_inputs_dict)
2140+
21252141
assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol)
21262142

21272143
# check error when not passing valid adapter name
@@ -2240,3 +2256,21 @@ def test_hotswap_second_adapter_targets_more_layers_raises(self):
22402256
do_compile=True, rank0=8, rank1=8, target_modules0=target_modules0, target_modules1=target_modules1
22412257
)
22422258
assert any("Hotswapping adapter0 was unsuccessful" in log for log in cm.output)
2259+
2260+
@parameterized.expand([(11, 11), (7, 13), (13, 7)])
2261+
@require_torch_version_greater("2.7.1")
2262+
def test_hotswapping_compile_on_different_shapes(self, rank0, rank1):
2263+
different_shapes_for_compilation = self.different_shapes_for_compilation
2264+
if different_shapes_for_compilation is None:
2265+
pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
2266+
torch.fx.experimental._config.use_duck_shape = False
2267+
2268+
target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
2269+
with torch._dynamo.config.patch(error_on_recompile=True):
2270+
self.check_model_hotswap(
2271+
do_compile=True,
2272+
rank0=rank0,
2273+
rank1=rank1,
2274+
target_modules0=target_modules,
2275+
different_resolutions=different_shapes_for_compilation,
2276+
)

0 commit comments

Comments
 (0)