Skip to content

Commit 5516a9e

Browse files
Merge pull request #23 from amd/alex_global_args
Added global args support in config
2 parents 72fb0b1 + ecffe28 commit 5516a9e

File tree

5 files changed

+114
-18
lines changed

5 files changed

+114
-18
lines changed

nodescraper/interfaces/dataplugin.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,7 @@ def analyze(
241241
Returns:
242242
TaskResult: result of data analysis
243243
"""
244+
244245
if self.ANALYZER is None:
245246
self.analysis_result = TaskResult(
246247
status=ExecutionStatus.NOT_RAN,

nodescraper/interfaces/task.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ def max_event_priority_level(self) -> EventPriority:
8080
def max_event_priority_level(self, input_value: str | EventPriority):
8181
if isinstance(input_value, str):
8282
value: EventPriority = getattr(EventPriority, input_value)
83+
elif isinstance(input_value, int):
84+
value = EventPriority(input_value)
8385
elif isinstance(input_value, EventPriority):
8486
value: EventPriority = input_value
8587
else:

nodescraper/pluginexecutor.py

Lines changed: 54 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from pydantic import BaseModel
3434

3535
from nodescraper.constants import DEFAULT_LOGGER
36-
from nodescraper.interfaces import ConnectionManager, DataPlugin
36+
from nodescraper.interfaces import ConnectionManager, DataPlugin, PluginInterface
3737
from nodescraper.models import PluginConfig, SystemInfo
3838
from nodescraper.models.pluginresult import PluginResult
3939
from nodescraper.pluginregistry import PluginRegistry
@@ -165,16 +165,25 @@ def run_queue(self) -> list[PluginResult]:
165165
plugin_inst = plugin_class(**init_payload)
166166

167167
run_payload = copy.deepcopy(plugin_args)
168-
169168
run_args = TypeUtils.get_func_arg_types(plugin_class.run, plugin_class)
169+
170170
for arg in run_args.keys():
171171
if arg == "preserve_connection" and issubclass(plugin_class, DataPlugin):
172172
run_payload[arg] = True
173-
elif arg in self.plugin_config.global_args:
174-
run_payload[arg] = self.plugin_config.global_args[arg]
175173

176-
# TODO
177-
# enable global substitution in collection and analysis args
174+
try:
175+
global_run_args = self.apply_global_args_to_plugin(
176+
plugin_inst, plugin_class, self.plugin_config.global_args
177+
)
178+
run_payload.update(global_run_args)
179+
except ValueError as ve:
180+
self.logger.error(
181+
"Invalid global_args for plugin %s: %s. Skipping plugin.",
182+
plugin_name,
183+
str(ve),
184+
)
185+
continue
186+
178187
self.logger.info("-" * 50)
179188
plugin_results.append(plugin_inst.run(**run_payload))
180189
except Exception as e:
@@ -210,3 +219,42 @@ def run_queue(self) -> list[PluginResult]:
210219
)
211220

212221
return plugin_results
222+
223+
def apply_global_args_to_plugin(
224+
self,
225+
plugin_inst: PluginInterface,
226+
plugin_class: type,
227+
global_args: dict,
228+
) -> dict:
229+
"""
230+
Applies global arguments to the plugin instance, including standard attributes
231+
and merging Pydantic model arguments (collection_args, analysis_args).
232+
233+
Args:
234+
plugin_inst: The plugin instance to update.
235+
plugin_class: The plugin class (needed for model instantiation).
236+
global_args: Dict of global argument overrides.
237+
"""
238+
239+
run_args = {}
240+
for key in global_args:
241+
if key in ["collection_args", "analysis_args"] and isinstance(plugin_inst, DataPlugin):
242+
continue
243+
else:
244+
run_args[key] = global_args[key]
245+
246+
if "collection_args" in global_args and hasattr(plugin_class, "COLLECTOR_ARGS"):
247+
plugin_fields = set(plugin_class.COLLECTOR_ARGS.__fields__.keys())
248+
filtered = {
249+
k: v for k, v in global_args["collection_args"].items() if k in plugin_fields
250+
}
251+
if filtered:
252+
run_args["collection_args"] = filtered
253+
254+
if "analysis_args" in global_args and hasattr(plugin_class, "ANALYZER_ARGS"):
255+
plugin_fields = set(plugin_class.ANALYZER_ARGS.__fields__.keys())
256+
filtered = {k: v for k, v in global_args["analysis_args"].items() if k in plugin_fields}
257+
if filtered:
258+
run_args["analysis_args"] = filtered
259+
260+
return run_args

test/unit/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ class DummyDataModel(DataModel):
5555

5656
class DummyArg(BaseModel):
5757
value: int
58+
regex_match: bool = True
5859

5960

6061
class DummyResult:

test/unit/framework/test_plugin_executor.py

Lines changed: 56 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,29 +24,41 @@
2424
#
2525
###############################################################################
2626
import pytest
27-
from common.shared_utils import MockConnectionManager
27+
from common.shared_utils import DummyDataModel, MockConnectionManager
28+
from pydantic import BaseModel
2829

2930
from nodescraper.enums import ExecutionStatus
31+
from nodescraper.enums.eventpriority import EventPriority
32+
from nodescraper.enums.systeminteraction import SystemInteractionLevel
3033
from nodescraper.interfaces import PluginInterface
3134
from nodescraper.models import PluginConfig, PluginResult
3235
from nodescraper.pluginexecutor import PluginExecutor
3336
from nodescraper.pluginregistry import PluginRegistry
3437

3538

36-
class TestPluginA(PluginInterface[MockConnectionManager, None]):
39+
class DummyArgs(BaseModel):
40+
foo: str = "bar"
41+
regex_match: bool = True
42+
3743

44+
class TestPluginA(PluginInterface[MockConnectionManager, None]):
3845
CONNECTION_TYPE = MockConnectionManager
46+
COLLECTOR_ARGS = DummyArgs(foo="initial")
47+
ANALYZER_ARGS = DummyArgs(foo="initial")
48+
collection = False
49+
analysis = False
50+
preserve_connection = False
51+
data = DummyDataModel(some_version="1")
52+
max_event_priority_level = EventPriority.INFO
53+
system_interaction_level = SystemInteractionLevel.PASSIVE
54+
collection_args = None
3955

4056
def run(self):
4157
self._update_queue(("TestPluginB", {}))
42-
return PluginResult(
43-
source="testA",
44-
status=ExecutionStatus.ERROR,
45-
)
58+
return PluginResult(source="testA", status=ExecutionStatus.ERROR)
4659

4760

4861
class TestPluginB(PluginInterface[MockConnectionManager, None]):
49-
5062
CONNECTION_TYPE = MockConnectionManager
5163

5264
def run(self, test_arg=None):
@@ -67,10 +79,7 @@ def plugin_registry():
6779
"input_configs, output_config",
6880
[
6981
(
70-
[
71-
PluginConfig(plugins={"Plugin1": {}}),
72-
PluginConfig(plugins={"Plugin2": {}}),
73-
],
82+
[PluginConfig(plugins={"Plugin1": {}}), PluginConfig(plugins={"Plugin2": {}})],
7483
PluginConfig(plugins={"Plugin1": {}, "Plugin2": {}}),
7584
),
7685
(
@@ -109,7 +118,8 @@ def test_plugin_queue(plugin_registry):
109118

110119
def test_queue_callback(plugin_registry):
111120
executor = PluginExecutor(
112-
plugin_configs=[PluginConfig(plugins={"TestPluginA": {}})], plugin_registry=plugin_registry
121+
plugin_configs=[PluginConfig(plugins={"TestPluginA": {}})],
122+
plugin_registry=plugin_registry,
113123
)
114124

115125
results = executor.run_queue()
@@ -119,3 +129,37 @@ def test_queue_callback(plugin_registry):
119129
assert results[0].status == ExecutionStatus.ERROR
120130
assert results[1].source == "testB"
121131
assert results[1].status == ExecutionStatus.OK
132+
133+
134+
def test_apply_global_args_to_plugin():
135+
plugin = TestPluginA()
136+
global_args = {
137+
"collection": True,
138+
"analysis": True,
139+
"preserve_connection": True,
140+
"data": {"some_version": "1"},
141+
"max_event_priority_level": 4,
142+
"system_interaction_level": "INTERACTIVE",
143+
"collection_args": {"foo": "collected", "regex_match": False, "not_in_model": "skip_this"},
144+
"analysis_args": {"foo": "analyzed", "regex_match": False, "ignore_this": True},
145+
}
146+
147+
executor = PluginExecutor(plugin_configs=[])
148+
run_payload = executor.apply_global_args_to_plugin(plugin, TestPluginA, global_args)
149+
150+
assert run_payload["collection"] is True
151+
assert run_payload["analysis"] is True
152+
assert run_payload["preserve_connection"] is True
153+
assert run_payload["data"]["some_version"] == "1"
154+
assert run_payload["max_event_priority_level"] == 4
155+
assert run_payload["system_interaction_level"] == "INTERACTIVE"
156+
157+
# Safely check filtered args
158+
assert run_payload.get("collection_args") == {
159+
"foo": "collected",
160+
"regex_match": False,
161+
}
162+
assert run_payload.get("analysis_args") == {
163+
"foo": "analyzed",
164+
"regex_match": False,
165+
}

0 commit comments

Comments
 (0)