Skip to content

Commit c3a87e3

Browse files
authored
support slice double grad, test=develop (#22166) (#22836)
* support slice double grad, test=develop * merge two doublegradopmaker to one doublegradopmaker,test=develop * change the shape of slice_OP's unittest, test=develop
1 parent 9ba5216 commit c3a87e3

File tree

2 files changed

+65
-0
lines changed

2 files changed

+65
-0
lines changed

paddle/fluid/operators/slice_op.cc

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,34 @@ class SliceOpGradMaker : public framework::SingleGradOpMaker<T> {
291291
}
292292
};
293293

294+
template <typename T>
295+
class SliceDoubleOpGradMaker : public framework::SingleGradOpMaker<T> {
296+
public:
297+
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
298+
299+
protected:
300+
std::unique_ptr<T> Apply() const override {
301+
auto *bind = new T();
302+
if (this->HasInput("StartsTensor")) {
303+
bind->SetInput("StartsTensor", this->Input("StartsTensor"));
304+
}
305+
if (this->HasInput("EndsTensor")) {
306+
bind->SetInput("EndsTensor", this->Input("EndsTensor"));
307+
}
308+
if (this->HasInput("StartsTensorList")) {
309+
bind->SetInput("StartsTensorList", this->Input("StartsTensorList"));
310+
}
311+
if (this->HasInput("EndsTensorList")) {
312+
bind->SetInput("EndsTensorList", this->Input("EndsTensorList"));
313+
}
314+
bind->SetInput("Input", this->OutputGrad(framework::GradVarName("Input")));
315+
bind->SetOutput("Out", this->InputGrad(framework::GradVarName("Out")));
316+
bind->SetAttrMap(this->Attrs());
317+
bind->SetType("slice");
318+
return std::unique_ptr<T>(bind);
319+
}
320+
};
321+
294322
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(SliceOpGradNoNeedBufferVarsInference,
295323
"Input");
296324

@@ -302,6 +330,8 @@ REGISTER_OPERATOR(slice, ops::SliceOp, ops::SliceOpMaker,
302330
ops::SliceOpGradMaker<paddle::framework::OpDesc>,
303331
ops::SliceOpGradMaker<paddle::imperative::OpBase>);
304332
REGISTER_OPERATOR(slice_grad, ops::SliceOpGrad,
333+
ops::SliceDoubleOpGradMaker<paddle::framework::OpDesc>,
334+
ops::SliceDoubleOpGradMaker<paddle::imperative::OpBase>,
305335
ops::SliceOpGradNoNeedBufferVarsInference);
306336

307337
REGISTER_OP_CPU_KERNEL(

python/paddle/fluid/tests/unittests/test_nn_grad.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,41 @@ def test_grad(self):
4343
self.func(p)
4444

4545

46+
class TestSliceOpDoubleGradCheck(unittest.TestCase):
47+
def func(self, place):
48+
self.config()
49+
50+
out = fluid.layers.slice(
51+
self.inputs, axes=self.axes, starts=self.starts, ends=self.ends)
52+
gradient_checker.double_grad_check(
53+
[self.inputs], out, x_init=self.x_arr, place=place)
54+
55+
def config(self):
56+
self.starts = [1, 0, -1]
57+
self.ends = [3, 3, 6]
58+
self.axes = [0, 1, 2]
59+
self.x_arr = np.random.random([3, 4, 5, 2]).astype("float64")
60+
self.inputs = layers.create_parameter(
61+
dtype="float64", shape=[3, 4, 5, 2], name='x')
62+
63+
def test_grad(self):
64+
places = [fluid.CPUPlace()]
65+
if core.is_compiled_with_cuda():
66+
places.append(fluid.CUDAPlace(0))
67+
for place in places:
68+
self.func(place)
69+
70+
71+
class TestSliceOpDoubleGradCheckCase3(TestSliceOpDoubleGradCheck):
72+
def config(self):
73+
self.starts = [1, -1, 1]
74+
self.ends = [3, 3, 3]
75+
self.axes = [0, 1, 2]
76+
self.x_arr = np.random.random([3, 3, 3]).astype("float64")
77+
self.inputs = layers.create_parameter(
78+
dtype="float64", shape=[3, 3, 3], name='x3')
79+
80+
4681
class TestReduceMeanWithDimDoubleGradCheck(unittest.TestCase):
4782
@prog_scope()
4883
def func(self, place):

0 commit comments

Comments
 (0)