
Commit 2cbc76c (parent 01a60c5)

Handle Deterministic Session for ORT and Update Tests

5 files changed, +322 −218 lines

DeepQuant/Pipeline/OnnxExport.py

Lines changed: 29 additions & 1 deletion
@@ -16,6 +16,30 @@
 from DeepQuant.Utils.ConsoleFormatter import ConsoleColor as cc
 
 
+def create_deterministic_session():
+    """
+    Create ONNX Runtime session options with deterministic settings for exact reproducibility.
+    """
+    options = ort.SessionOptions()
+
+    options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
+
+    options.use_deterministic_compute = True
+    options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
+
+    options.intra_op_num_threads = 1
+    options.inter_op_num_threads = 1
+
+    options.enable_cpu_mem_arena = False
+    options.enable_mem_pattern = False
+    options.enable_mem_reuse = False
+
+    options.log_severity_level = 3
+    options.enable_profiling = False
+
+    return options
+
+
 def exportToOnnx(
     model: nn.Module,
     exampleInput: torch.Tensor,

@@ -50,7 +74,11 @@ def exportToOnnx(
     print()
     print(cc.success(f"Input data saved to {inputFile}"))
 
-    ortSession = ort.InferenceSession(onnxFile)
+    options = create_deterministic_session()
+    # ortSession = ort.InferenceSession(onnxFile)
+    ortSession = ort.InferenceSession(
+        onnxFile, sess_options=options, providers=["CPUExecutionProvider"]
+    )
     ortInputs = {"input": exampleInput.cpu().numpy()}
     ortOutput = ortSession.run(None, ortInputs)[0]
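The new options disable all graph optimizations, force single-threaded sequential execution, and turn off the CPU memory arena and memory pattern/reuse, trading speed for bit-exact reproducibility. One way to sanity-check the effect (a sketch, not part of the commit; the file name "model.onnx" and the input name "input" are assumptions) is to build two independent sessions from the same options and require bitwise-equal outputs:

    import numpy as np
    import onnxruntime as ort

    from DeepQuant.Pipeline.OnnxExport import create_deterministic_session

    onnxFile = "model.onnx"  # hypothetical artifact exported by the pipeline
    x = np.random.rand(1, 3, 224, 224).astype(np.float32)

    outputs = []
    for _ in range(2):
        # Fresh session each iteration: determinism must hold across sessions too.
        session = ort.InferenceSession(
            onnxFile,
            sess_options=create_deterministic_session(),
            providers=["CPUExecutionProvider"],
        )
        outputs.append(session.run(None, {"input": x})[0])

    # Bitwise equality, not just np.allclose: exact reproducibility is the point.
    assert np.array_equal(outputs[0], outputs[1])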

Tests/TestCCT.py

Lines changed: 1 addition & 18 deletions
@@ -123,24 +123,7 @@ def prepareCCT(model) -> nn.Module:
         ),
     }
 
-    quantActMap = {
-        nn.ReLU: (
-            qnn.QuantReLU,
-            {
-                "act_quant": Uint8ActPerTensorFloat,
-                "return_quant_tensor": True,
-                "bit_width": 8,
-            },
-        ),
-        nn.GELU: (
-            qnn.QuantReLU,
-            {
-                "act_quant": Uint8ActPerTensorFloat,
-                "return_quant_tensor": True,
-                "bit_width": 8,
-            },
-        ),
-    }
+    quantActMap = {}
 
     quantIdentityMap = {
         "signed": (

Tests/TestResNet18.py

Lines changed: 13 additions & 199 deletions
@@ -4,17 +4,12 @@
 #
 # Federico Brancasi <[email protected]>
 
-import tarfile
-import urllib.request
-from pathlib import Path
 
 import brevitas.nn as qnn
 import pytest
 import torch
 import torch.nn as nn
-import torchvision
-import torchvision.transforms as transforms
-from brevitas.graph.calibrate import calibration_mode
+import torchvision.models as models
 from brevitas.graph.per_input import AdaptiveAvgPoolToAvgPool
 from brevitas.graph.quantize import preprocess_for_quantize, quantize
 from brevitas.quant import (

@@ -23,79 +18,15 @@
     Int32Bias,
     Uint8ActPerTensorFloat,
 )
-from torch.utils.data import DataLoader, Subset
-from torchvision.datasets import ImageFolder
-from tqdm import tqdm
 
 from DeepQuant import brevitasToTrueQuant
 
 
-def evaluateModel(model, dataLoader, evalDevice, name="Model"):
-    model.eval()
-    correctTop1 = 0
-    correctTop5 = 0
-    total = 0
-
-    with torch.no_grad():
-        for inputs, targets in tqdm(dataLoader, desc=f"Evaluating {name}"):
-            isTQ = "TQ" in name
-
-            if isTQ:
-                # FBRANCASI: Process different batches for the TQ model
-                for i in range(inputs.size(0)):
-                    singleInput = inputs[i : i + 1].to(evalDevice)
-                    singleOutput = model(singleInput)
-
-                    _, predicted = singleOutput.max(1)
-                    if predicted.item() == targets[i].item():
-                        correctTop1 += 1
-
-                    _, top5Pred = singleOutput.topk(5, dim=1, largest=True, sorted=True)
-                    if targets[i].item() in top5Pred[0].cpu().numpy():
-                        correctTop5 += 1
-
-                    total += 1
-            else:
-                inputs = inputs.to(evalDevice)
-                targets = targets.to(evalDevice)
-                output = model(inputs)
-
-                _, predicted = output.max(1)
-                correctTop1 += (predicted == targets).sum().item()
-
-                _, top5Pred = output.topk(5, dim=1, largest=True, sorted=True)
-                for i in range(targets.size(0)):
-                    if targets[i] in top5Pred[i]:
-                        correctTop5 += 1
-
-                total += targets.size(0)
-
-    top1Accuracy = 100.0 * correctTop1 / total
-    top5Accuracy = 100.0 * correctTop5 / total
-
-    print(
-        f"{name} - Top-1 Accuracy: {top1Accuracy:.2f}% ({correctTop1}/{total}), "
-        f"Top-5 Accuracy: {top5Accuracy:.2f}%"
-    )
-
-    return top1Accuracy, top5Accuracy
-
-
-def calibrateModel(model, calibLoader):
-    model.eval()
-    with torch.no_grad(), calibration_mode(model):
-        for inputs, _ in tqdm(calibLoader, desc="Calibrating model"):
-            inputs = inputs.to("cpu")
-            model(inputs)
-    print("Calibration completed.")
-
-
-def prepareFQResNet18():
+def prepareResnet18Model() -> nn.Module:
     """Prepare a fake-quantized (FQ) ResNet18 model."""
-    baseModel = torchvision.models.resnet18(
-        weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1
-    )
-    baseModel = baseModel.eval().to("cpu")
+    baseModel = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
+
+    baseModel = baseModel.eval()
 
     computeLayerMap = {
         nn.Conv2d: (

@@ -126,16 +57,7 @@ def prepareFQResNet18():
         ),
     }
 
-    quantActMap = {
-        nn.ReLU: (
-            qnn.QuantReLU,
-            {
-                "act_quant": Uint8ActPerTensorFloat,
-                "return_quant_tensor": True,
-                "bit_width": 8,
-            },
-        ),
-    }
+    quantActMap = {}
 
     quantIdentityMap = {
         "signed": (

@@ -156,133 +78,25 @@ def prepareFQResNet18():
         ),
     }
 
-    dummyInput = torch.ones(1, 3, 224, 224).to("cpu")
-
-    print("Preprocessing model for quantization...")
     baseModel = preprocess_for_quantize(
         baseModel, equalize_iters=20, equalize_scale_computation="range"
     )
+    baseModel = AdaptiveAvgPoolToAvgPool().apply(baseModel, torch.ones(1, 3, 224, 224))
 
-    print("Converting AdaptiveAvgPool to AvgPool...")
-    baseModel = AdaptiveAvgPoolToAvgPool().apply(baseModel, dummyInput)
-
-    print("Quantizing model...")
-    FQModel = quantize(
+    quantizedResnet = quantize(
         graph_model=baseModel,
         compute_layer_map=computeLayerMap,
         quant_act_map=quantActMap,
         quant_identity_map=quantIdentityMap,
     )
 
-    return FQModel
+    return quantizedResnet
 
 
 @pytest.mark.ModelTests
 def deepQuantTestResnet18() -> None:
-    HOME = Path.home()
-    BASE = HOME / "Documents" / "ImagenetV2"
-    TAR_URL = (
-        "https://huggingface.co/datasets/vaishaal/ImageNetV2/resolve/main/"
-        "imagenetv2-matched-frequency.tar.gz"
-    )
-    TAR_PATH = BASE / "imagenetv2-matched-frequency.tar.gz"
-    EXTRACT_DIR = BASE / "imagenetv2-matched-frequency-format-val"
-
-    if not TAR_PATH.exists():
-        BASE.mkdir(parents=True, exist_ok=True)
-        print(f"Downloading ImageNetV2 from {TAR_URL}...")
-        urllib.request.urlretrieve(TAR_URL, TAR_PATH)
-
-    if not EXTRACT_DIR.exists():
-        print(f"Extracting to {EXTRACT_DIR}...")
-        with tarfile.open(TAR_PATH, "r:*") as tar:
-            for member in tqdm(tar.getmembers(), desc="Extracting files"):
-                tar.extract(member, BASE)
-        print("Extraction completed.")
-
-    transformsVal = transforms.Compose(
-        [
-            transforms.Resize(256),
-            transforms.CenterCrop(224),
-            transforms.ToTensor(),
-            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
-        ]
-    )
-
-    dataset = ImageFolder(root=str(EXTRACT_DIR), transform=transformsVal)
-    dataset.classes = sorted(dataset.classes, key=lambda x: int(x))
-    dataset.class_to_idx = {cls: i for i, cls in enumerate(dataset.classes)}
-
-    newSamples = []
-    for path, _ in dataset.samples:
-        clsName = Path(path).parent.name
-        newLabel = dataset.class_to_idx[clsName]
-        newSamples.append((path, newLabel))
-    dataset.samples = newSamples
-    dataset.targets = [s[1] for s in newSamples]
-
-    # FBRANCASI: Optional, reduce number of example for faster validation
-    DATASET_LIMIT = 256
-    dataset = Subset(dataset, list(range(DATASET_LIMIT)))
-    print(f"Validation dataset size set to {len(dataset)} images.")
-
-    calibLoader = DataLoader(
-        Subset(dataset, list(range(256))), batch_size=32, shuffle=False, pin_memory=True
-    )
-    valLoader = DataLoader(dataset, batch_size=32, shuffle=False, pin_memory=True)
-
-    # FBRANCASI: I'm on mac, so mps for me
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    device = torch.device("mps" if torch.backends.mps.is_available() else device)
-    print(f"Using device: {device}")
-
-    originalModel = torchvision.models.resnet18(
-        weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1
-    )
-    originalModel = originalModel.eval().to(device)
-    print("Original ResNet18 loaded.")
-
-    print("Evaluating original model...")
-    originalTop1, originalTop5 = evaluateModel(
-        originalModel, valLoader, device, "Original ResNet18"
-    )
-
-    print("Preparing and quantizing ResNet18...")
-    FQModel = prepareFQResNet18()
-
-    print("Calibrating FQ model...")
-    calibrateModel(FQModel, calibLoader)
-
-    print("Evaluating FQ model...")
-    # FBRANCASI: I'm on mac, mps doesn't work with brevitas
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    FQTop1, FQTop5 = evaluateModel(FQModel, valLoader, device, "FQ ResNet18")
-
-    sampleInputImg = torch.randn(1, 3, 224, 224).to("cpu")
-    TQModel = brevitasToTrueQuant(FQModel, sampleInputImg, debug=True)
-
-    numParameters = sum(p.numel() for p in TQModel.parameters())
-    print(f"Number of parameters: {numParameters:,}")
-
-    print("Evaluating TQ model...")
-    TQTop1, TQTop5 = evaluateModel(TQModel, valLoader, device, "TQ ResNet18")
-
-    print("\nComparison Summary:")
-    print(f"{'Model':<25} {'Top-1 Accuracy':<25} {'Top-5 Accuracy':<25}")
-    print("-" * 75)
-    print(f"{'Original ResNet18':<25} {originalTop1:<24.2f} {originalTop5:<24.2f}")
-    print(f"{'FQ ResNet18':<25} {FQTop1:<24.2f} {FQTop5:<24.2f}")
-    print(f"{'TQ ResNet18':<25} {TQTop1:<24.2f} {TQTop5:<24.2f}")
-    print(
-        f"{'FQ Drop':<25} {originalTop1 - FQTop1:<24.2f} {originalTop5 - FQTop5:<24.2f}"
-    )
-    print(
-        f"{'TQ Drop':<25} {originalTop1 - TQTop1:<24.2f} {originalTop5 - TQTop5:<24.2f}"
-    )
 
-    if abs(FQTop1 - TQTop1) > 5.0 or abs(FQTop5 - TQTop5) > 5.0:
-        print(
-            f"Warning: Large accuracy drop between FQ and TQ models. "
-            f"Top-1 difference: {abs(FQTop1 - TQTop1):.2f}%, "
-            f"Top-5 difference: {abs(FQTop5 - TQTop5):.2f}%"
-        )
+    torch.manual_seed(42)
+    quantizedModel = prepareResnet18Model()
+    sampleInput = torch.randn(1, 3, 224, 224)
+    brevitasToTrueQuant(quantizedModel, sampleInput, debug=True)
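The fixed seed makes the random sample input identical from run to run, which pairs with the deterministic ORT session so the export check inside brevitasToTrueQuant compares the same numbers on every invocation. A minimal illustration of what the seeding guarantees (a sketch, not part of the commit):

    import torch

    torch.manual_seed(42)
    a = torch.randn(1, 3, 224, 224)

    torch.manual_seed(42)
    b = torch.randn(1, 3, 224, 224)

    # Re-seeding reproduces the same tensor bit for bit.
    assert torch.equal(a, b)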
