Skip to content

Commit 2bd9302

Browse files
committed
fix
1 parent fb8b44e commit 2bd9302

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

tests/quantization/torchao/test_torchao.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -624,13 +624,13 @@ def _test_quant_type(self, quantization_config, expected_slice):
624624
components = self.get_dummy_components(quantization_config)
625625
pipe = FluxPipeline(**components)
626626
pipe.enable_model_cpu_offload()
627+
628+
weight = pipe.transformer.transformer_blocks[0].ff.net[2].weight
629+
self.assertTrue(isinstance(weight, AffineQuantizedTensor))
627630

628631
inputs = self.get_dummy_inputs(torch_device)
629632
output = pipe(**inputs)[0].flatten()
630633
output_slice = np.concatenate((output[:16], output[-16:]))
631-
632-
weight = pipe.transformer.x_embedder.weight
633-
self.assertTrue(isinstance(weight, AffineQuantizedTensor))
634634
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
635635

636636
def test_quantization(self):

0 commit comments

Comments
 (0)