
Commit dbc78a4

move comment.
1 parent de312da commit dbc78a4

File tree

1 file changed: +3 -3 lines changed


tests/models/test_modeling_common.py

Lines changed: 3 additions & 3 deletions
@@ -1858,20 +1858,20 @@ def test_hotswapping_compiled_model_linear(self, rank0, rank1):
 
     @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
     def test_hotswapping_compiled_model_conv2d(self, rank0, rank1):
-        # It's important to add this context to raise an error on recompilation
         if "unet" not in self.model_class.__name__.lower():
             return
 
+        # It's important to add this context to raise an error on recompilation
         target_modules = ["conv", "conv1", "conv2"]
         with torch._dynamo.config.patch(error_on_recompile=True):
             self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
 
     @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
     def test_hotswapping_compiled_model_both_linear_and_conv2d(self, rank0, rank1):
-        # It's important to add this context to raise an error on recompilation
         if "unet" not in self.model_class.__name__.lower():
             return
 
+        # It's important to add this context to raise an error on recompilation
         target_modules = ["to_q", "conv"]
         with torch._dynamo.config.patch(error_on_recompile=True):
             self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
@@ -1882,14 +1882,14 @@ def test_hotswapping_compiled_model_both_linear_and_other(self, rank0, rank1):
         # with `torch.compile()` for models that have both linear and conv layers. In this test, we check
         # if we can target a linear layer from the transformer blocks and another linear layer from non-attention
         # block.
-        # It's important to add this context to raise an error on recompilation
         target_modules = ["to_q"]
         init_dict, _ = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**init_dict)
 
         target_modules.append(self.get_linear_module_name_other_than_attn(model))
         del model
 
+        # It's important to add this context to raise an error on recompilation
         with torch._dynamo.config.patch(error_on_recompile=True):
             self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
 
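Note: the moved comment refers to the `torch._dynamo.config.patch(error_on_recompile=True)` context used by these tests. Below is a minimal, self-contained sketch (not part of the commit; `forward`, `w0`, and `w1` are illustrative names) of the behavior it relies on: inside that context, any recompilation raises instead of happening silently, which is what lets `check_model_hotswap` assert that swapping same-shape weights reuses the compiled graph.

# Minimal sketch of error_on_recompile (illustrative, not from the diff).
import torch

def forward(x, weight):
    return torch.nn.functional.linear(x, weight)

compiled = torch.compile(forward)

x = torch.randn(4, 16)
w0 = torch.randn(8, 16)
w1 = torch.randn(8, 16)  # "hot-swapped" weight with identical shape/dtype

with torch._dynamo.config.patch(error_on_recompile=True):
    compiled(x, w0)  # first call compiles; the initial compilation is allowed
    compiled(x, w1)  # same shapes/dtypes -> guards pass, no recompile, no error
    # A call that changed a guarded property (e.g. a different input shape)
    # would trigger a recompile and, inside this context, raise an error.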