11from opto import trace
2- from copy import deepcopy
2+ from copy import deepcopy , copy
33
44@trace .model
55class Model :
@@ -98,8 +98,52 @@ def test_case_model_copy():
9898 assert len (y1 .parents ) == 3 # since it's trainable
9999 assert len (y2 .parents ) == 3
100100
101- def printout_deecopy_modules ():
102- pass
101+ def test_case_model_nested_copy ():
102+ m1 = Model ()
103+ m3 = deepcopy (m1 )
104+ m2 = deepcopy (m3 )
105+
106+ # Make sure the parameters are different
107+ try :
108+ assert m1 .__TRACE_RESERVED_self_node is not m2 .__TRACE_RESERVED_self_node
109+ except AttributeError :
110+ # These secrets attributes are not defined yet. They will only be defined after the bundled method is accessed.
111+ pass
112+
113+ # The hidden nodes are defined now
114+ assert len (m1 .parameters ()) == 1
115+ assert len (m2 .parameters ()) == 1
116+
117+ # Make sure the parameters are different
118+ assert m1 .__TRACE_RESERVED_self_node is not m2 .__TRACE_RESERVED_self_node # they are defined now
119+ assert m1 .parameters ()[0 ] is not m2 .parameters ()[0 ]
120+
121+ # check that the reserved node is the returned parameter
122+ assert getattr (m1 , '__TRACE_RESERVED_bundle_Model.forward' ).parameter is m1 .parameters ()[0 ]
123+ assert getattr (m2 , '__TRACE_RESERVED_bundle_Model.forward' ).parameter is m2 .parameters ()[0 ]
124+
125+ # each instance has a version different from the class' version
126+ assert m1 .forward is not m2 .forward
127+ assert m1 .forward is not Model .forward
128+ assert m2 .forward .parameter == Model .forward .parameter == m1 .forward .parameter
129+
130+ y1 = m1 .forward (1 )
131+ y2 = m2 .forward (2 )
132+
133+ from opto .trace .utils import contain
134+ # self is not duplicated
135+ assert contain (y1 .parents , m1 .__TRACE_RESERVED_self_node )
136+ assert contain (y2 .parents , m2 .__TRACE_RESERVED_self_node )
137+ # assert m1.__TRACE_RESERVED_self_node in y1.parents
138+ # assert m1.__TRACE_RESERVED_self_node in y2.parents
139+ assert contain (y1 .parents , m1 .forward .parameter )
140+ assert contain (y2 .parents , m2 .forward .parameter )
141+
142+ # assert m1.forward.parameter in y1.parents
143+ # assert m1.forward.parameter in y2.parents
144+ assert len (y1 .parents ) == 3 # since it's trainable
145+ assert len (y2 .parents ) == 3
103146
104147test_case_two_models ()
105- test_case_model_copy ()
148+ test_case_model_copy ()
149+ test_case_model_nested_copy ()
0 commit comments