
Commit 0e4f152

add group offloading+compile

1 parent 11cfd6c

File tree

3 files changed: +28 -3 lines changed

- tests/quantization/bnb/test_4bit.py
- tests/quantization/bnb/test_mixed_int8.py
- tests/quantization/test_torch_compile_utils.py

tests/quantization/bnb/test_4bit.py

Lines changed: 5 additions & 3 deletions

@@ -879,6 +879,8 @@ def test_torch_compile(self):
 
     def test_torch_compile_with_cpu_offload(self):
         torch._dynamo.config.capture_dynamic_output_shape_ops = True
-        super()._test_torch_compile_with_cpu_offload(
-            quantization_config=self.quantization_config, torch_dtype=torch.float16
-        )
+        super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)
+
+    def test_torch_compile_with_group_offload(self):
+        torch._dynamo.config.capture_dynamic_output_shape_ops = True
+        super()._test_torch_compile_with_group_offload(quantization_config=self.quantization_config)
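Note that the 4-bit cpu-offload call also drops the explicit torch_dtype=torch.float16 argument, so the helper's default dtype applies (the new group-offload helper below defaults to torch.bfloat16). self.quantization_config itself is defined in the suite's setup rather than in this commit; a hypothetical 4-bit bitsandbytes config of the kind this suite forwards might look like this sketch:

    import torch
    from diffusers import BitsAndBytesConfig

    # Hypothetical stand-in for self.quantization_config in the 4-bit suite;
    # the real values live in the test class setup, not in this commit.
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )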

tests/quantization/bnb/test_mixed_int8.py

Lines changed: 6 additions & 0 deletions

@@ -796,3 +796,9 @@ def test_torch_compile_with_cpu_offload(self):
         super()._test_torch_compile_with_cpu_offload(
             quantization_config=self.quantization_config, torch_dtype=torch.float16
         )
+
+    def test_torch_compile_with_group_offload(self):
+        torch._dynamo.config.capture_dynamic_output_shape_ops = True
+        super()._test_torch_compile_with_group_offload(
+            quantization_config=self.quantization_config, torch_dtype=torch.float16
+        )
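Unlike the 4-bit suite, the int8 tests keep the explicit torch_dtype=torch.float16 override of the helper's default. For reference, a hypothetical 8-bit config of the kind this suite forwards (the real one lives in the test setup):

    from diffusers import BitsAndBytesConfig

    # Hypothetical stand-in for self.quantization_config in the mixed-int8 suite.
    quantization_config = BitsAndBytesConfig(load_in_8bit=True)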

tests/quantization/test_torch_compile_utils.py

Lines changed: 17 additions & 0 deletions

@@ -63,3 +63,20 @@ def _test_torch_compile_with_cpu_offload(self, quantization_config, torch_dtype=
         for _ in range(2):
             # small resolutions to ensure speedy execution.
             pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
+
+    def _test_torch_compile_with_group_offload(self, quantization_config, torch_dtype=torch.bfloat16):
+        pipe = self._init_pipeline(quantization_config, torch_dtype)
+        group_offload_kwargs = {
+            "onload_device": "cuda",
+            "offload_device": "cpu",
+            "offload_type": "block_level",
+            "num_blocks_per_group": 1,
+            "use_stream": True,
+            "non_blocking": True,
+        }
+        pipe.transformer.enable_group_offload(**group_offload_kwargs)
+        pipe.transformer.compile()
+
+        for _ in range(2):
+            # small resolutions to ensure speedy execution.
+            pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
