|
4 | 4 | This module implements ONNX control flow operators like Loop and If. |
5 | 5 | """ |
6 | 6 |
|
7 | | -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple |
| 7 | +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple |
8 | 8 |
|
9 | 9 | import onnx |
10 | 10 | import torch |
@@ -70,6 +70,7 @@ def _build_subgraph_module( |
70 | 70 | parent_env: Dict[str, torch.fx.Node], |
71 | 71 | parent_opset_versions: Dict[str, int], |
72 | 72 | parent_type_info: Optional[Dict[str, bool]] = None, |
| 73 | + tensor_loader: Optional[Callable[[onnx.TensorProto], torch.Tensor]] = None, |
73 | 74 | ) -> Tuple[torch.fx.GraphModule, List[str], List[str], List[str]]: |
74 | 75 | """Build an FX GraphModule from an ONNX subgraph. |
75 | 76 |
|
@@ -106,8 +107,11 @@ def _build_subgraph_module( |
106 | 107 | # Load initializers from subgraph |
107 | 108 | initializer_map: Dict[str, torch.Tensor] = {} |
108 | 109 | for initializer in body_graph.initializer: |
109 | | - np_array = numpy_helper.to_array(initializer) |
110 | | - initializer_map[initializer.name] = torch.from_numpy(np_array.copy()) |
| 110 | + if tensor_loader is not None: |
| 111 | + initializer_map[initializer.name] = tensor_loader(initializer) |
| 112 | + else: |
| 113 | + np_array = numpy_helper.to_array(initializer) |
| 114 | + initializer_map[initializer.name] = torch.from_numpy(np_array.copy()) |
111 | 115 |
|
112 | 116 | # Register initializers as constants |
113 | 117 | for name, tensor in initializer_map.items(): |
@@ -159,9 +163,16 @@ def __init__(self): |
159 | 163 | self.initializer_map = initializer_map |
160 | 164 | self._body_graph = body_graph |
161 | 165 | self._parent_type_info = parent_type_info |
| 166 | + self._tensor_loader = tensor_loader |
162 | 167 | # Build type info for this subgraph (to pass to nested subgraphs) |
163 | 168 | self._type_info = self._build_type_info() |
164 | 169 |
|
| 170 | + def load_tensor(self, tensor: onnx.TensorProto) -> torch.Tensor: |
| 171 | + if self._tensor_loader is not None: |
| 172 | + return self._tensor_loader(tensor) |
| 173 | + np_array = numpy_helper.to_array(tensor) |
| 174 | + return torch.from_numpy(np_array.copy()) |
| 175 | + |
165 | 176 | def _build_type_info(self) -> Dict[str, bool]: |
166 | 177 | """Build a mapping of value names to whether they are optional types.""" |
167 | 178 | info: Dict[str, bool] = {} |
@@ -437,7 +448,11 @@ def loop_op(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node: |
437 | 448 | # Build subgraph module |
438 | 449 | body_module, body_input_names, body_output_names, outer_refs = ( |
439 | 450 | _build_subgraph_module( |
440 | | - body_graph, builder.env, builder._opset_versions, parent_type_info |
| 451 | + body_graph, |
| 452 | + builder.env, |
| 453 | + builder._opset_versions, |
| 454 | + parent_type_info, |
| 455 | + tensor_loader=builder.load_tensor, |
441 | 456 | ) |
442 | 457 | ) |
443 | 458 |
|
@@ -628,7 +643,11 @@ def scan_op(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node: |
628 | 643 | # Build subgraph module |
629 | 644 | body_module, body_input_names, body_output_names, outer_refs = ( |
630 | 645 | _build_subgraph_module( |
631 | | - body_graph, builder.env, builder._opset_versions, parent_type_info |
| 646 | + body_graph, |
| 647 | + builder.env, |
| 648 | + builder._opset_versions, |
| 649 | + parent_type_info, |
| 650 | + tensor_loader=builder.load_tensor, |
632 | 651 | ) |
633 | 652 | ) |
634 | 653 |
|
@@ -712,7 +731,11 @@ def scan_op_v8(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node: |
712 | 731 | # Build subgraph module |
713 | 732 | body_module, body_input_names, body_output_names, outer_refs = ( |
714 | 733 | _build_subgraph_module( |
715 | | - body_graph, builder.env, builder._opset_versions, parent_type_info |
| 734 | + body_graph, |
| 735 | + builder.env, |
| 736 | + builder._opset_versions, |
| 737 | + parent_type_info, |
| 738 | + tensor_loader=builder.load_tensor, |
716 | 739 | ) |
717 | 740 | ) |
718 | 741 |
|
@@ -884,12 +907,20 @@ def if_op(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node: |
884 | 907 | # Build subgraph modules for both branches |
885 | 908 | then_module, then_input_names, then_output_names, then_outer_refs = ( |
886 | 909 | _build_subgraph_module( |
887 | | - then_graph, builder.env, builder._opset_versions, parent_type_info |
| 910 | + then_graph, |
| 911 | + builder.env, |
| 912 | + builder._opset_versions, |
| 913 | + parent_type_info, |
| 914 | + tensor_loader=builder.load_tensor, |
888 | 915 | ) |
889 | 916 | ) |
890 | 917 | else_module, else_input_names, else_output_names, else_outer_refs = ( |
891 | 918 | _build_subgraph_module( |
892 | | - else_graph, builder.env, builder._opset_versions, parent_type_info |
| 919 | + else_graph, |
| 920 | + builder.env, |
| 921 | + builder._opset_versions, |
| 922 | + parent_type_info, |
| 923 | + tensor_loader=builder.load_tensor, |
893 | 924 | ) |
894 | 925 | ) |
895 | 926 |
|
|
0 commit comments