Skip to content

Commit 0b38822

Browse files
authored
Merge pull request #14345 from heavengate/fix_grid_sampler
fix #14344 : win compile error, EigenTenor * float unsupport. test=develop
2 parents 200c410 + 72108d8 commit 0b38822

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

paddle/fluid/operators/grid_sampler_op.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,19 @@ static void CalcGridLocations(const platform::CPUDeviceContext& ctx,
6363
Tensor ones;
6464
ones.mutable_data<T>({n, h, w}, ctx.GetPlace());
6565
auto ones_t = EigenTensor<T, 3>::From(ones).setConstant(1.0);
66+
Tensor half_xmax, half_ymax;
67+
half_xmax.mutable_data<T>({n, h, w}, ctx.GetPlace());
68+
auto half_xmax_t =
69+
EigenTensor<T, 3>::From(half_xmax).setConstant(0.5 * x_max);
70+
half_ymax.mutable_data<T>({n, h, w}, ctx.GetPlace());
71+
auto half_ymax_t =
72+
EigenTensor<T, 3>::From(half_ymax).setConstant(0.5 * y_max);
6673

6774
// scale grid to [0, h-1/w-1]
6875
auto grid_x_t = EigenTensor<T, 3>::From(grid_x);
6976
auto grid_y_t = EigenTensor<T, 3>::From(grid_y);
70-
grid_x_t.device(place) = 0.5 * ((grid_x_t + ones_t) * x_max);
71-
grid_y_t.device(place) = 0.5 * ((grid_y_t + ones_t) * y_max);
77+
grid_x_t.device(place) = (grid_x_t + ones_t) * half_xmax_t;
78+
grid_y_t.device(place) = (grid_y_t + ones_t) * half_ymax_t;
7279

7380
// calculate coords of 4 corner points
7481
x_w->mutable_data<T>({n, h, w}, ctx.GetPlace());

0 commit comments

Comments
 (0)