Skip to content

Commit ef6445e

Browse files
authored
Merge pull request #12908 from seiriosPlus/fill_constant_selectedrows
add SelectedRows support in fill_constant_op
2 parents 836e1e0 + 66cc185 commit ef6445e

File tree

4 files changed

+51
-9
lines changed

4 files changed

+51
-9
lines changed

paddle/fluid/operators/fill_constant_op.cc

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ limitations under the License. */
1515
#include "paddle/fluid/framework/data_type.h"
1616
#include "paddle/fluid/framework/op_registry.h"
1717
#include "paddle/fluid/operators/math/math_function.h"
18-
#include "paddle/fluid/platform/device_context.h"
1918

2019
namespace paddle {
2120
namespace operators {
@@ -41,19 +40,33 @@ class FillConstantOp : public framework::OperatorBase {
4140
static_cast<framework::proto::VarType::Type>(Attr<int>("dtype"));
4241
auto value = Attr<float>("value");
4342
auto force_cpu = Attr<bool>("force_cpu");
44-
auto &out =
45-
*scope.FindVar(Output("Out"))->GetMutable<framework::LoDTensor>();
46-
out.Resize(framework::make_ddim(Attr<std::vector<int>>("shape")));
43+
44+
framework::Tensor *tensor = nullptr;
45+
46+
auto &out_var = *scope.FindVar(Output("Out"));
47+
48+
if (out_var.IsType<framework::LoDTensor>()) {
49+
tensor = out_var.GetMutable<framework::LoDTensor>();
50+
tensor->Resize(framework::make_ddim(Attr<std::vector<int>>("shape")));
51+
} else if (out_var.IsType<framework::SelectedRows>()) {
52+
tensor = out_var.GetMutable<framework::SelectedRows>()->mutable_value();
53+
tensor->Resize(framework::make_ddim(Attr<std::vector<int>>("shape")));
54+
} else {
55+
PADDLE_THROW(
56+
"fill constant op's output only"
57+
"supports SelectedRows and LoDTensor");
58+
}
59+
4760
if (force_cpu) {
4861
auto cpu = platform::CPUPlace();
49-
out.mutable_data(cpu, framework::ToTypeIndex(data_type));
62+
tensor->mutable_data(cpu, framework::ToTypeIndex(data_type));
5063
} else {
51-
out.mutable_data(dev_place, framework::ToTypeIndex(data_type));
64+
tensor->mutable_data(dev_place, framework::ToTypeIndex(data_type));
5265
}
5366

5467
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
5568
auto &dev_ctx = *pool.Get(dev_place);
56-
math::set_constant(dev_ctx, &out, value);
69+
math::set_constant(dev_ctx, tensor, value);
5770
}
5871
};
5972

paddle/fluid/operators/uniform_random_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class CPUUniformRandomKernel : public framework::OpKernel<T> {
3737
} else {
3838
PADDLE_THROW(
3939
"uniform_random_op's output only"
40-
"supports SelectedRows and Tensor");
40+
"supports SelectedRows and LoDTensor");
4141
}
4242
T* data = tensor->mutable_data<T>(ctx.GetPlace());
4343
unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed"));

paddle/fluid/operators/uniform_random_op.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ class GPUUniformRandomKernel : public framework::OpKernel<T> {
5454
} else {
5555
PADDLE_THROW(
5656
"uniform_random_op's output only"
57-
"supports SelectedRows and Tensor");
57+
"supports SelectedRows and LoDTensor");
5858
}
5959
T* data = tensor->mutable_data<T>(context.GetPlace());
6060
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));

python/paddle/fluid/tests/unittests/test_fill_constant_op.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
import numpy as np
1919
from op_test import OpTest
2020

21+
import paddle.fluid.core as core
22+
from paddle.fluid.op import Operator
23+
2124

2225
class TestFillConstantOp1(OpTest):
2326
def setUp(self):
@@ -47,5 +50,31 @@ def test_check_output(self):
4750
self.check_output()
4851

4952

53+
class TestFillConstantOpWithSelectedRows(OpTest):
54+
def check_with_place(self, place):
55+
scope = core.Scope()
56+
# create Out Variable
57+
out = scope.var('Out').get_selected_rows()
58+
59+
# create and run fill_constant_op operator
60+
fill_constant_op = Operator(
61+
"fill_constant", shape=[123, 92], value=3.8, Out='Out')
62+
fill_constant_op.run(scope, place)
63+
64+
# get result from Out
65+
result_array = np.array(out.get_tensor())
66+
full_array = np.full((123, 92), 3.8, 'float32')
67+
68+
self.assertTrue(np.array_equal(result_array, full_array))
69+
70+
def test_fill_constant_with_selected_rows(self):
71+
places = [core.CPUPlace()]
72+
if core.is_compiled_with_cuda():
73+
places.append(core.CUDAPlace(0))
74+
75+
for place in places:
76+
self.check_with_place(place)
77+
78+
5079
if __name__ == "__main__":
5180
unittest.main()

0 commit comments

Comments
 (0)