@@ -192,7 +192,7 @@ bool StridedSliceOp::AttachImpl(const cpp::OpDesc &op_desc,
192
192
param_.strides = op_desc.GetAttr <std::vector<int >>(" strides" );
193
193
}
194
194
195
- if (op_desc.HasAttr (" infer_flags " )) {
195
+ if (op_desc.HasAttr (" axes " )) {
196
196
param_.axes = op_desc.GetAttr <std::vector<int >>(" axes" );
197
197
}
198
198
@@ -234,34 +234,40 @@ bool StridedSliceOp::AttachImpl(const cpp::OpDesc &op_desc,
234
234
}
235
235
}
236
236
auto tensor_input = false ;
237
- if (op_desc.HasInput (" EndsTensor" ) || op_desc.HasInput (" StartsTensor" ) ||
238
- op_desc.HasInput (" StridesTensor" )) {
237
+ if ((op_desc.HasInput (" EndsTensor" ) &&
238
+ !op_desc.Input (" EndsTensor" ).empty ()) ||
239
+ (op_desc.HasInput (" StartsTensor" ) &&
240
+ !op_desc.Input (" StartsTensor" ).empty ()) ||
241
+ (op_desc.HasInput (" StridesTensor" ) &&
242
+ !op_desc.Input (" StridesTensor" ).empty ())) {
239
243
tensor_input = true ;
240
244
}
241
245
param_.tensor_input = tensor_input;
242
- if (!op_desc.HasInput (" EndsTensor" )) {
246
+ if (op_desc.HasInput (" EndsTensor" ) && !op_desc.Input (" EndsTensor" ).empty ()) {
247
+ auto inputs = op_desc.Input (" EndsTensor" ).front ();
248
+ param_.EndsTensor = scope->FindVar (inputs)->GetMutable <Tensor>();
249
+ } else {
243
250
CHECK_EQ (param_.axes .size (), ends_size)
244
251
<< " axes.size(): " << param_.axes .size ()
245
252
<< " is not equal to ends_size: " << ends_size;
246
- } else {
247
- auto inputs = op_desc.Input (" EndsTensor" ).front ();
248
- param_.EndsTensor = scope->FindVar (inputs)->GetMutable <Tensor>();
249
253
}
250
- if (!op_desc.HasInput (" StartsTensor" )) {
254
+ if (op_desc.HasInput (" StartsTensor" ) &&
255
+ !op_desc.Input (" StartsTensor" ).empty ()) {
256
+ auto inputs = op_desc.Input (" StartsTensor" ).front ();
257
+ param_.StartsTensor = scope->FindVar (inputs)->GetMutable <Tensor>();
258
+ } else {
251
259
CHECK_EQ (param_.axes .size (), starts_size)
252
260
<< " axes.size(): " << param_.axes .size ()
253
261
<< " is not equal to starts_size: " << starts_size;
254
- } else {
255
- auto inputs = op_desc.Input (" StartsTensor" ).front ();
256
- param_.StartsTensor = scope->FindVar (inputs)->GetMutable <Tensor>();
257
262
}
258
- if (!op_desc.HasInput (" StridesTensor" )) {
263
+ if (op_desc.HasInput (" StridesTensor" ) &&
264
+ !op_desc.Input (" StridesTensor" ).empty ()) {
265
+ auto inputs = op_desc.Input (" StridesTensor" ).front ();
266
+ param_.StridesTensor = scope->FindVar (inputs)->GetMutable <Tensor>();
267
+ } else {
259
268
CHECK_EQ (param_.axes .size (), strides_size)
260
269
<< " axes.size(): " << param_.axes .size ()
261
270
<< " is not equal to ends_size: " << strides_size;
262
- } else {
263
- auto inputs = op_desc.Input (" StridesTensor" ).front ();
264
- param_.StridesTensor = scope->FindVar (inputs)->GetMutable <Tensor>();
265
271
}
266
272
return true ;
267
273
}
0 commit comments