diff --git a/.ci/docker/ci_commit_pins/torchao.txt b/.ci/docker/ci_commit_pins/torchao.txt index 768110b82ff..842571de2a3 100644 --- a/.ci/docker/ci_commit_pins/torchao.txt +++ b/.ci/docker/ci_commit_pins/torchao.txt @@ -1 +1 @@ -0916b5b29b092afcbf2b898caae49abe80662bac +c6abf2bd576828dc8ed175fba2c4c1d0d3681a1d diff --git a/examples/models/llama2/source_transformation/quantize.py b/examples/models/llama2/source_transformation/quantize.py index da832f8285a..06c294002ad 100644 --- a/examples/models/llama2/source_transformation/quantize.py +++ b/examples/models/llama2/source_transformation/quantize.py @@ -73,9 +73,12 @@ def quantize( if group_size is None: raise Exception("For 8da4w quantization, group size must be specified.") from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer + from torchao.quantization.quant_primitives import MappingType model = Int8DynActInt4WeightQuantizer( - precision=torch_dtype, groupsize=group_size + precision=torch_dtype, + groupsize=group_size, + mapping_type=MappingType.SYMMETRIC_NO_CLIPPING_ERR, ).quantize(model) if verbose: print("quantized model:", model)