
Commit 1490551

Merge pull request #10970 from JiayiFeng/dev_add_random_crop_op
Add random crop op
2 parents 654f5d3 + d2c1fac commit 1490551

7 files changed: +381 -12 lines changed


paddle/fluid/framework/operator.cc

Lines changed: 1 addition & 0 deletions
@@ -469,6 +469,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
  protected:
   DDim GetDim(const std::string& name) const override {
     Variable* var = scope_.FindVar(name);
+    PADDLE_ENFORCE_NOT_NULL(var);
     if (var->IsType<LoDTensor>()) {
       return var->Get<LoDTensor>().dims();
     } else if (var->IsType<SelectedRows>()) {
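
The single added line guards GetDim against an operator that names a variable absent from the scope: FindVar returns a null pointer in that case, so the enforce turns what would otherwise be a null dereference on the following line into an explicit error. A minimal sketch of the same guard pattern, using hypothetical stand-ins rather than Paddle's real Scope and enforce types:

#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

// Illustrative stand-ins for Scope::FindVar and PADDLE_ENFORCE_NOT_NULL.
struct Var {
  std::vector<int> dims;
};

const Var* FindVar(const std::unordered_map<std::string, Var>& scope,
                   const std::string& name) {
  auto it = scope.find(name);
  return it == scope.end() ? nullptr : &it->second;
}

std::vector<int> GetDim(const std::unordered_map<std::string, Var>& scope,
                        const std::string& name) {
  const Var* var = FindVar(scope, name);
  // Without this check, var->dims below would dereference a null pointer.
  if (var == nullptr) {
    throw std::runtime_error("variable '" + name + "' is not found in scope");
  }
  return var->dims;
}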
paddle/fluid/operators/random_crop_op.cc

Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/random_crop_op.h"

namespace paddle {
namespace operators {

class RandomCropOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
        ctx.device_context());
  }
};

class RandomCropOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "A batch of instances to random crop.");
    AddInput("Seed", "The random seed.");
    AddOutput("Out", "The cropped instance batch.");
    AddOutput("SeedOut", "The random seed after random cropping.")
        .AsDispensable();
    AddAttr<std::vector<int>>("shape", "The shape of a cropped instance.");
AddComment(R"DOC(
42+
This operator takes a batch of instance, and do random cropping on each instance.
43+
It means that cropping positions differs on each instance, which is determined
44+
by an uniform random generator. All cropped instances have the same shape, which
45+
is determined by the operator's attribute 'shape'.
46+
)DOC");
  }
};

class RandomCropOpInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext* ctx) const override {
    auto seed_dim = ctx->GetInputDim("Seed");
    PADDLE_ENFORCE(seed_dim.size() == 1 && seed_dim[0] == 1);
    auto shape = ctx->Attrs().Get<std::vector<int>>("shape");
    auto x_dim = ctx->GetInputDim("X");
    PADDLE_ENFORCE_GT(x_dim.size(), static_cast<int64_t>(shape.size()));
    auto out_dim = framework::vectorize2int(x_dim);
    for (size_t i = 1; i <= shape.size(); ++i) {
      size_t x_i = x_dim.size() - i;
      size_t shape_i = shape.size() - i;
      PADDLE_ENFORCE_GE(x_dim[x_i], shape[shape_i]);
      out_dim[x_i] = shape[shape_i];
    }
    ctx->SetOutputDim("Out", framework::make_ddim(out_dim));
    ctx->SetOutputDim("SeedOut", framework::make_ddim({1}));
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace f = paddle::framework;
REGISTER_OPERATOR(random_crop, ops::RandomCropOp, ops::RandomCropOpMaker,
                  ops::RandomCropOpInferShape, f::EmptyGradOpMaker);

template <typename T>
using Kernel = ops::RandomCropKernel<paddle::platform::CPUDeviceContext, T>;
REGISTER_OP_CPU_KERNEL(random_crop, Kernel<float>, Kernel<int>, Kernel<double>,
                       Kernel<uint8_t>, Kernel<int16_t>);
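
The infer-shape rule in RandomCropOpInferShape keeps the leading (batch) dimensions of X and overwrites its trailing dimensions with the 'shape' attribute. A minimal standalone sketch of that rule follows; the helper name InferRandomCropDim and the example dimensions are illustrative, not part of the operator:

#include <cassert>
#include <cstdio>
#include <vector>

// Mirrors RandomCropOpInferShape: trailing dims of x_dim are replaced by
// `shape`, leading (batch) dims are kept unchanged.
std::vector<int> InferRandomCropDim(const std::vector<int>& x_dim,
                                    const std::vector<int>& shape) {
  assert(x_dim.size() > shape.size());
  std::vector<int> out_dim = x_dim;
  for (size_t i = 1; i <= shape.size(); ++i) {
    size_t x_i = x_dim.size() - i;
    size_t shape_i = shape.size() - i;
    assert(x_dim[x_i] >= shape[shape_i]);
    out_dim[x_i] = shape[shape_i];
  }
  return out_dim;
}

int main() {
  // Hypothetical example: X is a batch of 4 RGB images of 256x256,
  // the 'shape' attribute asks for 3x224x224 crops.
  std::vector<int> out = InferRandomCropDim({4, 3, 256, 256}, {3, 224, 224});
  for (int d : out) std::printf("%d ", d);
  std::printf("\n");
  return 0;
}

Run, this prints 4 3 224 224: the batch dimension is preserved while each instance shrinks to the requested crop.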
paddle/fluid/operators/random_crop_op.cu

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/random_crop_op.h"

namespace ops = paddle::operators;
template <typename T>
using Kernel = ops::RandomCropKernel<paddle::platform::CUDADeviceContext, T>;
REGISTER_OP_CUDA_KERNEL(random_crop, Kernel<float>, Kernel<int>, Kernel<double>,
                        Kernel<uint8_t>, Kernel<int16_t>);
paddle/fluid/operators/random_crop_op.h

Lines changed: 181 additions & 0 deletions
@@ -0,0 +1,181 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/for_range.h"
#ifdef PADDLE_WITH_CUDA
#include <thrust/random.h>
#endif

namespace paddle {
namespace operators {

template <typename DeviceContext>
struct Random;

template <>
struct Random<platform::CPUDeviceContext> {
  using Engine = std::minstd_rand;

  template <typename T>
  using UniformIntDist = std::uniform_int_distribution<T>;
};

#ifdef PADDLE_WITH_CUDA
template <>
struct Random<platform::CUDADeviceContext> {
  using Engine = thrust::minstd_rand;

  template <typename T>
  using UniformIntDist = thrust::uniform_int_distribution<T>;
};
#endif

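// StridedMemcpy copies the cropped window of a single instance: at recursion
// level i it shifts the source pointer by that dimension's crop offset, loops
// over the output extent, and recurses into the next dimension; the innermost
// level copies a contiguous run of elements.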
template <typename T>
HOSTDEVICE inline void StridedMemcpy(const T* x, const size_t* x_dims, T* out,
                                     const size_t* out_dims, int i, int rank,
                                     size_t prod_x_remain,
                                     size_t prod_out_remain,
                                     const size_t* offsets) {
  size_t x_dim_i = x_dims[i];
  size_t out_dim_i = out_dims[i];
  size_t x_stride = prod_x_remain / x_dim_i;
  size_t out_stride = prod_out_remain / out_dim_i;
  size_t offset_i = offsets[i];

  if (i == rank - 1) {
    PADDLE_ASSERT(x_stride == 1 && out_stride == 1);
    x += offset_i;
    for (size_t j = 0; j < out_dim_i; ++j) {
      *out++ = *x++;
    }
  } else {
    x += offset_i * x_stride;
    for (size_t j = 0; j < out_dim_i; ++j) {
      StridedMemcpy<T>(x, x_dims, out, out_dims, i + 1, rank, x_stride,
                       out_stride, offsets);
      x += x_stride;
      out += out_stride;
    }
  }
}

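// RandomCropFunctor treats the first num_batchsize_dims_ dimensions as batch
// dimensions (they must match between X and Out) and crops the remaining
// instance dimensions. For each instance index, operator() seeds an engine
// with the shared seed, discards ins_idx * crop_rank draws so every instance
// gets its own deterministic sub-sequence, draws one offset per cropped
// dimension, and copies the window with StridedMemcpy.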
template <typename DeviceContext, typename T>
struct RandomCropFunctor {
  const T* x_;
  T* out_;
  size_t x_dims_[9];
  size_t out_dims_[9];
  int num_batchsize_dims_;
  int rank_;
  int64_t seed_;

  size_t prod_batchsize_dims_;
  size_t prod_x_ins_dims_;
  size_t prod_out_ins_dims_;

  RandomCropFunctor(const T* x, T* out, const framework::DDim& x_dims,
                    const framework::DDim& out_dims, int num_batchsize_dims,
                    int64_t seed)
      : x_(x),
        out_(out),
        num_batchsize_dims_(num_batchsize_dims),
        rank_(x_dims.size()),
        seed_(seed) {
    PADDLE_ENFORCE_EQ(x_dims.size(), out_dims.size());
    PADDLE_ENFORCE_GT(rank_, num_batchsize_dims_);
    prod_batchsize_dims_ = 1;
    prod_x_ins_dims_ = 1;
    prod_out_ins_dims_ = 1;
    for (size_t i = 0; i < static_cast<size_t>(rank_); ++i) {
      size_t x_dim_i = x_dims[i];
      size_t out_dim_i = out_dims[i];
      x_dims_[i] = x_dim_i;
      out_dims_[i] = out_dim_i;
      if (i < static_cast<size_t>(num_batchsize_dims_)) {
        PADDLE_ENFORCE_EQ(x_dim_i, out_dim_i);
        prod_batchsize_dims_ *= x_dim_i;
      } else {
        prod_x_ins_dims_ *= x_dim_i;
        prod_out_ins_dims_ *= out_dim_i;
      }
    }
  }

  HOSTDEVICE void operator()(size_t ins_idx) {
    typename Random<DeviceContext>::Engine engine(seed_);
    engine.discard(ins_idx * (rank_ - num_batchsize_dims_));
    size_t offsets[9];
    for (int i = num_batchsize_dims_; i < rank_; ++i) {
      typename Random<DeviceContext>::template UniformIntDist<size_t> dist(
          0, x_dims_[i] - out_dims_[i]);
      offsets[i - num_batchsize_dims_] = dist(engine);
    }

    const T* x = x_ + ins_idx * prod_x_ins_dims_;
    T* out = out_ + ins_idx * prod_out_ins_dims_;

    StridedMemcpy<T>(x, x_dims_ + num_batchsize_dims_, out,
                     out_dims_ + num_batchsize_dims_, 0,
                     rank_ - num_batchsize_dims_, prod_x_ins_dims_,
                     prod_out_ins_dims_, offsets);
  }
};

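// RandomCropKernel reads the seed (copying it to the CPU first if the Seed
// tensor lives on the GPU), runs RandomCropFunctor once per instance through
// platform::ForRange, then advances a CPU engine past every random number the
// crops consumed and stores the next value in SeedOut as the seed for the
// following step.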
template <typename DeviceContext, typename T>
class RandomCropKernel : public framework::OpKernel<T> {
 public:
  virtual void Compute(const framework::ExecutionContext& ctx) const {
    auto& seed_tensor = detail::Ref(ctx.Input<framework::LoDTensor>("Seed"));
    int64_t seed = 0;
    if (platform::is_cpu_place(seed_tensor.place())) {
      seed = *seed_tensor.data<int64_t>();
    } else {
      LOG(WARNING) << "It is slow to place seed in GPU memory. Please verify "
                      "your program";
      framework::LoDTensor cpu_seed;
      framework::TensorCopySync(seed_tensor, platform::CPUPlace(), &cpu_seed);
      seed = *cpu_seed.data<int64_t>();
    }
    auto shape = ctx.Attr<std::vector<int>>("shape");
    auto& x = detail::Ref(ctx.Input<framework::LoDTensor>("X"));
    auto& out = detail::Ref(ctx.Output<framework::LoDTensor>("Out"));

    int num_batchsize_dims = x.dims().size() - shape.size();
    RandomCropFunctor<DeviceContext, T> functor(
        x.data<T>(), out.mutable_data<T>(ctx.GetPlace()), x.dims(), out.dims(),
        num_batchsize_dims, seed);
    platform::ForRange<DeviceContext> for_range(
        ctx.template device_context<DeviceContext>(),
        functor.prod_batchsize_dims_);

    for_range(functor);

    Random<platform::CPUDeviceContext>::Engine engine(seed);
    engine.discard(functor.prod_batchsize_dims_ *
                   (functor.rank_ - functor.num_batchsize_dims_));
    *ctx.Output<framework::LoDTensor>("SeedOut")->mutable_data<int64_t>(
        platform::CPUPlace()) = engine();
  }
};

// TODO(fengjiayi): Backward of random crop op

}  // namespace operators
}  // namespace paddle
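
For a feel of the cropping scheme the header implements on the CPU path, here is a small self-contained sketch written along the same lines as RandomCropFunctor::operator(): seed std::minstd_rand with the shared seed, discard ins_idx * crop_rank draws so every instance gets its own deterministic offsets, then draw one uniform offset per cropped dimension. The 4x4 example data and the helper name CropOffsets are illustrative only, not part of the operator:

#include <cstdio>
#include <random>
#include <vector>

// Deterministic per-instance crop offsets: the engine is seeded once and
// fast-forwarded so instance ins_idx always sees the same sub-sequence.
std::vector<size_t> CropOffsets(int64_t seed, size_t ins_idx,
                                const std::vector<size_t>& x_dims,
                                const std::vector<size_t>& out_dims) {
  std::minstd_rand engine(seed);
  engine.discard(ins_idx * x_dims.size());
  std::vector<size_t> offsets(x_dims.size());
  for (size_t i = 0; i < x_dims.size(); ++i) {
    std::uniform_int_distribution<size_t> dist(0, x_dims[i] - out_dims[i]);
    offsets[i] = dist(engine);
  }
  return offsets;
}

int main() {
  // Crop a hypothetical 4x4 instance down to 2x2.
  const std::vector<size_t> x_dims = {4, 4}, out_dims = {2, 2};
  std::vector<float> x(16);
  for (size_t i = 0; i < x.size(); ++i) x[i] = static_cast<float>(i);

  std::vector<size_t> off =
      CropOffsets(/*seed=*/42, /*ins_idx=*/0, x_dims, out_dims);
  for (size_t r = 0; r < out_dims[0]; ++r) {
    for (size_t c = 0; c < out_dims[1]; ++c) {
      std::printf("%5.1f ", x[(off[0] + r) * x_dims[1] + off[1] + c]);
    }
    std::printf("\n");
  }
  return 0;
}

Because the offsets depend only on the seed and the instance index, re-running with the same seed reproduces the same crops.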
