Skip to content
20 changes: 18 additions & 2 deletions opto/trace/bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,9 +503,25 @@ def wrap(self, output: Any, inputs: Union[List[Node], Dict[str, Node]], external
def is_valid_output(output):
return isinstance(output, Node) or (isinstance(output, tuple) and all([isinstance(o, Node) for o in output]))

def __get__(self, obj, objtype):

# Define __set_name__ and __get__ for FunModule to act as a descriptor.
def __get__(self, obj, db_type):
if obj is None: # class method
return self
# Support instance methods.
return functools.partial(self.__call__, obj)
method_name = f'__TRACE_RESERVED_bundle_{self.name}' # NOTE we assume these are secret names not taken
obj_node_name = f'__TRACE_RESERVED_self_node'
if not hasattr(obj, obj_node_name):
setattr(obj, obj_node_name, node(obj))
if not hasattr(obj, method_name):
funmodule = copy.deepcopy(self)
funmodule.forward = functools.partial(self.forward, getattr(obj, obj_node_name))
setattr(obj, method_name, funmodule)
fun = getattr(obj, method_name)
assert fun is not self # self is defined in the class level
assert isinstance(fun, FunModule), f"Expected {method_name} to be a FunModule, but got {type(fun)}"
# fun = functools.partial(self.__call__, obj)
return fun

def detach(self):
return copy.deepcopy(self)
Expand Down
6 changes: 4 additions & 2 deletions opto/trace/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,19 @@ def parameters_dict(self):
"""
parameters = {}
for name, attr in inspect.getmembers(self):
if name.startswith('__TRACE_RESERVED_'):
# These are reserved for internal use.
continue
if isinstance(attr, functools.partial): # this is a class method
method = attr.func.__self__
if trainable_method(method):
parameters[name] = method.parameter
elif trainable_method(attr): # method attribute
if trainable_method(attr): # method attribute
parameters[name] = attr.parameter
elif isinstance(attr, ParameterNode):
parameters[name] = attr
elif isinstance(attr, ParameterContainer):
parameters[name] = attr

assert all(isinstance(v, (ParameterNode, ParameterContainer)) for v in parameters.values())

return parameters # include both trainable and non-trainable parameters
Expand Down
2 changes: 1 addition & 1 deletion opto/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.1.3.3"
__version__ = "0.1.3.4"
39 changes: 39 additions & 0 deletions tests/unit_tests/test_class_method.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from opto import trace


@trace.model
class Model:

@trace.bundle(trainable=True)
def forward(self, x):
return x + 1


m1 = Model()
m2 = Model()
try:
assert m1.__TRACE_RESERVED_self_node != m2.__TRACE_RESERVED_self_node
except AttributeError:
# These secrets attributes are not defined yet. They will only be defined after the bundled method is accessed.
pass

assert len(m1.parameters()) == 1
assert len(m2.parameters()) == 1

assert m1.__TRACE_RESERVED_self_node != m2.__TRACE_RESERVED_self_node # they are defined now

# each instance has a version different from the class' version
assert m1.forward != m2.forward
assert m1.forward != Model.forward
assert m2.forward.parameter == Model.forward.parameter == m1.forward.parameter

y1 = m1.forward(1)
y2 = m1.forward(2)

# self is not duplicated
assert m1.__TRACE_RESERVED_self_node in y1.parents
assert m1.__TRACE_RESERVED_self_node in y2.parents
assert m1.forward.parameter in y1.parents
assert m1.forward.parameter in y2.parents
assert len(y1.parents) == 3 # since it's trainable
assert len(y2.parents) == 3
Loading