Skip to content

Commit 06de824

Browse files
committed
fix shape in floats
1 parent 318ba99 commit 06de824

File tree

4 files changed

+21
-9
lines changed

4 files changed

+21
-9
lines changed

paddle/fluid/operators/split_selected_rows_op.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@ class SplitSelectedRowsOpMaker : public framework::OpProtoAndCheckerMaker {
2222
void Make() override {
2323
AddInput("X", "The input SelectedRows.");
2424
AddOutput("Out", "The outputs of the input SelectedRows.").AsDuplicable();
25-
AddAttr<std::vector<int>>("height_sections",
26-
"Height for each output SelectedRows.")
27-
.SetDefault(std::vector<int>({}));
25+
AddAttr<std::vector<int64_t>>("height_sections",
26+
"Height for each output SelectedRows.")
27+
.SetDefault(std::vector<int64_t>({}));
2828

2929
AddComment(R"DOC(
3030
Split a SelectedRows with a specified rows section.

paddle/fluid/operators/split_selected_rows_op.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ limitations under the License. */
2121
namespace paddle {
2222
namespace operators {
2323

24-
static int FindOutIdx(int row, const std::vector<int>& abs_sections) {
24+
static int FindOutIdx(int row, const std::vector<int64_t>& abs_sections) {
2525
for (size_t i = 1; i < abs_sections.size(); ++i) {
2626
if (row < abs_sections[i]) {
2727
return i - 1;
@@ -30,9 +30,9 @@ static int FindOutIdx(int row, const std::vector<int>& abs_sections) {
3030
return abs_sections.size() - 1;
3131
}
3232

33-
static std::vector<int> ToAbsoluteSection(
34-
const std::vector<int>& height_sections) {
35-
std::vector<int> abs_sections;
33+
static std::vector<int64_t> ToAbsoluteSection(
34+
const std::vector<int64_t>& height_sections) {
35+
std::vector<int64_t> abs_sections;
3636
abs_sections.resize(height_sections.size());
3737
abs_sections[0] = 0;
3838
for (size_t i = 1; i < height_sections.size(); ++i) {
@@ -47,7 +47,7 @@ class SplitSelectedRowsOpKernel : public framework::OpKernel<T> {
4747
void Compute(const framework::ExecutionContext& ctx) const override {
4848
auto* x = ctx.Input<framework::SelectedRows>("X");
4949
auto outs = ctx.MultiOutput<framework::SelectedRows>("Out");
50-
auto height_sections = ctx.Attr<std::vector<int>>("height_sections");
50+
auto height_sections = ctx.Attr<std::vector<int64_t>>("height_sections");
5151

5252
auto abs_sections = ToAbsoluteSection(height_sections);
5353

paddle/fluid/operators/uniform_random_op.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ class GPUUniformRandomKernel : public framework::OpKernel<T> {
4848
if (out_var->IsType<framework::LoDTensor>()) {
4949
tensor = out_var->GetMutable<framework::LoDTensor>();
5050
} else if (out_var->IsType<framework::SelectedRows>()) {
51-
auto shape = context.Attr<std::vector<int>>("shape");
51+
auto shape = context.Attr<std::vector<int64_t>>("shape");
5252
tensor = out_var->GetMutable<framework::SelectedRows>()->mutable_value();
5353
tensor->Resize(framework::make_ddim(shape));
5454
} else {

paddle/fluid/pybind/protobuf.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,18 @@ struct variant_caster<V<Ts...>> {
5757
auto caster = make_caster<T>();
5858
if (!load_success_ && caster.load(src, convert)) {
5959
load_success_ = true;
60+
61+
if (std::is_same<T, std::vector<float>>::value) {
62+
auto caster_ints = make_caster<std::vector<int64_t>>();
63+
if (caster_ints.load(src, convert)) {
64+
VLOG(4) << "This value are floats and int64_ts satisfy "
65+
"simultaneously, will set it's type to "
66+
"std::vector<int64_t>";
67+
value = cast_op<std::vector<int64_t>>(caster_ints);
68+
return true;
69+
}
70+
}
71+
6072
value = cast_op<T>(caster);
6173
return true;
6274
}

0 commit comments

Comments
 (0)