|
| 1 | +# Copyright 2025 ETH Zurich and University of Bologna. |
| 2 | +# Licensed under the Apache License, Version 2.0, see LICENSE for details. |
| 3 | +# SPDX-License-Identifier: Apache-2.0 |
| 4 | +# |
| 5 | +# Federico Brancasi <[email protected]> |
| 6 | + |
| 7 | +""" |
| 8 | +Script to fix the CCTTrueQuantized ONNX model by duplicating shared constants. |
| 9 | +This resolves the issue where a single Floor constant (onnx::Floor_772) is shared |
| 10 | +across multiple bias quantization operations. |
| 11 | +""" |
| 12 | + |
| 13 | +import argparse |
| 14 | +import os |
| 15 | + |
| 16 | +import numpy as np |
| 17 | +import onnx |
| 18 | +from onnx import helper |
| 19 | + |
| 20 | + |
| 21 | +def fix_shared_constants(model_path, output_path): |
| 22 | + """Fix shared constants in ONNX model by creating unique copies.""" |
| 23 | + print(f"Loading ONNX model from: {model_path}") |
| 24 | + model = onnx.load(model_path) |
| 25 | + |
| 26 | + graph = model.graph |
| 27 | + |
| 28 | + shared_floor_tensor = None |
| 29 | + for initializer in graph.initializer: |
| 30 | + if initializer.name == "onnx::Floor_772": |
| 31 | + shared_floor_tensor = initializer |
| 32 | + break |
| 33 | + |
| 34 | + if shared_floor_tensor is None: |
| 35 | + print("No shared Floor constant found. Model may already be fixed.") |
| 36 | + return False |
| 37 | + |
| 38 | + print(f"Found shared Floor constant: {shared_floor_tensor.name}") |
| 39 | + print(f"Tensor shape: {shared_floor_tensor.dims}") |
| 40 | + |
| 41 | + floor_nodes = [] |
| 42 | + for node in graph.node: |
| 43 | + if node.op_type == "Floor": |
| 44 | + for input_name in node.input: |
| 45 | + if input_name == shared_floor_tensor.name: |
| 46 | + floor_nodes.append(node) |
| 47 | + break |
| 48 | + |
| 49 | + print(f"Found {len(floor_nodes)} Floor nodes sharing the constant:") |
| 50 | + for node in floor_nodes: |
| 51 | + print(f" - {node.name}") |
| 52 | + |
| 53 | + new_initializers = [] |
| 54 | + for i, node in enumerate(floor_nodes): |
| 55 | + unique_name = f"Floor_772_unique_{i}_{node.name.replace('/', '_')}" |
| 56 | + |
| 57 | + new_tensor = helper.make_tensor( |
| 58 | + name=unique_name, |
| 59 | + data_type=shared_floor_tensor.data_type, |
| 60 | + dims=shared_floor_tensor.dims, |
| 61 | + vals=( |
| 62 | + shared_floor_tensor.float_data |
| 63 | + if shared_floor_tensor.float_data |
| 64 | + else np.frombuffer( |
| 65 | + shared_floor_tensor.raw_data, dtype=np.float32 |
| 66 | + ).tolist() |
| 67 | + ), |
| 68 | + ) |
| 69 | + |
| 70 | + new_initializers.append(new_tensor) |
| 71 | + |
| 72 | + for j, input_name in enumerate(node.input): |
| 73 | + if input_name == shared_floor_tensor.name: |
| 74 | + node.input[j] = unique_name |
| 75 | + break |
| 76 | + |
| 77 | + print(f" Created unique constant: {unique_name} for node: {node.name}") |
| 78 | + |
| 79 | + graph.initializer.remove(shared_floor_tensor) |
| 80 | + |
| 81 | + for new_tensor in new_initializers: |
| 82 | + graph.initializer.append(new_tensor) |
| 83 | + |
| 84 | + inputs_to_remove = [] |
| 85 | + for input_tensor in graph.input: |
| 86 | + if input_tensor.name == shared_floor_tensor.name: |
| 87 | + inputs_to_remove.append(input_tensor) |
| 88 | + |
| 89 | + for input_tensor in inputs_to_remove: |
| 90 | + graph.input.remove(input_tensor) |
| 91 | + |
| 92 | + try: |
| 93 | + onnx.checker.check_model(model) |
| 94 | + print("Model validation passed!") |
| 95 | + except Exception as e: |
| 96 | + print(f"Model validation failed: {e}") |
| 97 | + return False |
| 98 | + |
| 99 | + print(f"Saving fixed model to: {output_path}") |
| 100 | + onnx.save(model, output_path) |
| 101 | + |
| 102 | + return True |
| 103 | + |
| 104 | + |
| 105 | +def main(): |
| 106 | + parser = argparse.ArgumentParser( |
| 107 | + description="Fix shared constants in CCTTrueQuantized ONNX model" |
| 108 | + ) |
| 109 | + parser.add_argument( |
| 110 | + "--input", |
| 111 | + required=True, |
| 112 | + help="Path to input ONNX model", |
| 113 | + ) |
| 114 | + parser.add_argument( |
| 115 | + "--output", |
| 116 | + required=True, |
| 117 | + help="Path to output fixed ONNX model", |
| 118 | + ) |
| 119 | + |
| 120 | + args = parser.parse_args() |
| 121 | + |
| 122 | + if not os.path.exists(args.input): |
| 123 | + print(f"Error: Input file does not exist: {args.input}") |
| 124 | + return 1 |
| 125 | + |
| 126 | + success = fix_shared_constants(args.input, args.output) |
| 127 | + |
| 128 | + if success: |
| 129 | + print("Successfully fixed the ONNX model!") |
| 130 | + print(f"Original model: {args.input}") |
| 131 | + print(f"Fixed model: {args.output}") |
| 132 | + |
| 133 | + # FBRANCASI: Replace the original model with the fixed one |
| 134 | + backup_path = args.input + ".backup" |
| 135 | + print(f"Creating backup: {backup_path}") |
| 136 | + os.rename(args.input, backup_path) |
| 137 | + os.rename(args.output, args.input) |
| 138 | + print("Replaced original model with fixed version") |
| 139 | + |
| 140 | + return 0 |
| 141 | + else: |
| 142 | + print("Failed to fix the ONNX model") |
| 143 | + return 1 |
| 144 | + |
| 145 | + |
| 146 | +if __name__ == "__main__": |
| 147 | + exit(main()) |
0 commit comments