Skip to content

Commit 5f8e3eb

Browse files
authored
[SOT] Add fine gate side effects changed check for dict like proxy data (#73549)
1 parent 53919d3 commit 5f8e3eb

File tree

4 files changed

+51
-4
lines changed

4 files changed

+51
-4
lines changed

python/paddle/jit/sot/opcode_translator/executor/mutable_data.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,9 @@ def version(self):
148148
def has_changed(self):
149149
return self.version != 0
150150

151+
def check_changed(self, key: Any) -> bool:
152+
raise NotImplementedError
153+
151154
def rollback(self, version: int):
152155
assert version <= self.version
153156
self.records[:] = self.records[:version]
@@ -182,6 +185,17 @@ def __init__(self, data: Any, getter: DataGetter):
182185
def clear_read_cache(self):
183186
self.read_cache.clear()
184187

188+
def check_changed(self, key: Any) -> bool:
189+
if not self.has_changed:
190+
return False
191+
for mutation in self.records:
192+
if (
193+
isinstance(mutation, (MutationNew, MutationDel, MutationSet))
194+
and mutation.key == key
195+
):
196+
return True
197+
return False
198+
185199
def get(self, key: Any):
186200
# TODO(SigureMo): Optimize performance of this.
187201
write_cache = self.reproduce(self.version)
@@ -240,6 +254,9 @@ def __init__(self, data: Any, getter: DataGetter):
240254
def clear_read_cache(self):
241255
self.read_cache[:] = []
242256

257+
def check_changed(self, key: Any) -> bool:
258+
return self.has_changed
259+
243260
@property
244261
def length(self):
245262
return len(self.reproduce())

python/paddle/jit/sot/opcode_translator/executor/variables/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2475,7 +2475,7 @@ def proxy_getter(self, proxy: MutableDictLikeData, key: Any):
24752475
return VariableFactory.from_value(
24762476
proxy.original_data[key],
24772477
self.graph,
2478-
tracker=GetAttrTracker(self, key, changed=proxy.has_changed),
2478+
tracker=GetAttrTracker(self, key, changed=proxy.check_changed(key)),
24792479
)
24802480

24812481
def get_py_type(self):

python/paddle/jit/sot/opcode_translator/executor/variables/container.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def proxy_getter(self, proxy: MutableListLikeData, key: Any):
199199
return VariableFactory.from_value(
200200
proxy.original_data[key],
201201
self.graph,
202-
tracker=GetItemTracker(self, key, changed=proxy.has_changed),
202+
tracker=GetItemTracker(self, key, changed=proxy.check_changed(key)),
203203
)
204204

205205
def get_py_value(self, allow_tensor=False):
@@ -249,7 +249,7 @@ def getitem(self, key):
249249
items[key],
250250
self.graph,
251251
tracker=GetItemTracker(
252-
self, key, changed=self.proxy.has_changed
252+
self, key, changed=self.proxy.check_changed(key)
253253
),
254254
)
255255
else:
@@ -870,7 +870,7 @@ def proxy_getter(self, proxy: MutableDictLikeData, key: Any):
870870
return VariableFactory.from_value(
871871
proxy.original_data[key],
872872
self.graph,
873-
tracker=GetItemTracker(self, key, changed=proxy.has_changed),
873+
tracker=GetItemTracker(self, key, changed=proxy.check_changed(key)),
874874
)
875875

876876
def get_py_value(self, allow_tensor=False):

test/sot/test_side_effects.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,24 @@ def slice_list_after_change(l):
181181
return sum
182182

183183

184+
class ReadBufferAfterChanged(paddle.nn.Layer):
185+
def __init__(self):
186+
super().__init__()
187+
self.buffer1 = paddle.to_tensor(1)
188+
self.buffer2 = paddle.to_tensor(2)
189+
190+
def forward(self, x):
191+
self.buffer1 += 1
192+
return x + self.buffer1 + self.buffer2
193+
194+
def __eq__(self, other):
195+
if not isinstance(other, ReadBufferAfterChanged):
196+
return False
197+
return paddle.equal(self.buffer1, other.buffer1) and paddle.equal(
198+
self.buffer2, other.buffer2
199+
)
200+
201+
184202
class TestDictSideEffect(TestCaseBase):
185203
def test_dict_setitem(self):
186204
self.assert_results_with_side_effects(
@@ -246,6 +264,7 @@ def test_list_insert(self):
246264
def test_list_remove(self):
247265
self.assert_results_with_side_effects(list_remove, [1, 1, 1])
248266
self.assert_results_with_side_effects(list_remove, [0, 1, 2])
267+
# TODO(DrRyanHuang): change this to ValueError
249268
with self.assertRaises(InnerError):
250269
symbolic_translate(list_remove)([0, 2, 4])
251270

@@ -329,5 +348,16 @@ def test_attr_set_breakgraph(self):
329348
self.attr_check(object_attr_breakgraph, ["x"], CustomObject, 1000)
330349

331350

351+
class TestReadBufferAfterChanged(TestCaseBase):
352+
def test_read_buffer_after_change(self):
353+
layer = ReadBufferAfterChanged()
354+
x = paddle.randn([1, 2, 3])
355+
self.assert_results_with_side_effects(
356+
layer.__class__.forward,
357+
layer,
358+
x,
359+
)
360+
361+
332362
if __name__ == "__main__":
333363
unittest.main()

0 commit comments

Comments
 (0)