Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/qonnx/transformation/change_datalayout.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

from qonnx.transformation.base import Transformation
from qonnx.transformation.infer_shapes import InferShapes
from qonnx.util.basic import get_by_name
from qonnx.util.basic import copy_metadata_props, get_by_name


class ChangeDataLayoutQuantAvgPool2d(Transformation):
Expand Down Expand Up @@ -78,6 +78,7 @@ def apply(self, model):
graph.value_info.append(quantavg_out)
quantavg_out = quantavg_out.name
inp_trans_node = helper.make_node("Transpose", [node_input], [inp_trans_out], perm=[0, 2, 3, 1])
copy_metadata_props(n, inp_trans_node)
quantavg_node = helper.make_node(
"QuantAvgPool2d",
[inp_trans_out],
Expand All @@ -90,8 +91,10 @@ def apply(self, model):
signed=signed,
data_layout="NHWC",
)
copy_metadata_props(n, quantavg_node)
# NHWC -> NCHW
out_trans_node = helper.make_node("Transpose", [quantavg_out], [node_output], perm=[0, 3, 1, 2])
copy_metadata_props(n, out_trans_node)
# insert nodes
graph.node.insert(node_ind, inp_trans_node)
graph.node.insert(node_ind + 1, quantavg_node)
Expand Down
5 changes: 4 additions & 1 deletion src/qonnx/transformation/channels_last.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from qonnx.transformation.infer_shapes import InferShapes
from qonnx.transformation.make_input_chanlast import MakeInputChannelsLast
from qonnx.transformation.quant_constant_folding import FoldTransposeIntoQuantInit
from qonnx.util.basic import get_by_name
from qonnx.util.basic import copy_metadata_props, get_by_name
from qonnx.util.onnx import is_eltwise_optype

# Standard ONNX nodes which require a ChannelsLast data format to function properly
Expand Down Expand Up @@ -96,6 +96,7 @@ def move_transpose_past_eltwise(transpose_node, eltwise_node, model: ModelWrappe
new_t_inp = model.make_new_valueinfo_name()
inv_perm = np.argsort(perm)
new_transpose_node = helper.make_node("Transpose", [eltwise_inp], [new_t_inp], perm=inv_perm)
copy_metadata_props(transpose_node, new_transpose_node)
t_shape = np.transpose(np.empty(inp_shape), axes=inv_perm).shape
model.set_tensor_shape(new_t_inp, t_shape)
eltwise_node.input[ind] = new_t_inp
Expand All @@ -107,13 +108,15 @@ def move_transpose_past_eltwise(transpose_node, eltwise_node, model: ModelWrappe
model.set_initializer(unsqueeze_param_name, np.asarray(list(range(ndim_inp - ndim)), dtype=np.int64))
unsqueeze_out_name = model.make_new_valueinfo_name()
new_unsqueeze_node = helper.make_node("Unsqueeze", [eltwise_inp, unsqueeze_param_name], [unsqueeze_out_name])
copy_metadata_props(eltwise_inp, new_unsqueeze_node)
unsqueeze_out_shape = np.expand_dims(np.empty(inp_shape), axis=tuple(range(ndim_inp - ndim))).shape
model.set_tensor_shape(unsqueeze_out_name, unsqueeze_out_shape)
model.graph.node.append(new_unsqueeze_node)
# now add inverse transpose
new_t_inp = model.make_new_valueinfo_name()
inv_perm = np.argsort(perm)
new_transpose_node = helper.make_node("Transpose", [unsqueeze_out_name], [new_t_inp], perm=inv_perm)
copy_metadata_props(transpose_node, new_transpose_node)
t_shape = np.transpose(np.empty(unsqueeze_out_shape), axes=inv_perm).shape
model.set_tensor_shape(new_t_inp, t_shape)
eltwise_node.input[ind] = new_t_inp
Expand Down
2 changes: 2 additions & 0 deletions src/qonnx/transformation/extract_conv_bias.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from onnx import helper

from qonnx.transformation.base import Transformation
from qonnx.util.basic import copy_metadata_props


class ExtractBiasFromConv(Transformation):
Expand Down Expand Up @@ -75,6 +76,7 @@ def apply(self, model):
[act_add_tensor.name, n.input[2]],
[n.output[0]],
)
copy_metadata_props(n, add_node)
graph.node.insert(node_ind, add_node)

# Repoint Conv output and remove bias tensor
Expand Down
5 changes: 5 additions & 0 deletions src/qonnx/transformation/extract_quant_scale_zeropt.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from qonnx.transformation.base import Transformation
from qonnx.transformation.general import GiveUniqueParameterTensors, SortGraph
from qonnx.transformation.remove import RemoveIdentityOps
from qonnx.util.basic import copy_metadata_props


class ExtractQuantScaleZeroPt(Transformation):
Expand Down Expand Up @@ -69,6 +70,7 @@ def apply(self, model: ModelWrapper):
)
graph.value_info.append(inp_scaled)
inp_scale_node = helper.make_node("Div", [running_input, scale_nm], [inp_scaled_nm])
copy_metadata_props(node, inp_scale_node)
graph.node.append(inp_scale_node)
# create new Mul node
# remove scale from Quant node
Expand All @@ -87,6 +89,7 @@ def apply(self, model: ModelWrapper):
)
graph.value_info.append(inp_zeropt)
inp_zeropt_node = helper.make_node("Add", [running_input, zeropt_nm], [inp_zeropt_nm])
copy_metadata_props(node, inp_zeropt_node)
graph.node.append(inp_zeropt_node)
# remove zeropt from Quant node
new_zeropt_nm = model.make_new_valueinfo_name()
Expand All @@ -108,6 +111,7 @@ def apply(self, model: ModelWrapper):
)
graph.value_info.append(out_zeropt)
out_zeropt_node = helper.make_node("Sub", [out_zeropt_nm, zeropt_nm], [final_output])
copy_metadata_props(node, out_zeropt_node)
last_node.output[0] = out_zeropt_nm
graph.node.append(out_zeropt_node)
# important: when tracking a pointer to newly added nodes,
Expand All @@ -127,6 +131,7 @@ def apply(self, model: ModelWrapper):
last_node.output[0] = out_scale_nm
graph.value_info.append(out_scale)
out_scale_node = helper.make_node("Mul", [out_scale_nm, scale_nm], [final_output])
copy_metadata_props(node, out_scale_node)
graph.node.append(out_scale_node)

if extract_scale or extract_zeropt:
Expand Down
9 changes: 7 additions & 2 deletions src/qonnx/transformation/gemm_to_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from qonnx.core.datatype import DataType
from qonnx.transformation.base import Transformation
from qonnx.transformation.remove import RemoveIdentityOps
from qonnx.util.basic import get_by_name
from qonnx.util.basic import copy_metadata_props, get_by_name


class GemmToMatMul(Transformation):
Expand Down Expand Up @@ -76,6 +76,7 @@ def apply(self, model):
)
graph.value_info.append(inp_trans_out)
inp_trans_node = helper.make_node("Transpose", [n.input[0]], [inp_trans_out.name])
copy_metadata_props(n, inp_trans_node)
graph.node.insert(running_node_index, inp_trans_node)
running_node_index += 1
dt = model.get_tensor_datatype(n.input[0])
Expand All @@ -98,6 +99,7 @@ def apply(self, model):
)
graph.value_info.append(inp_trans_out)
inp_trans_node = helper.make_node("Transpose", [n.input[1]], [inp_trans_out.name])
copy_metadata_props(n, inp_trans_node)
graph.node.insert(running_node_index, inp_trans_node)
running_node_index += 1
# Copy over the datatype
Expand All @@ -109,6 +111,7 @@ def apply(self, model):

# Insert MatMul: A * B
matMul_node = helper.make_node("MatMul", [n.input[0], n.input[1]], [n.output[0]])
copy_metadata_props(n, matMul_node)
graph.node.insert(running_node_index, matMul_node)
matMul_node = graph.node[running_node_index]
running_node_index += 1
Expand Down Expand Up @@ -144,6 +147,7 @@ def apply(self, model):
[act_mul_tensor.name, mul_tensor.name],
[n.output[0]],
)
copy_metadata_props(n, mul_node)
graph.node.insert(running_node_index, mul_node)
mul_node_main_branch = graph.node[running_node_index]
running_node_index += 1
Expand Down Expand Up @@ -175,6 +179,7 @@ def apply(self, model):
[n.input[2], mul_tensor.name],
[act_mul_tensor.name],
)
copy_metadata_props(n, mul_node)
graph.node.insert(running_node_index, mul_node)
running_node_index += 1
dt = model.get_tensor_datatype(n.input[2])
Expand All @@ -196,7 +201,7 @@ def apply(self, model):
[act_add_tensor.name, n.input[2]],
[n.output[0]],
)

copy_metadata_props(n, add_node)
graph.node.insert(running_node_index, add_node)
running_node_index += 1

Expand Down
4 changes: 3 additions & 1 deletion src/qonnx/transformation/lower_convs_to_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

from qonnx.transformation.base import Transformation
from qonnx.transformation.extract_conv_bias import ExtractBiasFromConv
from qonnx.util.basic import auto_pad_to_explicit_padding, get_by_name
from qonnx.util.basic import auto_pad_to_explicit_padding, copy_metadata_props, get_by_name


class LowerConvsToMatMul(Transformation):
Expand Down Expand Up @@ -178,8 +178,10 @@ def apply(self, model):
matmul_input = im2col_out if need_im2col else inp_trans_out
# do matmul
matmul_node = helper.make_node("MatMul", [matmul_input, conv_weight_inp_name], [matmul_out])
copy_metadata_props(node, matmul_node)
# NHWC -> NCHW
out_trans_node = helper.make_node("Transpose", [matmul_out], [cnv_output], perm=[0, 3, 1, 2])
copy_metadata_props(node, out_trans_node)

nodes_to_insert.extend([matmul_node, out_trans_node])

Expand Down
4 changes: 3 additions & 1 deletion src/qonnx/transformation/qcdq_to_qonnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

from qonnx.core.modelwrapper import ModelWrapper
from qonnx.transformation.base import Transformation
from qonnx.util.basic import get_by_name
from qonnx.util.basic import copy_metadata_props, get_by_name


def extract_elem_type(elem_type: int, clip_range=None) -> Tuple[int, int, bool]:
Expand Down Expand Up @@ -203,6 +203,8 @@ def apply(self, model: ModelWrapper) -> Tuple[ModelWrapper, bool]:
rounding_mode="ROUND", # round-to-even
signed=signed,
)
# Preserve metadata from all nodes being fused
copy_metadata_props(node, fused_node)
model.graph.node.insert(dequant_node_index, fused_node)
for node_to_remove in nodes_to_remove:
model.graph.node.remove(node_to_remove)
Expand Down
2 changes: 2 additions & 0 deletions src/qonnx/transformation/rebalance_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

from qonnx.custom_op.registry import getCustomOp
from qonnx.transformation.base import Transformation
from qonnx.util.basic import copy_metadata_props


class RebalanceIm2Col(Transformation):
Expand Down Expand Up @@ -103,6 +104,7 @@ def apply(self, model):
inp_reshape_node = helper.make_node(
"Reshape", [node.input[0], inp_shapedata.name], [inp_reshape_out.name]
)
copy_metadata_props(node, inp_reshape_node)
graph.node.insert(running_node_index, inp_reshape_node)
# rewire Im2Col input
node.input[0] = inp_reshape_out.name
Expand Down
3 changes: 2 additions & 1 deletion src/qonnx/transformation/resize_conv_to_deconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from qonnx.core.datatype import DataType
from qonnx.custom_op.general.quant import quant, resolve_rounding_mode
from qonnx.transformation.base import Transformation
from qonnx.util.basic import auto_pad_to_explicit_padding, get_by_name
from qonnx.util.basic import auto_pad_to_explicit_padding, copy_metadata_props, get_by_name


def _weight_convolution(cnv_weights: np.ndarray, scale: int) -> np.ndarray:
Expand Down Expand Up @@ -242,6 +242,7 @@ def apply(self, model):
group=group,
dilations=dilation,
)
copy_metadata_props(conv, deconv_node)
W_deconv_init = weight_name
if weight_prod is not None:
W_deconv_init = q_w_name
Expand Down
3 changes: 2 additions & 1 deletion src/qonnx/transformation/subpixel_to_deconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from onnx import helper

from qonnx.transformation.base import Transformation
from qonnx.util.basic import auto_pad_to_explicit_padding, get_by_name
from qonnx.util.basic import auto_pad_to_explicit_padding, copy_metadata_props, get_by_name


def _weight_shuffle(cnv_weights: np.ndarray, block_size: int) -> np.ndarray:
Expand Down Expand Up @@ -197,6 +197,7 @@ def apply(self, model):
group=group,
dilations=dilation,
)
copy_metadata_props(n, deconv_node)
W_deconv_init = weight_name
if weight_prod is not None:
W_deconv_init = q_w_name
Expand Down
32 changes: 32 additions & 0 deletions src/qonnx/util/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,3 +350,35 @@ def auto_pad_to_explicit_padding(autopad_str, idim_h, idim_w, k_h, k_w, stride_h
return [pad_half_large_h, pad_half_large_w, pad_half_small_h, pad_half_small_w]
else:
raise Exception("Unsupported auto_pad: " + autopad_str)


def copy_metadata_props(source_node, target_node):
"""Copy metadata properties from source node(s) to target node.

Parameters
----------
source_node : onnx.NodeProto or list of onnx.NodeProto
Source node(s) from which to copy metadata_props. If a list is provided,
metadata from all nodes will be merged into the target node.
target_node : onnx.NodeProto
Target node to which metadata_props will be copied.

Returns
-------
None
Modifies target_node in place by extending its metadata_props.

Examples
--------
>>> # Copy from single node
>>> copy_metadata_props(old_node, new_node)
>>>
>>> # Copy from multiple nodes (e.g., when fusing)
>>> copy_metadata_props([quant_node, dequant_node], fused_node)
"""
# Handle both single node and list of nodes
source_nodes = source_node if isinstance(source_node, list) else [source_node]

for node in source_nodes:
if hasattr(node, "metadata_props"):
target_node.metadata_props.extend(node.metadata_props)
Comment on lines +355 to +384
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please provide some unit test(s) for this new utility function, ideally also covering edge cases (e.g. source node has no attributes, source and target node have attribute with the same name).

Loading