Skip to content

Commit 9c61409

Browse files
committed
Make crop op supporting taking offsets as one of its inputs
1 parent 9ce0885 commit 9c61409

File tree

3 files changed

+38
-2
lines changed

3 files changed

+38
-2
lines changed

paddle/fluid/operators/crop_op.cc

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,13 @@ class CropOp : public framework::OperatorWithKernel {
4848
ctx->SetOutputDim("Out", y_dim);
4949
}
5050
}
51+
52+
framework::OpKernelType GetExpectedKernelType(
53+
const framework::ExecutionContext& ctx) const override {
54+
return framework::OpKernelType(
55+
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
56+
ctx.device_context());
57+
}
5158
};
5259

5360
class CropOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -163,6 +170,15 @@ class CropOpGrad : public framework::OperatorWithKernel {
163170
ctx->SetOutputDim(x_grad_name, x_dims);
164171
}
165172
}
173+
174+
framework::OpKernelType GetExpectedKernelType(
175+
const framework::ExecutionContext& ctx) const override {
176+
return framework::OpKernelType(
177+
framework::ToDataType(
178+
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))
179+
->type()),
180+
ctx.device_context());
181+
}
166182
};
167183

168184
} // namespace operators

paddle/fluid/operators/random_crop_op.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ class RandomCropOp : public framework::OperatorWithKernel {
2020
public:
2121
using framework::OperatorWithKernel::OperatorWithKernel;
2222

23-
protected:
2423
framework::OpKernelType GetExpectedKernelType(
2524
const framework::ExecutionContext& ctx) const override {
2625
return framework::OpKernelType(

python/paddle/fluid/tests/unittests/test_crop_op.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,9 @@ class TestCropOp(OpTest):
4242
def setUp(self):
4343
self.op_type = "crop"
4444
self.crop_by_input = False
45+
self.offset_by_input = False
4546
self.attrs = {}
4647
self.initTestCase()
47-
self.attrs['offsets'] = self.offsets
4848
if self.crop_by_input:
4949
self.inputs = {
5050
'X': np.random.random(self.x_shape).astype("float32"),
@@ -55,6 +55,10 @@ def setUp(self):
5555
self.inputs = {
5656
'X': np.random.random(self.x_shape).astype("float32"),
5757
}
58+
if self.offset_by_input:
59+
self.inputs['Offsets'] = np.array(self.offsets).astype('int32')
60+
else:
61+
self.attrs['offsets'] = self.offsets
5862
self.outputs = {
5963
'Out': crop(self.inputs['X'], self.offsets, self.crop_shape)
6064
}
@@ -101,5 +105,22 @@ def initTestCase(self):
101105
self.crop_by_input = True
102106

103107

108+
class TestCase5(TestCropOp):
109+
def initTestCase(self):
110+
self.x_shape = (3, 4, 5)
111+
self.crop_shape = [2, 2, 3]
112+
self.offsets = [1, 0, 2]
113+
self.offset_by_input = True
114+
115+
116+
class TestCase6(TestCropOp):
117+
def initTestCase(self):
118+
self.x_shape = (10, 9, 14)
119+
self.crop_shape = [3, 3, 5]
120+
self.offsets = [3, 5, 4]
121+
self.crop_by_input = True
122+
self.offset_by_input = True
123+
124+
104125
if __name__ == '__main__':
105126
unittest.main()

0 commit comments

Comments
 (0)