`docs/source/en/optimization/torch2.0.md` — 18 additions, 0 deletions

@@ -78,6 +78,24 @@ For more information and different options about `torch.compile`, refer to the [
> [!TIP]
> Learn more about other ways PyTorch 2.0 can help optimize your model in the [Accelerate inference of text-to-image diffusion models](../tutorials/fast_diffusion) tutorial.

### Regional compilation

Compiling the whole model usually presents a large problem space for optimization. Models are often composed of multiple repeated blocks. [Regional compilation](https://pytorch.org/tutorials/recipes/regional_compilation.html) compiles a repeated block first (a transformer encoder block, for example) so that the compiler can re-use its cached, optimized code for the other blocks, which often massively reduces the cold-start compilation time observed on the first inference call.
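Conceptually, the technique looks something like the following minimal sketch. The `Block` and `Model` classes here are hypothetical stand-ins (not diffusers code), used only to show how compiling a repeated block lets the compiler cache hit on the remaining blocks:

```py
import torch
import torch.nn as nn

class Block(nn.Module):
    # a stand-in for a repeated transformer block
    def __init__(self, dim=64):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))

    def forward(self, x):
        return x + self.ff(x)

class Model(nn.Module):
    # a stand-in model: repeated blocks plus a task-specific head
    def __init__(self, dim=64, depth=8):
        super().__init__()
        self.blocks = nn.ModuleList(Block(dim) for _ in range(depth))
        self.head = nn.Linear(dim, dim)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return self.head(x)

model = Model()

# Compile each repeated block individually instead of the whole model.
# Because every block has the same structure, the compiler re-uses the
# cached artifact after the first block, so cold-start time scales with
# one block rather than the full depth of the model.
for i, block in enumerate(model.blocks):
    model.blocks[i] = torch.compile(block)
```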

Enabling regional compilation might require simple yet intrusive changes to the modeling code. However, 🤗 Accelerate provides a utility, [`compile_regions()`](https://huggingface.co/docs/accelerate/main/en/usage_guides/compilation#how-to-use-regional-compilation), which automatically _only_ compiles the repeated blocks of the provided `nn.Module`.
**Member:** no, we actually compile the rest of the model as well 😅 I found out in my post that some people thought only the encoder/decoder block would be compiled in regional compilation, which is not true. I changed the docs to be more explicit: huggingface/accelerate#3572 (comment)

**Member Author:** 👁️ But https://docs.pytorch.org/tutorials/recipes/regional_compilation.html suggests a completely different recipe, no? No full compilation, only regional compilation, and I always thought that was what should be done.

What am I missing?

**IlyasMoutawwakil** (May 15, 2025): regional compilation is simply: cut the model into regions, then compile those regions. I didn't compare the two approaches, but I believe that in the context of the PyTorch tutorial they were only trying to reduce cold start, not trying to keep inference optimized as well (they didn't benchmark inference).

**sayakpaul** (May 15, 2025): So

1. inference latency of compiling the full model >= inference latency of regionally compiling the repeated blocks + compiling the additional blocks in a model
2. cold-start time of compiling the full model >> cold-start time of regionally compiling the repeated blocks + compiling the additional blocks in a model

Is my understanding right, or is it still fragmented?

Do you think providing an option to NOT compile the rest of the blocks could still make sense?

**IlyasMoutawwakil** (May 15, 2025): yes, that is how it works!

> Do you think providing an option to NOT compile the rest of the blocks could still make sense?

It doesn't make sense to me personally, since you would miss out on the tuning of the task-specific head. Do you have any specific cases where we don't want to compile the rest of the model?
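To make the trade-off discussed in this thread concrete, here is a rough, hypothetical timing sketch (not from the PR, and not a rigorous benchmark) that reuses the stand-in `Block` and `Model` classes from the sketch above and compares the cold start of full-model compilation against compiling the repeated blocks plus the rest of the model:

```py
import time
import torch

# Block and Model as defined in the earlier stand-in sketch.

def cold_start(compile_fn):
    # Time the first call, which is what triggers compilation.
    torch._dynamo.reset()  # clear compile caches between strategies
    model = compile_fn(Model())
    x = torch.randn(2, 64)
    start = time.perf_counter()
    model(x)
    return time.perf_counter() - start

def full(model):
    # strategy 1: compile the whole model in one shot
    return torch.compile(model)

def regional(model):
    # strategy 2: compile the repeated blocks individually (they share
    # one cache entry) *and* the rest of the model, as described above
    for i, block in enumerate(model.blocks):
        model.blocks[i] = torch.compile(block)
    model.head = torch.compile(model.head)
    return model

print("full model cold start:", cold_start(full))
print("regional cold start:  ", cold_start(regional))
```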


```py
# Make sure you're on the latest `accelerate`: `pip install -U accelerate`.
from accelerate.utils import compile_regions

pipe.unet = compile_regions(pipe.unet, mode="reduce-overhead", fullgraph=True)
```

**Member Author:** Merge after the new `accelerate` version is released this week.

**Member:** released!

As you may have noticed, `compile_regions()` takes the same arguments as `torch.compile()`, allowing flexibility in how the regions are compiled.
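For instance, since the docs state that `compile_regions()` accepts the same arguments as `torch.compile()`, other standard options can be forwarded too. A minimal sketch, assuming the same `pipe` as above:

```py
from accelerate.utils import compile_regions

# Any keyword argument accepted by torch.compile can be passed through,
# e.g. `dynamic=True` to reduce re-compilation across input shapes.
pipe.unet = compile_regions(pipe.unet, mode="max-autotune", dynamic=True)
```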

## Benchmark

We conducted a comprehensive benchmark of PyTorch 2.0's efficient attention implementation and `torch.compile` across different GPUs and batch sizes for five of our most used pipelines. The benchmark was run on 🤗 Diffusers v0.17.0.dev0, which optimizes `torch.compile` usage (see [here](https://github.com/huggingface/diffusers/pull/3313) for more details).