@@ -1216,24 +1216,30 @@ def get_tensor_shape(tensor: onnx.TensorProto) -> List[int]:
12161216 return [dim .dim_value for dim in tensor .type .tensor_type .shape .dim ]
12171217
12181218
1219- def get_tensor_dim_shape (tensor : onnx .TensorProto , dim : int ) -> int :
1219+ def get_tensor_dim_shape (tensor : onnx .TensorProto , dim : Union [ int , str ] ) -> int :
12201220 """
12211221 :param tensor: ONNX tensor to get the shape of a dimension of
12221222 :param dim: dimension index of the tensor to get the shape of
12231223 :return: shape of the tensor at the given dimension
12241224 """
1225- return tensor .type .tensor_type .shape .dim [dim ].dim_value
1225+ return (
1226+ tensor .type .tensor_type .shape .dim [dim ].dim_value
1227+ or tensor .type .tensor_type .shape .dim [dim ].dim_param
1228+ )
12261229
12271230
1228- def set_tensor_dim_shape (tensor : onnx .TensorProto , dim : int , value : int ):
1231+ def set_tensor_dim_shape (tensor : onnx .TensorProto , dim : int , value : Union [ int , str ] ):
12291232 """
12301233 Sets the shape of the tensor at the given dimension to the given value
12311234
12321235 :param tensor: ONNX tensor to modify the shape of
12331236 :param dim: dimension index of the tensor to modify the shape of
12341237 :param value: new shape for the given dimension
12351238 """
1236- tensor .type .tensor_type .shape .dim [dim ].dim_value = value
1239+ if isinstance (value , str ):
1240+ tensor .type .tensor_type .shape .dim [dim ].dim_param = value
1241+ else :
1242+ tensor .type .tensor_type .shape .dim [dim ].dim_value = value
12371243
12381244
12391245def override_model_input_shape (model : Union [str , onnx .ModelProto ], shape : List [int ]):
0 commit comments