55
55
F16Type ,
56
56
F32Type ,
57
57
F64Type ,
58
+ Float8E4M3FNUZType ,
58
59
Float8E4M3FNType ,
59
60
Float8E5M2FNUZType ,
60
61
Float8E5M2Type ,
@@ -643,7 +644,7 @@ def get_list_element_type(self, tp: onnx.TypeProto) -> IrType:
643
644
if tt .elem_type :
644
645
element_type = self .tensor_element_type (tt .elem_type )
645
646
dims = tuple (
646
- (d .dim_value if not d . dim_param else None ) for d in tt .shape .dim
647
+ (d .dim_value if d . HasField ( "dim_value" ) else None ) for d in tt .shape .dim
647
648
)
648
649
shape_asm = "," .join ("?" if d is None else str (d ) for d in dims )
649
650
return f"vtensor<[{ shape_asm } ],{ element_type } >"
@@ -656,7 +657,7 @@ def get_optional_element_type(self, tp: onnx.TypeProto) -> IrType:
656
657
if tt .elem_type :
657
658
element_type = self .tensor_element_type (tt .elem_type )
658
659
dims = tuple (
659
- (d .dim_value if not d . dim_param else None ) for d in tt .shape .dim
660
+ (d .dim_value if d . HasField ( "dim_value" ) else None ) for d in tt .shape .dim
660
661
)
661
662
shape_asm = "," .join ("?" if d is None else str (d ) for d in dims )
662
663
return f"vtensor<[{ shape_asm } ],{ element_type } >"
@@ -707,13 +708,15 @@ def type_proto_to_type(self, tp: onnx.TypeProto) -> IrType:
707
708
708
709
tt = tp .tensor_type
709
710
if tt .elem_type :
710
- if not tt .shape :
711
- raise OnnxImportError (
712
- f"Unsupported Tensor type without shape (run shape inference?): { tp } "
713
- )
714
711
element_type = self .tensor_element_type (tt .elem_type )
715
712
dims = tuple (
716
- (d .dim_value if not d .dim_param else None ) for d in tt .shape .dim
713
+ # NOTE: dynamic dimension can either be denoted by d.dim_param being set
714
+ # (and d.dim_value consequently not set) or
715
+ # by neither d.dim_value nor d.dim_param being set. Also note that
716
+ # d.dim_value being 0 corresponds to the protobuf default when the field
717
+ # is not set.
718
+ d .dim_value if d .HasField ("dim_value" ) else None
719
+ for d in tt .shape .dim
717
720
)
718
721
return self .get_vtensor_type (dims , element_type )
719
722
@@ -1097,7 +1100,7 @@ def get_operator_function(
1097
1100
onnx .TensorProto .DataType .COMPLEX128 : lambda : ComplexType .get (F64Type .get ()),
1098
1101
onnx .TensorProto .DataType .BFLOAT16 : lambda : BF16Type .get (),
1099
1102
onnx .TensorProto .DataType .FLOAT8E4M3FN : lambda : Float8E4M3FNType .get (),
1100
- onnx .TensorProto .DataType .FLOAT8E4M3FNUZ : lambda : Float8E5M2FNUZType .get (),
1103
+ onnx .TensorProto .DataType .FLOAT8E4M3FNUZ : lambda : Float8E4M3FNUZType .get (),
1101
1104
onnx .TensorProto .DataType .FLOAT8E5M2 : lambda : Float8E5M2Type .get (),
1102
1105
onnx .TensorProto .DataType .FLOAT8E5M2FNUZ : lambda : Float8E5M2FNUZType .get (),
1103
1106
onnx .TensorProto .DataType .STRING : lambda : "!torch.str" ,
0 commit comments