Skip to content

Commit 0b30697

Browse files
committed
Path attrib plugin to work with generators
Signed-off-by: Krzysztof Lecki <[email protected]>
1 parent 22c0541 commit 0b30697

File tree

5 files changed

+143
-37
lines changed

5 files changed

+143
-37
lines changed
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""
16+
Custom nose2 plugin to filter generator test functions by attributes
17+
before they are called (preventing imports of optional dependencies or other code execution).
18+
19+
This plugin monkey-patches the Generators plugin's _testsFromGeneratorFunc
20+
method to check attributes before calling generator functions.
21+
"""
22+
from nose2.events import Plugin
23+
import logging
24+
25+
log = logging.getLogger(__name__)
26+
27+
28+
class AttributeGeneratorFilter(Plugin):
29+
"""Filter generator functions by attributes before calling them."""
30+
31+
configSection = "attrib-generators"
32+
alwaysOn = True
33+
34+
def __init__(self):
35+
super().__init__()
36+
self._patched = False
37+
38+
def _get_attrib_plugin(self):
39+
"""Get the attrib plugin from the session."""
40+
for plugin in self.session.plugins:
41+
if plugin.__class__.__name__ == "AttributeSelector":
42+
return plugin
43+
return None
44+
45+
def _build_attribs_list(self, attrib_plugin):
46+
"""Build the attribs list from the attrib plugin's -A configuration.
47+
48+
This replicates the logic from AttributeSelector.moduleLoadedSuite
49+
for -A filters only (not -E eval filters).
50+
"""
51+
attribs = []
52+
53+
# Handle -A (attribute) filters
54+
for attr in attrib_plugin.attribs:
55+
attr_group = []
56+
for attrib in attr.strip().split(","):
57+
if not attrib:
58+
continue
59+
items = attrib.split("=", 1)
60+
if len(items) > 1:
61+
# "name=value"
62+
key, value = items
63+
else:
64+
key = items[0]
65+
if key[0] == "!":
66+
# "!name"
67+
key = key[1:]
68+
value = False
69+
else:
70+
# "name"
71+
value = True
72+
attr_group.append((key, value))
73+
attribs.append(attr_group)
74+
75+
return attribs
76+
77+
def _matches_attrib_filter(self, test_func, attrib_plugin):
78+
"""Check if test_func matches the attribute filter from attrib plugin."""
79+
if not attrib_plugin:
80+
return True
81+
82+
if not attrib_plugin.attribs:
83+
return True
84+
85+
# Build attribs list using attrib plugin's logic
86+
attribs = self._build_attribs_list(attrib_plugin)
87+
88+
if not attribs:
89+
return True
90+
91+
# Use the plugin's validateAttrib method
92+
return attrib_plugin.validateAttrib(test_func, attribs)
93+
94+
def _patch_generator_plugin(self):
95+
"""Monkey-patch the Generators plugin to check attributes first."""
96+
if self._patched:
97+
return
98+
99+
# Find the Generators plugin
100+
gen_plugin = None
101+
for plugin in self.session.plugins:
102+
if plugin.__class__.__name__ == "Generators":
103+
gen_plugin = plugin
104+
break
105+
106+
if not gen_plugin:
107+
log.warning("Could not find Generators plugin to patch")
108+
return
109+
110+
# Save original method
111+
original_tests_from_gen = gen_plugin._testsFromGeneratorFunc
112+
attrib_filter_self = self
113+
114+
# Create patched method
115+
def patched_tests_from_gen(event, obj):
116+
"""Check attributes before calling generator function."""
117+
attrib_plugin = attrib_filter_self._get_attrib_plugin()
118+
119+
# Check if generator function matches attribute filter
120+
if not attrib_filter_self._matches_attrib_filter(obj, attrib_plugin):
121+
log.debug(f"Skipping generator {obj.__name__} due to attribute filter")
122+
return [] # Return empty list
123+
124+
# Call original method
125+
return original_tests_from_gen(event, obj)
126+
127+
# Monkey-patch it
128+
gen_plugin._testsFromGeneratorFunc = patched_tests_from_gen
129+
self._patched = True
130+
log.debug("Patched Generators plugin to check attributes")
131+
132+
def handleArgs(self, event):
133+
"""Patch right after argument handling, before test discovery."""
134+
self._patch_generator_plugin()

dali/test/python/test_fw_iterators_detection.py

Lines changed: 3 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from nvidia.dali.pipeline import Pipeline
1818

1919
from test_utils import get_dali_extra_path
20-
from nose_utils import assert_raises, attr
20+
from nose_utils import assert_raises, attr, nottest
2121

2222
DALI_EXTRA_PATH = get_dali_extra_path()
2323
EPOCH_SIZE = 32
@@ -54,26 +54,6 @@ def data_paths():
5454
##############
5555

5656

57-
def test_mxnet_pipeline_dynamic_shape():
58-
from nvidia.dali.plugin.mxnet import DALIGenericIterator as MXNetIterator
59-
60-
root, annotations = data_paths()
61-
pipeline = DetectionPipeline(BATCH_SIZE, 0, root, annotations)
62-
train_loader = MXNetIterator(
63-
[pipeline],
64-
[
65-
("data", MXNetIterator.DATA_TAG),
66-
("bboxes", MXNetIterator.LABEL_TAG),
67-
("label", MXNetIterator.LABEL_TAG),
68-
],
69-
EPOCH_SIZE,
70-
auto_reset=False,
71-
dynamic_shape=True,
72-
)
73-
for data in train_loader:
74-
assert data is not None
75-
76-
7757
@attr("pytorch")
7858
def test_pytorch_pipeline_dynamic_shape():
7959
from nvidia.dali.plugin.pytorch import DALIGenericIterator as PyTorchIterator
@@ -127,6 +107,7 @@ def test_api_fw_check1_paddle():
127107
yield from test_api_fw_check1(PaddleIterator, ["data", "bboxes", "label"])
128108

129109

110+
@nottest
130111
def test_api_fw_check1(iter_type, data_definition):
131112
root, annotations = data_paths()
132113
pipe = DetectionPipeline(BATCH_SIZE, 0, root, annotations)
@@ -163,19 +144,6 @@ def test_api_fw_check1(iter_type, data_definition):
163144
yield check, iter_type
164145

165146

166-
def test_api_fw_check2_mxnet():
167-
from nvidia.dali.plugin.mxnet import DALIGenericIterator as MXNetIterator
168-
169-
yield from test_api_fw_check2(
170-
MXNetIterator,
171-
[
172-
("data", MXNetIterator.DATA_TAG),
173-
("bboxes", MXNetIterator.LABEL_TAG),
174-
("label", MXNetIterator.LABEL_TAG),
175-
],
176-
)
177-
178-
179147
@attr("pytorch")
180148
def test_api_fw_check2_pytorch():
181149
from nvidia.dali.plugin.pytorch import DALIGenericIterator as PyTorchIterator
@@ -190,6 +158,7 @@ def test_api_fw_check2_paddle():
190158
yield from test_api_fw_check2(PaddleIterator, ["data", "bboxes", "label"])
191159

192160

161+
@nottest
193162
def test_api_fw_check2(iter_type, data_definition):
194163
root, annotations = data_paths()
195164

dali/test/python/unittest.cfg

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[unittest]
2-
plugins = nose2.plugins.attrib
2+
plugins = nose2_attrib_generators
3+
nose2.plugins.attrib
34
nose2.plugins.collect
45
nose2.plugins.printhooks
56

dali/test/python/unittest_failure.cfg

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[unittest]
2-
plugins = nose2.plugins.attrib
2+
plugins = nose2_attrib_generators
3+
nose2.plugins.attrib
34
nose2.plugins.collect
45
nose2.plugins.printhooks
56

dali/test/python/unittest_slow.cfg

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[unittest]
2-
plugins = nose2.plugins.attrib
2+
plugins = nose2_attrib_generators
3+
nose2.plugins.attrib
34
nose2.plugins.collect
45
nose2.plugins.printhooks
56

0 commit comments

Comments
 (0)