Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit faebb8b

Browse files
authored
[cherry-pick-1.4.4][ONNX] patch previous commit to accept path-like objects (#1475) (#1476)
* [ONNX] override_model_input_shape helper function (#1471) * [ONNX] patch previous commit to accept path-like objects (#1475) * bump version to 1.4.4
1 parent cfbfedf commit faebb8b

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

src/sparseml/onnx/utils/helpers.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
"get_tensor_shape",
7575
"get_tensor_dim_shape",
7676
"set_tensor_dim_shape",
77+
"override_model_input_shape",
7778
]
7879

7980

@@ -1233,3 +1234,25 @@ def set_tensor_dim_shape(tensor: onnx.TensorProto, dim: int, value: int):
12331234
:param value: new shape for the given dimension
12341235
"""
12351236
tensor.type.tensor_type.shape.dim[dim].dim_value = value
1237+
1238+
1239+
def override_model_input_shape(model: Union[str, onnx.ModelProto], shape: List[int]):
1240+
"""
1241+
Set the shape of the first input of the given model to the given shape.
1242+
If given a file, the file will be overwritten
1243+
1244+
:param model: ONNX model or model path to overrwrite
1245+
:param shape: shape as list of integers to override with. must match
1246+
existing dimensions
1247+
"""
1248+
if not isinstance(model, onnx.ModelProto):
1249+
model_path = model
1250+
model = onnx.load(model)
1251+
else:
1252+
model_path = None
1253+
1254+
for dim, dim_size in enumerate(shape):
1255+
set_tensor_dim_shape(model.graph.input[0], dim, dim_size)
1256+
1257+
if model_path:
1258+
onnx.save(model, model_path)

src/sparseml/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from datetime import date
2020

2121

22-
version_base = "1.4.3"
22+
version_base = "1.4.4"
2323
is_release = False # change to True to set the generated version as a release version
2424

2525

0 commit comments

Comments
 (0)