diff --git a/onnx2pytorch/convert/operations.py b/onnx2pytorch/convert/operations.py index abdaea3..bdd6e29 100644 --- a/onnx2pytorch/convert/operations.py +++ b/onnx2pytorch/convert/operations.py @@ -75,6 +75,15 @@ def convert_operations(onnx_graph, opset_version, batch_dim=0, enable_pruning=Tr iterator: (op_id, op_name, op) """ weights = {tensor.name: tensor for tensor in onnx_graph.initializer} + branching_factor = {} + + # `len(node.output)` is unreliable, so we need to count ourselves + for i, node in enumerate(onnx_graph.node): + if len(node.input) > 0: + input_id = node.input[0] + if input_id not in branching_factor: + branching_factor[input_id] = 0 + branching_factor[input_id] += 1 for i, node in enumerate(onnx_graph.node): # extract only useful inputs @@ -107,7 +116,8 @@ def convert_operations(onnx_graph, opset_version, batch_dim=0, enable_pruning=Tr elif node.op_type == "Div": op = Div() elif node.op_type == "Elu": - op = nn.ELU(**extract_attributes(node), inplace=True) + is_input_branching = branching_factor[node.input[0]] > 1 + op = nn.ELU(**extract_attributes(node), inplace=not is_input_branching) elif node.op_type == "Equal": op = OperatorWrapper(torch.eq) elif node.op_type == "Erf": @@ -136,7 +146,10 @@ def convert_operations(onnx_graph, opset_version, batch_dim=0, enable_pruning=Tr elif node.op_type == "InstanceNormalization": op = convert_instance_norm_layer(node, params=params) elif node.op_type == "LeakyRelu": - op = nn.LeakyReLU(**extract_attributes(node), inplace=True) + is_input_branching = branching_factor[node.input[0]] > 1 + op = nn.LeakyReLU( + **extract_attributes(node), inplace=not is_input_branching + ) elif node.op_type == "Less": op = OperatorWrapper(torch.less) elif node.op_type == "Log": @@ -215,7 +228,8 @@ def convert_operations(onnx_graph, opset_version, batch_dim=0, enable_pruning=Tr elif node.op_type == "ReduceSum": op = ReduceSum(opset_version=opset_version, **extract_attributes(node)) elif node.op_type == "Relu": - op = nn.ReLU(inplace=True) + is_input_branching = branching_factor[node.input[0]] > 1 + op = nn.ReLU(inplace=not is_input_branching) elif node.op_type == "Reshape": shape = list( filter(lambda x: x.name == node.input[1], onnx_graph.initializer)