@@ -345,6 +345,51 @@ class ConvertAtenSliceScatterOp
345345 }
346346};
347347
348+ class ConvertAtenArangeStartStepOp
349+ : public OpConversionPattern<AtenArangeStartStepOp> {
350+ using OpConversionPattern::OpConversionPattern;
351+
352+ LogicalResult
353+ matchAndRewrite (AtenArangeStartStepOp op, OpAdaptor adaptor,
354+ ConversionPatternRewriter &rewriter) const override {
355+
356+ // At this point all tensors should have value semantics, and hence the
357+ // `layout` check can be ignored.
358+
359+ // The pin_memory should be either `False` or `none`.
360+ bool pinMemory;
361+ if (!isa<Torch::NoneType>(op.getPinMemory ().getType ()) &&
362+ (!matchPattern (op.getPinMemory (), m_TorchConstantBool (&pinMemory)) ||
363+ pinMemory)) {
364+ return rewriter.notifyMatchFailure (
365+ op, " unimplemented: pin_memory must be either None or false" );
366+ }
367+
368+ torch_to_tcp::TorchToTcpCustomOpConversionHelper helper{op, rewriter,
369+ getTypeConverter ()};
370+ bool allStatic = true ;
371+ // trt-mlir takes F64Attr, so we need to convert const int to fp attr
372+ if (!helper.tryConvertConstToFloatAttr (" start" , op.getStart ())) {
373+ allStatic = false ;
374+ helper.addOperand (" start" , adaptor.getStart ());
375+ }
376+ if (!helper.tryConvertConstToFloatAttr (" end" , op.getEnd ())) {
377+ allStatic = false ;
378+ helper.addOperand (" end" , adaptor.getEnd ());
379+ }
380+ if (!helper.tryConvertConstToFloatAttr (" step" , op.getStep ())) {
381+ allStatic = false ;
382+ helper.addOperand (" step" , adaptor.getStep ());
383+ }
384+ // static start, end, and step case will be handled through TOSA dialect
385+ if (allStatic)
386+ return rewriter.notifyMatchFailure (op,
387+ " only non-constant values supported" );
388+
389+ return helper.replace ();
390+ }
391+ };
392+
348393} // namespace
349394
350395void torch_to_tcp::populateTcpCustomOpPatternsAndLegality (
@@ -365,8 +410,10 @@ void torch_to_tcp::populateTcpCustomOpPatternsAndLegality(
365410 INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN (AtenCumsumOp);
366411 INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN (AtenMinDimOp);
367412 INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN (AtenSliceScatterOp);
368- // AtenViewOp can still live after torch-to-tcp conversion
413+ // Following ops can still live after torch-to-tcp conversion
369414 patterns.add <ConvertAtenViewOp>(typeConverter, patterns.getContext ());
415+ patterns.add <ConvertAtenArangeStartStepOp>(typeConverter,
416+ patterns.getContext ());
370417#undef INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN
371418
372419 // Torch -> TOSA doesn't handle transposed convolutions; map them to
0 commit comments