 import numpy as np
 import onnx
 import onnx_graphsurgeon as gs
-from onnx import TensorProto, ValueInfoProto, numpy_helper
 from onnx.helper import get_attribute_value
 from onnx_graphsurgeon import Constant, Node, Variable
 
@@ -289,7 +288,7 @@ def _convert_types_to_np(types: dict[str, int] | list[int] | int) -> Any:
 
 def get_tensor_by_name(
     onnx_model: onnx.ModelProto, tensor_name: str
-) -> ValueInfoProto | TensorProto | None:
+) -> onnx.ValueInfoProto | onnx.TensorProto | None:
     """This function returns a tensor from its name.
 
     This function searches for a tensor in the model's:
@@ -438,7 +437,7 @@ def randomize_weights_onnx_bytes(onnx_bytes: bytes, seed: int = 0) -> bytes:
         numpy_array = np.random.normal(float(avg), float(var), size=init.dims).astype(
             dtype
         )
-        tensor = numpy_helper.from_array(numpy_array, init.name)
+        tensor = onnx.numpy_helper.from_array(numpy_array, init.name)
         model.graph.initializer[idx].CopyFrom(tensor)
 
     buffer = io.BytesIO()
@@ -751,3 +750,53 @@ def onnx_type_str_to_enum(dtype: str) -> int:
     dtype = dtype.split("tensor(")[-1].split(")")[0]
     dtype = "FLOAT" if dtype == "float32" else dtype.upper()
     return getattr(onnx.TensorProto, dtype)
+
+
+def remove_node_training_mode(onnx_model: onnx.ModelProto, node_op_type: str) -> onnx.ModelProto:
+    """Remove the `training_mode` attribute and training outputs from nodes of a given op type.
+
+    Outputs that are not consumed by other nodes and are not graph outputs are removed as well.
+
+    Args:
+        onnx_model: The onnx model.
+        node_op_type: The op type of the nodes to remove the training_mode attribute from.
+
+    Returns:
+        The onnx model with the training_mode attribute and training-only outputs removed.
+    """
+    removed_output_names = set()
+    all_inputs = {inp for n in onnx_model.graph.node for inp in n.input}
+    graph_outputs = {o.name for o in onnx_model.graph.output}
+    keep = all_inputs | graph_outputs
+
+    for node in onnx_model.graph.node:
+        if node.op_type != node_op_type:
+            continue
+
+        is_training_mode = False
+        # Drop the 'training_mode' attribute if present
+        for idx, attr in enumerate(list(node.attribute)):
+            if attr.name == "training_mode":
+                if attr.i == 1:
+                    is_training_mode = True
+                del node.attribute[idx]
+                break
+
+        # Drop outputs that are neither consumed by other nodes nor graph outputs (training-only)
+        if is_training_mode:
+            to_remove = []
+            for name in node.output:
+                if name not in keep:
+                    removed_output_names.add(name)
+                    to_remove.append(name)
+
+            for name in to_remove:
+                node.output.remove(name)
+
+    if removed_output_names:
+        # Clean up the value_info entries of the removed outputs
+        keep = [vi for vi in onnx_model.graph.value_info if vi.name not in removed_output_names]
+        del onnx_model.graph.value_info[:]
+        onnx_model.graph.value_info.extend(keep)
+
+    return onnx_model
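A minimal usage sketch for the new helper, assuming remove_node_training_mode is importable from this module and that the exported model's BatchNormalization nodes still carry a training_mode attribute; the file names below are illustrative, not part of this change:

import onnx

# Hypothetical model exported with training-mode BatchNormalization nodes.
model = onnx.load("model_training.onnx")
model = remove_node_training_mode(model, "BatchNormalization")
# The graph should still validate once the training-only outputs are gone.
onnx.checker.check_model(model)
onnx.save(model, "model_inference.onnx")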