Skip to content

Commit 35fc83f

Browse files
authored
[onnx_importer.py] Fix dim_value None not correctly processed and missing Float8E4M3FNUZType. (#4037)
As per title. Changes tested on SHARK-TestSuite's `alt_e2eshark`.
1 parent 2c7a639 commit 35fc83f

File tree

2 files changed

+22
-9
lines changed

2 files changed

+22
-9
lines changed

python/torch_mlir/extras/onnx_importer.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
F16Type,
5656
F32Type,
5757
F64Type,
58+
Float8E4M3FNUZType,
5859
Float8E4M3FNType,
5960
Float8E5M2FNUZType,
6061
Float8E5M2Type,
@@ -643,7 +644,7 @@ def get_list_element_type(self, tp: onnx.TypeProto) -> IrType:
643644
if tt.elem_type:
644645
element_type = self.tensor_element_type(tt.elem_type)
645646
dims = tuple(
646-
(d.dim_value if not d.dim_param else None) for d in tt.shape.dim
647+
(d.dim_value if d.HasField("dim_value") else None) for d in tt.shape.dim
647648
)
648649
shape_asm = ",".join("?" if d is None else str(d) for d in dims)
649650
return f"vtensor<[{shape_asm}],{element_type}>"
@@ -656,7 +657,7 @@ def get_optional_element_type(self, tp: onnx.TypeProto) -> IrType:
656657
if tt.elem_type:
657658
element_type = self.tensor_element_type(tt.elem_type)
658659
dims = tuple(
659-
(d.dim_value if not d.dim_param else None) for d in tt.shape.dim
660+
(d.dim_value if d.HasField("dim_value") else None) for d in tt.shape.dim
660661
)
661662
shape_asm = ",".join("?" if d is None else str(d) for d in dims)
662663
return f"vtensor<[{shape_asm}],{element_type}>"
@@ -707,13 +708,15 @@ def type_proto_to_type(self, tp: onnx.TypeProto) -> IrType:
707708

708709
tt = tp.tensor_type
709710
if tt.elem_type:
710-
if not tt.shape:
711-
raise OnnxImportError(
712-
f"Unsupported Tensor type without shape (run shape inference?): {tp}"
713-
)
714711
element_type = self.tensor_element_type(tt.elem_type)
715712
dims = tuple(
716-
(d.dim_value if not d.dim_param else None) for d in tt.shape.dim
713+
# NOTE: a dynamic dimension can be denoted either by d.dim_param being set
714+
# (and d.dim_value consequently not set) or
715+
# by neither d.dim_value nor d.dim_param being set. Also note that
716+
# d.dim_value being 0 corresponds to the protobuf default when the field
717+
# is not set.
718+
d.dim_value if d.HasField("dim_value") else None
719+
for d in tt.shape.dim
717720
)
718721
return self.get_vtensor_type(dims, element_type)
719722

@@ -1097,7 +1100,7 @@ def get_operator_function(
10971100
onnx.TensorProto.DataType.COMPLEX128: lambda: ComplexType.get(F64Type.get()),
10981101
onnx.TensorProto.DataType.BFLOAT16: lambda: BF16Type.get(),
10991102
onnx.TensorProto.DataType.FLOAT8E4M3FN: lambda: Float8E4M3FNType.get(),
1100-
onnx.TensorProto.DataType.FLOAT8E4M3FNUZ: lambda: Float8E5M2FNUZType.get(),
1103+
onnx.TensorProto.DataType.FLOAT8E4M3FNUZ: lambda: Float8E4M3FNUZType.get(),
11011104
onnx.TensorProto.DataType.FLOAT8E5M2: lambda: Float8E5M2Type.get(),
11021105
onnx.TensorProto.DataType.FLOAT8E5M2FNUZ: lambda: Float8E5M2FNUZType.get(),
11031106
onnx.TensorProto.DataType.STRING: lambda: "!torch.str",

python/torch_mlir/tools/import_onnx/__main__.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,15 +86,19 @@ def load_onnx_model(args: argparse.Namespace) -> onnx.ModelProto:
8686
raw_model = onnx.load(args.input_file, load_external_data=False)
8787
onnx.load_external_data_for_model(raw_model, str(args.data_dir))
8888

89+
raw_model_modified = False
90+
8991
if args.opset_version:
9092
raw_model = onnx.version_converter.convert_version(
9193
raw_model, args.opset_version
9294
)
95+
raw_model_modified = True
9396

9497
if args.clear_domain:
9598
graph = raw_model.graph
9699
for n in graph.node:
97100
n.ClearField("domain")
101+
raw_model_modified = True
98102

99103
# Run the checker to test whether the file is above the threshold for
100104
# in-memory shape inference. If not, go ahead and do the shape inference.
@@ -119,9 +123,15 @@ def load_onnx_model(args: argparse.Namespace) -> onnx.ModelProto:
119123

120124
# Model is too big for in-memory inference: do file-based shape inference
121125
# to a temp file.
126+
# First we need to save the model when it has been changed (e.g. by version conversion).
127+
if raw_model_modified:
128+
temp_raw_file = temp_dir / "raw.onnx"
129+
onnx.save(raw_model, temp_raw_file, save_as_external_data=True)
122130
temp_inferred_file = temp_dir / "inferred.onnx"
123131
onnx.shape_inference.infer_shapes_path(
124-
args.input_file, temp_inferred_file, data_prop=args.data_prop
132+
temp_raw_file if raw_model_modified else args.input_file,
133+
temp_inferred_file,
134+
data_prop=args.data_prop,
125135
)
126136

127137
# Sanity check the shape-inferred model to be sure we have a good model

0 commit comments

Comments
 (0)