Skip to content

Commit 85a4528

Browse files
committed
revert adding deepcopy override. Add test for nested deepcopy
1 parent 9e2db27 commit 85a4528

File tree

2 files changed

+49
-15
lines changed

2 files changed

+49
-15
lines changed

opto/trace/modules.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -61,14 +61,4 @@ def _set(self, new_parameters):
6161
parameters_dict[k]._set(v)
6262
else: # if the parameter does not exist
6363
assert k not in self.__dict__
64-
setattr(self, k, v)
65-
66-
def __deepcopy__(self, memo):
67-
""" Custom deepcopy behavior for Module. """
68-
cls = self.__class__
69-
result = cls.__new__(cls)
70-
memo[id(self)] = result
71-
for k, v in self.__dict__.items():
72-
if '__TRACE_RESERVED_' not in k:
73-
setattr(result, k, copy.deepcopy(v, memo))
74-
return result
64+
setattr(self, k, v)

tests/unit_tests/test_class_method.py

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from opto import trace
2-
from copy import deepcopy
2+
from copy import deepcopy, copy
33

44
@trace.model
55
class Model:
@@ -98,8 +98,52 @@ def test_case_model_copy():
9898
assert len(y1.parents) == 3 # since it's trainable
9999
assert len(y2.parents) == 3
100100

101-
def printout_deecopy_modules():
102-
pass
101+
def test_case_model_nested_copy():
102+
m1 = Model()
103+
m3 = deepcopy(m1)
104+
m2 = deepcopy(m3)
105+
106+
# Make sure the parameters are different
107+
try:
108+
assert m1.__TRACE_RESERVED_self_node is not m2.__TRACE_RESERVED_self_node
109+
except AttributeError:
110+
# These secrets attributes are not defined yet. They will only be defined after the bundled method is accessed.
111+
pass
112+
113+
# The hidden nodes are defined now
114+
assert len(m1.parameters()) == 1
115+
assert len(m2.parameters()) == 1
116+
117+
# Make sure the parameters are different
118+
assert m1.__TRACE_RESERVED_self_node is not m2.__TRACE_RESERVED_self_node # they are defined now
119+
assert m1.parameters()[0] is not m2.parameters()[0]
120+
121+
# check that the reserved node is the returned parameter
122+
assert getattr(m1, '__TRACE_RESERVED_bundle_Model.forward').parameter is m1.parameters()[0]
123+
assert getattr(m2, '__TRACE_RESERVED_bundle_Model.forward').parameter is m2.parameters()[0]
124+
125+
# each instance has a version different from the class' version
126+
assert m1.forward is not m2.forward
127+
assert m1.forward is not Model.forward
128+
assert m2.forward.parameter == Model.forward.parameter == m1.forward.parameter
129+
130+
y1 = m1.forward(1)
131+
y2 = m2.forward(2)
132+
133+
from opto.trace.utils import contain
134+
# self is not duplicated
135+
assert contain(y1.parents, m1.__TRACE_RESERVED_self_node)
136+
assert contain(y2.parents, m2.__TRACE_RESERVED_self_node)
137+
# assert m1.__TRACE_RESERVED_self_node in y1.parents
138+
# assert m1.__TRACE_RESERVED_self_node in y2.parents
139+
assert contain(y1.parents, m1.forward.parameter)
140+
assert contain(y2.parents, m2.forward.parameter)
141+
142+
# assert m1.forward.parameter in y1.parents
143+
# assert m1.forward.parameter in y2.parents
144+
assert len(y1.parents) == 3 # since it's trainable
145+
assert len(y2.parents) == 3
103146

104147
test_case_two_models()
105-
test_case_model_copy()
148+
test_case_model_copy()
149+
test_case_model_nested_copy()

0 commit comments

Comments
 (0)