Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 42 additions & 17 deletions src/qonnx/transformation/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@
import warnings

# Protobuf onnx graph node type
from onnx import NodeProto # noqa
from onnx import mapping
from onnx import AttributeProto, NodeProto, mapping # noqa
from toposort import toposort_flatten

import qonnx.util.basic as util
Expand Down Expand Up @@ -335,34 +334,43 @@ def __init__(self, config, node_filter=lambda x: True):
super().__init__()
self.config = config
self.node_filter = node_filter
self.used_configurations = ["Defaults"]
self.missing_configurations = []

def apply(self, model):
if isinstance(self.config, dict):
model_config = self.config
else:
with open(self.config, "r") as f:
model_config = json.load(f)

used_configurations = ["Defaults"]
missing_configurations = []

def configure_network(self, model, model_config, subgraph_hier):
"""Apply per-node attribute configuration from ``model_config`` to the
nodes of ``model``, recursing into subgraphs held in GRAPH-typed node
attributes.

``subgraph_hier`` is the hierarchy path of the graph currently being
configured (``None`` for the top-level graph); a config entry is only
applied when its optional ``"subgraph_hier"`` key matches this value.
Nodes without a config entry are recorded in
``self.missing_configurations``; applied entries in
``self.used_configurations``.
"""
# Configure network
for node_idx, node in enumerate(model.graph.node):
if not self.node_filter(node):
continue

try:
node_config = model_config[node.name]
except KeyError:
# NOTE(review): next line is a leftover *removed* diff line — it
# references an undefined local `missing_configurations`; the
# corrected `self.missing_configurations` line follows it.
missing_configurations += [node.name]
self.missing_configurations += [node.name]
node_config = {}

# check if config matches subhierarchy parameter
try:
node_subgraph_hier = node_config["subgraph_hier"]
except KeyError:
node_subgraph_hier = None
# if the subgraph hierarchy parameter does not match
# the fct parameter skip
# else: remove the parameter from config dict (if not None)
# to prevent applying it to the node as an attribute
if node_subgraph_hier != subgraph_hier:
continue
elif node_subgraph_hier:
del node_config["subgraph_hier"]

self.used_configurations += [node.name]

# local import, presumably to avoid a circular import at module load —
# TODO confirm
from qonnx.custom_op.registry import getCustomOp

try:
inst = getCustomOp(node)
except Exception:
# nodes without a registered custom op cannot take attributes
continue
# NOTE(review): next line is a leftover *removed* diff line — it
# references an undefined local `used_configurations`; its replacement
# (`self.used_configurations += ...`) already appears above.
used_configurations += [node.name]

# set specified defaults
default_values = []
# NOTE(review): the following marker hides a collapsed hunk
# (@@ -380,11 +388,28 @@) — the defaults-handling code between
# `default_values` and the attribute-setting loop is not visible here.
Expand All @@ -380,11 +388,28 @@
for attr, value in node_config.items():
inst.set_nodeattr(attr, value)

# apply to subgraph
for attr in node.attribute:
if attr.type == AttributeProto.GRAPH:
# this is a subgraph, add it to the list
subgraph = model.make_subgraph_modelwrapper(attr.g)
self.configure_network(subgraph, model_config, subgraph_hier=str(subgraph_hier) + "/" + node.name)

def apply(self, model):
    """Apply the stored configuration to ``model``.

    ``self.config`` may be either an in-memory dict or a path to a JSON
    file with the same structure. The configuration is applied to the
    top-level graph first; ``configure_network`` recurses into subgraphs
    on its own. Afterwards, warn about nodes that had no configuration
    entry and about configuration entries that matched no node.
    """
    if isinstance(self.config, dict):
        model_config = self.config
    else:
        # self.config is a path to a JSON configuration file
        with open(self.config, "r") as f:
            model_config = json.load(f)

    # apply configuration on upper level (subgraph_hier=None marks the top graph)
    self.configure_network(model, model_config, subgraph_hier=None)

    # Configuration verification
    if len(self.missing_configurations) > 0:
        warnings.warn("\nNo HW configuration for nodes: " + ", ".join(self.missing_configurations))

    # "Defaults" is pre-seeded in self.used_configurations, so it never counts as unused
    unused_configs = [x for x in model_config if x not in self.used_configurations]
    if len(unused_configs) > 0:
        warnings.warn("\nUnused HW configurations: " + ", ".join(unused_configs))

Expand Down
32 changes: 22 additions & 10 deletions src/qonnx/util/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,26 +27,38 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import json
import onnx

from qonnx.custom_op.registry import getCustomOp


def extract_model_config(model, attr_names_to_extract):
    """Create a dictionary with layer name -> attribute mappings extracted from the
    model. The created dictionary can be later applied on a model with
    qonnx.transform.general.ApplyConfig.

    Subgraph-aware: nodes whose attributes contain a GraphProto are recursed
    into, so layers inside subgraphs are exported as well.
    """

    cfg = dict()
    cfg["Defaults"] = dict()
    for n in model.graph.node:
        oi = getCustomOp(n)
        layer_dict = dict()
        for attr in n.attribute:
            if attr.type == onnx.AttributeProto.GRAPH:  # Graph type
                # If the attribute is a graph, extract the attributes from the subgraph
                # NOTE(review): dict.update lets subgraph entries overwrite
                # same-named top-level entries (including "Defaults") —
                # confirm node names are unique across graph levels.
                cfg.update(extract_model_config(model.make_subgraph_modelwrapper(attr.g), attr_names_to_extract))
            elif attr.name in attr_names_to_extract:
                # If the attribute name is in the list, we can add it directly
                layer_dict[attr.name] = oi.get_nodeattr(attr.name)
        if len(layer_dict) > 0:
            cfg[n.name] = layer_dict
    return cfg


def extract_model_config_to_json(model, json_filename, attr_names_to_extract):
    """Create a json file with layer name -> attribute mappings extracted from the
    model. The created json file can be later applied on a model with
    qonnx.transform.general.ApplyConfig."""

    # delegate extraction to extract_model_config, then serialize the dict
    with open(json_filename, "w") as f:
        json.dump(extract_model_config(model, attr_names_to_extract), f, indent=2)
Loading