Skip to content

Commit 862e81e

Browse files
authored
[cherry-pick][Dy2Stat]supplet several interface of static Variable to consistent with dygraph Tensor (#33330) #34401
As the title [cherry-pick][Dy2Stat]supplet several interface of static Variable to consistent with dygraph Tensor (#33330)
1 parent 9b48cfd commit 862e81e

File tree

9 files changed

+364
-26
lines changed

9 files changed

+364
-26
lines changed
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/operators/share_data_op.h"
16+
#include "paddle/fluid/framework/op_registry.h"
17+
18+
namespace paddle {
19+
namespace operators {
20+
21+
class ShareDataOp : public framework::OperatorWithKernel {
22+
public:
23+
using framework::OperatorWithKernel::OperatorWithKernel;
24+
25+
void InferShape(framework::InferShapeContext *ctx) const override {
26+
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ShareData");
27+
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "ShareData");
28+
auto in_type = ctx->GetInputsVarType("X")[0];
29+
auto out_type = ctx->GetOutputsVarType("Out")[0];
30+
31+
PADDLE_ENFORCE_EQ(
32+
in_type == framework::proto::VarType::LOD_TENSOR ||
33+
in_type == framework::proto::VarType::SELECTED_ROWS,
34+
true, platform::errors::InvalidArgument(
35+
"Type of Variable[X] must be LoDTensor or SelectedRows!"));
36+
PADDLE_ENFORCE_EQ(
37+
in_type, out_type,
38+
platform::errors::InvalidArgument(
39+
"The type of input (X) and output (Out) are inconsistent."));
40+
41+
ctx->ShareDim("X", "Out");
42+
}
43+
};
44+
45+
class ShareDataOpMaker : public framework::OpProtoAndCheckerMaker {
46+
public:
47+
void Make() override {
48+
AddInput("X", "(Tensor), The input tensor of share_data op");
49+
AddOutput("Out", "(Tensor), The output tensor of share_data op");
50+
AddComment(R"DOC(
51+
ShareData Operator.
52+
53+
Return a tensor $Out$ that shares data with the input tensor $X$ and without tensor copy.
54+
)DOC");
55+
}
56+
};
57+
58+
} // namespace operators
59+
} // namespace paddle
60+
61+
namespace ops = paddle::operators;
62+
REGISTER_OPERATOR(
63+
share_data, ops::ShareDataOp, ops::ShareDataOpMaker,
64+
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
65+
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
66+
REGISTER_OP_CPU_KERNEL(share_data, ops::ShareDataKernel<bool>,
67+
ops::ShareDataKernel<int>, ops::ShareDataKernel<int8_t>,
68+
ops::ShareDataKernel<uint8_t>,
69+
ops::ShareDataKernel<paddle::platform::float16>,
70+
ops::ShareDataKernel<int64_t>,
71+
ops::ShareDataKernel<float>,
72+
ops::ShareDataKernel<double>)
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/operators/share_data_op.h"
16+
17+
REGISTER_OP_CUDA_KERNEL(
18+
share_data, paddle::operators::ShareDataKernel<bool>,
19+
paddle::operators::ShareDataKernel<int>,
20+
paddle::operators::ShareDataKernel<int8_t>,
21+
paddle::operators::ShareDataKernel<uint8_t>,
22+
paddle::operators::ShareDataKernel<paddle::platform::float16>,
23+
paddle::operators::ShareDataKernel<int64_t>,
24+
paddle::operators::ShareDataKernel<float>,
25+
paddle::operators::ShareDataKernel<double>);
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
#include "paddle/fluid/framework/op_registry.h"
17+
18+
namespace paddle {
19+
namespace operators {
20+
21+
template <typename T>
22+
class ShareDataKernel : public framework::OpKernel<T> {
23+
public:
24+
void Compute(const framework::ExecutionContext &ctx) const override {
25+
auto *in_var = ctx.InputVar("X");
26+
auto *out_var = ctx.OutputVar("Out");
27+
if (in_var->IsType<framework::LoDTensor>()) {
28+
const auto &origin_tensor = in_var->Get<framework::LoDTensor>();
29+
auto *detach_tensor = out_var->GetMutable<framework::LoDTensor>();
30+
detach_tensor->ShareDataWith(origin_tensor);
31+
} else {
32+
const auto &origin_selected_rows = in_var->Get<framework::SelectedRows>();
33+
auto *detach_selected_rows =
34+
out_var->GetMutable<framework::SelectedRows>();
35+
detach_selected_rows->mutable_value()->ShareDataWith(
36+
origin_selected_rows.value());
37+
}
38+
}
39+
};
40+
} // namespace operators
41+
} // namespace paddle

python/paddle/fluid/framework.py

Lines changed: 53 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -942,35 +942,43 @@ def __init__(self,
942942
self._stop_gradient = stop_gradient
943943
self.is_data = is_data
944944

945-
@fake_interface_only
946945
def detach(self):
947946
"""
948-
**Notes**:
949-
**This API is ONLY available in Dygraph mode**
950-
951947
Returns a new Variable, detached from the current graph.
948+
It will share data with origin Variable and without tensor copy.
949+
In addition, the detached Variable doesn't provide gradient propagation.
952950
953951
Returns:
954952
( :ref:`api_guide_Variable_en` | dtype is same as current Variable): The detached Variable.
955953
956-
957954
Examples:
958955
.. code-block:: python
959956
960-
import paddle.fluid as fluid
961-
from paddle.fluid.dygraph.base import to_variable
962-
from paddle.fluid.dygraph import Linear
963-
import numpy as np
957+
import paddle
964958
965-
data = np.random.uniform(-1, 1, [30, 10, 32]).astype('float32')
966-
with fluid.dygraph.guard():
967-
linear = Linear(32, 64)
968-
data = to_variable(data)
969-
x = linear(data)
970-
y = x.detach()
959+
paddle.enable_static()
960+
961+
# create a static Variable
962+
x = paddle.static.data(name='x', shape=[3, 2, 1])
971963
964+
# create a detached Variable
965+
y = x.detach()
972966
"""
973-
pass
967+
968+
assert self.type == core.VarDesc.VarType.SELECTED_ROWS or \
969+
self.type == core.VarDesc.VarType.LOD_TENSOR, \
970+
"only support a variable with SELECTED_ROWS or LOD_TENSOR to be detached"
971+
972+
output = self.block.create_var(
973+
name=unique_name.generate_with_ignorable_key("detach_" + self.name),
974+
dtype=self.dtype,
975+
type=self.type,
976+
persistable=self.persistable,
977+
stop_gradient=True)
978+
979+
self.block.append_op(
980+
type='share_data', inputs={'X': [self]}, outputs={'Out': [output]})
981+
return output
974982

975983
@fake_interface_only
976984
def numpy(self):
@@ -1805,6 +1813,35 @@ def set_value(self, value, scope=None):
18051813

18061814
t.set(value, place)
18071815

1816+
def size(self):
1817+
"""
1818+
Returns the number of elements for current Variable, which is a int64 Variable with shape [1]
1819+
1820+
Returns:
1821+
Variable: the number of elements for current Variable
1822+
1823+
Examples:
1824+
.. code-block:: python
1825+
1826+
import paddle
1827+
1828+
paddle.enable_static()
1829+
1830+
# create a static Variable
1831+
x = paddle.static.data(name='x', shape=[3, 2, 1])
1832+
1833+
# get the number of elements of the Variable
1834+
y = x.size()
1835+
"""
1836+
1837+
output = self.block.create_var(
1838+
name=unique_name.generate_with_ignorable_key(self.name + "_size"),
1839+
dtype=core.VarDesc.VarType.INT64)
1840+
1841+
self.block.append_op(
1842+
type='size', inputs={'Input': [self]}, outputs={'Out': [output]})
1843+
return output
1844+
18081845

18091846
def get_all_op_protos():
18101847
"""

python/paddle/fluid/layers/math_op_patch.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
"__rpow__": "A **= B",
4848
"__floordiv__": "A //B",
4949
"__mod__": "A % B",
50+
"__matmul__": "A @ B",
5051
"__eq__": "A == B",
5152
"__ne__": "A != B",
5253
"__lt__": "A < B",
@@ -197,6 +198,28 @@ def _scalar_op_(var, scale, bias):
197198
def _neg_(var):
198199
return _scalar_op_(var, -1.0, 0.0)
199200

201+
@property
202+
def _ndim_(self):
203+
"""
204+
Returns the dimension of current Variable
205+
206+
Returns:
207+
the dimension
208+
209+
Examples:
210+
.. code-block:: python
211+
212+
import paddle
213+
214+
paddle.enable_static()
215+
216+
# create a static Variable
217+
x = paddle.static.data(name='x', shape=[3, 2, 1])
218+
# print the dimension of the Variable
219+
print(x.ndim)
220+
"""
221+
return len(self.shape)
222+
200223
def _scalar_add_(var, value):
201224
return _scalar_op_(var, 1.0, value)
202225

@@ -233,9 +256,9 @@ def __impl__(self, other_var):
233256
other_var = float(other_var)
234257
# division is a special case
235258
# NOTE(chenweihang): because we cast tensor to float32 instead float64,
236-
# the division result can only guarantee the numerical accuracy of 6 digits
237-
# after the decimal point. The result of numpy calculation is of float64 type,
238-
# so the calculation result here and the calculation result of numpy are
259+
# the division result can only guarantee the numerical accuracy of 6 digits
260+
# after the decimal point. The result of numpy calculation is of float64 type,
261+
# so the calculation result here and the calculation result of numpy are
239262
# different after 6 decimal point. If necessary, we can also use float64 here.
240263
# torch's behavior here is consistent with ours
241264
if op_type == 'elementwise_div' and self.dtype in _supported_int_dtype_:
@@ -323,6 +346,9 @@ def __impl__(self, other_var):
323346
# b=-a
324347
('__neg__', _neg_),
325348
('astype', astype),
349+
('dim', lambda x: len(x.shape)),
350+
('ndimension', lambda x: len(x.shape)),
351+
('ndim', _ndim_),
326352
('__add__', _binary_creator_('__add__', 'elementwise_add', False,
327353
_scalar_add_)),
328354
# a+b == b+a. Do not need to reverse explicitly
@@ -353,6 +379,8 @@ def __impl__(self, other_var):
353379
'elementwise_floordiv', False, None)),
354380
('__mod__', _binary_creator_('__mod__', 'elementwise_mod', False,
355381
None)),
382+
('__matmul__', _binary_creator_('__matmul__', "matmul_v2", False,
383+
None)),
356384
# for logical compare
357385
('__eq__', _binary_creator_('__eq__', 'equal', False, None)),
358386
('__ne__', _binary_creator_('__ne__', 'not_equal', False, None)),

python/paddle/fluid/tests/unittests/test_detach.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -149,12 +149,6 @@ def test_NoDetachSingle_DetachMulti(self):
149149
array_detach_multi = self.detach_multi()
150150
assert np.array_equal(array_no_detach_single, array_detach_multi)
151151

152-
def test_detach_exception(self):
153-
x = fluid.layers.data(name="a", shape=[3, 4], dtype='float32')
154-
y = fluid.layers.fc(input=x, size=10, bias_attr=True)
155-
with self.assertRaises(AssertionError):
156-
y_detach = y.detach()
157-
158152

159153
class TestInplace(unittest.TestCase):
160154
def test_forward_version(self):

0 commit comments

Comments
 (0)