diff --git a/src/qonnx/transformation/change_datalayout.py b/src/qonnx/transformation/change_datalayout.py index 7b73e4bf..62e6140b 100644 --- a/src/qonnx/transformation/change_datalayout.py +++ b/src/qonnx/transformation/change_datalayout.py @@ -30,7 +30,7 @@ from qonnx.transformation.base import Transformation from qonnx.transformation.infer_shapes import InferShapes -from qonnx.util.basic import get_by_name +from qonnx.util.basic import copy_metadata_props, get_by_name class ChangeDataLayoutQuantAvgPool2d(Transformation): @@ -78,6 +78,7 @@ def apply(self, model): graph.value_info.append(quantavg_out) quantavg_out = quantavg_out.name inp_trans_node = helper.make_node("Transpose", [node_input], [inp_trans_out], perm=[0, 2, 3, 1]) + copy_metadata_props(n, inp_trans_node) quantavg_node = helper.make_node( "QuantAvgPool2d", [inp_trans_out], @@ -90,8 +91,10 @@ def apply(self, model): signed=signed, data_layout="NHWC", ) + copy_metadata_props(n, quantavg_node) # NHWC -> NCHW out_trans_node = helper.make_node("Transpose", [quantavg_out], [node_output], perm=[0, 3, 1, 2]) + copy_metadata_props(n, out_trans_node) # insert nodes graph.node.insert(node_ind, inp_trans_node) graph.node.insert(node_ind + 1, quantavg_node) diff --git a/src/qonnx/transformation/channels_last.py b/src/qonnx/transformation/channels_last.py index 175af058..444a326c 100644 --- a/src/qonnx/transformation/channels_last.py +++ b/src/qonnx/transformation/channels_last.py @@ -40,7 +40,7 @@ from qonnx.transformation.infer_shapes import InferShapes from qonnx.transformation.make_input_chanlast import MakeInputChannelsLast from qonnx.transformation.quant_constant_folding import FoldTransposeIntoQuantInit -from qonnx.util.basic import get_by_name +from qonnx.util.basic import copy_metadata_props, get_by_name from qonnx.util.onnx import is_eltwise_optype # Standard ONNX nodes which require a ChannelsLast data format to function properly @@ -96,6 +96,7 @@ def move_transpose_past_eltwise(transpose_node, 
eltwise_node, model: ModelWrappe new_t_inp = model.make_new_valueinfo_name() inv_perm = np.argsort(perm) new_transpose_node = helper.make_node("Transpose", [eltwise_inp], [new_t_inp], perm=inv_perm) + copy_metadata_props(transpose_node, new_transpose_node) t_shape = np.transpose(np.empty(inp_shape), axes=inv_perm).shape model.set_tensor_shape(new_t_inp, t_shape) eltwise_node.input[ind] = new_t_inp @@ -107,6 +108,7 @@ def move_transpose_past_eltwise(transpose_node, eltwise_node, model: ModelWrappe model.set_initializer(unsqueeze_param_name, np.asarray(list(range(ndim_inp - ndim)), dtype=np.int64)) unsqueeze_out_name = model.make_new_valueinfo_name() new_unsqueeze_node = helper.make_node("Unsqueeze", [eltwise_inp, unsqueeze_param_name], [unsqueeze_out_name]) + copy_metadata_props(transpose_node, new_unsqueeze_node) unsqueeze_out_shape = np.expand_dims(np.empty(inp_shape), axis=tuple(range(ndim_inp - ndim))).shape model.set_tensor_shape(unsqueeze_out_name, unsqueeze_out_shape) model.graph.node.append(new_unsqueeze_node) @@ -114,6 +116,7 @@ def move_transpose_past_eltwise(transpose_node, eltwise_node, model: ModelWrappe new_t_inp = model.make_new_valueinfo_name() inv_perm = np.argsort(perm) new_transpose_node = helper.make_node("Transpose", [unsqueeze_out_name], [new_t_inp], perm=inv_perm) + copy_metadata_props(transpose_node, new_transpose_node) t_shape = np.transpose(np.empty(unsqueeze_out_shape), axes=inv_perm).shape model.set_tensor_shape(new_t_inp, t_shape) eltwise_node.input[ind] = new_t_inp diff --git a/src/qonnx/transformation/extract_conv_bias.py b/src/qonnx/transformation/extract_conv_bias.py index bf2cf8b4..34b017bd 100644 --- a/src/qonnx/transformation/extract_conv_bias.py +++ b/src/qonnx/transformation/extract_conv_bias.py @@ -30,6 +30,7 @@ from onnx import helper from qonnx.transformation.base import Transformation +from qonnx.util.basic import copy_metadata_props class ExtractBiasFromConv(Transformation): @@ -75,6 +76,7 @@ def apply(self, model):
[act_add_tensor.name, n.input[2]], [n.output[0]], ) + copy_metadata_props(n, add_node) graph.node.insert(node_ind, add_node) # Repoint Conv output and remove bias tensor diff --git a/src/qonnx/transformation/extract_quant_scale_zeropt.py b/src/qonnx/transformation/extract_quant_scale_zeropt.py index 58863f08..f76e5555 100644 --- a/src/qonnx/transformation/extract_quant_scale_zeropt.py +++ b/src/qonnx/transformation/extract_quant_scale_zeropt.py @@ -33,6 +33,7 @@ from qonnx.transformation.base import Transformation from qonnx.transformation.general import GiveUniqueParameterTensors, SortGraph from qonnx.transformation.remove import RemoveIdentityOps +from qonnx.util.basic import copy_metadata_props class ExtractQuantScaleZeroPt(Transformation): @@ -69,6 +70,7 @@ def apply(self, model: ModelWrapper): ) graph.value_info.append(inp_scaled) inp_scale_node = helper.make_node("Div", [running_input, scale_nm], [inp_scaled_nm]) + copy_metadata_props(node, inp_scale_node) graph.node.append(inp_scale_node) # create new Mul node # remove scale from Quant node @@ -87,6 +89,7 @@ def apply(self, model: ModelWrapper): ) graph.value_info.append(inp_zeropt) inp_zeropt_node = helper.make_node("Add", [running_input, zeropt_nm], [inp_zeropt_nm]) + copy_metadata_props(node, inp_zeropt_node) graph.node.append(inp_zeropt_node) # remove zeropt from Quant node new_zeropt_nm = model.make_new_valueinfo_name() @@ -108,6 +111,7 @@ def apply(self, model: ModelWrapper): ) graph.value_info.append(out_zeropt) out_zeropt_node = helper.make_node("Sub", [out_zeropt_nm, zeropt_nm], [final_output]) + copy_metadata_props(node, out_zeropt_node) last_node.output[0] = out_zeropt_nm graph.node.append(out_zeropt_node) # important: when tracking a pointer to newly added nodes, @@ -127,6 +131,7 @@ def apply(self, model: ModelWrapper): last_node.output[0] = out_scale_nm graph.value_info.append(out_scale) out_scale_node = helper.make_node("Mul", [out_scale_nm, scale_nm], [final_output]) + 
copy_metadata_props(node, out_scale_node) graph.node.append(out_scale_node) if extract_scale or extract_zeropt: diff --git a/src/qonnx/transformation/gemm_to_matmul.py b/src/qonnx/transformation/gemm_to_matmul.py index 5396a7d6..245a0a2a 100644 --- a/src/qonnx/transformation/gemm_to_matmul.py +++ b/src/qonnx/transformation/gemm_to_matmul.py @@ -32,7 +32,7 @@ from qonnx.core.datatype import DataType from qonnx.transformation.base import Transformation from qonnx.transformation.remove import RemoveIdentityOps -from qonnx.util.basic import get_by_name +from qonnx.util.basic import copy_metadata_props, get_by_name class GemmToMatMul(Transformation): @@ -76,6 +76,7 @@ def apply(self, model): ) graph.value_info.append(inp_trans_out) inp_trans_node = helper.make_node("Transpose", [n.input[0]], [inp_trans_out.name]) + copy_metadata_props(n, inp_trans_node) graph.node.insert(running_node_index, inp_trans_node) running_node_index += 1 dt = model.get_tensor_datatype(n.input[0]) @@ -98,6 +99,7 @@ def apply(self, model): ) graph.value_info.append(inp_trans_out) inp_trans_node = helper.make_node("Transpose", [n.input[1]], [inp_trans_out.name]) + copy_metadata_props(n, inp_trans_node) graph.node.insert(running_node_index, inp_trans_node) running_node_index += 1 # Copy over the datatype @@ -109,6 +111,7 @@ def apply(self, model): # Insert MatMul: A * B matMul_node = helper.make_node("MatMul", [n.input[0], n.input[1]], [n.output[0]]) + copy_metadata_props(n, matMul_node) graph.node.insert(running_node_index, matMul_node) matMul_node = graph.node[running_node_index] running_node_index += 1 @@ -144,6 +147,7 @@ def apply(self, model): [act_mul_tensor.name, mul_tensor.name], [n.output[0]], ) + copy_metadata_props(n, mul_node) graph.node.insert(running_node_index, mul_node) mul_node_main_branch = graph.node[running_node_index] running_node_index += 1 @@ -175,6 +179,7 @@ def apply(self, model): [n.input[2], mul_tensor.name], [act_mul_tensor.name], ) + copy_metadata_props(n, mul_node) 
graph.node.insert(running_node_index, mul_node) running_node_index += 1 dt = model.get_tensor_datatype(n.input[2]) @@ -196,7 +201,7 @@ def apply(self, model): [act_add_tensor.name, n.input[2]], [n.output[0]], ) - + copy_metadata_props(n, add_node) graph.node.insert(running_node_index, add_node) running_node_index += 1 diff --git a/src/qonnx/transformation/lower_convs_to_matmul.py b/src/qonnx/transformation/lower_convs_to_matmul.py index 81f0b713..5140b71d 100644 --- a/src/qonnx/transformation/lower_convs_to_matmul.py +++ b/src/qonnx/transformation/lower_convs_to_matmul.py @@ -32,7 +32,7 @@ from qonnx.transformation.base import Transformation from qonnx.transformation.extract_conv_bias import ExtractBiasFromConv -from qonnx.util.basic import auto_pad_to_explicit_padding, get_by_name +from qonnx.util.basic import auto_pad_to_explicit_padding, copy_metadata_props, get_by_name class LowerConvsToMatMul(Transformation): @@ -178,8 +178,10 @@ def apply(self, model): matmul_input = im2col_out if need_im2col else inp_trans_out # do matmul matmul_node = helper.make_node("MatMul", [matmul_input, conv_weight_inp_name], [matmul_out]) + copy_metadata_props(node, matmul_node) # NHWC -> NCHW out_trans_node = helper.make_node("Transpose", [matmul_out], [cnv_output], perm=[0, 3, 1, 2]) + copy_metadata_props(node, out_trans_node) nodes_to_insert.extend([matmul_node, out_trans_node]) diff --git a/src/qonnx/transformation/qcdq_to_qonnx.py b/src/qonnx/transformation/qcdq_to_qonnx.py index b7e35c0d..b4e18f25 100644 --- a/src/qonnx/transformation/qcdq_to_qonnx.py +++ b/src/qonnx/transformation/qcdq_to_qonnx.py @@ -34,7 +34,7 @@ from qonnx.core.modelwrapper import ModelWrapper from qonnx.transformation.base import Transformation -from qonnx.util.basic import get_by_name +from qonnx.util.basic import copy_metadata_props, get_by_name def extract_elem_type(elem_type: int, clip_range=None) -> Tuple[int, int, bool]: @@ -203,6 +203,8 @@ def apply(self, model: ModelWrapper) -> Tuple[ModelWrapper, 
bool]: rounding_mode="ROUND", # round-to-even signed=signed, ) + # Preserve metadata from the original node on the fused Quant node + copy_metadata_props(node, fused_node) model.graph.node.insert(dequant_node_index, fused_node) for node_to_remove in nodes_to_remove: model.graph.node.remove(node_to_remove) diff --git a/src/qonnx/transformation/rebalance_conv.py b/src/qonnx/transformation/rebalance_conv.py index ecb2b5e4..0107a62a 100644 --- a/src/qonnx/transformation/rebalance_conv.py +++ b/src/qonnx/transformation/rebalance_conv.py @@ -31,6 +31,7 @@ from qonnx.custom_op.registry import getCustomOp from qonnx.transformation.base import Transformation +from qonnx.util.basic import copy_metadata_props class RebalanceIm2Col(Transformation): @@ -103,6 +104,7 @@ def apply(self, model): inp_reshape_node = helper.make_node( "Reshape", [node.input[0], inp_shapedata.name], [inp_reshape_out.name] ) + copy_metadata_props(node, inp_reshape_node) graph.node.insert(running_node_index, inp_reshape_node) # rewire Im2Col input node.input[0] = inp_reshape_out.name diff --git a/src/qonnx/transformation/resize_conv_to_deconv.py b/src/qonnx/transformation/resize_conv_to_deconv.py index 0dd40972..7eda4fa7 100644 --- a/src/qonnx/transformation/resize_conv_to_deconv.py +++ b/src/qonnx/transformation/resize_conv_to_deconv.py @@ -33,7 +33,7 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.general.quant import quant, resolve_rounding_mode from qonnx.transformation.base import Transformation -from qonnx.util.basic import auto_pad_to_explicit_padding, get_by_name +from qonnx.util.basic import auto_pad_to_explicit_padding, copy_metadata_props, get_by_name def _weight_convolution(cnv_weights: np.ndarray, scale: int) -> np.ndarray: @@ -242,6 +242,7 @@ def apply(self, model): group=group, dilations=dilation, ) + copy_metadata_props(conv, deconv_node) W_deconv_init = weight_name if weight_prod is not None: W_deconv_init = q_w_name diff --git a/src/qonnx/transformation/subpixel_to_deconv.py
b/src/qonnx/transformation/subpixel_to_deconv.py index 3f330c99..73ef3f8f 100644 --- a/src/qonnx/transformation/subpixel_to_deconv.py +++ b/src/qonnx/transformation/subpixel_to_deconv.py @@ -31,7 +31,7 @@ from onnx import helper from qonnx.transformation.base import Transformation -from qonnx.util.basic import auto_pad_to_explicit_padding, get_by_name +from qonnx.util.basic import auto_pad_to_explicit_padding, copy_metadata_props, get_by_name def _weight_shuffle(cnv_weights: np.ndarray, block_size: int) -> np.ndarray: @@ -197,6 +197,7 @@ def apply(self, model): group=group, dilations=dilation, ) + copy_metadata_props(n, deconv_node) W_deconv_init = weight_name if weight_prod is not None: W_deconv_init = q_w_name diff --git a/src/qonnx/util/basic.py b/src/qonnx/util/basic.py index 4e300dd1..1696c42b 100644 --- a/src/qonnx/util/basic.py +++ b/src/qonnx/util/basic.py @@ -350,3 +350,35 @@ def auto_pad_to_explicit_padding(autopad_str, idim_h, idim_w, k_h, k_w, stride_h return [pad_half_large_h, pad_half_large_w, pad_half_small_h, pad_half_small_w] else: raise Exception("Unsupported auto_pad: " + autopad_str) + + +def copy_metadata_props(source_node, target_node): + """Copy metadata properties from source node(s) to target node. + + Parameters + ---------- + source_node : onnx.NodeProto or list of onnx.NodeProto + Source node(s) from which to copy metadata_props. If a list is provided, + metadata from all nodes will be merged into the target node. + target_node : onnx.NodeProto + Target node to which metadata_props will be copied. + + Returns + ------- + None + Modifies target_node in place by extending its metadata_props. 
+ + Examples + -------- + >>> # Copy from single node + >>> copy_metadata_props(old_node, new_node) + >>> + >>> # Copy from multiple nodes (e.g., when fusing) + >>> copy_metadata_props([quant_node, dequant_node], fused_node) + """ + # Handle both single node and list of nodes + source_nodes = source_node if isinstance(source_node, list) else [source_node] + + for node in source_nodes: + if hasattr(node, "metadata_props"): + target_node.metadata_props.extend(node.metadata_props)