Skip to content

Commit 41a62c9

Browse files
authored
Merge pull request #32 from microsoft/0.1.3.4
Update to version 0.1.3.4
2 parents 27db891 + b8f6760 commit 41a62c9

File tree

5 files changed

+173
-4
lines changed

5 files changed

+173
-4
lines changed

opto/trace/bundle.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -596,9 +596,25 @@ def is_valid_output(output):
596596
isinstance(output, tuple) and all([isinstance(o, Node) for o in output])
597597
)
598598

599-
def __get__(self, obj, objtype):
599+
600+
# Define __set_name__ and __get__ for FunModule to act as a descriptor.
601+
def __get__(self, obj, db_type):
602+
if obj is None: # class method
603+
return self
600604
# Support instance methods.
601-
return functools.partial(self.__call__, obj)
605+
method_name = f'__TRACE_RESERVED_bundle_{self.name}' # NOTE we assume these are secret names not taken
606+
obj_node_name = f'__TRACE_RESERVED_self_node'
607+
if not hasattr(obj, obj_node_name):
608+
setattr(obj, obj_node_name, node(obj))
609+
if not hasattr(obj, method_name):
610+
funmodule = copy.deepcopy(self) # instance specific version
611+
funmodule.forward = functools.partial(funmodule.forward, getattr(obj, obj_node_name))
612+
setattr(obj, method_name, funmodule)
613+
fun = getattr(obj, method_name)
614+
assert fun is not self # self is defined in the class level
615+
assert isinstance(fun, FunModule), f"Expected {method_name} to be a FunModule, but got {type(fun)}"
616+
# fun = functools.partial(self.__call__, obj)
617+
return fun
602618

603619
def detach(self):
604620
return copy.deepcopy(self)

opto/trace/containers.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,14 @@ def parameters_dict(self):
4242
"""
4343
parameters = {}
4444
for name, attr in inspect.getmembers(self):
45+
if name.startswith('__TRACE_RESERVED_'):
46+
# These are reserved for internal use.
47+
continue
4548
if isinstance(attr, functools.partial): # this is a class method
4649
method = attr.func.__self__
4750
if trainable_method(method):
4851
parameters[name] = method.parameter
49-
elif trainable_method(attr): # method attribute
52+
if trainable_method(attr): # method attribute
5053
parameters[name] = attr.parameter
5154
elif isinstance(attr, ParameterNode):
5255
parameters[name] = attr

opto/trace/nodes.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,7 @@ def __deepcopy__(self, memo):
455455
setattr(result, k, defaultdict(list))
456456
else:
457457
setattr(result, k, copy.deepcopy(v, memo))
458+
GRAPH.register(result)
458459
return result
459460

460461
def lt(self, other):

opto/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.1.3.3"
1+
__version__ = "0.1.3.4"
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
from opto import trace
2+
from copy import deepcopy, copy
3+
4+
@trace.model
5+
class Model:
6+
7+
@trace.bundle(trainable=True)
8+
def forward(self, x):
9+
return x + 1
10+
11+
12+
def test_case_two_models():
13+
m1 = Model()
14+
m2 = Model()
15+
16+
# Make sure the parameters are different
17+
try:
18+
assert m1.__TRACE_RESERVED_self_node is not m2.__TRACE_RESERVED_self_node
19+
except AttributeError:
20+
# These secrets attributes are not defined yet. They will only be defined after the bundled method is accessed.
21+
pass
22+
23+
# The hidden nodes are defined now
24+
assert len(m1.parameters()) == 1
25+
assert len(m2.parameters()) == 1
26+
27+
# Make sure the parameters are different
28+
assert m1.__TRACE_RESERVED_self_node is not m2.__TRACE_RESERVED_self_node # they are defined now
29+
assert m1.parameters()[0] is not m2.parameters()[0]
30+
31+
# check that the reserved node is the returned parameter
32+
assert getattr(m1, '__TRACE_RESERVED_bundle_Model.forward').parameter is m1.parameters()[0]
33+
assert getattr(m2, '__TRACE_RESERVED_bundle_Model.forward').parameter is m2.parameters()[0]
34+
35+
# each instance has a version different from the class' version
36+
assert m1.forward != m2.forward
37+
assert m1.forward != Model.forward
38+
assert m2.forward.parameter == Model.forward.parameter == m1.forward.parameter
39+
40+
y1 = m1.forward(1)
41+
y2 = m1.forward(2)
42+
43+
from opto.trace.utils import contain
44+
# self is not duplicated
45+
assert contain(y1.parents, m1.__TRACE_RESERVED_self_node)
46+
assert contain(y2.parents, m1.__TRACE_RESERVED_self_node)
47+
# assert m1.__TRACE_RESERVED_self_node in y1.parents
48+
# assert m1.__TRACE_RESERVED_self_node in y2.parents
49+
assert contain(y1.parents, m1.forward.parameter)
50+
assert contain(y2.parents, m1.forward.parameter)
51+
# assert m1.forward.parameter in y1.parents
52+
# assert m1.forward.parameter in y2.parents
53+
assert len(y1.parents) == 3 # since it's trainable
54+
assert len(y2.parents) == 3
55+
56+
def test_case_model_copy():
57+
m1 = Model()
58+
m2 = deepcopy(m1)
59+
60+
# Make sure the parameters are different
61+
try:
62+
assert m1.__TRACE_RESERVED_self_node is not m2.__TRACE_RESERVED_self_node
63+
except AttributeError:
64+
# These secrets attributes are not defined yet. They will only be defined after the bundled method is accessed.
65+
pass
66+
67+
# The hidden nodes are defined now
68+
assert len(m1.parameters()) == 1
69+
assert len(m2.parameters()) == 1
70+
71+
# Make sure the parameters are different
72+
assert m1.__TRACE_RESERVED_self_node is not m2.__TRACE_RESERVED_self_node # they are defined now
73+
assert m1.parameters()[0] is not m2.parameters()[0]
74+
75+
# check that the reserved node is the returned parameter
76+
assert getattr(m1, '__TRACE_RESERVED_bundle_Model.forward').parameter is m1.parameters()[0]
77+
assert getattr(m2, '__TRACE_RESERVED_bundle_Model.forward').parameter is m2.parameters()[0]
78+
79+
# each instance has a version different from the class' version
80+
assert m1.forward is not m2.forward
81+
assert m1.forward is not Model.forward
82+
assert m2.forward.parameter == Model.forward.parameter == m1.forward.parameter
83+
84+
y1 = m1.forward(1)
85+
y2 = m2.forward(2)
86+
87+
from opto.trace.utils import contain
88+
# self is not duplicated
89+
assert contain(y1.parents, m1.__TRACE_RESERVED_self_node)
90+
assert contain(y2.parents, m2.__TRACE_RESERVED_self_node)
91+
# assert m1.__TRACE_RESERVED_self_node in y1.parents
92+
# assert m1.__TRACE_RESERVED_self_node in y2.parents
93+
assert contain(y1.parents, m1.forward.parameter)
94+
assert contain(y2.parents, m2.forward.parameter)
95+
96+
# assert m1.forward.parameter in y1.parents
97+
# assert m1.forward.parameter in y2.parents
98+
assert len(y1.parents) == 3 # since it's trainable
99+
assert len(y2.parents) == 3
100+
101+
def test_case_model_nested_copy():
102+
m1 = Model()
103+
m3 = deepcopy(m1)
104+
m2 = deepcopy(m3)
105+
106+
# Make sure the parameters are different
107+
try:
108+
assert m1.__TRACE_RESERVED_self_node is not m2.__TRACE_RESERVED_self_node
109+
except AttributeError:
110+
# These secrets attributes are not defined yet. They will only be defined after the bundled method is accessed.
111+
pass
112+
113+
# The hidden nodes are defined now
114+
assert len(m1.parameters()) == 1
115+
assert len(m2.parameters()) == 1
116+
117+
# Make sure the parameters are different
118+
assert m1.__TRACE_RESERVED_self_node is not m2.__TRACE_RESERVED_self_node # they are defined now
119+
assert m1.parameters()[0] is not m2.parameters()[0]
120+
121+
# check that the reserved node is the returned parameter
122+
assert getattr(m1, '__TRACE_RESERVED_bundle_Model.forward').parameter is m1.parameters()[0]
123+
assert getattr(m2, '__TRACE_RESERVED_bundle_Model.forward').parameter is m2.parameters()[0]
124+
125+
# each instance has a version different from the class' version
126+
assert m1.forward is not m2.forward
127+
assert m1.forward is not Model.forward
128+
assert m2.forward.parameter == Model.forward.parameter == m1.forward.parameter
129+
130+
y1 = m1.forward(1)
131+
y2 = m2.forward(2)
132+
133+
from opto.trace.utils import contain
134+
# self is not duplicated
135+
assert contain(y1.parents, m1.__TRACE_RESERVED_self_node)
136+
assert contain(y2.parents, m2.__TRACE_RESERVED_self_node)
137+
# assert m1.__TRACE_RESERVED_self_node in y1.parents
138+
# assert m1.__TRACE_RESERVED_self_node in y2.parents
139+
assert contain(y1.parents, m1.forward.parameter)
140+
assert contain(y2.parents, m2.forward.parameter)
141+
142+
# assert m1.forward.parameter in y1.parents
143+
# assert m1.forward.parameter in y2.parents
144+
assert len(y1.parents) == 3 # since it's trainable
145+
assert len(y2.parents) == 3
146+
147+
test_case_two_models()
148+
test_case_model_copy()
149+
test_case_model_nested_copy()

0 commit comments

Comments
 (0)