diff --git a/graph_net/torch/dim_gen_passes/non_batch_call_function_arange_pass.py b/graph_net/torch/dim_gen_passes/non_batch_call_function_arange_pass.py
index 9c55aa518..dfd81479b 100644
--- a/graph_net/torch/dim_gen_passes/non_batch_call_function_arange_pass.py
+++ b/graph_net/torch/dim_gen_passes/non_batch_call_function_arange_pass.py
@@ -85,7 +85,28 @@ def create_new_node(node):
             self.node_target(), args=new_node_args, kwargs=node.kwargs
         )
 
+        # If any consumer of this rewritten arange is an embedding lookup,
+        # clamp the generated indices into the table's valid row range
+        # [0, num_rows - 1] so the new node cannot index out of bounds.
+        static_limit = _get_static_limit(node)
+        if static_limit != float("inf"):
+            max_val = int(static_limit - 1)
+            new_node = new_graph.call_function(
+                torch.clamp, args=(new_node, 0, max_val)
+            )
         return new_node
 
+    def _get_static_limit(node):
+        # Smallest embedding-table size among `node`'s users, or float("inf")
+        # as a sentinel meaning "no embedding consumer found; no clamp needed".
+        static_limit = float("inf")
+        for user in node.users:
+            # NOTE(review): assumes args[1] is the embedding weight — true for
+            # F.embedding(input, weight); aten.embedding takes weight first.
+            if user.op == "call_function" and ("embedding" in str(user.target)):
+                indexed_dim_size = user.args[1].meta["tensor_meta"].shape[0]
+                static_limit = min(static_limit, indexed_dim_size)
+        return static_limit
+
     for node in traced_module.graph.nodes:
         val_map[node] = create_new_node(node)