Skip to content

Commit 9624bd7

Browse files
committed
Refactor test.utils.tag
1 parent 5a8e51c commit 9624bd7

File tree

3 files changed

+101
-75
lines changed

3 files changed

+101
-75
lines changed

test/pt2_to_circle_test/builder.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
)
3333
from test.utils.base_builders import TestDictBuilderBase, TestRunnerBase
3434

35-
from test.utils.tag import is_tagged
35+
from test.utils.tag import TestTag
3636

3737

3838
class NNModuleTest(TestRunnerBase):
@@ -41,9 +41,11 @@ def __init__(self, test_name: str, nnmodule: torch.nn.Module):
4141
self.test_dir = Path(os.path.dirname(os.path.abspath(__file__))) / "artifacts"
4242

4343
# Get tags
44-
self.test_without_pt2: bool = is_tagged(self.nnmodule, "test_without_pt2")
45-
self.test_without_inference: bool = is_tagged(
46-
self.nnmodule, "test_without_inference"
44+
self.test_without_pt2: bool = TestTag.get(
45+
self.nnmodule, "test_without_pt2", False
46+
)
47+
self.test_without_inference: bool = TestTag.get(
48+
self.nnmodule, "test_without_inference", False
4749
)
4850

4951
# Set tolerance

test/utils/base_builders.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@
1616
import inspect
1717
import pkgutil
1818
from abc import abstractmethod
19+
from typing import Optional
1920

2021
import torch
21-
22-
from test.utils.tag import is_tagged
22+
from test.utils.tag import get_tag, has_tag, TestTag
2323

2424

2525
class TestRunnerBase:
@@ -31,12 +31,17 @@ def __init__(self, test_name: str, nnmodule: torch.nn.Module):
3131
self.nnmodule = nnmodule
3232
self.example_inputs = nnmodule.get_example_inputs() # type: ignore[operator]
3333

34-
# Get tags
35-
self.skip: bool = is_tagged(self.nnmodule, "skip")
36-
self.skip_reason: str = getattr(self.nnmodule, "__tag_skip_reason", "")
37-
self.test_negative: bool = is_tagged(self.nnmodule, "test_negative")
38-
self.expected_err: str = getattr(self.nnmodule, "__tag_expected_err", "")
39-
self.use_onert: bool = is_tagged(self.nnmodule, "use_onert")
34+
skip: Optional[object] = TestTag.get(type(self.nnmodule), "skip")
35+
self.skip: bool = skip is not None
36+
self.skip_reason: str = skip.get("reason") if skip else ""
37+
38+
test_negative: Optional[object] = TestTag.get(
39+
type(self.nnmodule), "test_negative"
40+
)
41+
self.test_negative: bool = test_negative is not None
42+
self.expected_err: str = test_negative.get("reason") if test_negative else ""
43+
44+
self.use_onert: bool = TestTag.get(type(self.nnmodule), "use_onert", False)
4045

4146
@abstractmethod
4247
def make(self):
@@ -79,16 +84,17 @@ def _get_nnmodules(self, submodule: str):
7984
)
8085
)
8186

82-
# If any of the nnmodule_classes has a tag `__tag_target`, only those nnmodule_classes will be added
87+
# If any of the nnmodule_classes is marked as target, only those will be added
8388
target_only: bool = any(
84-
hasattr(nnmodule_cls, "__tag_target") for nnmodule_cls in nnmodule_classes
89+
TestTag.get(nnmodule_cls, "target", False)
90+
for nnmodule_cls in nnmodule_classes
8591
)
8692

8793
if target_only:
8894
nnmodule_classes = [
8995
nnmodule_cls
9096
for nnmodule_cls in nnmodule_classes
91-
if hasattr(nnmodule_cls, "__tag_target")
97+
if TestTag.get(nnmodule_cls, "target", False)
9298
]
9399

94100
return nnmodule_classes

test/utils/tag.py

Lines changed: 78 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -12,95 +12,113 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from typing import Any, Dict, Type
1516

16-
def skip(reason):
17-
def __inner_skip(orig_class):
18-
setattr(orig_class, "__tag_skip", True)
19-
setattr(orig_class, "__tag_skip_reason", reason)
2017

21-
def __init__(self, *args_, **kwargs_):
22-
pass
18+
class TestTag:
19+
"""Central registry for managing test tag"""
2320

24-
# Ignore initialization of skipped modules
25-
orig_class.__init__ = __init__
21+
_registry: Dict[Type, Dict[str, Any]] = {}
2622

27-
return orig_class
23+
@classmethod
24+
def add(cls, test_class: Type, tag_key: str, tag_value: Any = None) -> None:
25+
"""Add test tag object to a class
2826
29-
return __inner_skip
27+
Args:
28+
test_class: The test class to add tag to
29+
tag_key: Name of Tag object to add
30+
tag_value: Tag object to add
31+
"""
32+
if test_class not in cls._registry:
33+
cls._registry[test_class] = {}
3034

35+
cls._registry[test_class][tag_key] = tag_value
3136

32-
def skip_if(predicate, reason):
33-
def __inner_skip(orig_class):
34-
setattr(orig_class, "__tag_skip", True)
35-
setattr(orig_class, "__tag_skip_reason", reason)
37+
@classmethod
38+
def has(cls, test_class: Type, tag_key: str) -> bool:
39+
"""Check if a class has specific tag type
3640
37-
def __init__(self, *args_, **kwargs_):
38-
pass
41+
Args:
42+
test_class: The test class to check
43+
tag_key: Type of tag object to check for
3944
40-
# Ignore initialization of skipped modules
41-
orig_class.__init__ = __init__
45+
Returns:
46+
bool: True if the tag exists, False otherwise
47+
"""
48+
return test_class in cls._registry and tag_key in cls._registry[test_class]
4249

43-
return orig_class
50+
@classmethod
51+
def get(cls, test_class: Type, tag_key: str, default: Any = None) -> Any:
52+
"""Get tag object for a class
4453
45-
if predicate:
46-
return __inner_skip
47-
else:
48-
return lambda x: x
54+
Args:
55+
test_class: The test class to get tag from
56+
tag_key: Type of tag object to retrieve
57+
default: Default value to return if tag not found
4958
59+
Returns:
60+
The tag object or default if not found
61+
"""
62+
return cls._registry.get(test_class, {}).get(tag_key, default)
5063

51-
def test_without_inference(orig_class):
52-
setattr(orig_class, "__tag_test_without_inference", True)
53-
return orig_class
5464

65+
####################################################################
66+
################## Add tag here ##################
67+
####################################################################
5568

56-
def test_without_pt2(orig_class):
57-
setattr(orig_class, "__tag_test_without_pt2", True)
58-
return orig_class
5969

70+
def skip(reason):
71+
"""
72+
Mark a test class to be skipped with a reason
6073
61-
def test_negative(expected_err):
62-
def __inner_test_negative(orig_class):
63-
setattr(orig_class, "__tag_test_negative", True)
64-
setattr(orig_class, "__tag_expected_err", expected_err)
74+
e.g.
75+
@skip(reason="Not implemented yet")
76+
class MyTest(unittest.TestCase): # <-- This test will be skipped
77+
"""
6578

66-
return orig_class
79+
def decorator(cls):
80+
TestTag.add(cls, "skip", {"reason": reason})
81+
return cls
6782

68-
return __inner_test_negative
83+
return decorator
6984

7085

71-
def target(orig_class):
72-
setattr(orig_class, "__tag_target", True)
73-
return orig_class
86+
def skip_if(predicate, reason):
87+
"""Conditionally mark a test class to be skipped with a reason"""
88+
if predicate:
89+
return skip(reason)
90+
return lambda cls: cls
7491

7592

76-
def use_onert(orig_class):
77-
"""
78-
Decorator to mark a test class so that Circle models are executed
79-
with the 'onert' runtime.
93+
def test_negative(expected_err):
94+
"""Mark a test class as negative test case with expected error"""
8095

81-
Useful when the default 'circle-interpreter' cannot run the model
82-
under test.
83-
"""
84-
setattr(orig_class, "__tag_use_onert", True)
85-
return orig_class
96+
def decorator(cls):
97+
TestTag.add(cls, "test_negative", {"expected_err": expected_err})
98+
return cls
99+
100+
return decorator
86101

87102

88-
def init_args(*args, **kwargs):
89-
def __inner_init_args(orig_class):
90-
orig_init = orig_class.__init__
91-
# Make copy of original __init__, so we can call it without recursion
103+
def target(cls):
104+
"""Mark a test class as target test case"""
105+
TestTag.add(cls, "target")
106+
return cls
92107

93-
def __init__(self, *args_, **kwargs_):
94-
args_ = (*args, *args_)
95-
kwargs_ = {**kwargs, **kwargs_}
96108

97-
orig_init(self, *args_, **kwargs_) # Call the original __init__
109+
def use_onert(cls):
110+
"""Mark a test class to use ONERT runtime"""
111+
TestTag.add(cls, "use_onert")
112+
return cls
98113

99-
orig_class.__init__ = __init__
100-
return orig_class
101114

102-
return __inner_init_args
115+
def test_without_pt2(cls):
116+
"""Mark a test class to not convert along pt2 during test execution"""
117+
TestTag.add(cls, "test_without_pt2")
118+
return cls
103119

104120

105-
def is_tagged(cls, tag: str):
106-
return hasattr(cls, f"__tag_{tag}")
121+
def test_without_inference(cls):
122+
"""Mark a test class to not run inference during test execution"""
123+
TestTag.add(cls, "test_without_inference")
124+
return cls

0 commit comments

Comments
 (0)