diff --git a/opto/trace/bundle.py b/opto/trace/bundle.py index 048aed45..57b9901a 100644 --- a/opto/trace/bundle.py +++ b/opto/trace/bundle.py @@ -596,9 +596,25 @@ def is_valid_output(output): isinstance(output, tuple) and all([isinstance(o, Node) for o in output]) ) - def __get__(self, obj, objtype): + + # Define __set_name__ and __get__ for FunModule to act as a descriptor. + def __get__(self, obj, db_type): + if obj is None: # class method + return self # Support instance methods. - return functools.partial(self.__call__, obj) + method_name = f'__TRACE_RESERVED_bundle_{self.name}' # NOTE we assume these are secret names not taken + obj_node_name = f'__TRACE_RESERVED_self_node' + if not hasattr(obj, obj_node_name): + setattr(obj, obj_node_name, node(obj)) + if not hasattr(obj, method_name): + funmodule = copy.deepcopy(self) # instance specific version + funmodule.forward = functools.partial(funmodule.forward, getattr(obj, obj_node_name)) + setattr(obj, method_name, funmodule) + fun = getattr(obj, method_name) + assert fun is not self # self is defined in the class level + assert isinstance(fun, FunModule), f"Expected {method_name} to be a FunModule, but got {type(fun)}" + # fun = functools.partial(self.__call__, obj) + return fun def detach(self): return copy.deepcopy(self) diff --git a/opto/trace/containers.py b/opto/trace/containers.py index 85b2f0e9..a216118d 100644 --- a/opto/trace/containers.py +++ b/opto/trace/containers.py @@ -42,11 +42,14 @@ def parameters_dict(self): """ parameters = {} for name, attr in inspect.getmembers(self): + if name.startswith('__TRACE_RESERVED_'): + # These are reserved for internal use. + continue if isinstance(attr, functools.partial): # this is a class method method = attr.func.__self__ if trainable_method(method): parameters[name] = method.parameter - elif trainable_method(attr): # method attribute + if trainable_method(attr): # method attribute parameters[name] = attr.parameter elif isinstance(attr, ParameterNode): parameters[name] = attr diff --git a/opto/trace/nodes.py b/opto/trace/nodes.py index 6719aeb3..fbadeeb6 100644 --- a/opto/trace/nodes.py +++ b/opto/trace/nodes.py @@ -455,6 +455,7 @@ def __deepcopy__(self, memo): setattr(result, k, defaultdict(list)) else: setattr(result, k, copy.deepcopy(v, memo)) + GRAPH.register(result) return result def lt(self, other): diff --git a/opto/version.py b/opto/version.py index 223cabec..df98922f 100644 --- a/opto/version.py +++ b/opto/version.py @@ -1 +1 @@ -__version__ = "0.1.3.3" \ No newline at end of file +__version__ = "0.1.3.4" \ No newline at end of file diff --git a/tests/unit_tests/test_class_method.py b/tests/unit_tests/test_class_method.py new file mode 100644 index 00000000..2bf62600 --- /dev/null +++ b/tests/unit_tests/test_class_method.py @@ -0,0 +1,149 @@ +from opto import trace +from copy import deepcopy, copy + +@trace.model +class Model: + + @trace.bundle(trainable=True) + def forward(self, x): + return x + 1 + + +def test_case_two_models(): + m1 = Model() + m2 = Model() + + # Make sure the parameters are different + try: + assert m1.__TRACE_RESERVED_self_node is not m2.__TRACE_RESERVED_self_node + except AttributeError: + # These secrets attributes are not defined yet. They will only be defined after the bundled method is accessed. + pass + + # The hidden nodes are defined now + assert len(m1.parameters()) == 1 + assert len(m2.parameters()) == 1 + + # Make sure the parameters are different + assert m1.__TRACE_RESERVED_self_node is not m2.__TRACE_RESERVED_self_node # they are defined now + assert m1.parameters()[0] is not m2.parameters()[0] + + # check that the reserved node is the returned parameter + assert getattr(m1, '__TRACE_RESERVED_bundle_Model.forward').parameter is m1.parameters()[0] + assert getattr(m2, '__TRACE_RESERVED_bundle_Model.forward').parameter is m2.parameters()[0] + + # each instance has a version different from the class' version + assert m1.forward != m2.forward + assert m1.forward != Model.forward + assert m2.forward.parameter == Model.forward.parameter == m1.forward.parameter + + y1 = m1.forward(1) + y2 = m1.forward(2) + + from opto.trace.utils import contain + # self is not duplicated + assert contain(y1.parents, m1.__TRACE_RESERVED_self_node) + assert contain(y2.parents, m1.__TRACE_RESERVED_self_node) + # assert m1.__TRACE_RESERVED_self_node in y1.parents + # assert m1.__TRACE_RESERVED_self_node in y2.parents + assert contain(y1.parents, m1.forward.parameter) + assert contain(y2.parents, m1.forward.parameter) + # assert m1.forward.parameter in y1.parents + # assert m1.forward.parameter in y2.parents + assert len(y1.parents) == 3 # since it's trainable + assert len(y2.parents) == 3 + +def test_case_model_copy(): + m1 = Model() + m2 = deepcopy(m1) + + # Make sure the parameters are different + try: + assert m1.__TRACE_RESERVED_self_node is not m2.__TRACE_RESERVED_self_node + except AttributeError: + # These secrets attributes are not defined yet. They will only be defined after the bundled method is accessed. + pass + + # The hidden nodes are defined now + assert len(m1.parameters()) == 1 + assert len(m2.parameters()) == 1 + + # Make sure the parameters are different + assert m1.__TRACE_RESERVED_self_node is not m2.__TRACE_RESERVED_self_node # they are defined now + assert m1.parameters()[0] is not m2.parameters()[0] + + # check that the reserved node is the returned parameter + assert getattr(m1, '__TRACE_RESERVED_bundle_Model.forward').parameter is m1.parameters()[0] + assert getattr(m2, '__TRACE_RESERVED_bundle_Model.forward').parameter is m2.parameters()[0] + + # each instance has a version different from the class' version + assert m1.forward is not m2.forward + assert m1.forward is not Model.forward + assert m2.forward.parameter == Model.forward.parameter == m1.forward.parameter + + y1 = m1.forward(1) + y2 = m2.forward(2) + + from opto.trace.utils import contain + # self is not duplicated + assert contain(y1.parents, m1.__TRACE_RESERVED_self_node) + assert contain(y2.parents, m2.__TRACE_RESERVED_self_node) + # assert m1.__TRACE_RESERVED_self_node in y1.parents + # assert m1.__TRACE_RESERVED_self_node in y2.parents + assert contain(y1.parents, m1.forward.parameter) + assert contain(y2.parents, m2.forward.parameter) + + # assert m1.forward.parameter in y1.parents + # assert m1.forward.parameter in y2.parents + assert len(y1.parents) == 3 # since it's trainable + assert len(y2.parents) == 3 + +def test_case_model_nested_copy(): + m1 = Model() + m3 = deepcopy(m1) + m2 = deepcopy(m3) + + # Make sure the parameters are different + try: + assert m1.__TRACE_RESERVED_self_node is not m2.__TRACE_RESERVED_self_node + except AttributeError: + # These secrets attributes are not defined yet. They will only be defined after the bundled method is accessed. + pass + + # The hidden nodes are defined now + assert len(m1.parameters()) == 1 + assert len(m2.parameters()) == 1 + + # Make sure the parameters are different + assert m1.__TRACE_RESERVED_self_node is not m2.__TRACE_RESERVED_self_node # they are defined now + assert m1.parameters()[0] is not m2.parameters()[0] + + # check that the reserved node is the returned parameter + assert getattr(m1, '__TRACE_RESERVED_bundle_Model.forward').parameter is m1.parameters()[0] + assert getattr(m2, '__TRACE_RESERVED_bundle_Model.forward').parameter is m2.parameters()[0] + + # each instance has a version different from the class' version + assert m1.forward is not m2.forward + assert m1.forward is not Model.forward + assert m2.forward.parameter == Model.forward.parameter == m1.forward.parameter + + y1 = m1.forward(1) + y2 = m2.forward(2) + + from opto.trace.utils import contain + # self is not duplicated + assert contain(y1.parents, m1.__TRACE_RESERVED_self_node) + assert contain(y2.parents, m2.__TRACE_RESERVED_self_node) + # assert m1.__TRACE_RESERVED_self_node in y1.parents + # assert m1.__TRACE_RESERVED_self_node in y2.parents + assert contain(y1.parents, m1.forward.parameter) + assert contain(y2.parents, m2.forward.parameter) + + # assert m1.forward.parameter in y1.parents + # assert m1.forward.parameter in y2.parents + assert len(y1.parents) == 3 # since it's trainable + assert len(y2.parents) == 3 + +test_case_two_models() +test_case_model_copy() +test_case_model_nested_copy() \ No newline at end of file