
Commit 47ecd20

ironsided777 authored and pytorchmergebot committed
[ONNX] Fix index_put_ usage (pytorch#161263)
Summary: It's hard to understand how this has been working in most of our models, but in general it looks like `aten::copy_` is replaced incorrectly. There are two schemas for `aten::copy_`:

1. `aten::copy_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)`
2. `aten::copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)`

According to the logic in the comments, one of the parameters is not needed for the `aten::index_put_` replacement. That logic seems to have been inferred from the ordinary `aten::copy`, which can take a third parameter, the `non_blocking` flag. Depending on the execution environment, the sliced copy can be replaced by either the first schema or the second schema with the default parameter explicitly set to `False`. If the first schema is selected, the export crashes (easy to catch in our prod env). If the second schema is selected, there is no crash, but the third parameter is treated as the `accumulate` parameter of `index_put_`, which makes no sense. Either way, the third parameter must not be forwarded in the `aten::copy_` replacement. For more details, see this post: https://fb.workplace.com/groups/1405155842844877/permalink/25337687649165028/
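To make the misinterpretation concrete, here is a small sketch (not part of the patch) of `Tensor.index_put_`, whose third positional argument is `accumulate`; forwarding `copy_`'s `non_blocking` flag into that slot silently toggles accumulation:

```python
import torch

t = torch.ones(3, 3)
# Indices selecting column 0 of every row, in the form index_put_ expects.
idx = (torch.arange(3), torch.zeros(3, dtype=torch.long))
vals = torch.zeros(3)

a = t.clone()
a.index_put_(idx, vals, False)  # accumulate=False: column 0 is overwritten with 0
b = t.clone()
b.index_put_(idx, vals, True)   # accumulate=True: 0 is *added*, so column 0 stays 1

print(a[:, 0])  # tensor([0., 0., 0.])
print(b[:, 0])  # tensor([1., 1., 1.])
```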
Test Plan: The test fails only in the production environment. In the test environment, the `non_blocking` flag is mapped as `False` onto the `accumulate` flag, which does not make the test fail but makes no sense as a flag mapping. With the fix, the export works without errors; before the fix it failed by indexing out of the bounds of a vector, like this:

```
 1095 _C._jit_onnx_log("Torch IR graph at exception: ", graph)

File ~/.bento/kernels/bento_kernel_gaia_ml/1578/bento_kernel_gaia_ml_binary-inplace#link-tree/torch/onnx/utils.py:636, in _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop, fixed_batch_size, params_dict, dynamic_axes, input_names, module)
    629 _C._jit_pass_lower_all_tuples(graph)
    630 # in _jit_pass_onnx, symbolic functions are called for each node for conversion.
    631 # However, there are nodes that cannot be converted without additional context.
    632 # For example, the number of outputs from split (and whether it is static or dynamic) is unknown
    633 # until the point where it is unpacked by listUnpack node.
    634 # This pass does a preprocess, and prepares the nodes such that enough context can be received
    635 # by the symbolic function.
--> 636 _C._jit_pass_onnx_remove_inplace_ops_for_onnx(graph, module)
    637 _C._jit_pass_onnx_preprocess(graph)
    639 # onnx does not support tuples, so try to remove them

RuntimeError: vector::_M_range_check: __n (which is 2) >= this->size() (which is 2)
```

The test script:

```python
import tempfile

import torch as th


class CopyTest(th.nn.Module):
    def forward(self, input_th: th.Tensor):
        to_fill = th.ones((3, 3))
        to_fill[:, 0] = input_th[:, 0]
        return to_fill


m = CopyTest()
test_tensor = th.zeros((3, 3))

with tempfile.NamedTemporaryFile() as f:
    th.onnx.export(
        m,
        (test_tensor,),
        f,
        export_params=True,
        opset_version=17,
        do_constant_folding=True,
        input_names=["input"],
        output_names=["features"],
        dynamo=False,
    )
```

The exported model test:

```python
import onnx
import onnxruntime
import torch

model_name = '/home/ironsided/test_model.onnx'
onnx_model = onnx.load(model_name)
onnx.checker.check_model(onnx_model)

example_inputs = (torch.zeros(3, 3),)
onnx_inputs = [tensor.numpy(force=True) for tensor in example_inputs]
print(f"Input length: {len(onnx_inputs)}")
print(f"Sample input: {onnx_inputs}")

ort_session = onnxruntime.InferenceSession(
    model_name, providers=["CPUExecutionProvider"]
)

onnxruntime_input = {
    input_arg.name: input_value
    for input_arg, input_value in zip(ort_session.get_inputs(), onnx_inputs)
}

# ONNX Runtime returns a list of outputs
onnxruntime_outputs = ort_session.run(None, onnxruntime_input)[0]
print(onnxruntime_outputs)
```

The produced result is correct:

```
Input length: 1
Sample input: [array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]], dtype=float32)]
[[0. 1. 1.]
 [0. 1. 1.]
 [0. 1. 1.]]
```

Rollback Plan:

Differential Revision: D80797028

Pull Request resolved: pytorch#161263
Approved by: https://github.com/justinchuby, https://github.com/jermenkoo
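As an aside (not from the commit itself), a quick way to see the `aten::copy_` pattern that this pass rewrites is to trace the sliced assignment and print the TorchScript graph the exporter consumes; on recent PyTorch versions the basic-indexing assignment below records an `aten::copy_` on the selected column:

```python
import torch


def f(x):
    out = torch.ones(3, 3)
    # Sliced in-place assignment; under tracing this is recorded as a
    # select/slice followed by aten::copy_, which the ONNX pass
    # PrepareCopyForONNX rewrites to aten::index_put_.
    out[:, 0] = x[:, 0]
    return out


traced = torch.jit.trace(f, torch.zeros(3, 3))
print(traced.graph)  # look for aten::copy_ in the printed IR
```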
1 parent 1750cc8 commit 47ecd20

File tree: 1 file changed (+1, −2 lines)


torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp

Lines changed: 1 addition & 2 deletions
```diff
@@ -191,8 +191,7 @@ std::pair<Value*, Value*> PrepareCopyForONNX(Node* node) {
   expanded_value->node()->copyMetadata(node);
 
   auto index_put = graph->insert(
-      aten::index_put_,
-      {node->input(0), dummy_list, expanded_value, node->input(2)});
+      aten::index_put_, {node->input(0), dummy_list, expanded_value});
   index_put->node()->copyMetadata(node);
   index_put->copyMetadata(node->output());
   node->output()->replaceAllUsesWith(index_put);
```
