Skip to content

Commit a1b8656

Browse files
[fix] fix strided_slice_op_cc cherry-pick bug,test=develop (#5792)
1 parent 35cabd3 commit a1b8656

File tree

1 file changed

+21
-15
lines changed

1 file changed

+21
-15
lines changed

lite/operators/strided_slice_op.cc

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ bool StridedSliceOp::AttachImpl(const cpp::OpDesc &op_desc,
192192
param_.strides = op_desc.GetAttr<std::vector<int>>("strides");
193193
}
194194

195-
if (op_desc.HasAttr("infer_flags")) {
195+
if (op_desc.HasAttr("axes")) {
196196
param_.axes = op_desc.GetAttr<std::vector<int>>("axes");
197197
}
198198

@@ -234,34 +234,40 @@ bool StridedSliceOp::AttachImpl(const cpp::OpDesc &op_desc,
234234
}
235235
}
236236
auto tensor_input = false;
237-
if (op_desc.HasInput("EndsTensor") || op_desc.HasInput("StartsTensor") ||
238-
op_desc.HasInput("StridesTensor")) {
237+
if ((op_desc.HasInput("EndsTensor") &&
238+
!op_desc.Input("EndsTensor").empty()) ||
239+
(op_desc.HasInput("StartsTensor") &&
240+
!op_desc.Input("StartsTensor").empty()) ||
241+
(op_desc.HasInput("StridesTensor") &&
242+
!op_desc.Input("StridesTensor").empty())) {
239243
tensor_input = true;
240244
}
241245
param_.tensor_input = tensor_input;
242-
if (!op_desc.HasInput("EndsTensor")) {
246+
if (op_desc.HasInput("EndsTensor") && !op_desc.Input("EndsTensor").empty()) {
247+
auto inputs = op_desc.Input("EndsTensor").front();
248+
param_.EndsTensor = scope->FindVar(inputs)->GetMutable<Tensor>();
249+
} else {
243250
CHECK_EQ(param_.axes.size(), ends_size)
244251
<< "axes.size(): " << param_.axes.size()
245252
<< " is not equal to ends_size: " << ends_size;
246-
} else {
247-
auto inputs = op_desc.Input("EndsTensor").front();
248-
param_.EndsTensor = scope->FindVar(inputs)->GetMutable<Tensor>();
249253
}
250-
if (!op_desc.HasInput("StartsTensor")) {
254+
if (op_desc.HasInput("StartsTensor") &&
255+
!op_desc.Input("StartsTensor").empty()) {
256+
auto inputs = op_desc.Input("StartsTensor").front();
257+
param_.StartsTensor = scope->FindVar(inputs)->GetMutable<Tensor>();
258+
} else {
251259
CHECK_EQ(param_.axes.size(), starts_size)
252260
<< "axes.size(): " << param_.axes.size()
253261
<< " is not equal to starts_size: " << starts_size;
254-
} else {
255-
auto inputs = op_desc.Input("StartsTensor").front();
256-
param_.StartsTensor = scope->FindVar(inputs)->GetMutable<Tensor>();
257262
}
258-
if (!op_desc.HasInput("StridesTensor")) {
263+
if (op_desc.HasInput("StridesTensor") &&
264+
!op_desc.Input("StridesTensor").empty()) {
265+
auto inputs = op_desc.Input("StridesTensor").front();
266+
param_.StridesTensor = scope->FindVar(inputs)->GetMutable<Tensor>();
267+
} else {
259268
CHECK_EQ(param_.axes.size(), strides_size)
260269
<< "axes.size(): " << param_.axes.size()
261270
<< " is not equal to ends_size: " << strides_size;
262-
} else {
263-
auto inputs = op_desc.Input("StridesTensor").front();
264-
param_.StridesTensor = scope->FindVar(inputs)->GetMutable<Tensor>();
265271
}
266272
return true;
267273
}

0 commit comments

Comments
 (0)