Skip to content

Commit aa752cb

Browse files
committed
refact(op): simplify hierarchy, Op fully abstract
1 parent fca66a0 commit aa752cb

File tree

4 files changed

+35
-74
lines changed

4 files changed

+35
-74
lines changed

graphtik/netop.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,11 @@ def __init__(
6464
:raises ValueError:
6565
see :meth:`narrow()`
6666
"""
67+
self.name = name
68+
self.inputs = inputs
69+
self.provides = outputs
6770
self.net = net.pruned(inputs, outputs)
6871
## Set data asap, for debugging, although `prune()` will reset them.
69-
super().__init__(name, inputs, outputs)
7072
self.set_execution_method(method)
7173
self.set_overwrites_collector(overwrites_collector)
7274

graphtik/op.py

Lines changed: 29 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import abc
66
import logging
77
from collections import abc as cabc
8-
from typing import Callable
8+
from typing import Callable, Collection, Tuple, Union
99

1010
from boltons.setutils import IndexedSet as iset
1111

@@ -43,43 +43,6 @@ def reparse_operation_data(name, needs, provides):
4343
class Operation(abc.ABC):
4444
"""An abstract class representing a data transformation by :meth:`.compute()`."""
4545

46-
def __init__(self, name, needs=None, provides=None):
47-
"""
48-
Create a new layer instance.
49-
Names may be given to this layer and its inputs and outputs. This is
50-
important when connecting layers and data in a Network object, as the
51-
names are used to construct the graph.
52-
53-
:param str name:
54-
The name the operation (e.g. conv1, conv2, etc..)
55-
56-
:param list needs:
57-
Names of input data objects this layer requires.
58-
59-
:param list provides:
60-
Names of output data objects this provides.
61-
62-
"""
63-
64-
# (Optional) names for this layer, and the data it needs and provides
65-
self.name = name
66-
self.needs = needs
67-
self.provides = provides
68-
69-
def __eq__(self, other):
70-
"""
71-
Operation equality is based on name of layer.
72-
(__eq__ and __hash__ must be overridden together)
73-
"""
74-
return bool(self.name is not None and self.name == getattr(other, "name", None))
75-
76-
def __hash__(self):
77-
"""
78-
Operation equality is based on name of layer.
79-
(__eq__ and __hash__ must be overridden together)
80-
"""
81-
return hash(self.name)
82-
8346
@abc.abstractmethod
8447
def compute(self, named_inputs, outputs=None):
8548
"""
@@ -89,23 +52,12 @@ def compute(self, named_inputs, outputs=None):
8952
End-users should simply call the operation with `named_inputs` as kwargs.
9053
9154
:param list named_inputs:
92-
A list of :class:`Data` objects on which to run the layer's
93-
feed-forward computation.
55+
the input values with which to feed the computation.
9456
:returns list:
9557
Should return a list values representing
9658
the results of running the feed-forward computation on
9759
``inputs``.
9860
"""
99-
pass
100-
101-
def __repr__(self):
102-
"""
103-
Display more informative names for the Operation class
104-
"""
105-
clsname = type(self).__name__
106-
needs = aslist(self.needs, "needs")
107-
provides = aslist(self.provides, "provides")
108-
return f"{clsname}({self.name!r}, needs={needs!r}, provides={provides!r})"
10961

11062

11163
class FunctionalOperation(Operation):
@@ -119,19 +71,38 @@ def __init__(
11971
self,
12072
fn: Callable,
12173
name,
122-
needs=None,
123-
provides=None,
74+
needs: Union[Collection, str] = None,
75+
provides: Union[Collection, str] = None,
12476
*,
125-
parents=None,
77+
parents: Tuple = None,
12678
returns_dict=None,
12779
):
80+
"""
81+
Build a new operation out of some function and its requirements.
82+
83+
:param name:
84+
a name for the operation (e.g. `'conv1'`, `'sum'`, etc..);
85+
it will be prefixed by `parents`.
86+
:param needs:
87+
Names of input data objects this operation requires.
88+
:param provides:
89+
Names of output data objects this provides.
90+
:param parent:
91+
a tuple wth the names of the parents, prefixing `name`,
92+
but also kept for equality/hash check.
93+
94+
"""
95+
## Set op-data early, for repr() to work on errors.
12896
self.fn = fn
97+
self.name = name
98+
self.needs = needs
99+
self.provides = provides
129100
self.parents = parents
130101
self.returns_dict = returns_dict
131-
## Set op-data early, for repr() to work on errors.
132-
super().__init__(name, needs, provides)
133102
if not fn or not callable(fn):
134-
raise ValueError(f"Operation was not provided with a callable: {self.fn}")
103+
raise ValueError(
104+
f"Operation was not provided with a callable: {fn}\n {self}"
105+
)
135106
if parents and not isinstance(parents, tuple):
136107
raise ValueError(
137108
f"Operation `parents` must be tuple, was {parents}\n {self}"
@@ -143,21 +114,15 @@ def __init__(
143114
)
144115

145116
def __eq__(self, other):
146-
"""
147-
Operation equality is based on name of layer.
148-
(__eq__ and __hash__ must be overridden together)
149-
"""
117+
"""Operation identity is based on `name` and `parents`."""
150118
return bool(
151119
self.name is not None
152120
and self.name == getattr(other, "name", None)
153121
and self.parents == getattr(other, "parents", None)
154122
)
155123

156124
def __hash__(self):
157-
"""
158-
Operation equality is based on name of layer.
159-
(__eq__ and __hash__ must be overridden together)
160-
"""
125+
"""Operation identity is based on `name` and `parents`."""
161126
return hash(self.name) ^ hash(self.parents)
162127

163128
def __repr__(self):

test/test_base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,9 @@ def test_jetsam_sites_screaming_func(acallable, expected_jetsam):
163163

164164
class DummyOperation(op.Operation):
165165
def __init__(self):
166-
super().__init__("", (), ("a"))
166+
self.name = ("",)
167+
self.needs = ()
168+
self.provides = ("a",)
167169

168170
def compute(self, named_inputs, outputs=None):
169171
pass

test/test_op.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,6 @@ def opprovides(request):
2222
return request.param
2323

2424

25-
class MyOp(Operation):
26-
def compute(self):
27-
pass
28-
29-
3025
def test_repr_smoke(opname, opneeds, opprovides):
3126
# Simply check __repr__() does not crash on partial attributes.
3227
kw = locals().copy()
@@ -35,9 +30,6 @@ def test_repr_smoke(opname, opneeds, opprovides):
3530
op = operation(**kw)
3631
str(op)
3732

38-
op = MyOp(**kw)
39-
str(op)
40-
4133

4234
def test_repr_returns_dict():
4335
assert (

0 commit comments

Comments
 (0)