Commit c47d09d

deit3_small_patch16_224_in21ft1k
1 parent 093c428 commit c47d09d

2 files changed (+17, −10 lines)

backends/openvino/quantizer/quantizer.py

Lines changed: 2 additions & 0 deletions
@@ -15,6 +15,7 @@
 import torch.fx
 from torch.ao.quantization.observer import HistogramObserver
 from torch.ao.quantization.observer import PerChannelMinMaxObserver
+from torch.ao.quantization.observer import MinMaxObserver
 from torch.ao.quantization.quantizer.quantizer import EdgeOrNode
 from torch.ao.quantization.quantizer.quantizer import QuantizationAnnotation
 from torch.ao.quantization.quantizer.quantizer import QuantizationSpec

@@ -276,6 +277,7 @@ def _get_torch_ao_qspec_from_qp(qp: QuantizationPointBase) -> QuantizationSpec:
         torch.per_tensor_symmetric if qconfig.mode is QuantizationScheme.SYMMETRIC else torch.per_tensor_affine
     )
     if is_weight:
+        observer = PerChannelMinMaxObserver if qconfig.per_channel else MinMaxObserver
         observer = PerChannelMinMaxObserver
         quant_min = -128
         quant_max = 127
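The second hunk lets _get_torch_ao_qspec_from_qp fall back to the per-tensor MinMaxObserver whenever the NNCF qconfig does not request per-channel weight quantization. Note that the commit adds the conditional without deleting the unconditional observer = PerChannelMinMaxObserver on the next line (0 deletions), so the fallback stays shadowed until that line goes away. A minimal sketch of the weight spec this branch is aiming at; the per_channel flag is an assumption standing in for NNCF's full qconfig handling, while the int8 range comes from the diff above:

import torch
from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
from torch.ao.quantization.quantizer.quantizer import QuantizationSpec

def weight_qspec_sketch(per_channel: bool) -> QuantizationSpec:
    # Per-channel qconfigs keep PerChannelMinMaxObserver; anything else
    # falls back to the per-tensor MinMaxObserver, as in the diff above.
    observer = PerChannelMinMaxObserver if per_channel else MinMaxObserver
    return QuantizationSpec(
        dtype=torch.int8,
        observer_or_fake_quant_ctr=observer,
        quant_min=-128,
        quant_max=127,
        qscheme=torch.per_channel_symmetric if per_channel else torch.per_tensor_symmetric,
        ch_axis=0 if per_channel else None,
    )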

examples/openvino/aot/aot_openvino_compiler.py

Lines changed: 15 additions & 10 deletions
@@ -65,20 +65,17 @@ def load_calibration_dataset(dataset_path: str):
     return calibration_dataset


-def quantize_model(model: torch.fx.GraphModule, calibration_dataset: torch.utils.data.DataLoader, subset_size=300):
-    quantizer = OpenVINOQuantizer()
+def quantize_model(model: torch.fx.GraphModule, example_args, subset_size=300):
+    quantizer = OpenVINOQuantizer(ignored_scope=nncf.IgnoredScope(types=["__getitem__", "layer_norm"]))

     print("PTQ: Annotate the model...")
     annotated_model = prepare_pt2e(model, quantizer)

     print("PTQ: Calibrate the model...")
-    for idx, data in enumerate(calibration_dataset):
-        if idx >= subset_size:
-            break
-        annotated_model(data[0])
+    annotated_model(*example_args)

     print("PTQ: Convert the quantized model...")
-    quantized_model = convert_pt2e(annotated_model)
+    quantized_model = convert_pt2e(annotated_model, fold_quantize=False)
     return quantized_model

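This hunk replaces the DataLoader calibration loop with a single forward pass over example_args (leaving the subset_size parameter unused), tells OpenVINOQuantizer to skip __getitem__ and layer_norm nodes via nncf.IgnoredScope, and passes fold_quantize=False so convert_pt2e keeps explicit quantize/dequantize ops in the graph instead of folding quantized weights into constants. A sketch of the new calling convention; the torchvision model and input shape are illustrative assumptions, not taken from the script:

import torch
import torchvision.models as models

# Hypothetical stand-in model; any eager nn.Module exported with torch.export works.
model = models.resnet18(weights=None).eval()
example_args = (torch.randn(1, 3, 224, 224),)  # a single calibration batch

aten_dialect = torch.export.export(model, example_args)
captured_model = aten_dialect.module()
quantized_model = quantize_model(captured_model, example_args)  # helper from this diff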
@@ -106,7 +103,9 @@ def main(suite: str, model_name: str, input_shape, quantize: bool, dataset_path:
         calibration_dataset = load_calibration_dataset(dataset_path)

         captured_model = aten_dialect.module()
-        quantized_model = quantize_model(captured_model, calibration_dataset)
+        visualize_fx_model(captured_model, f"{model_name}_fp32.svg")
+        quantized_model = quantize_model(captured_model, example_args)
+        visualize_fx_model(quantized_model, f"{model_name}_int8.svg")
         aten_dialect: ExportedProgram = export(quantized_model, example_args)

     # Convert to edge dialect
@@ -121,9 +120,15 @@ def main(suite: str, model_name: str, input_shape, quantize: bool, dataset_path:
     exec_prog = lowered_module.to_executorch(config=executorch.exir.ExecutorchBackendConfig())

     # Serialize and save it to a file
-    with open(f"{model_name}.pte", "wb") as file:
+    model_name = f"{model_name}_{'int8' if quantize else 'fp32'}.pte"
+    with open(model_name, "wb") as file:
         exec_prog.write_to_file(file)
-    print(f"Model exported and saved as {model_name}.pte on {device}.")
+    print(f"Model exported and saved as {model_name} on {device}.")
+
+from torch.fx.passes.graph_drawer import FxGraphDrawer
+def visualize_fx_model(model: torch.fx.GraphModule, output_svg_path: str):
+    g = FxGraphDrawer(model, output_svg_path)
+    g.get_dot_graph().write_svg(output_svg_path)

 if __name__ == "__main__":
     # Argument parser for dynamic inputs
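The visualize_fx_model helper added at the bottom relies on FxGraphDrawer from torch.fx.passes.graph_drawer, which needs the pydot package (and a Graphviz install) at runtime; its second constructor argument is a graph name, which the helper reuses as the output path. A standalone usage sketch under those assumptions:

import torch
from torch.fx.passes.graph_drawer import FxGraphDrawer

class Tiny(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x + 1)

gm = torch.fx.symbolic_trace(Tiny())
drawer = FxGraphDrawer(gm, "tiny")            # second argument names the graph
drawer.get_dot_graph().write_svg("tiny.svg")  # requires pydot + graphviz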
