
Commit 25d3cf8

add torch compile test

1 parent 747bd7d

1 file changed (+27, -1)


tests/quantization/torchao/test_torchao.py

Lines changed: 27 additions & 1 deletion
@@ -154,21 +154,31 @@ def get_dummy_inputs(self, device: torch.device, seed: int = 0):
 
         return inputs
 
-    def get_dummy_tensor_inputs(self, device=None):
+    def get_dummy_tensor_inputs(self, device=None, seed: int = 0):
         batch_size = 1
         num_latent_channels = 4
         num_image_channels = 3
         height = width = 4
         sequence_length = 48
         embedding_dim = 32
 
+        torch.manual_seed(seed)
         hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(device, dtype=torch.bfloat16)
+
+        torch.manual_seed(seed)
         encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(
             device, dtype=torch.bfloat16
         )
+
+        torch.manual_seed(seed)
         pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16)
+
+        torch.manual_seed(seed)
         text_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16)
+
+        torch.manual_seed(seed)
         image_ids = torch.randn((height * width, num_image_channels)).to(device, dtype=torch.bfloat16)
+
         timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size)
 
         return {
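
The re-seeding added above pins every dummy tensor to `seed` individually, so a tensor's values no longer depend on how many random draws happened before it. A minimal sketch of the difference, using only `torch` (the function names here are illustrative, not part of the test file):

import torch

def seed_once(seed: int = 0):
    # Single seed for the whole sequence: `b` shifts if any draw is
    # inserted or removed before it.
    torch.manual_seed(seed)
    a = torch.randn(4)
    b = torch.randn(4)
    return a, b

def seed_per_tensor(seed: int = 0):
    # Re-seed before each draw, mirroring the diff: each tensor is
    # reproducible in isolation.
    torch.manual_seed(seed)
    a = torch.randn(4)
    torch.manual_seed(seed)
    b = torch.randn(4)
    return a, b

a, b = seed_per_tensor()
assert torch.equal(a, b)  # same seed + same shape -> identical values

One side effect worth noting: with per-call re-seeding, tensors of the same shape come out identical (as the assert shows), which is fine for smoke-testing numerics but worth remembering when inspecting outputs.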
@@ -322,6 +332,22 @@ def test_training(self):
         self.assertTrue(module.adapter[1].weight.grad is not None)
         self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0)
 
+    def test_torch_compile(self):
+        quantization_config = TorchAoConfig("int8_weight_only")
+        components = self.get_dummy_components(quantization_config)
+        pipe = FluxPipeline(**components)
+        pipe.to(device=torch_device, dtype=torch.bfloat16)
+
+        inputs = self.get_dummy_inputs(torch_device)
+        normal_output = pipe(**inputs)[0].flatten()[-32:]
+
+        pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True, dynamic=False)
+        inputs = self.get_dummy_inputs(torch_device)
+        compile_output = pipe(**inputs)[0].flatten()[-32:]
+
+        # Note: Seems to require higher tolerance
+        self.assertTrue(np.allclose(normal_output, compile_output, atol=1e-2, rtol=1e-3))
+
 
 @require_torch
 @require_torch_gpu
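
For context, the new test quantizes the Flux transformer with torchao's int8 weight-only scheme and checks that a `torch.compile`d run stays within a loosened tolerance of the eager run. A standalone sketch of the same pattern outside the test harness, assuming a diffusers build with torchao support (the checkpoint id, prompt, and step count are illustrative, not from this commit):

import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig

model_id = "black-forest-labs/FLUX.1-dev"  # assumed checkpoint, not named in the commit
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=TorchAoConfig("int8_weight_only"),
    torch_dtype=torch.bfloat16,
)
pipe = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Compile only the quantized transformer, as the test does.
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True, dynamic=False)
image = pipe("a photo of a cat", num_inference_steps=4).images[0]

The test itself can be run in isolation with `pytest tests/quantization/torchao/test_torchao.py -k test_torch_compile`; given the `@require_torch_gpu` marker in the file, expect it to need a CUDA device.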
