@@ -189,6 +189,21 @@ static Value createInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
   Value initTensor = b.create<linalg::InitTensorOp>(loc, sizes, elemTy);
   return b.create<linalg::FillOp>(loc, initElem, initTensor).getResult(0);
 }
+// Creates a constant of type `elemType` with value `val`.
+static Value getConstant(OpBuilder &b, Location loc, int64_t val,
+                         Type elemType) {
+  Attribute attr = {};
+  if (elemType.isa<mlir::FloatType>())
+    attr = b.getFloatAttr(elemType, val);
+  if (elemType.isa<mlir::IndexType>())
+    attr = b.getIndexAttr(val);
+  if (elemType.isa<mlir::IntegerType>())
+    attr = b.getIntegerAttr(
+        elemType, APInt(elemType.cast<IntegerType>().getWidth(), val));
+  if (!attr)
+    return nullptr;
+  return b.create<arith::ConstantOp>(loc, elemType, attr);
+}
 
 // Helper function to calculate the output tensor dims for convolution-like ops.
 // Along each dim:
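
A quick usage sketch for the new `getConstant` helper (the `createZeroInit` wrapper below is hypothetical, not part of this patch): the helper returns `nullptr` for element types it does not recognize, so callers should check the result before use.

    // Hypothetical sketch; assumes this file's usual includes and a builder
    // with a valid insertion point.
    static Value createZeroInit(OpBuilder &b, Location loc, ValueRange sizes,
                                Type elemTy) {
      Value zero = getConstant(b, loc, /*val=*/0, elemTy);
      if (!zero) // unhandled element type
        return nullptr;
      Value init = b.create<linalg::InitTensorOp>(loc, sizes, elemTy);
      return b.create<linalg::FillOp>(loc, zero, init).getResult(0);
    }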
@@ -1828,13 +1843,14 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     Type dtype = converter->convertType(mulScalar.getType())
                      .cast<RankedTensorType>()
                      .getElementType();
-    if (!dtype.isa<mlir::FloatType>()) {
-      mulScalar.emitError("unimplemented: non-floating point dtype");
-      return nullptr;
-    }
-    Value self = payloadArgs[0];
-    Value other = convertScalarToDtype(b, loc, operands[1], dtype);
-    return b.create<arith::MulFOp>(loc, self, other);
+    Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
+    Value rhs = convertScalarToDtype(b, loc, operands[1], dtype);
+    if (dtype.isa<mlir::FloatType>())
+      return b.create<arith::MulFOp>(loc, lhs, rhs);
+    if (dtype.isa<mlir::IntegerType>())
+      return b.create<arith::MulIOp>(loc, lhs, rhs);
+    mulScalar.emitError("unimplemented: Only integer/float dtype supported");
+    return nullptr;
   }
   if (auto atenToDtype = dyn_cast<AtenToDtypeOp>(op)) {
     Value input = payloadArgs[0];
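
The fix above follows a promote-then-dispatch shape: convert both operands to the result dtype, then pick the arith op by type. The same shape extends to other binary elementwise ops; a hedged sketch for addition (`createAddLike` is illustrative, not part of the patch):

    // Hypothetical sketch of the promote-then-dispatch idiom.
    static Value createAddLike(OpBuilder &b, Location loc, Value a, Value c,
                               Type dtype) {
      Value lhs = convertScalarToDtype(b, loc, a, dtype);
      Value rhs = convertScalarToDtype(b, loc, c, dtype);
      if (dtype.isa<mlir::FloatType>())
        return b.create<arith::AddFOp>(loc, lhs, rhs);
      if (dtype.isa<mlir::IntegerType>())
        return b.create<arith::AddIOp>(loc, lhs, rhs);
      return nullptr; // caller reports the unsupported dtype
    }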
@@ -3417,83 +3433,68 @@ class ConvertAtenContiguousOp : public OpConversionPattern<AtenContiguousOp> {
 } // namespace
 
 namespace {
-// Converts AtenOnesOp and AtenZerosOp.
-struct ConvertAtenOnesZerosOp : ConversionPattern {
-  ConvertAtenOnesZerosOp(TypeConverter &typeConverter, MLIRContext *context)
-      : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
-                          context) {}
-
+// Converts constant tensor allocation-like ops.
+template <typename OpTy>
+class ConvertConstantTensorAllocOp : public OpConversionPattern<OpTy> {
+public:
+  using OpConversionPattern<OpTy>::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    if (!isa<AtenOnesOp, AtenZerosOp>(op))
-      return rewriter.notifyMatchFailure(op,
-                                         "not a supported ones or zeros op");
-
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
       return failure();
-    Location loc = op->getLoc();
-
-    Value size, layout, pin_memory;
-    int64_t elementValue;
 
-    if (AtenOnesOp onesOp = dyn_cast<AtenOnesOp>(op)) {
-      size = onesOp.size();
-      layout = onesOp.layout();
-      pin_memory = onesOp.pin_memory();
-      elementValue = 1;
-    } else if (AtenZerosOp zerosOp = dyn_cast<AtenZerosOp>(op)) {
-      size = zerosOp.size();
-      layout = zerosOp.layout();
-      pin_memory = zerosOp.pin_memory();
-      elementValue = 0;
+    // Currently, memory pinning and non-default layouts are not supported.
+    if (!op.layout().getType().template isa<Torch::NoneType>())
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: only default layout is supported");
+    bool pinMemory;
+    if (!op.pin_memory().getType().template isa<Torch::NoneType>() &&
+        (!matchPattern(op.pin_memory(), m_TorchConstantBool(&pinMemory)) ||
+         pinMemory)) {
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: pin_memory must be either None or false");
     }
 
-    // We ignore device, but add simple asserts for unimplemented kwargs
-    if (!layout.getType().isa<Torch::NoneType>())
-      return rewriter.notifyMatchFailure(op,
-                                         "only default layout is supported");
-
-    bool pinMemory = false;
-    if (!pin_memory.getType().isa<Torch::NoneType>() &&
-        !matchPattern(pin_memory, m_TorchConstantBool(&pinMemory))) {
-      return rewriter.notifyMatchFailure(
-          op, "pin_memory must be constant bool or None");
+    // Memory formats are not supported for `AtenEmptyMemoryFormatOp`.
+    if constexpr (std::is_same<OpTy, AtenEmptyMemoryFormatOp>::value) {
+      if (!op.memory_format().getType().template isa<Torch::NoneType>())
+        return rewriter.notifyMatchFailure(
+            op, "unimplemented: only default memory format is supported");
     }
-    if (pinMemory)
-      return rewriter.notifyMatchFailure(op, "memory pinning not supported");
 
-    SmallVector<Value> sizes, sizeIndex;
-    if (!getListConstructElements(size, sizes)) {
+    Location loc = op.getLoc();
+    TypeConverter *typeConverter = this->getTypeConverter();
+    SmallVector<Value> resultSizeTorchInt, resultSize, resultSizeIndex;
+    if (!getListConstructElements(op.size(), resultSizeTorchInt)) {
       return rewriter.notifyMatchFailure(
-          op, "size must be created by ListConstruct");
+          op, "unimplemented: size must be constructed using ListConstruct");
     }
-    sizes = getTypeConvertedValues(rewriter, loc, getTypeConverter(), sizes);
-    for (size_t i = 0; i < sizes.size(); i++)
-      sizeIndex.push_back(castIntToIndex(rewriter, loc, sizes[i]));
-
-    RankedTensorType newResultType =
-        getTypeConverter()
-            ->convertType(op->getResult(0).getType())
-            .cast<RankedTensorType>();
-    Type outElementType = newResultType.getElementType();
-
-    Value constantOp = rewriter.create<arith::ConstantOp>(
-        loc, outElementType,
-        (outElementType.isa<mlir::FloatType>()
-             ? rewriter.getFloatAttr(outElementType, elementValue)
-                   .cast<mlir::Attribute>()
-             : rewriter.getIntegerAttr(outElementType, elementValue)
-                   .cast<mlir::Attribute>()));
-    Value outTensor = rewriter
-                          .create<linalg::InitTensorOp>(
-                              loc, sizeIndex, newResultType.getElementType())
-                          .getResult();
-    Value fillOp = rewriter.create<linalg::FillOp>(loc, constantOp, outTensor)
-                       .getResult(0);
+    resultSize = getTypeConvertedValues(rewriter, loc, typeConverter,
+                                        resultSizeTorchInt);
+    for (auto size : resultSize)
+      resultSizeIndex.push_back(castIntToIndex(rewriter, loc, size));
 
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, fillOp);
+    auto resultType = typeConverter->convertType(op.getType())
+                          .template cast<RankedTensorType>();
+    Type outElemType = resultType.getElementType();
 
+    // Create an uninitialized tensor of `resultSize` shape. In the case of
+    // `AtenEmptyMemoryFormatOp` it is returned without being filled.
+    Value outputTensor = rewriter.create<linalg::InitTensorOp>(
+        loc, resultSizeIndex, outElemType);
+
+    // `AtenZerosOp` and `AtenOnesOp` are filled with the corresponding value.
+    if (std::is_same<OpTy, AtenZerosOp>::value) {
+      Value zero = getConstant(rewriter, loc, 0, outElemType);
+      outputTensor =
+          rewriter.create<linalg::FillOp>(loc, zero, outputTensor).getResult(0);
+    } else if (std::is_same<OpTy, AtenOnesOp>::value) {
+      Value one = getConstant(rewriter, loc, 1, outElemType);
+      outputTensor =
+          rewriter.create<linalg::FillOp>(loc, one, outputTensor).getResult(0);
+    }
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, outputTensor);
     return success();
   }
 };
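
Note the `if constexpr` guard: `memory_format()` is only generated for `AtenEmptyMemoryFormatOp`, so the check must be compiled out of the `AtenZerosOp` and `AtenOnesOp` instantiations. The same idiom could admit further alloc-like ops; a hypothetical sketch (the op and its `fill_value()` accessor are assumptions, not part of this patch) that would slot in after `outputTensor` is created:

    // Hypothetical extension; `AtenFullLikeOpTy` and fill_value() are assumed.
    if constexpr (std::is_same<OpTy, AtenFullLikeOpTy>::value) {
      Value fillValue =
          convertScalarToDtype(rewriter, loc, adaptor.fill_value(), outElemType);
      outputTensor =
          rewriter.create<linalg::FillOp>(loc, fillValue, outputTensor)
              .getResult(0);
    }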
@@ -3704,8 +3705,15 @@ class ConvertTorchToLinalg
     patterns.add<ConvertAtenSizeIntOp>(typeConverter, context);
     target.addIllegalOp<AtenEmbeddingOp>();
     patterns.add<ConvertAtenEmbeddingOp>(typeConverter, context);
-    target.addIllegalOp<AtenOnesOp, AtenZerosOp>();
-    patterns.add<ConvertAtenOnesZerosOp>(typeConverter, context);
+    target.addIllegalOp<AtenEmptyMemoryFormatOp>();
+    patterns.add<ConvertConstantTensorAllocOp<AtenEmptyMemoryFormatOp>>(
+        typeConverter, context);
+    target.addIllegalOp<AtenZerosOp>();
+    patterns.add<ConvertConstantTensorAllocOp<AtenZerosOp>>(typeConverter,
+                                                            context);
+    target.addIllegalOp<AtenOnesOp>();
+    patterns.add<ConvertConstantTensorAllocOp<AtenOnesOp>>(typeConverter,
+                                                           context);
     target.addIllegalOp<AtenContiguousOp>();
     patterns.add<ConvertAtenContiguousOp>(typeConverter, context);
     target.addIllegalOp<AtenIntTensorOp>();
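
Since both `ConversionTarget::addIllegalOp` and `RewritePatternSet::add` accept multiple template arguments, the three registrations could equally be folded into one pair of calls; an equivalent sketch (compactness only, behavior assumed unchanged):

    target.addIllegalOp<AtenEmptyMemoryFormatOp, AtenZerosOp, AtenOnesOp>();
    patterns.add<ConvertConstantTensorAllocOp<AtenEmptyMemoryFormatOp>,
                 ConvertConstantTensorAllocOp<AtenZerosOp>,
                 ConvertConstantTensorAllocOp<AtenOnesOp>>(typeConverter,
                                                           context);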