Skip to content

Commit 8aa5757

Browse files
authored
Fix slice op shape=-1 bug (#18107) (#18227)
* fix slice op bug; test=develop * fix variable test bug; test=develop * remove slice while true; test=develop
1 parent cac315f commit 8aa5757

File tree

5 files changed

+370
-79
lines changed

5 files changed

+370
-79
lines changed

paddle/fluid/operators/slice_op.cc

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -39,21 +39,49 @@ class SliceOp : public framework::OperatorWithKernel {
3939
auto axes = ctx->Attrs().Get<std::vector<int>>("axes");
4040
auto starts = ctx->Attrs().Get<std::vector<int>>("starts");
4141
auto ends = ctx->Attrs().Get<std::vector<int>>("ends");
42+
auto decrease_axis = ctx->Attrs().Get<std::vector<int>>("decrease_axis");
4243

4344
PADDLE_ENFORCE_EQ(starts.size(), ends.size());
4445
PADDLE_ENFORCE_EQ(starts.size(), axes.size());
4546
int dim_value, start, end;
4647
for (size_t i = 0; i < axes.size(); ++i) {
4748
dim_value = out_dims[axes[i]];
48-
start = starts[i] < 0 ? (starts[i] + dim_value) : starts[i];
49-
end = ends[i] < 0 ? (ends[i] + dim_value) : ends[i];
50-
start = std::max(start, 0);
51-
end = std::max(end, 0);
52-
start = std::min(start, dim_value);
53-
end = std::min(end, dim_value);
54-
start = std::min(start, end);
55-
out_dims[axes[i]] = end - start;
49+
if (dim_value > 0) {
50+
start = starts[i] < 0 ? (starts[i] + dim_value) : starts[i];
51+
end = ends[i] < 0 ? (ends[i] + dim_value) : ends[i];
52+
start = std::max(start, 0);
53+
end = std::max(end, 0);
54+
// start = std::min(start, dim_value);
55+
end = std::min(end, dim_value);
56+
// start = std::min(start, end);
57+
PADDLE_ENFORCE_GT(end, start, "end should greater than start");
58+
out_dims[axes[i]] = end - start;
59+
}
5660
}
61+
62+
// generate new shape
63+
if (decrease_axis.size() > 0) {
64+
std::vector<int> new_out_shape;
65+
for (size_t i = 0; i < decrease_axis.size(); ++i) {
66+
if (ctx->IsRuntime()) {
67+
PADDLE_ENFORCE_EQ(out_dims[decrease_axis[i]], 1,
68+
"decrease dim should be 1");
69+
}
70+
out_dims[decrease_axis[i]] = 0;
71+
}
72+
73+
for (int i = 0; i < out_dims.size(); ++i) {
74+
if (out_dims[i] != 0) {
75+
new_out_shape.push_back(out_dims[i]);
76+
}
77+
}
78+
if (new_out_shape.size() == 0) {
79+
new_out_shape.push_back(1);
80+
}
81+
82+
out_dims = framework::make_ddim(new_out_shape);
83+
}
84+
5785
ctx->SetOutputDim("Out", out_dims);
5886
if (axes[0] != 0) {
5987
ctx->ShareLoD("Input", /*->*/ "Out");
@@ -84,7 +112,8 @@ class SliceOpMaker : public framework::OpProtoAndCheckerMaker {
84112
AddAttr<std::vector<int>>(
85113
"ends",
86114
"(list<int>) Starting indices of corresponding axis in `axes`.");
87-
115+
AddAttr<std::vector<int>>("decrease_axis", "(list<int>) decrease_axis")
116+
.SetDefault({});
88117
AddComment(R"DOC(
89118
Slice Operator.
90119

paddle/fluid/operators/slice_op.h

Lines changed: 61 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,17 +55,45 @@ class SliceKernel : public framework::OpKernel<T> {
5555
*context.template device_context<DeviceContext>().eigen_device();
5656
auto in = context.Input<framework::Tensor>("Input");
5757
auto out = context.Output<framework::Tensor>("Out");
58-
out->mutable_data<T>(context.GetPlace());
5958
auto out_dims = out->dims();
6059
auto in_dims = in->dims();
60+
61+
// resize out_dims
62+
auto decrease_axis = context.Attr<std::vector<int>>("decrease_axis");
63+
if (decrease_axis.size() > 0) {
64+
if (decrease_axis.size() == (size_t)in_dims.size()) {
65+
std::vector<int> vec_origin_out_shape(decrease_axis.size(), 1);
66+
out->Resize(framework::make_ddim(vec_origin_out_shape));
67+
} else {
68+
std::vector<int> vec_origin_out_shape(
69+
out_dims.size() + decrease_axis.size(), -1);
70+
71+
for (size_t i = 0; i < decrease_axis.size(); ++i) {
72+
vec_origin_out_shape[decrease_axis[i]] = 1;
73+
}
74+
75+
int index = 0;
76+
for (size_t i = 0; i < vec_origin_out_shape.size(); ++i) {
77+
if (vec_origin_out_shape[i] == -1) {
78+
vec_origin_out_shape[i] = out_dims[index];
79+
++index;
80+
}
81+
}
82+
83+
out->Resize(framework::make_ddim(vec_origin_out_shape));
84+
}
85+
}
86+
87+
out->mutable_data<T>(context.GetPlace());
6188
auto axes = context.Attr<std::vector<int>>("axes");
6289
auto starts = context.Attr<std::vector<int>>("starts");
6390

91+
auto new_out_dims = out->dims();
6492
auto offsets = Eigen::array<int, D>();
6593
auto extents = Eigen::array<int, D>();
6694
for (size_t i = 0; i < D; ++i) {
6795
offsets[i] = 0;
68-
extents[i] = out_dims[i];
96+
extents[i] = new_out_dims[i];
6997
}
7098
int start;
7199
for (size_t i = 0; i < axes.size(); ++i) {
@@ -81,18 +109,18 @@ class SliceKernel : public framework::OpKernel<T> {
81109
*in);
82110
auto out_t =
83111
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
84-
*out);
112+
*out, new_out_dims);
85113
out_t.device(place) = in_t.slice(offsets, extents);
114+
115+
out->Resize(out_dims);
86116
}
87117
};
88118

89119
template <typename DeviceContext, typename T>
90120
class SliceGradKernel : public framework::OpKernel<T> {
91121
public:
92122
void Compute(const framework::ExecutionContext& ctx) const override {
93-
size_t rank = ctx.Input<framework::Tensor>(framework::GradVarName("Out"))
94-
->dims()
95-
.size();
123+
size_t rank = ctx.Input<framework::Tensor>("Input")->dims().size();
96124
switch (rank) {
97125
case 1:
98126
SliceCompute<1>(ctx);
@@ -130,6 +158,32 @@ class SliceGradKernel : public framework::OpKernel<T> {
130158
auto axes = context.Attr<std::vector<int>>("axes");
131159
auto starts = context.Attr<std::vector<int>>("starts");
132160

161+
auto decrease_axis = context.Attr<std::vector<int>>("decrease_axis");
162+
if (decrease_axis.size() > 0) {
163+
if (decrease_axis.size() == (size_t)in_dims.size()) {
164+
// all dims decrease
165+
std::vector<int> vec_origin_out_shape(decrease_axis.size(), 1);
166+
out_dims = framework::make_ddim(vec_origin_out_shape);
167+
} else {
168+
std::vector<int> vec_origin_out_shape(
169+
out_dims.size() + decrease_axis.size(), -1);
170+
171+
for (size_t i = 0; i < decrease_axis.size(); ++i) {
172+
vec_origin_out_shape[decrease_axis[i]] = 1;
173+
}
174+
175+
int index = 0;
176+
for (size_t i = 0; i < vec_origin_out_shape.size(); ++i) {
177+
if (vec_origin_out_shape[i] == -1) {
178+
vec_origin_out_shape[i] = out_dims[index];
179+
++index;
180+
}
181+
}
182+
183+
out_dims = framework::make_ddim(vec_origin_out_shape);
184+
}
185+
}
186+
133187
auto offsets = Eigen::array<int, D>();
134188
auto extents = Eigen::array<int, D>();
135189
for (size_t i = 0; i < D; ++i) {
@@ -155,7 +209,7 @@ class SliceGradKernel : public framework::OpKernel<T> {
155209
*d_input);
156210
auto d_out_t =
157211
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
158-
*d_out);
212+
*d_out, out_dims);
159213
d_in_t.device(place) = d_out_t.pad(paddings, 0);
160214
}
161215
};

python/paddle/fluid/framework.py

Lines changed: 76 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -822,35 +822,84 @@ def __getitem__(self, item):
822822
Returns:
823823
Sliced variable
824824
"""
825-
new_var = None
826-
if isinstance(item, tuple):
827-
if len(item) > len(self.shape):
828-
raise IndexError("Too many indexes")
829-
fixedSize = True
830-
for i in range(len(self.shape)):
831-
if self.shape[i] == -1:
832-
fixedSize = False
833-
break
834825

835-
newitem = self._reconstructSliceinfo(item) or item
836-
if fixedSize:
837-
check, info = self._detectContinuesSlice(newitem)
838-
if check:
839-
starts = info[0]
840-
ends = info[1]
841-
axes = [i for i in range(len(starts))]
842-
return self._sliceVar(axes, starts, ends)
843-
else:
844-
new_var = self
845-
for index, o in enumerate(newitem):
846-
new_var = new_var._sliceAndConcatVar(o, index)
826+
if not isinstance(item, tuple):
827+
item = [item]
828+
829+
decrease_axis = []
830+
slice_axis = []
831+
slice_start = []
832+
slice_end = []
833+
reverse_axis = []
834+
835+
for dim, slice_item in enumerate(item):
836+
if isinstance(slice_item, slice):
837+
start = slice_item.start
838+
end = slice_item.stop
839+
step = slice_item.step if slice_item.step else 1
840+
841+
assert (step == 1 or step == -1)
842+
843+
if step == -1:
844+
reverse_axis.append(dim)
845+
assert (start is None and end is None)
846+
847+
if start is None and end is None:
848+
continue
849+
850+
if start is None:
851+
start = 0
852+
853+
if end is None:
854+
end = 10000000
855+
856+
slice_axis.append(dim)
857+
slice_start.append(start)
858+
slice_end.append(end)
847859
else:
848-
new_var = self
849-
for index, o in enumerate(newitem):
850-
new_var = new_var._sliceAndConcatVar(o, index)
851-
else:
852-
new_var = self._sliceAndConcatVar(item, 0)
853-
return new_var
860+
# int
861+
decrease_axis.append(dim)
862+
slice_axis.append(dim)
863+
slice_start.append(slice_item)
864+
slice_end.append(slice_item + 1
865+
if slice_item != -1 else 10000000)
866+
867+
out = self
868+
if len(slice_axis) > 0:
869+
# append slice_op here
870+
871+
slice_out_var = self.block.create_var(
872+
name=unique_name.generate_with_ignorable_key(self.name +
873+
"_slice"),
874+
dtype=self.dtype)
875+
876+
self.block.append_op(
877+
type="slice",
878+
inputs={'Input': [out]},
879+
outputs={'Out': [slice_out_var]},
880+
attrs={
881+
'axes': slice_axis,
882+
'starts': slice_start,
883+
'ends': slice_end,
884+
'decrease_axis': decrease_axis
885+
})
886+
887+
out = slice_out_var
888+
889+
if len(reverse_axis) > 0:
890+
reverse_out_var = self.block.create_var(
891+
name=unique_name.generate_with_ignorable_key(self.name +
892+
"_slice_reverse"),
893+
dtype=self.dtype)
894+
self.block.append_op(
895+
type="reverse",
896+
inputs={'X': out},
897+
outputs={'Out': [reverse_out_var]},
898+
attrs={'axis': reverse_axis})
899+
900+
out = reverse_out_var
901+
902+
return out
854903

855904

856905
def get_all_op_protos():

0 commit comments

Comments
 (0)