Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions docs/source/en/optimization/fp16.md
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,23 @@ pipeline(prompt, num_inference_steps=30).images[0]

Compilation is slow the first time, but once compiled, it is significantly faster. Try to only use the compiled pipeline on the same type of inference operations. Calling the compiled pipeline on a different image size retriggers compilation which is slow and inefficient.

### Compilation on shape changes

`torch.compile()` maintains a stack of "guards" for the shapes and conditions it sees when it is triggered. When that is violated, the compiler triggers recompilation. This means that if a model was compiled on the 1024x1024 resolution, for example, it will trigger recompilation if it is called on a different resolution.

In these cases, it's beneficial to compile with `dynamic=True`:

```diff
+ torch.fx.experimental._config.use_duck_shape = False
+ pipeline.unet = torch.compile(
pipeline.unet, fullgraph=True, dynamic=True
)
```

Make sure to always use the nightly version of PyTorch for this. Specifying `use_duck_shape` to be `False` instructs the compiler if it should use the same symbolic variable to represent input sizes that are the same. For more details, check out [this comment](https://github.com/huggingface/diffusers/pull/11327#discussion_r2047659790).

All models might not benefit from this out of the box and may require changes. Refer to [this PR](https://github.com/huggingface/diffusers/pull/11297/) that improved the implementation of [`AuraFlowPipeline`] to benefit from compilation with `dynamic=True`. Feel free to open an issue if dynamic compilation doesn't work expected for a model inside Diffusers.

### Regional compilation

[Regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) reduces the cold start compilation time by only compiling a specific repeated region (or block) of the model instead of the entire model. The compiler reuses the cached and compiled code for the other blocks.
Expand Down
24 changes: 21 additions & 3 deletions tests/models/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
require_torch_accelerator_with_training,
require_torch_gpu,
require_torch_multi_accelerator,
require_torch_version_greater,
run_test_in_subprocess,
slow,
torch_all_close,
Expand Down Expand Up @@ -1908,6 +1909,8 @@ def test_push_to_hub_library_name(self):
@is_torch_compile
@slow
class TorchCompileTesterMixin:
different_shapes_for_compilation = None

def setUp(self):
# clean up the VRAM before each test
super().setUp()
Expand Down Expand Up @@ -1937,14 +1940,14 @@ def test_torch_compile_recompilation_and_graph_break(self):
_ = model(**inputs_dict)

def test_compile_with_group_offloading(self):
if not self.model_class._supports_group_offloading:
pytest.skip("Model does not support group offloading.")

torch._dynamo.config.cache_size_limit = 10000

init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)

if not getattr(model, "_supports_group_offloading", True):
return

model.eval()
# TODO: Can test for other group offloading kwargs later if needed.
group_offload_kwargs = {
Expand All @@ -1961,6 +1964,21 @@ def test_compile_with_group_offloading(self):
_ = model(**inputs_dict)
_ = model(**inputs_dict)

@require_torch_version_greater("2.7.1")
def test_compile_on_different_shapes(self):
if self.different_shapes_for_compilation is None:
pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
torch.fx.experimental._config.use_duck_shape = False

init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
model = torch.compile(model, fullgraph=True, dynamic=True)

for height, width in self.different_shapes_for_compilation:
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
inputs_dict = self.prepare_dummy_input(height=height, width=width)
_ = model(**inputs_dict)


@slow
@require_torch_2
Expand Down
24 changes: 15 additions & 9 deletions tests/models/transformers/test_models_transformer_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,20 @@ class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):

@property
def dummy_input(self):
return self.prepare_dummy_input()

@property
def input_shape(self):
return (16, 4)

@property
def output_shape(self):
return (16, 4)

def prepare_dummy_input(self, height=4, width=4):
batch_size = 1
num_latent_channels = 4
num_image_channels = 3
height = width = 4
sequence_length = 48
embedding_dim = 32

Expand All @@ -114,14 +124,6 @@ def dummy_input(self):
"timestep": timestep,
}

@property
def input_shape(self):
return (16, 4)

@property
def output_shape(self):
return (16, 4)

def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"patch_size": 1,
Expand Down Expand Up @@ -173,10 +175,14 @@ def test_gradient_checkpointing_is_applied(self):

class FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = FluxTransformer2DModel
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]

def prepare_init_args_and_inputs_for_common(self):
return FluxTransformerTests().prepare_init_args_and_inputs_for_common()

def prepare_dummy_input(self, height, width):
return FluxTransformerTests().prepare_dummy_input(height=height, width=width)


class FluxTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
model_class = FluxTransformer2DModel
Expand Down
Loading