3838'''
3939
4040
41- @pytest .fixture
42- def tpc ():
43- """Fixture that returns a TargetPlatformCapabilities instance for testing."""
44- op1 = schema .OperatorsSet (name = "opset1" )
45- op2 = schema .OperatorsSet (name = "opset2" )
46- op3 = schema .OperatorsSet (name = "opset3" )
47- op12 = schema .OperatorSetGroup (operators_set = [op1 , op2 ])
48- return schema .TargetPlatformCapabilities (
49- default_qco = TEST_QCO ,
50- operator_set = (op1 , op2 , op3 ),
51- fusing_patterns = (
52- schema .Fusing (operator_groups = (op12 , op3 )),
53- schema .Fusing (operator_groups = (op1 , op2 ))
54- ),
55- tpc_minor_version = 1 ,
56- tpc_patch_version = 0 ,
57- tpc_platform_type = "dump_to_json" ,
58- add_metadata = False
59- )
60-
61-
62- @pytest .fixture (params = ALL_SCHEMA_VERSIONS )
63- def tpc_by_schema_version (request ):
64- """Fixture that returns a TargetPlatformCapabilities instance for testing."""
65- selected_schema = request .param
41+ def get_tpc (selected_schema ):
42+ """
43+ :param selected_schema: A schema to create tpc from
44+ :return: TargetPlatformCapabilities instance using the given selected_schema
45+ """
6646 op1 = selected_schema .OperatorsSet (name = "opset1" )
6747 op2 = selected_schema .OperatorsSet (name = "opset2" )
6848 op3 = selected_schema .OperatorsSet (name = "opset3" )
6949 op12 = selected_schema .OperatorSetGroup (operators_set = [op1 , op2 ])
70- yield selected_schema .TargetPlatformCapabilities (
50+ return selected_schema .TargetPlatformCapabilities (
7151 default_qco = TEST_QCO ,
7252 operator_set = (op1 , op2 , op3 ),
7353 fusing_patterns = (
@@ -81,6 +61,19 @@ def tpc_by_schema_version(request):
8161 )
8262
8363
64+ @pytest .fixture
65+ def tpc ():
66+ """Fixture that returns a TargetPlatformCapabilities instance of current schema."""
67+ return get_tpc (schema )
68+
69+
70+ @pytest .fixture (params = ALL_SCHEMA_VERSIONS )
71+ def tpc_by_schema_version (request ):
72+ """Fixture that yields a TargetPlatformCapabilities instance for each schema version."""
73+ selected_schema = request .param
74+ yield get_tpc (selected_schema )
75+
76+
8477@pytest .fixture
8578def tmp_invalid_json (tmp_path ):
8679 """Fixture that creates an invalid JSON file."""
@@ -120,7 +113,6 @@ def test_new_schema(self):
120113 all_supported_versions = [s .TargetPlatformCapabilities .SCHEMA_VERSION for s in ALL_SCHEMA_VERSIONS ]
121114 assert current_version in all_supported_versions , "Current schema need to be added to ALL_SCHEMA_VERSIONS"
122115
123-
124116 def test_schema_compatibility (self , tpc_by_schema_version ):
125117 """Tests that tpc of any schema version is supported and can be converted into current schema version"""
126118 tpc_by_schema = tpc_by_schema_version
0 commit comments