Skip to content

Commit 579fb76

Browse files
committed
fix
1 parent 2076a53 commit 579fb76

File tree

1 file changed

+7
-9
lines changed

1 file changed

+7
-9
lines changed

tests/models/test_modeling_common.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2047,9 +2047,7 @@ def get_linear_module_name_other_than_attn(self, model):
20472047
]
20482048
return linear_names[0]
20492049

2050-
def check_model_hotswap(
2051-
self, do_compile, rank0, rank1, target_modules0, target_modules1=None, different_resolutions=None
2052-
):
2050+
def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_modules1=None):
20532051
"""
20542052
Check that hotswapping works on a small unet.
20552053
@@ -2065,6 +2063,7 @@ def check_model_hotswap(
20652063
fail if the values are different. Since rank != alpha does not matter for the purpose of this test, this is
20662064
fine.
20672065
"""
2066+
different_shapes = self.different_shapes_for_compilation
20682067
# create 2 adapters with different ranks and alphas
20692068
torch.manual_seed(0)
20702069
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -2114,12 +2113,12 @@ def check_model_hotswap(
21142113
model.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)
21152114

21162115
if do_compile:
2117-
model = torch.compile(model, mode="reduce-overhead", dynamic=different_resolutions is not None)
2116+
model = torch.compile(model, mode="reduce-overhead", dynamic=different_shapes is not None)
21182117

21192118
with torch.inference_mode():
21202119
# additionally check if dynamic compilation works.
2121-
if different_resolutions is not None:
2122-
for height, width in self.different_shapes_for_compilation:
2120+
if different_shapes is not None:
2121+
for height, width in different_shapes:
21232122
new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
21242123
_ = model(**new_inputs_dict)
21252124
else:
@@ -2131,8 +2130,8 @@ def check_model_hotswap(
21312130

21322131
# we need to call forward to potentially trigger recompilation
21332132
with torch.inference_mode():
2134-
if different_resolutions is not None:
2135-
for height, width in self.different_shapes_for_compilation:
2133+
if different_shapes is not None:
2134+
for height, width in different_shapes:
21362135
new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
21372136
_ = model(**new_inputs_dict)
21382137
else:
@@ -2274,5 +2273,4 @@ def test_hotswapping_compile_on_different_shapes(self, rank0, rank1):
22742273
rank0=rank0,
22752274
rank1=rank1,
22762275
target_modules0=target_modules,
2277-
different_resolutions=different_shapes_for_compilation,
22782276
)

0 commit comments

Comments
 (0)