@@ -37,18 +37,65 @@ using Array4 = Eigen::DSizes<int64_t, 4>;
37
37
*/
38
38
template <typename DeviceContext, typename T>
39
39
struct Linspace {
// Functor declaration only; the definition is not part of this view.
// The removed ("-") signature returned a framework::Tensor by value; the
// added ("+") signature returns void and takes a caller-provided output
// tensor pointer `numbers` instead.
// NOTE(review): presumably fills `numbers` with `count` evenly spaced values
// running from `start` to `end` -- confirm against the definition.
40
- framework::Tensor operator ()(T start, T end, int count,
41
- const framework::ExecutionContext& ctx);
40
+ void operator ()(T start, T end, int count, framework::Tensor* numbers ,
41
+ const framework::ExecutionContext& ctx);
42
42
};
43
43
44
+ template <typename DeviceContext, typename T>
45
+ inline void GetIdxMap (int n, int h, int w, Tensor* grid,
46
+ const framework::ExecutionContext& ctx) {
47
+ auto & place = *ctx.template device_context <DeviceContext>().eigen_device ();
48
+ grid->mutable_data <T>({n, h, w, 3 }, ctx.GetPlace ());
49
+ auto grid_t = EigenTensor<T, 4 >::From (*grid);
50
+ // Get indexes of height with shape [height, width, 1]
51
+ Tensor h_idx;
52
+ Linspace<DeviceContext, T> linspace;
53
+ linspace ((T)-1 , (T)1 , h, &h_idx, ctx);
54
+ auto h_idx_t = EigenTensor<T, 1 >::From (h_idx);
55
+ // Get indexes of width with shape [height, width, 1]
56
+ Tensor w_idx;
57
+ linspace ((T)-1 , (T)1 , w, &w_idx, ctx);
58
+ auto w_idx_t = EigenTensor<T, 1 >::From (w_idx);
59
+ // Get constant ones tensor with shape [height, width, 1]
60
+ Tensor ones;
61
+ ones.mutable_data <T>({h, w, 1 }, ctx.GetPlace ());
62
+ auto ones_t = EigenTensor<T, 3 >::From (ones).setConstant ((T)1 );
63
+ // Get grid tensor with shape [n, h, w, 3] by concatenating h_idx, w_idx and
64
+ // ones
65
+ Tensor w_idx_map;
66
+ w_idx_map.mutable_data <T>({h, w, 1 }, ctx.GetPlace ());
67
+ auto w_idx_map_t = EigenTensor<T, 3 >::From (w_idx_map);
68
+ Tensor h_idx_map;
69
+ h_idx_map.mutable_data <T>({h, w, 1 }, ctx.GetPlace ());
70
+ auto h_idx_map_t = EigenTensor<T, 3 >::From (h_idx_map);
71
+ Tensor w_h_idx_map;
72
+ w_h_idx_map.mutable_data <T>({h, w, 2 }, ctx.GetPlace ());
73
+ auto w_h_idx_map_t = EigenTensor<T, 3 >::From (w_h_idx_map);
74
+ Tensor w_h_one_idx_map;
75
+ w_h_one_idx_map.mutable_data <T>({h, w, 3 }, ctx.GetPlace ());
76
+ auto w_h_one_idx_map_t = EigenTensor<T, 3 >::From (w_h_one_idx_map);
77
+
78
+ w_idx_map_t .device (place) = w_idx_t .reshape (Array2 (1 , w))
79
+ .broadcast (Array2 (h, 1 ))
80
+ .reshape (Array3 (h, w, 1 ));
81
+
82
+ h_idx_map_t .device (place) = h_idx_t .reshape (Array2 (1 , h))
83
+ .broadcast (Array2 (w, 1 ))
84
+ .shuffle (Array2 (1 , 0 ))
85
+ .reshape (Array3 (h, w, 1 ));
86
+
87
+ w_h_idx_map_t .device (place) = w_idx_map_t .concatenate (h_idx_map_t , 2 );
88
+ w_h_one_idx_map_t .device (place) = w_h_idx_map_t .concatenate (ones_t , 2 );
89
+ grid_t .device (place) = w_h_one_idx_map_t .reshape (Array4 (1 , h, w, 3 ))
90
+ .broadcast (Array4 (n, 1 , 1 , 1 ));
91
+ }
92
+
44
93
template <typename DeviceContext, typename T>
45
94
class AffineGridOpKernel : public framework ::OpKernel<T> {
46
95
public:
47
96
void Compute (const framework::ExecutionContext& ctx) const override {
48
- auto & place = *ctx.template device_context <DeviceContext>().eigen_device ();
49
97
auto * theta = ctx.Input <Tensor>(" Theta" );
50
98
int n = theta->dims ()[0 ];
51
-
52
99
auto size_attr = ctx.Attr <std::vector<int >>(" output_shape" );
53
100
int h = 0 ;
54
101
int w = 0 ;
@@ -63,44 +110,13 @@ class AffineGridOpKernel : public framework::OpKernel<T> {
63
110
h = size_attr[2 ];
64
111
w = size_attr[3 ];
65
112
}
66
-
67
113
auto * output = ctx.Output <Tensor>(" Output" );
68
114
output->mutable_data <T>({n, h, w, 2 }, ctx.GetPlace ());
69
-
70
115
math::SetConstant<DeviceContext, T>()(
71
116
ctx.template device_context <DeviceContext>(), output,
72
117
static_cast <T>(0 ));
73
-
74
- Linspace<DeviceContext, T> linspace;
75
- // Get indexes of height with shape [height, width, 1]
76
- auto h_idx = linspace ((T)-1 , (T)1 , h, ctx);
77
- auto h_idx_t = EigenTensor<T, 1 >::From (h_idx);
78
- // Get indexes of width with shape [height, width, 1]
79
- auto w_idx = linspace ((T)-1 , (T)1 , w, ctx);
80
- auto w_idx_t = EigenTensor<T, 1 >::From (w_idx);
81
- // Get constant ones tensor with shape [height, width, 1]
82
- Tensor ones;
83
- ones.mutable_data <T>({h, w, 1 }, ctx.GetPlace ());
84
- auto ones_t = EigenTensor<T, 3 >::From (ones).setConstant ((T)1 );
85
- // Get grid tensor with shape [n, h, w, 3] by concatenating h_idx, w_idx and
86
- // ones
87
118
Tensor grid;
88
- grid.mutable_data <T>({n, h, w, 3 }, ctx.GetPlace ());
89
- auto grid_t = EigenTensor<T, 4 >::From (grid);
90
-
91
- grid_t .device (place) = w_idx_t .reshape (Array2 (1 , w))
92
- .broadcast (Array2 (h, 1 ))
93
- .reshape (Array3 (h, w, 1 ))
94
- .concatenate (h_idx_t .reshape (Array2 (1 , h))
95
- .broadcast (Array2 (w, 1 ))
96
- .shuffle (Array2 (1 , 0 ))
97
- .reshape (Array3 (h, w, 1 )),
98
- 2 )
99
- .eval ()
100
- .concatenate (ones_t , 2 )
101
- .reshape (Array4 (1 , h, w, 3 ))
102
- .broadcast (Array4 (n, 1 , 1 , 1 ));
103
-
119
+ GetIdxMap<DeviceContext, T>(n, h, w, &grid, ctx);
104
120
// output = grid * theta.T
105
121
// TODO(wanghaoshuang): Refine batched matrix multiply
106
122
auto blas = math::GetBlas<DeviceContext, T>(ctx);
@@ -118,10 +134,8 @@ template <typename DeviceContext, typename T>
118
134
class AffineGridGradOpKernel : public framework ::OpKernel<T> {
119
135
public:
120
136
void Compute (const framework::ExecutionContext& ctx) const override {
121
- auto & place = *ctx.template device_context <DeviceContext>().eigen_device ();
122
137
auto output_grad = ctx.Input <Tensor>(framework::GradVarName (" Output" ));
123
138
auto theta_grad = ctx.Output <Tensor>(framework::GradVarName (" Theta" ));
124
-
125
139
int n = output_grad->dims ()[0 ];
126
140
auto size_attr = ctx.Attr <std::vector<int >>(" output_shape" );
127
141
int h = 0 ;
@@ -137,42 +151,12 @@ class AffineGridGradOpKernel : public framework::OpKernel<T> {
137
151
h = size_attr[2 ];
138
152
w = size_attr[3 ];
139
153
}
140
-
141
154
theta_grad->mutable_data <T>({n, 2 , 3 }, ctx.GetPlace ());
142
-
143
155
math::SetConstant<DeviceContext, T>()(
144
156
ctx.template device_context <DeviceContext>(), theta_grad,
145
157
static_cast <T>(0 ));
146
-
147
- Linspace<DeviceContext, T> linspace;
148
-
149
- // Get indexes of height with shape [height, width, 1]
150
- auto h_idx = linspace ((T)-1 , (T)1 , h, ctx);
151
- auto h_idx_t = EigenTensor<T, 1 >::From (h_idx);
152
- // Get indexes of width with shape [height, width, 1]
153
- auto w_idx = linspace ((T)-1 , (T)1 , w, ctx);
154
- auto w_idx_t = EigenTensor<T, 1 >::From (w_idx);
155
- // Get constant ones tensor with shape [height, width, 1]
156
- Tensor ones;
157
- ones.mutable_data <T>({h, w, 1 }, ctx.GetPlace ());
158
- auto ones_t = EigenTensor<T, 3 >::From (ones).setConstant ((T)1 );
159
- // Get grid tensor with shape [n, h, w, 3] by concatenating h_idx, w_idx and
160
- // ones
161
158
Tensor grid;
162
- grid.mutable_data <T>({n, h, w, 3 }, ctx.GetPlace ());
163
- auto grid_t = EigenTensor<T, 4 >::From (grid);
164
- grid_t .device (place) = w_idx_t .reshape (Array2 (1 , w))
165
- .broadcast (Array2 (h, 1 ))
166
- .reshape (Array3 (h, w, 1 ))
167
- .concatenate (h_idx_t .reshape (Array2 (1 , h))
168
- .broadcast (Array2 (w, 1 ))
169
- .shuffle (Array2 (1 , 0 ))
170
- .reshape (Array3 (h, w, 1 )),
171
- 2 )
172
- .eval ()
173
- .concatenate (ones_t , 2 )
174
- .reshape (Array4 (1 , h, w, 3 ))
175
- .broadcast (Array4 (n, 1 , 1 , 1 ));
159
+ GetIdxMap<DeviceContext, T>(n, h, w, &grid, ctx);
176
160
// output = grid * theta.T
177
161
// TODO(wanghaoshuang): Refine batched matrix multiply
178
162
auto blas = math::GetBlas<DeviceContext, T>(ctx);
0 commit comments