Skip to content

Commit d6a6a13

Browse files
Fix build error of affine grid op in mac os. (#14237)
* Fix build error of affine grid op in mac os. test=develop * Make function return reference. test=develop
1 parent d55481c commit d6a6a13

File tree

2 files changed

+56
-74
lines changed

2 files changed

+56
-74
lines changed

paddle/fluid/operators/affine_grid_op.cc

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,13 @@ using Tensor = framework::Tensor;
2626

2727
template <typename T>
2828
struct Linspace<paddle::platform::CPUDeviceContext, T> {
29-
framework::Tensor operator()(T start, T end, int count,
30-
const framework::ExecutionContext& ctx) {
31-
Tensor numbers;
32-
T* number_data = numbers.mutable_data<T>({count}, platform::CPUPlace());
29+
void operator()(T start, T end, int count, framework::Tensor* numbers,
30+
const framework::ExecutionContext& ctx) {
31+
T* number_data = numbers->mutable_data<T>({count}, platform::CPUPlace());
3332
T slice = (end - start) / (T)(count - 1);
3433
for (int i = 0; i < count; ++i) {
3534
number_data[i] = start + (T)i * slice;
3635
}
37-
return numbers;
3836
}
3937
};
4038

paddle/fluid/operators/affine_grid_op.h

Lines changed: 53 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -37,18 +37,65 @@ using Array4 = Eigen::DSizes<int64_t, 4>;
3737
*/
3838
template <typename DeviceContext, typename T>
3939
struct Linspace {
40-
framework::Tensor operator()(T start, T end, int count,
41-
const framework::ExecutionContext& ctx);
40+
void operator()(T start, T end, int count, framework::Tensor* numbers,
41+
const framework::ExecutionContext& ctx);
4242
};
4343

44+
template <typename DeviceContext, typename T>
45+
inline void GetIdxMap(int n, int h, int w, Tensor* grid,
46+
const framework::ExecutionContext& ctx) {
47+
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
48+
grid->mutable_data<T>({n, h, w, 3}, ctx.GetPlace());
49+
auto grid_t = EigenTensor<T, 4>::From(*grid);
50+
// Get indexes of height with shape [height, width, 1]
51+
Tensor h_idx;
52+
Linspace<DeviceContext, T> linspace;
53+
linspace((T)-1, (T)1, h, &h_idx, ctx);
54+
auto h_idx_t = EigenTensor<T, 1>::From(h_idx);
55+
// Get indexes of width with shape [height, width, 1]
56+
Tensor w_idx;
57+
linspace((T)-1, (T)1, w, &w_idx, ctx);
58+
auto w_idx_t = EigenTensor<T, 1>::From(w_idx);
59+
// Get constant ones tensor with shape [height, width, 1]
60+
Tensor ones;
61+
ones.mutable_data<T>({h, w, 1}, ctx.GetPlace());
62+
auto ones_t = EigenTensor<T, 3>::From(ones).setConstant((T)1);
63+
// Get grid tensor with shape [n, h, w, 3] by concatenating h_idx, w_idx and
64+
// ones
65+
Tensor w_idx_map;
66+
w_idx_map.mutable_data<T>({h, w, 1}, ctx.GetPlace());
67+
auto w_idx_map_t = EigenTensor<T, 3>::From(w_idx_map);
68+
Tensor h_idx_map;
69+
h_idx_map.mutable_data<T>({h, w, 1}, ctx.GetPlace());
70+
auto h_idx_map_t = EigenTensor<T, 3>::From(h_idx_map);
71+
Tensor w_h_idx_map;
72+
w_h_idx_map.mutable_data<T>({h, w, 2}, ctx.GetPlace());
73+
auto w_h_idx_map_t = EigenTensor<T, 3>::From(w_h_idx_map);
74+
Tensor w_h_one_idx_map;
75+
w_h_one_idx_map.mutable_data<T>({h, w, 3}, ctx.GetPlace());
76+
auto w_h_one_idx_map_t = EigenTensor<T, 3>::From(w_h_one_idx_map);
77+
78+
w_idx_map_t.device(place) = w_idx_t.reshape(Array2(1, w))
79+
.broadcast(Array2(h, 1))
80+
.reshape(Array3(h, w, 1));
81+
82+
h_idx_map_t.device(place) = h_idx_t.reshape(Array2(1, h))
83+
.broadcast(Array2(w, 1))
84+
.shuffle(Array2(1, 0))
85+
.reshape(Array3(h, w, 1));
86+
87+
w_h_idx_map_t.device(place) = w_idx_map_t.concatenate(h_idx_map_t, 2);
88+
w_h_one_idx_map_t.device(place) = w_h_idx_map_t.concatenate(ones_t, 2);
89+
grid_t.device(place) = w_h_one_idx_map_t.reshape(Array4(1, h, w, 3))
90+
.broadcast(Array4(n, 1, 1, 1));
91+
}
92+
4493
template <typename DeviceContext, typename T>
4594
class AffineGridOpKernel : public framework::OpKernel<T> {
4695
public:
4796
void Compute(const framework::ExecutionContext& ctx) const override {
48-
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
4997
auto* theta = ctx.Input<Tensor>("Theta");
5098
int n = theta->dims()[0];
51-
5299
auto size_attr = ctx.Attr<std::vector<int>>("output_shape");
53100
int h = 0;
54101
int w = 0;
@@ -63,44 +110,13 @@ class AffineGridOpKernel : public framework::OpKernel<T> {
63110
h = size_attr[2];
64111
w = size_attr[3];
65112
}
66-
67113
auto* output = ctx.Output<Tensor>("Output");
68114
output->mutable_data<T>({n, h, w, 2}, ctx.GetPlace());
69-
70115
math::SetConstant<DeviceContext, T>()(
71116
ctx.template device_context<DeviceContext>(), output,
72117
static_cast<T>(0));
73-
74-
Linspace<DeviceContext, T> linspace;
75-
// Get indexes of height with shape [height, width, 1]
76-
auto h_idx = linspace((T)-1, (T)1, h, ctx);
77-
auto h_idx_t = EigenTensor<T, 1>::From(h_idx);
78-
// Get indexes of width with shape [height, width, 1]
79-
auto w_idx = linspace((T)-1, (T)1, w, ctx);
80-
auto w_idx_t = EigenTensor<T, 1>::From(w_idx);
81-
// Get constant ones tensor with shape [height, width, 1]
82-
Tensor ones;
83-
ones.mutable_data<T>({h, w, 1}, ctx.GetPlace());
84-
auto ones_t = EigenTensor<T, 3>::From(ones).setConstant((T)1);
85-
// Get grid tensor with shape [n, h, w, 3] by concatenating h_idx, w_idx and
86-
// ones
87118
Tensor grid;
88-
grid.mutable_data<T>({n, h, w, 3}, ctx.GetPlace());
89-
auto grid_t = EigenTensor<T, 4>::From(grid);
90-
91-
grid_t.device(place) = w_idx_t.reshape(Array2(1, w))
92-
.broadcast(Array2(h, 1))
93-
.reshape(Array3(h, w, 1))
94-
.concatenate(h_idx_t.reshape(Array2(1, h))
95-
.broadcast(Array2(w, 1))
96-
.shuffle(Array2(1, 0))
97-
.reshape(Array3(h, w, 1)),
98-
2)
99-
.eval()
100-
.concatenate(ones_t, 2)
101-
.reshape(Array4(1, h, w, 3))
102-
.broadcast(Array4(n, 1, 1, 1));
103-
119+
GetIdxMap<DeviceContext, T>(n, h, w, &grid, ctx);
104120
// output = grid * theta.T
105121
// TODO(wanghaoshuang): Refine batched matrix multiply
106122
auto blas = math::GetBlas<DeviceContext, T>(ctx);
@@ -118,10 +134,8 @@ template <typename DeviceContext, typename T>
118134
class AffineGridGradOpKernel : public framework::OpKernel<T> {
119135
public:
120136
void Compute(const framework::ExecutionContext& ctx) const override {
121-
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
122137
auto output_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
123138
auto theta_grad = ctx.Output<Tensor>(framework::GradVarName("Theta"));
124-
125139
int n = output_grad->dims()[0];
126140
auto size_attr = ctx.Attr<std::vector<int>>("output_shape");
127141
int h = 0;
@@ -137,42 +151,12 @@ class AffineGridGradOpKernel : public framework::OpKernel<T> {
137151
h = size_attr[2];
138152
w = size_attr[3];
139153
}
140-
141154
theta_grad->mutable_data<T>({n, 2, 3}, ctx.GetPlace());
142-
143155
math::SetConstant<DeviceContext, T>()(
144156
ctx.template device_context<DeviceContext>(), theta_grad,
145157
static_cast<T>(0));
146-
147-
Linspace<DeviceContext, T> linspace;
148-
149-
// Get indexes of height with shape [height, width, 1]
150-
auto h_idx = linspace((T)-1, (T)1, h, ctx);
151-
auto h_idx_t = EigenTensor<T, 1>::From(h_idx);
152-
// Get indexes of width with shape [height, width, 1]
153-
auto w_idx = linspace((T)-1, (T)1, w, ctx);
154-
auto w_idx_t = EigenTensor<T, 1>::From(w_idx);
155-
// Get constant ones tensor with shape [height, width, 1]
156-
Tensor ones;
157-
ones.mutable_data<T>({h, w, 1}, ctx.GetPlace());
158-
auto ones_t = EigenTensor<T, 3>::From(ones).setConstant((T)1);
159-
// Get grid tensor with shape [n, h, w, 3] by concatenating h_idx, w_idx and
160-
// ones
161158
Tensor grid;
162-
grid.mutable_data<T>({n, h, w, 3}, ctx.GetPlace());
163-
auto grid_t = EigenTensor<T, 4>::From(grid);
164-
grid_t.device(place) = w_idx_t.reshape(Array2(1, w))
165-
.broadcast(Array2(h, 1))
166-
.reshape(Array3(h, w, 1))
167-
.concatenate(h_idx_t.reshape(Array2(1, h))
168-
.broadcast(Array2(w, 1))
169-
.shuffle(Array2(1, 0))
170-
.reshape(Array3(h, w, 1)),
171-
2)
172-
.eval()
173-
.concatenate(ones_t, 2)
174-
.reshape(Array4(1, h, w, 3))
175-
.broadcast(Array4(n, 1, 1, 1));
159+
GetIdxMap<DeviceContext, T>(n, h, w, &grid, ctx);
176160
// output = grid * theta.T
177161
// TODO(wanghaoshuang): Refine batched matrix multiply
178162
auto blas = math::GetBlas<DeviceContext, T>(ctx);

0 commit comments

Comments
 (0)