Skip to content

Commit b46ad23

Browse files
author
yarden-sony
committed
remove code duplication from test_tpc
1 parent b1824df commit b46ad23

File tree

1 file changed

+19
-27
lines changed
  • tests_pytest/common_tests/unit_tests/target_platform_capabilities

1 file changed

+19
-27
lines changed

tests_pytest/common_tests/unit_tests/target_platform_capabilities/test_tpc.py

Lines changed: 19 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -38,36 +38,16 @@
3838
'''
3939

4040

41-
@pytest.fixture
42-
def tpc():
43-
"""Fixture that returns a TargetPlatformCapabilities instance for testing."""
44-
op1 = schema.OperatorsSet(name="opset1")
45-
op2 = schema.OperatorsSet(name="opset2")
46-
op3 = schema.OperatorsSet(name="opset3")
47-
op12 = schema.OperatorSetGroup(operators_set=[op1, op2])
48-
return schema.TargetPlatformCapabilities(
49-
default_qco=TEST_QCO,
50-
operator_set=(op1, op2, op3),
51-
fusing_patterns=(
52-
schema.Fusing(operator_groups=(op12, op3)),
53-
schema.Fusing(operator_groups=(op1, op2))
54-
),
55-
tpc_minor_version=1,
56-
tpc_patch_version=0,
57-
tpc_platform_type="dump_to_json",
58-
add_metadata=False
59-
)
60-
61-
62-
@pytest.fixture(params=ALL_SCHEMA_VERSIONS)
63-
def tpc_by_schema_version(request):
64-
"""Fixture that returns a TargetPlatformCapabilities instance for testing."""
65-
selected_schema = request.param
41+
def get_tpc(selected_schema):
42+
"""
43+
:param selected_schema: A schema to create tpc from
44+
:return: TargetPlatformCapabilities instance using the given selected_schema
45+
"""
6646
op1 = selected_schema.OperatorsSet(name="opset1")
6747
op2 = selected_schema.OperatorsSet(name="opset2")
6848
op3 = selected_schema.OperatorsSet(name="opset3")
6949
op12 = selected_schema.OperatorSetGroup(operators_set=[op1, op2])
70-
yield selected_schema.TargetPlatformCapabilities(
50+
return selected_schema.TargetPlatformCapabilities(
7151
default_qco=TEST_QCO,
7252
operator_set=(op1, op2, op3),
7353
fusing_patterns=(
@@ -81,6 +61,19 @@ def tpc_by_schema_version(request):
8161
)
8262

8363

64+
@pytest.fixture
65+
def tpc():
66+
"""Fixture that returns a TargetPlatformCapabilities instance of current schema."""
67+
return get_tpc(schema)
68+
69+
70+
@pytest.fixture(params=ALL_SCHEMA_VERSIONS)
71+
def tpc_by_schema_version(request):
72+
"""Fixture that yields a TargetPlatformCapabilities instance for each schema version."""
73+
selected_schema = request.param
74+
yield get_tpc(selected_schema)
75+
76+
8477
@pytest.fixture
8578
def tmp_invalid_json(tmp_path):
8679
"""Fixture that creates an invalid JSON file."""
@@ -120,7 +113,6 @@ def test_new_schema(self):
120113
all_supported_versions = [s.TargetPlatformCapabilities.SCHEMA_VERSION for s in ALL_SCHEMA_VERSIONS]
121114
assert current_version in all_supported_versions, "Current schema need to be added to ALL_SCHEMA_VERSIONS"
122115

123-
124116
def test_schema_compatibility(self, tpc_by_schema_version):
125117
"""Tests that tpc of any schema version is supported and can be converted into current schema version"""
126118
tpc_by_schema = tpc_by_schema_version

0 commit comments

Comments
 (0)