18 changes: 18 additions & 0 deletions tests/models/test_modeling_common.py
@@ -76,6 +76,7 @@
    require_torch_accelerator_with_training,
    require_torch_gpu,
    require_torch_multi_accelerator,
    require_torch_version_greater,
    run_test_in_subprocess,
    slow,
    torch_all_close,
@@ -1908,6 +1909,8 @@ def test_push_to_hub_library_name(self):
@is_torch_compile
@slow
class TorchCompileTesterMixin:
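    # Subclasses opt into `test_compile_on_different_shapes` by setting this to a list of (height, width) pairs.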
    different_shapes_for_compilation = None

    def setUp(self):
        # clean up the VRAM before each test
        super().setUp()
@@ -1961,6 +1964,21 @@ def test_compile_with_group_offloading(self):
        _ = model(**inputs_dict)
        _ = model(**inputs_dict)

    @require_torch_version_greater("2.7.1")
    def test_compile_on_different_shapes(self):
        if self.different_shapes_for_compilation is None:
            pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
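        # Duck shaping (the default) reuses one symbol for dynamic dimensions that share a value at
        # trace time (e.g. height == width), which forces a recompile once the values diverge.
        # Disable it so every dimension gets its own symbol and one dynamic graph covers all shapes below.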
        torch.fx.experimental._config.use_duck_shape = False

        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict).to(torch_device)
        model = torch.compile(model, fullgraph=True, dynamic=True)

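        # With `error_on_recompile=True`, any recompilation raises, so this loop passes only if the
        # graph compiled on the first call handles every subsequent (height, width) pair as-is.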
        for height, width in self.different_shapes_for_compilation:
            with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
                inputs_dict = self.prepare_dummy_input(height=height, width=width)
                _ = model(**inputs_dict)


@slow
@require_torch_2
24 changes: 15 additions & 9 deletions tests/models/transformers/test_models_transformer_flux.py
@@ -91,10 +91,20 @@ class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):

    @property
    def dummy_input(self):
        return self.prepare_dummy_input()

    @property
    def input_shape(self):
        return (16, 4)

    @property
    def output_shape(self):
        return (16, 4)

    def prepare_dummy_input(self, height=4, width=4):
        batch_size = 1
        num_latent_channels = 4
        num_image_channels = 3
        sequence_length = 48
        embedding_dim = 32

@@ -114,14 +124,6 @@ def dummy_input(self):
"timestep": timestep,
}

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "patch_size": 1,
@@ -173,10 +175,14 @@ def test_gradient_checkpointing_is_applied(self):

class FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
    model_class = FluxTransformer2DModel
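    # (height, width) pairs passed to `prepare_dummy_input` by `TorchCompileTesterMixin.test_compile_on_different_shapes`.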
    different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]

    def prepare_init_args_and_inputs_for_common(self):
        return FluxTransformerTests().prepare_init_args_and_inputs_for_common()

    def prepare_dummy_input(self, height, width):
        return FluxTransformerTests().prepare_dummy_input(height=height, width=width)


class FluxTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
    model_class = FluxTransformer2DModel