Skip to content

Commit aac2570

Browse files
authored
[3.0][Dy2St] Support element_size method for value and breakgraph when access unknown tensor attr (#71605) (#71618)
* [Dy2St] Support `element_size` method for value and breakgraph when access unknown tensor attr (#71605) * fix hard coded branch name * set $BRANCH to `release/3.0` * use `github.base_ref` * Revert "use `github.base_ref`" This reverts commit 036a01d. * Revert "set $BRANCH to `release/3.0`" This reverts commit 8addc14. * Revert "fix hard coded branch name" This reverts commit 827d8c9.
1 parent 3eb020a commit aac2570

File tree

6 files changed

+15
-6
lines changed

6 files changed

+15
-6
lines changed

paddle/fluid/pybind/pir.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1556,6 +1556,8 @@ void BindValue(py::module *m) {
15561556
.def("apply", &apply)
15571557
.def("is_same", &Value::operator==)
15581558
.def("hash", [](Value self) { return std::hash<pir::Value>{}(self); })
1559+
.def("element_size",
1560+
[](Value self) { return phi::SizeOf(pir::GetValueDtype(self)); })
15591561
.def("_rename", &name_analysis::RenameValue)
15601562
.def("_has_only_one_name",
15611563
[](Value self) -> bool {

python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -899,6 +899,7 @@ def load_method(self, method_name):
899899
if CALL_METHOD_LAYOUT_NULL_AFTER_VALUE:
900900
self.stack.push(NullVariable())
901901

902+
@call_break_graph_decorator(push_n=2)
902903
def LOAD_METHOD(self, instr: Instruction):
903904
method_name = self._code.co_names[instr.arg]
904905
self.load_method(method_name)

python/paddle/jit/sot/opcode_translator/executor/variables/basic.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,10 @@
4848
printable,
4949
)
5050
from ....utils.envs import ENV_SOT_BREAK_GRAPH_ON_GET_SYMBOLIC_VALUE
51-
from ....utils.exceptions import HasNoAttributeError, InnerError
51+
from ....utils.exceptions import (
52+
InnerError,
53+
UnsupportedPaddleAPIBreak,
54+
)
5255
from ..dispatch_functions import tensor_numel
5356
from ..guard import (
5457
FasterStringifiedExpression,
@@ -704,7 +707,9 @@ def getattr(self, name: str, default=None):
704707
)
705708
return fn_var.bind(self, name)
706709
else:
707-
raise HasNoAttributeError(f"Unknown Tensor attribute: {name}")
710+
raise BreakGraphError(
711+
UnsupportedPaddleAPIBreak(fn_name=f"Tensor.{name}")
712+
)
708713

709714
def setattr(self, key, val):
710715
# support tensor variable store attr, like:

test/dygraph_to_static/test_tensor_attr_consistency.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737
'data',
3838
'data_ptr',
3939
'detach_',
40-
'element_size',
4140
'fill_',
4241
'fill_diagonal_',
4342
'fill_diagonal_tensor',

test/sot/test_18_tensor_method.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def tensor_method_property_without_breakgraph(
5757
+ a.ndim
5858
+ a.dim()
5959
+ a.rank(),
60+
a.element_size(),
6061
)
6162

6263

@@ -67,6 +68,7 @@ def tensor_method_property_with_breakgraph(a: paddle.Tensor, b: paddle.Tensor):
6768
a.tolist(),
6869
str(a.place),
6970
a.clear_gradient(),
71+
a.is_dense(),
7072
)
7173

7274

test/sot/test_sot_exception.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323

2424
def case1(x):
25-
return n # noqa: F821
25+
return undefined_var # noqa: F821
2626

2727

2828
def case2(x):
@@ -31,15 +31,15 @@ def case2(x):
3131

3232

3333
def case3(x):
34-
y = x.undefined_attr
34+
y = undefined_var # noqa: F821
3535
return y
3636

3737

3838
def case4_inner(x):
3939
y = x * 2
4040
print()
4141
y = y + 1
42-
return y.undefined_attr
42+
return undefined_var # noqa: F821
4343

4444

4545
def case4(x):

0 commit comments

Comments
 (0)