@@ -63,12 +63,19 @@ static void CalcGridLocations(const platform::CPUDeviceContext& ctx,
63
63
Tensor ones;
64
64
ones.mutable_data <T>({n, h, w}, ctx.GetPlace ());
65
65
auto ones_t = EigenTensor<T, 3 >::From (ones).setConstant (1.0 );
66
+ Tensor half_xmax, half_ymax;
67
+ half_xmax.mutable_data <T>({n, h, w}, ctx.GetPlace ());
68
+ auto half_xmax_t =
69
+ EigenTensor<T, 3 >::From (half_xmax).setConstant (0.5 * x_max);
70
+ half_ymax.mutable_data <T>({n, h, w}, ctx.GetPlace ());
71
+ auto half_ymax_t =
72
+ EigenTensor<T, 3 >::From (half_ymax).setConstant (0.5 * y_max);
66
73
67
74
// scale grid to [0, h-1/w-1]
68
75
auto grid_x_t = EigenTensor<T, 3 >::From (grid_x);
69
76
auto grid_y_t = EigenTensor<T, 3 >::From (grid_y);
70
- grid_x_t .device (place) = 0.5 * (( grid_x_t + ones_t ) * x_max) ;
71
- grid_y_t .device (place) = 0.5 * (( grid_y_t + ones_t ) * y_max) ;
77
+ grid_x_t .device (place) = ( grid_x_t + ones_t ) * half_xmax_t ;
78
+ grid_y_t .device (place) = ( grid_y_t + ones_t ) * half_ymax_t ;
72
79
73
80
// calculate coords of 4 corner points
74
81
x_w->mutable_data <T>({n, h, w}, ctx.GetPlace ());
0 commit comments