Skip to content

Commit 6c180af

Browse files
committed
ENH(op): allow OpFuncs returning Dicts
1 parent 891c46f commit 6c180af

File tree

2 files changed

+60
-17
lines changed

2 files changed

+60
-17
lines changed

graphtik/op.py

Lines changed: 42 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,11 @@ def __repr__(self):
101101
class FunctionalOperation(Operation):
102102
"""Use operation() to build instances of this class instead"""
103103

104-
def __init__(self, fn=None, name=None, needs=None, provides=None):
104+
def __init__(
105+
self, fn=None, name=None, needs=None, provides=None, *, returns_dict=None
106+
):
105107
self.fn = fn
108+
self.returns_dict = returns_dict
106109
## Set op-data early, for repr() to work on errors.
107110
Operation.__init__(self, name=name, needs=needs, provides=provides)
108111
if not fn or not callable(fn):
@@ -119,7 +122,11 @@ def __repr__(self):
119122
needs = aslist(self.needs, "needs")
120123
provides = aslist(self.provides, "provides")
121124
fn_name = self.fn and getattr(self.fn, "__name__", str(self.fn))
122-
return f"FunctionalOperation(name={self.name!r}, needs={needs!r}, provides={provides!r}, fn={fn_name!r})"
125+
returns_dict_marker = self.returns_dict and "{}" or ""
126+
return (
127+
f"FunctionalOperation(name={self.name!r}, needs={needs!r}, "
128+
f"provides={provides!r}, fn{returns_dict_marker}={fn_name!r})"
129+
)
123130

124131
def compute(self, named_inputs, outputs=None):
125132
try:
@@ -143,18 +150,19 @@ def compute(self, named_inputs, outputs=None):
143150
results = self.fn(*args, **optionals)
144151

145152
if not provides:
146-
# All outputs were sideffects.
153+
# All outputs were sideffects?
147154
return {}
148155

149-
if len(provides) == 1:
150-
results = [results]
156+
if not self.returns_dict:
157+
if len(provides) == 1:
158+
results = [results]
151159

152-
results = zip(provides, results)
153-
if outputs:
154-
outputs = set(n for n in outputs if not isinstance(n, sideffect))
155-
results = {key: val for key, val in results if key in outputs}
156-
else:
157-
results = dict(results)
160+
results = zip(provides, results)
161+
if outputs:
162+
outputs = set(n for n in outputs if not isinstance(n, sideffect))
163+
results = {key: val for key, val in results if key in outputs}
164+
else:
165+
results = dict(results)
158166

159167
return results
160168
except Exception as ex:
@@ -190,6 +198,12 @@ class operation:
190198
correspond to the ``args`` of ``fn``.
191199
:param list provides:
192200
Names of output data objects this operation provides.
201+
If more than one given, those must be returned in an iterable,
202+
unless `returns_dict` is true, in which cae a dictionary with as many
203+
elements must be returned
204+
:param bool returns_dict:
205+
if true, it means the `fn` returns a dictionary with all `provides`,
206+
and no further processing is done on them.
193207
194208
:return:
195209
when called, it returns a :class:`FunctionalOperation`
@@ -216,10 +230,16 @@ class operation:
216230

217231
fn = name = needs = provides = None
218232

219-
def __init__(self, fn=None, *, name=None, needs=None, provides=None):
220-
self.withset(fn=fn, name=name, needs=needs, provides=provides)
233+
def __init__(
234+
self, fn=None, *, name=None, needs=None, provides=None, returns_dict=None
235+
):
236+
self.withset(
237+
fn=fn, name=name, needs=needs, provides=provides, returns_dict=returns_dict
238+
)
221239

222-
def withset(self, fn=None, *, name=None, needs=None, provides=None):
240+
def withset(
241+
self, *, fn=None, name=None, needs=None, provides=None, returns_dict=None
242+
):
223243
if fn is not None:
224244
self.fn = fn
225245
if name is not None:
@@ -228,10 +248,14 @@ def withset(self, fn=None, *, name=None, needs=None, provides=None):
228248
self.needs = needs
229249
if provides is not None:
230250
self.provides = provides
251+
if returns_dict is not None:
252+
self.returns_dict = returns_dict
231253

232254
return self
233255

234-
def __call__(self, fn=None, *, name=None, needs=None, provides=None):
256+
def __call__(
257+
self, fn=None, *, name=None, needs=None, provides=None, returns_dict=None
258+
):
235259
"""
236260
This enables ``operation`` to act as a decorator or as a functional
237261
operation, for example::
@@ -254,7 +278,9 @@ def myadd(a, b):
254278
composed into a computation graph.
255279
"""
256280

257-
self.withset(fn=fn, name=name, needs=needs, provides=provides)
281+
self.withset(
282+
fn=fn, name=name, needs=needs, provides=provides, returns_dict=returns_dict
283+
)
258284

259285
return FunctionalOperation(**vars(self))
260286

test/test_op.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def compute(self):
2727
pass
2828

2929

30-
def test_operation_repr(opname, opneeds, opprovides):
30+
def test_operation_repr_smoke(opname, opneeds, opprovides):
3131
# Simply check __repr__() does not crash on partial attributes.
3232
kw = locals().copy()
3333
kw = {name[2:]: arg for name, arg in kw.items()}
@@ -39,6 +39,13 @@ def test_operation_repr(opname, opneeds, opprovides):
3939
str(op)
4040

4141

42+
def test_operation_repr_returns_dict():
43+
assert (
44+
str(operation(lambda: None, returns_dict=True)())
45+
== "FunctionalOperation(name=None, needs=[], provides=[], fn{}='<lambda>')"
46+
)
47+
48+
4249
@pytest.mark.parametrize(
4350
"opargs, exp",
4451
[
@@ -68,3 +75,13 @@ def test_operation_validation(opargs, exp):
6875
reparse_operation_data(*opargs)
6976
else:
7077
assert reparse_operation_data(*opargs) == exp
78+
79+
80+
def test_operation_returns_dict():
81+
result = {"a": 1}
82+
83+
op = operation(lambda: result, provides="a", returns_dict=True)()
84+
assert op.compute({}, ["a"]) == result
85+
86+
op = operation(lambda: 1, provides="a", returns_dict=False)()
87+
assert op.compute({}, ["a"]) == result

0 commit comments

Comments
 (0)