Skip to content

Commit f681c58

Browse files
committed
fix: preserve original dtype for all-zero input type inference
1 parent 3eda2dc commit f681c58

File tree

4 files changed

+29
-16
lines changed

4 files changed

+29
-16
lines changed

Deeploy/DeeployTypes.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3105,16 +3105,9 @@ def _exportGraph(self, folderPath, fileName):
31053105
raise OSError(f"Error exporting the context to: {absoluteOnnxPath}")
31063106

31073107
# VJUNG: ONNX-Graphsurgeon needs tensors to be in their export types
3108-
<<<<<<< HEAD
3109-
constTensors = [tensor for tensor in self.graph.tensors().values() if isinstance(tensor, gs.Constant)]
3110-
for tensor in constTensors:
3111-
if tensor.dtype != tensor.export_dtype:
3112-
=======
3113-
# Added hasattr check for compatibility with different onnx-graphsurgeon versions
31143108
constTensors = [tensor for tensor in self.graph.tensors().values() if isinstance(tensor, gs.Constant)]
31153109
for tensor in constTensors:
31163110
if hasattr(tensor, 'export_dtype') and tensor.dtype != tensor.export_dtype:
3117-
>>>>>>> 937e3cb3 (refactor: restore Snitch framework code to origin/devel)
31183111
tensor.values = tensor.values.astype(tensor.export_dtype)
31193112

31203113
model = gs.export_onnx(self.graph)

DeeployTest/generateNetwork.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,8 @@ def generateNetwork(args):
7373
test_inputs, test_outputs, graph = generateDebugConfig(inputs, outputs, activations, graph)
7474

7575
else:
76-
# Load as float64 and infer types later
76+
# Load as float64 for uniform handling, but preserve original dtypes for type inference
77+
test_input_original_dtypes = [inputs[x].dtype for x in inputs.files]
7778
test_inputs = [inputs[x].reshape(-1).astype(np.float64) for x in inputs.files]
7879
test_outputs = [outputs[x].reshape(-1).astype(np.float64) for x in outputs.files]
7980

@@ -122,7 +123,8 @@ def generateNetwork(args):
122123

123124
_type = PointerClass(_type)
124125
else:
125-
_type, offset = inferTypeAndOffset(values, signProp)
126+
original_dtype = test_input_original_dtypes[index] if index < len(test_input_original_dtypes) else None
127+
_type, offset = inferTypeAndOffset(values, signProp, original_dtype = original_dtype)
126128

127129
inputTypes[f"input_{index}"] = _type
128130
inputOffsets[f"input_{index}"] = offset

DeeployTest/testMVP.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,8 @@ def setupDeployer(graph: gs.Graph, memoryHierarchy: MemoryHierarchy, defaultTarg
6868
inputs = np.load(f'{args.dir}/inputs.npz')
6969
tensors = graph.tensors()
7070

71-
# Load as int64 and infer types later
71+
# Load as float64 for uniform handling, but preserve original dtypes for type inference
72+
test_input_original_dtypes = [inputs[x].dtype for x in inputs.files]
7273
test_inputs = [inputs[x].reshape(-1).astype(np.float64) for x in inputs.files]
7374

7475
platform, signProp = mapPlatform(args.platform)
@@ -83,7 +84,8 @@ def setupDeployer(graph: gs.Graph, memoryHierarchy: MemoryHierarchy, defaultTarg
8384
cluster.n_cores = args.cores
8485

8586
for index, num in enumerate(test_inputs):
86-
_type, offset = inferTypeAndOffset(num, signProp)
87+
original_dtype = test_input_original_dtypes[index] if index < len(test_input_original_dtypes) else None
88+
_type, offset = inferTypeAndOffset(num, signProp, original_dtype = original_dtype)
8789
inputTypes[f"input_{index}"] = _type
8890
inputOffsets[f"input_{index}"] = offset
8991

@@ -241,7 +243,8 @@ def setupDeployer(graph: gs.Graph, memoryHierarchy: MemoryHierarchy, defaultTarg
241243
if args.debug:
242244
test_inputs, test_outputs, graph = generateDebugConfig(inputs, outputs, activations, graph)
243245
else:
244-
# Load as int64 and infer types later
246+
# Load as float64 for uniform handling, but preserve original dtypes for type inference
247+
test_input_original_dtypes = [inputs[x].dtype for x in inputs.files]
245248
test_inputs = [inputs[x].reshape(-1).astype(np.float64) for x in inputs.files]
246249
test_outputs = [outputs[x].reshape(-1).astype(np.float64) for x in outputs.files]
247250

@@ -280,7 +283,8 @@ def setupDeployer(graph: gs.Graph, memoryHierarchy: MemoryHierarchy, defaultTarg
280283
log.debug(f"Deployer: {deployer}")
281284

282285
for index, num in enumerate(test_inputs):
283-
_type, offset = inferTypeAndOffset(num, signProp)
286+
original_dtype = test_input_original_dtypes[index] if index < len(test_input_original_dtypes) else None
287+
_type, offset = inferTypeAndOffset(num, signProp, original_dtype = original_dtype)
284288
inputTypes[f"input_{index}"] = _type
285289
inputOffsets[f"input_{index}"] = offset
286290

DeeployTest/testUtils/typeMapping.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,20 @@ def isInteger(x: npt.NDArray) -> bool:
4242
return np.abs((x.astype(int) - x)).max() <= 0.001
4343

4444

45-
def inferMinimalType(values: np.ndarray, default: Type[BaseType] = int8_t) -> Type[BaseType]:
45+
def inferMinimalType(values: np.ndarray,
46+
default: Type[BaseType] = int8_t,
47+
original_dtype: np.dtype = None) -> Type[BaseType]:
4648
# WIESEP: We cannot do type inference for empty arrays.
4749
if np.prod(values.shape) == 0:
4850
print(f"Warning: Empty input array for type inference for {values}!")
4951
return default
5052

53+
# For all-zero arrays, use original dtype to distinguish int vs float
54+
if np.all(values == 0) and original_dtype is not None:
55+
if np.issubdtype(original_dtype, np.floating):
56+
return minimalFloatType(values)
57+
return minimalIntegerType(values)
58+
5159
if isInteger(values):
5260
return minimalIntegerType(values)
5361
else:
@@ -67,7 +75,9 @@ def signPropTypeAndOffset(_type: Type[IntegerImmediate]) -> Tuple[Type[IntegerIm
6775
return signedType, 2**(signedType.typeWidth - 1)
6876

6977

70-
def inferTypeAndOffset(values: np.ndarray, signProp: bool = False) -> Tuple[Type[Pointer], int]:
78+
def inferTypeAndOffset(values: np.ndarray,
79+
signProp: bool = False,
80+
original_dtype: np.dtype = None) -> Tuple[Type[Pointer], int]:
7181
"""Infers the data type of the provided input array.
7282
7383
Parameters
@@ -77,13 +87,17 @@ def inferTypeAndOffset(values: np.ndarray, signProp: bool = False) -> Tuple[Type
7787
7888
signProp : bool
7989
Whether to consider signedness when inferring the data type.
90+
91+
original_dtype : np.dtype, optional
92+
Original numpy dtype before float64 cast, used to resolve all-zero ambiguity.
93+
8094
Returns
8195
-------
8296
Tuple[Type[BaseType], int]
8397
The inferred type and offset
8498
"""
8599

86-
_type = inferMinimalType(values)
100+
_type = inferMinimalType(values, original_dtype = original_dtype)
87101

88102
if signProp and issubclass(_type, IntegerImmediate):
89103
_type, offset = signPropTypeAndOffset(_type)

0 commit comments

Comments
 (0)