Skip to content

Commit 3d3b90e

Browse files
Update Tests to use right version of CCT
1 parent 0adc5a3 commit 3d3b90e

File tree

5 files changed

+1399
-1
lines changed

5 files changed

+1399
-1
lines changed

.gitignore

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ dist/
2525
*.gz
2626
*-ubyte
2727
*.pt
28-
.c*
2928
*.onnx
3029
*.npz
3130
onnx/*

DeepQuant/Utils/FixCTT2Graph.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
# Copyright 2025 ETH Zurich and University of Bologna.
2+
# Licensed under the Apache License, Version 2.0, see LICENSE for details.
3+
# SPDX-License-Identifier: Apache-2.0
4+
#
5+
# Federico Brancasi <[email protected]>
6+
7+
"""
8+
Script to fix the CCTTrueQuantized ONNX model by duplicating shared constants.
9+
This resolves the issue where a single Floor constant (onnx::Floor_772) is shared
10+
across multiple bias quantization operations.
11+
"""
12+
13+
import argparse
14+
import os
15+
16+
import numpy as np
17+
import onnx
18+
from onnx import helper
19+
20+
21+
def fix_shared_constants(model_path, output_path):
22+
"""Fix shared constants in ONNX model by creating unique copies."""
23+
print(f"Loading ONNX model from: {model_path}")
24+
model = onnx.load(model_path)
25+
26+
graph = model.graph
27+
28+
shared_floor_tensor = None
29+
for initializer in graph.initializer:
30+
if initializer.name == "onnx::Floor_772":
31+
shared_floor_tensor = initializer
32+
break
33+
34+
if shared_floor_tensor is None:
35+
print("No shared Floor constant found. Model may already be fixed.")
36+
return False
37+
38+
print(f"Found shared Floor constant: {shared_floor_tensor.name}")
39+
print(f"Tensor shape: {shared_floor_tensor.dims}")
40+
41+
floor_nodes = []
42+
for node in graph.node:
43+
if node.op_type == "Floor":
44+
for input_name in node.input:
45+
if input_name == shared_floor_tensor.name:
46+
floor_nodes.append(node)
47+
break
48+
49+
print(f"Found {len(floor_nodes)} Floor nodes sharing the constant:")
50+
for node in floor_nodes:
51+
print(f" - {node.name}")
52+
53+
new_initializers = []
54+
for i, node in enumerate(floor_nodes):
55+
unique_name = f"Floor_772_unique_{i}_{node.name.replace('/', '_')}"
56+
57+
new_tensor = helper.make_tensor(
58+
name=unique_name,
59+
data_type=shared_floor_tensor.data_type,
60+
dims=shared_floor_tensor.dims,
61+
vals=(
62+
shared_floor_tensor.float_data
63+
if shared_floor_tensor.float_data
64+
else np.frombuffer(
65+
shared_floor_tensor.raw_data, dtype=np.float32
66+
).tolist()
67+
),
68+
)
69+
70+
new_initializers.append(new_tensor)
71+
72+
for j, input_name in enumerate(node.input):
73+
if input_name == shared_floor_tensor.name:
74+
node.input[j] = unique_name
75+
break
76+
77+
print(f" Created unique constant: {unique_name} for node: {node.name}")
78+
79+
graph.initializer.remove(shared_floor_tensor)
80+
81+
for new_tensor in new_initializers:
82+
graph.initializer.append(new_tensor)
83+
84+
inputs_to_remove = []
85+
for input_tensor in graph.input:
86+
if input_tensor.name == shared_floor_tensor.name:
87+
inputs_to_remove.append(input_tensor)
88+
89+
for input_tensor in inputs_to_remove:
90+
graph.input.remove(input_tensor)
91+
92+
try:
93+
onnx.checker.check_model(model)
94+
print("Model validation passed!")
95+
except Exception as e:
96+
print(f"Model validation failed: {e}")
97+
return False
98+
99+
print(f"Saving fixed model to: {output_path}")
100+
onnx.save(model, output_path)
101+
102+
return True
103+
104+
105+
def main():
106+
parser = argparse.ArgumentParser(
107+
description="Fix shared constants in CCTTrueQuantized ONNX model"
108+
)
109+
parser.add_argument(
110+
"--input",
111+
required=True,
112+
help="Path to input ONNX model",
113+
)
114+
parser.add_argument(
115+
"--output",
116+
required=True,
117+
help="Path to output fixed ONNX model",
118+
)
119+
120+
args = parser.parse_args()
121+
122+
if not os.path.exists(args.input):
123+
print(f"Error: Input file does not exist: {args.input}")
124+
return 1
125+
126+
success = fix_shared_constants(args.input, args.output)
127+
128+
if success:
129+
print("Successfully fixed the ONNX model!")
130+
print(f"Original model: {args.input}")
131+
print(f"Fixed model: {args.output}")
132+
133+
# FBRANCASI: Replace the original model with the fixed one
134+
backup_path = args.input + ".backup"
135+
print(f"Creating backup: {backup_path}")
136+
os.rename(args.input, backup_path)
137+
os.rename(args.output, args.input)
138+
print("Replaced original model with fixed version")
139+
140+
return 0
141+
else:
142+
print("Failed to fix the ONNX model")
143+
return 1
144+
145+
146+
if __name__ == "__main__":
147+
exit(main())

0 commit comments

Comments
 (0)