
Commit 934624b ("update")
Parent: bc83cb8

2 files changed (+26, -18 lines)


tests/models/test_modeling_common.py

Lines changed: 11 additions & 9 deletions

@@ -76,6 +76,7 @@
     require_torch_accelerator_with_training,
     require_torch_gpu,
     require_torch_multi_accelerator,
+    require_torch_version_greater,
     run_test_in_subprocess,
     slow,
     torch_all_close,
@@ -1908,6 +1909,8 @@ def test_push_to_hub_library_name(self):
 @is_torch_compile
 @slow
 class TorchCompileTesterMixin:
+    different_shapes_for_compilation = None
+
     def setUp(self):
         # clean up the VRAM before each test
         super().setUp()
@@ -1961,21 +1964,20 @@ def test_compile_with_group_offloading(self):
         _ = model(**inputs_dict)
         _ = model(**inputs_dict)

+    @require_torch_version_greater("2.7.1")
     def test_compile_on_different_shapes(self):
+        if self.different_shapes_for_compilation is None:
+            pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
         torch.fx.experimental._config.use_duck_shape = False

-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**init_dict).to(torch_device)
         model = torch.compile(model, fullgraph=True, dynamic=True)

-        with (
-            torch._inductor.utils.fresh_inductor_cache(),
-            torch._dynamo.config.patch(error_on_recompile=True),
-            torch.no_grad(),
-        ):
-            print(f"{inputs_dict.keys()=}")
-            out = model(**inputs_dict)
-            assert out is None
+        for height, width in self.different_shapes_for_compilation:
+            with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
+                inputs_dict = self.prepare_dummy_input(height=height, width=width)
+                _ = model(**inputs_dict)


 @slow
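The pattern the new test relies on is easier to see in isolation. Below is a minimal, self-contained sketch; the toy Linear model and the shapes are illustrative stand-ins, not part of the diff. Disabling duck shaping gives each dynamic dimension its own symbol (with it enabled, two dims that happen to be equal on the first call, such as height == width, share one symbol, so a later shape where they differ would force a recompile), and error_on_recompile=True turns any recompilation after the initial compile into a hard failure:

import torch

# Give each dynamic dimension its own shape symbol instead of letting
# equal-valued dims share one (duck shaping).
torch.fx.experimental._config.use_duck_shape = False

model = torch.nn.Linear(32, 32)  # stand-in for the real model under test
model = torch.compile(model, fullgraph=True, dynamic=True)

# Analogous to looping over `different_shapes_for_compilation`: after the
# first compile, any recompilation raises instead of happening silently.
for batch, seq in [(1, 16), (1, 32), (2, 24)]:
    with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
        _ = model(torch.randn(batch, seq, 32))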

tests/models/transformers/test_models_transformer_flux.py

Lines changed: 15 additions & 9 deletions

@@ -91,10 +91,20 @@ class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):

     @property
     def dummy_input(self):
+        return self.prepare_dummy_input()
+
+    @property
+    def input_shape(self):
+        return (16, 4)
+
+    @property
+    def output_shape(self):
+        return (16, 4)
+
+    def prepare_dummy_input(self, height=4, width=4):
         batch_size = 1
         num_latent_channels = 4
         num_image_channels = 3
-        height = width = 4
         sequence_length = 48
         embedding_dim = 32

@@ -114,14 +124,6 @@ def dummy_input(self):
             "timestep": timestep,
         }

-    @property
-    def input_shape(self):
-        return (16, 4)
-
-    @property
-    def output_shape(self):
-        return (16, 4)
-
     def prepare_init_args_and_inputs_for_common(self):
         init_dict = {
             "patch_size": 1,
@@ -173,10 +175,14 @@ def test_gradient_checkpointing_is_applied(self):

 class FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
     model_class = FluxTransformer2DModel
+    different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]

     def prepare_init_args_and_inputs_for_common(self):
         return FluxTransformerTests().prepare_init_args_and_inputs_for_common()

+    def prepare_dummy_input(self, height, width):
+        return FluxTransformerTests().prepare_dummy_input(height=height, width=width)
+

 class FluxTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
     model_class = FluxTransformer2DModel
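Opting another model tester into the same check takes two pieces: the class attribute listing the shapes, and a prepare_dummy_input whose tensors genuinely scale with (height, width). A hedged sketch mirroring the Flux wiring above (MyTransformer2DModel and MyTransformerTests are hypothetical placeholders, not real classes in the repo):

import unittest

class MyTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
    model_class = MyTransformer2DModel  # hypothetical model under test
    # The shapes should differ from one another so the compiled graph is
    # exercised with more than one symbolic shape.
    different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]

    def prepare_dummy_input(self, height, width):
        # Delegate to the model's regular tester; its inputs must actually
        # vary with (height, width) for the test to mean anything.
        return MyTransformerTests().prepare_dummy_input(height=height, width=width)

Because the mixin defaults different_shapes_for_compilation to None and skips in that case, existing subclasses that do not set the attribute are unaffected.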
