
Commit 0d3af90

Merge branch 'main' into benchmarking-overhaul
2 parents 36afdea + 9836f0e

File tree

2 files changed: 45 additions, 0 deletions

docs/source/en/optimization/torch2.0.md

Lines changed: 17 additions & 0 deletions
@@ -78,6 +78,23 @@ For more information and different options about `torch.compile`, refer to the [
> [!TIP]
> Learn more about other ways PyTorch 2.0 can help optimize your model in the [Accelerate inference of text-to-image diffusion models](../tutorials/fast_diffusion) tutorial.

### Regional compilation

Compiling the whole model usually creates a large problem space for optimization. Models are often composed of multiple repeated blocks. [Regional compilation](https://pytorch.org/tutorials/recipes/regional_compilation.html) compiles one repeated block first (a transformer encoder block, for example) so that the compiler can reuse its cached and optimized code for the other blocks, often massively reducing the cold-start compilation time observed on the first inference call.

Enabling regional compilation might require simple yet intrusive changes to the modeling code. However, 🤗 Accelerate provides a utility, [`compile_regions()`](https://huggingface.co/docs/accelerate/main/en/usage_guides/compilation#how-to-use-regional-compilation), which automatically compiles the repeated blocks of the provided `nn.Module` sequentially and the rest of the model separately. This reduces cold-start time while keeping most (if not all) of the speedup you would get from full compilation.

```py
# Make sure you're on the latest `accelerate`: `pip install -U accelerate`.
from accelerate.utils import compile_regions

pipe.unet = compile_regions(pipe.unet, mode="reduce-overhead", fullgraph=True)
```

As you may have noticed, `compile_regions()` takes the same arguments as `torch.compile()`, giving you the same flexibility.
## Benchmark

We conducted a comprehensive benchmark with PyTorch 2.0's efficient attention implementation and `torch.compile` across different GPUs and batch sizes for five of our most used pipelines. The code is benchmarked on 🤗 Diffusers v0.17.0.dev0 to optimize `torch.compile` usage (see [here](https://github.com/huggingface/diffusers/pull/3313) for more details).
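The docs change above relies on 🤗 Accelerate's `compile_regions()`, but the underlying idea can also be sketched by hand with plain `torch.compile` applied to each repeated block instead of the whole model. The toy module below (`Block`, `ToyModel`) is purely illustrative and is not part of this commit or of diffusers:

```py
# Hand-rolled sketch of regional compilation on a toy model (illustrative names only).
import torch
import torch.nn as nn


class Block(nn.Module):
    def __init__(self, dim=64):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))

    def forward(self, x):
        return x + self.ff(x)


class ToyModel(nn.Module):
    def __init__(self, dim=64, depth=4):
        super().__init__()
        self.blocks = nn.ModuleList(Block(dim) for _ in range(depth))
        self.out = nn.Linear(dim, dim)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return self.out(x)


model = ToyModel()

# Compile each repeated block rather than the whole model. Because every Block
# shares the same forward code, the compiler's cached artifacts from the first
# block are reused for the rest, which is where the cold-start savings come from.
for i, block in enumerate(model.blocks):
    model.blocks[i] = torch.compile(block, fullgraph=True)

x = torch.randn(2, 16, 64)
out = model(x)  # first call triggers the (regional) compilation
```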

tests/models/test_modeling_common.py

Lines changed: 28 additions & 0 deletions
@@ -1580,6 +1580,34 @@ def run_forward(model):
        self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5))
        self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5))

    @parameterized.expand([(False, "block_level"), (True, "leaf_level")])
    @require_torch_accelerator
    @torch.no_grad()
    def test_group_offloading_with_layerwise_casting(self, record_stream, offload_type):
        torch.manual_seed(0)
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)

        if not getattr(model, "_supports_group_offloading", True):
            return

        model.to(torch_device)
        model.eval()
        _ = model(**inputs_dict)[0]

        torch.manual_seed(0)
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        storage_dtype, compute_dtype = torch.float16, torch.float32
        inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
        model = self.model_class(**init_dict)
        model.eval()
        additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": 1}
        model.enable_group_offload(
            torch_device, offload_type=offload_type, use_stream=True, record_stream=record_stream, **additional_kwargs
        )
        model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
        _ = model(**inputs_dict)[0]

    def test_auto_model(self, expected_max_diff=5e-5):
        if self.forward_requires_fresh_args:
            model = self.model_class(**self.init_dict)
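For context, outside the test harness the combination this new test exercises would look roughly as follows on a real diffusers model. This is a minimal sketch that assumes a recent diffusers release exposing `enable_group_offload()` and `enable_layerwise_casting()` with the call pattern used in the test; the checkpoint name is only an example:

```py
# Rough usage sketch mirroring the new test: group offloading plus layerwise casting.
# The methods and argument names follow the test above; the checkpoint is an example.
import torch
from diffusers import UNet2DConditionModel

model = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
)
model.eval()

# Offload repeated blocks to CPU one group at a time, moving them back on a CUDA stream.
model.enable_group_offload(
    torch.device("cuda"),
    offload_type="block_level",
    num_blocks_per_group=1,
    use_stream=True,
    record_stream=True,
)

# Store weights in float16 but upcast to float32 for compute, as the test does.
model.enable_layerwise_casting(storage_dtype=torch.float16, compute_dtype=torch.float32)
```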
