11# Set pytest parameters
22import pytest
3+
34# Numpy for handling simulation of tensor operations
45import numpy as np
6+
57# Helper for creating ONNX nodes
68from onnx import TensorProto
79from onnx import helper as oh
10+
811# QONNX wrapper of ONNX model graphs
912from qonnx .core .modelwrapper import ModelWrapper
10- # QONNX utility for creating models from ONNX graphs
11- from qonnx .util .basic import qonnx_make_model
13+
1214# Execute QONNX model graphs
1315from qonnx .core .onnx_exec import execute_onnx
16+
1417# Graph transformation to be tested: Sorts the input list of commutative
1518# operations to have all dynamic inputs first followed by all initializer inputs
1619from qonnx .transformation .general import SortCommutativeInputsInitializerLast
1720
21+ # QONNX utility for creating models from ONNX graphs
22+ from qonnx .util .basic import qonnx_make_model
23+
1824
1925# Specify how many inputs the test should cover
2026@pytest .mark .parametrize ("num_inputs" , [4 , 5 , 6 ])
2127# Specify which inputs should be turned into initializers
2228@pytest .mark .parametrize (
29+ # fmt: off
2330 "initializers" , [[], [0 ], [1 ], [0 , 1 ], [0 , 3 ], [0 , 1 , 2 , 3 ]]
31+ # fmt: on
2432)
2533# Tests the SortCommutativeInputsInitializerLast transformation
2634def test_sort_commutative_inputs_initializer_last (num_inputs , initializers ):
@@ -29,11 +37,15 @@ def test_sort_commutative_inputs_initializer_last(num_inputs, initializers):
2937 # We will use the Sum ONNX operation to test this behavior, as it allows for
3038 # arbitrary many inputs
3139 node = oh .make_node (
40+ # fmt: off
3241 op_type = "Sum" , inputs = inputs , outputs = ["out" ], name = "Sum"
42+ # fmt: on
3343 )
3444 # Create value infos for all input and the output tensor
3545 inputs = [
46+ # fmt: off
3647 oh .make_tensor_value_info (i , TensorProto .FLOAT , (16 ,)) for i in inputs
48+ # fmt: on
3749 ]
3850 out = oh .make_tensor_value_info ("out" , TensorProto .FLOAT , (16 ,))
3951 # Make a graph comprising the Sum node and value infos for all inputs and
@@ -42,9 +54,7 @@ def test_sort_commutative_inputs_initializer_last(num_inputs, initializers):
4254 # Wrap the graph in an QONNX model wrapper
4355 model = ModelWrapper (qonnx_make_model (graph , producer_name = "qonnx-tests" ))
4456 # Prepare the execution context
45- context = {
46- f"in{ i } " : np .random .rand (16 ) for i in range (num_inputs )
47- }
57+ context = {f"in{ i } " : np .random .rand (16 ) for i in range (num_inputs )}
4858 # Make sure all inputs are of type float32
4959 context = {key : value .astype (np .float32 ) for key , value in context .items ()}
5060 # Turn selected inputs into initializers
@@ -57,7 +67,9 @@ def test_sort_commutative_inputs_initializer_last(num_inputs, initializers):
5767 # Note: No cleanup, as the tested transformation is part of the cleanup, and
5868 # we want to test this in isolation
5969 model = model .transform (
70+ # fmt: off
6071 SortCommutativeInputsInitializerLast (), cleanup = False
72+ # fmt: on
6173 )
6274 # Execute the ONNX model after transforming
6375 out_produced = execute_onnx (model , context )["out" ]
@@ -71,8 +83,9 @@ def test_sort_commutative_inputs_initializer_last(num_inputs, initializers):
7183 seen_initializer = True
7284 # If there has already been an initializer, this input must be an
7385 # initializer as well
74- assert not seen_initializer or model .get_initializer (i ) is not None , \
75- "Non-initializer input following initializer after sorting"
86+ assert (
87+ not seen_initializer or model .get_initializer (i ) is not None
88+ ), "Non-initializer input following initializer after sorting"
7689
7790 # Outputs before and after must match
7891 assert np .allclose (out_produced , out_expected )
0 commit comments