Skip to content
20 changes: 18 additions & 2 deletions opto/trace/bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,9 +596,25 @@ def is_valid_output(output):
isinstance(output, tuple) and all([isinstance(o, Node) for o in output])
)

def __get__(self, obj, objtype):

# Define __set_name__ and __get__ for FunModule to act as a descriptor.
def __get__(self, obj, db_type):
if obj is None: # class method
return self
# Support instance methods.
return functools.partial(self.__call__, obj)
method_name = f'__TRACE_RESERVED_bundle_{self.name}' # NOTE we assume these are secret names not taken
obj_node_name = f'__TRACE_RESERVED_self_node'
if not hasattr(obj, obj_node_name):
setattr(obj, obj_node_name, node(obj))
if not hasattr(obj, method_name):
funmodule = copy.deepcopy(self) # instance specific version
funmodule.forward = functools.partial(funmodule.forward, getattr(obj, obj_node_name))
setattr(obj, method_name, funmodule)
fun = getattr(obj, method_name)
assert fun is not self # self is defined in the class level
assert isinstance(fun, FunModule), f"Expected {method_name} to be a FunModule, but got {type(fun)}"
# fun = functools.partial(self.__call__, obj)
return fun

def detach(self):
return copy.deepcopy(self)
Expand Down
5 changes: 4 additions & 1 deletion opto/trace/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,14 @@ def parameters_dict(self):
"""
parameters = {}
for name, attr in inspect.getmembers(self):
if name.startswith('__TRACE_RESERVED_'):
# These are reserved for internal use.
continue
if isinstance(attr, functools.partial): # this is a class method
method = attr.func.__self__
if trainable_method(method):
parameters[name] = method.parameter
elif trainable_method(attr): # method attribute
if trainable_method(attr): # method attribute
parameters[name] = attr.parameter
elif isinstance(attr, ParameterNode):
parameters[name] = attr
Expand Down
1 change: 1 addition & 0 deletions opto/trace/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,7 @@ def __deepcopy__(self, memo):
setattr(result, k, defaultdict(list))
else:
setattr(result, k, copy.deepcopy(v, memo))
GRAPH.register(result)
return result

def lt(self, other):
Expand Down
2 changes: 1 addition & 1 deletion opto/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.1.3.3"
__version__ = "0.1.3.4"
149 changes: 149 additions & 0 deletions tests/unit_tests/test_class_method.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
from opto import trace
from copy import deepcopy, copy

@trace.model
class Model:

@trace.bundle(trainable=True)
def forward(self, x):
return x + 1


def test_case_two_models():
m1 = Model()
m2 = Model()

# Make sure the parameters are different
try:
assert m1.__TRACE_RESERVED_self_node is not m2.__TRACE_RESERVED_self_node
except AttributeError:
# These secrets attributes are not defined yet. They will only be defined after the bundled method is accessed.
pass

# The hidden nodes are defined now
assert len(m1.parameters()) == 1
assert len(m2.parameters()) == 1

# Make sure the parameters are different
assert m1.__TRACE_RESERVED_self_node is not m2.__TRACE_RESERVED_self_node # they are defined now
assert m1.parameters()[0] is not m2.parameters()[0]

# check that the reserved node is the returned parameter
assert getattr(m1, '__TRACE_RESERVED_bundle_Model.forward').parameter is m1.parameters()[0]
assert getattr(m2, '__TRACE_RESERVED_bundle_Model.forward').parameter is m2.parameters()[0]

# each instance has a version different from the class' version
assert m1.forward != m2.forward
assert m1.forward != Model.forward
assert m2.forward.parameter == Model.forward.parameter == m1.forward.parameter

y1 = m1.forward(1)
y2 = m1.forward(2)

from opto.trace.utils import contain
# self is not duplicated
assert contain(y1.parents, m1.__TRACE_RESERVED_self_node)
assert contain(y2.parents, m1.__TRACE_RESERVED_self_node)
# assert m1.__TRACE_RESERVED_self_node in y1.parents
# assert m1.__TRACE_RESERVED_self_node in y2.parents
assert contain(y1.parents, m1.forward.parameter)
assert contain(y2.parents, m1.forward.parameter)
# assert m1.forward.parameter in y1.parents
# assert m1.forward.parameter in y2.parents
assert len(y1.parents) == 3 # since it's trainable
assert len(y2.parents) == 3

def test_case_model_copy():
m1 = Model()
m2 = deepcopy(m1)

# Make sure the parameters are different
try:
assert m1.__TRACE_RESERVED_self_node is not m2.__TRACE_RESERVED_self_node
except AttributeError:
# These secrets attributes are not defined yet. They will only be defined after the bundled method is accessed.
pass

# The hidden nodes are defined now
assert len(m1.parameters()) == 1
assert len(m2.parameters()) == 1

# Make sure the parameters are different
assert m1.__TRACE_RESERVED_self_node is not m2.__TRACE_RESERVED_self_node # they are defined now
assert m1.parameters()[0] is not m2.parameters()[0]

# check that the reserved node is the returned parameter
assert getattr(m1, '__TRACE_RESERVED_bundle_Model.forward').parameter is m1.parameters()[0]
assert getattr(m2, '__TRACE_RESERVED_bundle_Model.forward').parameter is m2.parameters()[0]

# each instance has a version different from the class' version
assert m1.forward is not m2.forward
assert m1.forward is not Model.forward
assert m2.forward.parameter == Model.forward.parameter == m1.forward.parameter

y1 = m1.forward(1)
y2 = m2.forward(2)

from opto.trace.utils import contain
# self is not duplicated
assert contain(y1.parents, m1.__TRACE_RESERVED_self_node)
assert contain(y2.parents, m2.__TRACE_RESERVED_self_node)
# assert m1.__TRACE_RESERVED_self_node in y1.parents
# assert m1.__TRACE_RESERVED_self_node in y2.parents
assert contain(y1.parents, m1.forward.parameter)
assert contain(y2.parents, m2.forward.parameter)

# assert m1.forward.parameter in y1.parents
# assert m1.forward.parameter in y2.parents
assert len(y1.parents) == 3 # since it's trainable
assert len(y2.parents) == 3

def test_case_model_nested_copy():
m1 = Model()
m3 = deepcopy(m1)
m2 = deepcopy(m3)

# Make sure the parameters are different
try:
assert m1.__TRACE_RESERVED_self_node is not m2.__TRACE_RESERVED_self_node
except AttributeError:
# These secrets attributes are not defined yet. They will only be defined after the bundled method is accessed.
pass

# The hidden nodes are defined now
assert len(m1.parameters()) == 1
assert len(m2.parameters()) == 1

# Make sure the parameters are different
assert m1.__TRACE_RESERVED_self_node is not m2.__TRACE_RESERVED_self_node # they are defined now
assert m1.parameters()[0] is not m2.parameters()[0]

# check that the reserved node is the returned parameter
assert getattr(m1, '__TRACE_RESERVED_bundle_Model.forward').parameter is m1.parameters()[0]
assert getattr(m2, '__TRACE_RESERVED_bundle_Model.forward').parameter is m2.parameters()[0]

# each instance has a version different from the class' version
assert m1.forward is not m2.forward
assert m1.forward is not Model.forward
assert m2.forward.parameter == Model.forward.parameter == m1.forward.parameter

y1 = m1.forward(1)
y2 = m2.forward(2)

from opto.trace.utils import contain
# self is not duplicated
assert contain(y1.parents, m1.__TRACE_RESERVED_self_node)
assert contain(y2.parents, m2.__TRACE_RESERVED_self_node)
# assert m1.__TRACE_RESERVED_self_node in y1.parents
# assert m1.__TRACE_RESERVED_self_node in y2.parents
assert contain(y1.parents, m1.forward.parameter)
assert contain(y2.parents, m2.forward.parameter)

# assert m1.forward.parameter in y1.parents
# assert m1.forward.parameter in y2.parents
assert len(y1.parents) == 3 # since it's trainable
assert len(y2.parents) == 3

test_case_two_models()
test_case_model_copy()
test_case_model_nested_copy()