Skip to content

Commit 27cb6f7

Browse files
authored
Merge pull request #939 from mgxd/fix/tagging
FIX: Allow passing arguments through tag decorator
2 parents c55c945 + 614dee1 commit 27cb6f7

File tree

2 files changed

+12
-8
lines changed

2 files changed

+12
-8
lines changed

niworkflows/engine/splicer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import logging
44
import typing as ty
5+
from functools import wraps
56

67
import nipype.pipeline.engine as pe
78
from nipype.pipeline.engine.base import EngineBase
@@ -14,8 +15,9 @@ def tag(tag: str) -> ty.Callable:
1415
This is used to mark nodes or workflows for replacement in the splicing process.
1516
"""
1617

17-
def _decorator(func, *args, **kwargs) -> ty.Callable:
18-
def _tag() -> EngineBase:
18+
def _decorator(func) -> ty.Callable:
19+
@wraps(func)
20+
def _tag(*args, **kwargs) -> EngineBase:
1921
node = func(*args, **kwargs)
2022
node._tag = tag
2123
return node

niworkflows/engine/tests/test_splicer.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -155,9 +155,11 @@ def test_splice(wf0):
155155
assert wf.get_node('c_wf') is None
156156

157157

158-
def test_tag():
159-
@tag('foo')
160-
def init_workflow():
161-
return Workflow(name='foo')
162-
163-
assert init_workflow()._tag == 'foo'
158+
@pytest.mark.parametrize('name', ['foo'])
159+
def test_tag(name):
160+
@tag(name)
161+
def init_workflow(name, *, xarg: str):
162+
return Workflow(name=name)
163+
164+
wf = init_workflow(name, xarg='bar')
165+
assert wf._tag == name

0 commit comments

Comments
 (0)