Skip to content
Draft
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 22 additions & 10 deletions src/qonnx/util/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,26 +27,38 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import json
import onnx

from qonnx.custom_op.registry import getCustomOp


def extract_model_config_to_json(model, json_filename, attr_names_to_extract):
"""Create a json file with layer name -> attribute mappings extracted from the
model. The created json file can be later applied on a model with
# update this code to handle export configs from subgraphs
# where the subgraph is found in a node's attribute as a graph type
def extract_model_config(model, attr_names_to_extract):
"""Create a dictionary with layer name -> attribute mappings extracted from the
model. The created dictionary can be later applied on a model with
qonnx.transform.general.ApplyConfig."""

cfg = dict()
cfg["Defaults"] = dict()
for n in model.graph.node:
oi = getCustomOp(n)
layer_dict = dict()
for attr in attr_names_to_extract:
try:
layer_dict[attr] = oi.get_nodeattr(attr)
except AttributeError:
pass
for attr in n.attribute:
if attr.type == onnx.AttributeProto.GRAPH: # Graph type
# If the attribute is a graph, we need to extract the attributes from the subgraph
cfg.update(extract_model_config(model.make_subgraph_modelwrapper(attr.g), attr_names_to_extract))
elif attr.name in attr_names_to_extract:
# If the attribute name is in the list, we can add it directly
layer_dict[attr.name] = oi.get_nodeattr(attr.name)
if len(layer_dict) > 0:
cfg[n.name] = layer_dict
return cfg


def extract_model_config_to_json(model, json_filename, attr_names_to_extract):
"""Create a json file with layer name -> attribute mappings extracted from the
model. The created json file can be later applied on a model with
qonnx.transform.general.ApplyConfig."""

with open(json_filename, "w") as f:
json.dump(cfg, f, indent=2)
json.dump(extract_model_config(model, attr_names_to_extract), f, indent=2)
Loading