Commit de97a51
improve test check
1 parent 55d6155 commit de97a51

2 files changed (+8, -2 lines)

docs/source/en/quantization/torchao.md (2 additions, 2 deletions)

````diff
@@ -11,7 +11,7 @@ specific language governing permissions and limitations under the License. -->
 
 # torchao
 
-[TorchAO](https://github.com/pytorch/ao) is an architecture optimization library for PyTorch, it provides high performance dtypes, optimization techniques and kernels for inference and training, featuring composability with native PyTorch features like `torch.compile`, FSDP etc.. Some benchmark numbers can be found [here](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks).
+[TorchAO](https://github.com/pytorch/ao) is an architecture optimization library for PyTorch, it provides high performance dtypes, optimization techniques and kernels for inference and training, featuring composability with native PyTorch features like `torch.compile`, FSDP etc. Some benchmark numbers can be found [here](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks).
 
 Before you begin, make sure you have Pytorch version 2.5, or above, and TorchAO installed:
 
@@ -21,7 +21,7 @@ pip install -U torch torchao
 
 ## Usage
 
-Now you can quantize a model by passing a [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`]. This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers.
+Now you can quantize a model by passing a [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`]. Loading pre-quantized models is supported as well! This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers.
 
 ```python
 from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig
````

tests/quantization/torchao/test_torchao.py (6 additions, 0 deletions)

```diff
@@ -74,6 +74,7 @@ def forward(self, input, *args, **kwargs):
 if is_torchao_available():
     from torchao.dtypes import AffineQuantizedTensor
     from torchao.dtypes.affine_quantized_tensor import TensorCoreTiledLayoutType
+    from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
 
 
 @require_torch
@@ -494,6 +495,11 @@ def check_serialization_expected_slice(self, expected_slice):
         output = loaded_quantized_model(**inputs)[0]
 
         output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
+        self.assertTrue(
+            isinstance(
+                loaded_quantized_model.proj_out.weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor)
+            )
+        )
         self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
 
     def test_serialization_expected_slice(self):
```

0 commit comments
