
Commit 8d4879f

Gaurav Shukla authored and committed
[TORCH][MLIR] Add and templatize lowering of [aten.zeros|aten.ones|aten.empty] ops
- Templatize `aten.zeros` and `aten.ones` ops lowering.
- Add E2E support for `aten.empty` op.
- Add Integer type support in `aten.mul.Scalar` op lowering.

Signed-Off-by: Gaurav Shukla <[email protected]>
1 parent 528354d commit 8d4879f
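
For orientation while reading the diffs below, here is a small eager-mode PyTorch sketch (not part of the commit) of the constant tensor allocation ops that the new templatized pattern covers, together with the caveat that motivates how the `aten.empty` e2e tests are written. The `aten.mul.Scalar` change is illustrated separately after the elementwise.py diff.

```python
import torch

# Constant tensor allocation ops routed through the new templatized lowering.
z = torch.zeros(3, 4)                      # aten.zeros
o = torch.ones(3, 4, dtype=torch.int64)    # aten.ones
e = torch.empty((3, 4), dtype=torch.int64,
                pin_memory=False)          # aten.empty.memory_format

# torch.empty returns uninitialized data, so the new e2e tests only check
# value-independent properties; an integer empty tensor times zero is all zeros.
print((0 * e).eq(0).all())                 # tensor(True)
```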

File tree

3 files changed: 167 additions, 76 deletions

- e2e_testing/torchscript/basic.py
- e2e_testing/torchscript/elementwise.py
- lib/Conversion/TorchToLinalg/TorchToLinalg.cpp

e2e_testing/torchscript/basic.py

Lines changed: 51 additions & 0 deletions
@@ -637,6 +637,57 @@ def OnesModuleFalsePinMemory_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
+class EmptyIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+    ])
+    def forward(self):
+        return 0 * torch.empty((3, 4), dtype=torch.int64)
+
+@register_test_case(module_factory=lambda: EmptyIntModule())
+def EmptyModule_int(module, tu: TestUtils):
+    module.forward()
+
+# ==============================================================================
+
+class EmptyFloatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+    ])
+    def forward(self):
+        return torch.abs(torch.empty((3, 4), dtype=torch.float32)) > -1.0
+
+@register_test_case(module_factory=lambda: EmptyFloatModule())
+def EmptyModule_float(module, tu: TestUtils):
+    module.forward()
+
+
+class EmptyFalsePinMemoryModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+    ])
+    def forward(self):
+        return torch.abs(torch.empty((3, 4), dtype=torch.float32,
+                                      pin_memory=False)) > -1.0
+
+@register_test_case(module_factory=lambda: EmptyFalsePinMemoryModule())
+def EmptyModule_falsePinMemory(module, tu: TestUtils):
+    module.forward()
+
+# ==============================================================================
+
 class ContiguousModule(torch.nn.Module):
     def __init__(self):
         super().__init__()

e2e_testing/torchscript/elementwise.py

Lines changed: 35 additions & 3 deletions
@@ -443,8 +443,24 @@ def RsubModule_noalpha_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
+class ElementwiseMulScalarIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
 
-class ElementwiseMulScalarModule(torch.nn.Module):
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+    ])
+    def forward(self, x):
+        return torch.mul(x, 4)
+
+@register_test_case(module_factory=lambda: ElementwiseMulScalarIntModule())
+def ElementwiseMulScalarModule_int(module, tu: TestUtils):
+    module.forward(torch.randint(10, (3, 4)))
+
+
+class ElementwiseMulScalarFloatModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
 
@@ -456,11 +472,27 @@ def __init__(self):
     def forward(self, x):
         return torch.mul(x, 100.0)
 
-@register_test_case(module_factory=lambda: ElementwiseMulScalarModule())
-def ElementwiseMulScalarModule_basic(module, tu: TestUtils):
+@register_test_case(module_factory=lambda: ElementwiseMulScalarFloatModule())
+def ElementwiseMulScalarModule_float(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))
 
 
+class ElementwiseMulScalarModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+    ])
+    def forward(self, x):
+        return torch.mul(x, 8.0)
+
+@register_test_case(module_factory=lambda: ElementwiseMulScalarModule())
+def ElementwiseMulScalarModule_basic(module, tu: TestUtils):
+    module.forward(torch.randint(10, (3, 4)))
+
 
 class ElementwiseMulTensorFloatModule(torch.nn.Module):
     def __init__(self):
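
A note on the test split above (not part of the commit): the three `ElementwiseMulScalar*` variants pin down the dtype combinations the lowering now has to handle, and eager PyTorch's type promotion gives the result element types the converted IR must reproduce. A quick sanity check:

```python
import torch

int_x = torch.randint(10, (3, 4))    # torch.int64
float_x = torch.rand(3, 4)           # torch.float32

print(torch.mul(int_x, 4).dtype)        # torch.int64  -> integer path (arith.muli)
print(torch.mul(float_x, 100.0).dtype)  # torch.float32 -> float path (arith.mulf)
print(torch.mul(int_x, 8.0).dtype)      # torch.float32 -> both operands converted
                                        #                  to float, then arith.mulf
```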

lib/Conversion/TorchToLinalg/TorchToLinalg.cpp

Lines changed: 81 additions & 73 deletions
@@ -189,6 +189,21 @@ static Value createInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
   Value initTensor = b.create<linalg::InitTensorOp>(loc, sizes, elemTy);
   return b.create<linalg::FillOp>(loc, initElem, initTensor).getResult(0);
 }
+// Creates a constant of type `elemType` with value `val`.
+static Value getConstant(OpBuilder &b, Location loc, int64_t val,
+                         Type elemType) {
+  Attribute attr = {};
+  if (elemType.isa<mlir::FloatType>())
+    attr = b.getFloatAttr(elemType, val);
+  if (elemType.isa<mlir::IndexType>())
+    attr = b.getIndexAttr(val);
+  if (elemType.isa<mlir::IntegerType>())
+    attr = b.getIntegerAttr(
+        elemType, APInt(elemType.cast<IntegerType>().getWidth(), val));
+  if (!attr)
+    return nullptr;
+  return b.create<arith::ConstantOp>(loc, elemType, attr);
+}
 
 // Helper function to caculate the output tensor dims for convolution-like ops.
 // Along each dim:
@@ -1828,13 +1843,14 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     Type dtype = converter->convertType(mulScalar.getType())
                      .cast<RankedTensorType>()
                      .getElementType();
-    if (!dtype.isa<mlir::FloatType>()) {
-      mulScalar.emitError("unimplemented: non-floating point dtype");
-      return nullptr;
-    }
-    Value self = payloadArgs[0];
-    Value other = convertScalarToDtype(b, loc, operands[1], dtype);
-    return b.create<arith::MulFOp>(loc, self, other);
+    Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
+    Value rhs = convertScalarToDtype(b, loc, operands[1], dtype);
+    if (dtype.isa<mlir::FloatType>())
+      return b.create<arith::MulFOp>(loc, lhs, rhs);
+    if (dtype.isa<mlir::IntegerType>())
+      return b.create<arith::MulIOp>(loc, lhs, rhs);
+    mulScalar.emitError("unimplemented: Only integer/float dtype supported");
+    return nullptr;
   }
   if (auto atenToDtype = dyn_cast<AtenToDtypeOp>(op)) {
     Value input = payloadArgs[0];
@@ -3417,83 +3433,68 @@ class ConvertAtenContiguousOp : public OpConversionPattern<AtenContiguousOp> {
 } // namespace
 
 namespace {
-// Converts AtenOnesOp and AtenZerosOp.
-struct ConvertAtenOnesZerosOp : ConversionPattern {
-  ConvertAtenOnesZerosOp(TypeConverter &typeConverter, MLIRContext *context)
-      : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
-                          context) {}
-
+// Converts constant tensor allocation like ops.
+template <typename OpTy>
+class ConvertConstantTensorAllocOp : public OpConversionPattern<OpTy> {
+public:
+  using OpConversionPattern<OpTy>::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    if (!isa<AtenOnesOp, AtenZerosOp>(op))
-      return rewriter.notifyMatchFailure(op,
-                                         "not a supported ones or zeros op");
-
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
       return failure();
-    Location loc = op->getLoc();
-
-    Value size, layout, pin_memory;
-    int64_t elementValue;
 
-    if (AtenOnesOp onesOp = dyn_cast<AtenOnesOp>(op)) {
-      size = onesOp.size();
-      layout = onesOp.layout();
-      pin_memory = onesOp.pin_memory();
-      elementValue = 1;
-    } else if (AtenZerosOp zerosOp = dyn_cast<AtenZerosOp>(op)) {
-      size = zerosOp.size();
-      layout = zerosOp.layout();
-      pin_memory = zerosOp.pin_memory();
-      elementValue = 0;
+    // Currently memory pinning and layout features are not supported.
+    if (!op.layout().getType().template isa<Torch::NoneType>())
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: only default layout is supported");
+    bool pinMemory;
+    if (!op.pin_memory().getType().template isa<Torch::NoneType>() &&
+        (!matchPattern(op.pin_memory(), m_TorchConstantBool(&pinMemory)) ||
+         pinMemory)) {
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: pin_memory must be either None or false");
     }
 
-    // We ignore device, but add simple asserts for unimplemented kwargs
-    if (!layout.getType().isa<Torch::NoneType>())
-      return rewriter.notifyMatchFailure(op,
-                                         "only default layout is supported");
-
-    bool pinMemory = false;
-    if (!pin_memory.getType().isa<Torch::NoneType>() &&
-        !matchPattern(pin_memory, m_TorchConstantBool(&pinMemory))) {
-      return rewriter.notifyMatchFailure(
-          op, "pin_memory must be constant bool or None");
+    // Memory formats are not supported in the case of `AtenEmptyMemoryFormat`.
+    if constexpr (std::is_same<OpTy, AtenEmptyMemoryFormatOp>::value) {
+      if (!op.memory_format().getType().template isa<Torch::NoneType>())
+        return rewriter.notifyMatchFailure(
+            op, "unimplemented: only default memory format is supported");
     }
-    if (pinMemory)
-      return rewriter.notifyMatchFailure(op, "memory pinning not supported");
 
-    SmallVector<Value> sizes, sizeIndex;
-    if (!getListConstructElements(size, sizes)) {
+    Location loc = op.getLoc();
+    TypeConverter *typeConverter = this->getTypeConverter();
+    SmallVector<Value> resultSizeTorchInt, resultSize, resultSizeIndex;
+    if (!getListConstructElements(op.size(), resultSizeTorchInt)) {
       return rewriter.notifyMatchFailure(
-          op, "size must be created by ListConstruct");
+          op, "unimplemented: size must be constructed using ListConstruct");
     }
-    sizes = getTypeConvertedValues(rewriter, loc, getTypeConverter(), sizes);
-    for (size_t i = 0; i < sizes.size(); i++)
-      sizeIndex.push_back(castIntToIndex(rewriter, loc, sizes[i]));
-
-    RankedTensorType newResultType =
-        getTypeConverter()
-            ->convertType(op->getResult(0).getType())
-            .cast<RankedTensorType>();
-    Type outElementType = newResultType.getElementType();
-
-    Value constantOp = rewriter.create<arith::ConstantOp>(
-        loc, outElementType,
-        (outElementType.isa<mlir::FloatType>()
-             ? rewriter.getFloatAttr(outElementType, elementValue)
-                   .cast<mlir::Attribute>()
-             : rewriter.getIntegerAttr(outElementType, elementValue)
-                   .cast<mlir::Attribute>()));
-    Value outTensor = rewriter
-                          .create<linalg::InitTensorOp>(
-                              loc, sizeIndex, newResultType.getElementType())
-                          .getResult();
-    Value fillOp = rewriter.create<linalg::FillOp>(loc, constantOp, outTensor)
-                       .getResult(0);
+    resultSize = getTypeConvertedValues(rewriter, loc, typeConverter,
+                                        resultSizeTorchInt);
+    for (auto size : resultSize)
+      resultSizeIndex.push_back(castIntToIndex(rewriter, loc, size));
 
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, fillOp);
+    auto resultType =
+        typeConverter->convertType(op.getType()).template cast<RankedTensorType>();
+    Type outElemType = resultType.getElementType();
 
+    // Create an uninitialized tensor of `resultSize` shape. It will be returned
+    // without initialization/filling in the case of `AtenEmptyMemoryFormatOp`.
+    Value outputTensor = rewriter.create<linalg::InitTensorOp>(
+        loc, resultSizeIndex, outElemType);
+
+    // `AtenZeros` and `AtenOnes` ops will be filled with corresponding values.
+    if (std::is_same<OpTy, AtenZerosOp>::value) {
+      Value zero = getConstant(rewriter, loc, 0, outElemType);
+      outputTensor =
+          rewriter.create<linalg::FillOp>(loc, zero, outputTensor).getResult(0);
+    } else if (std::is_same<OpTy, AtenOnesOp>::value) {
+      Value one = getConstant(rewriter, loc, 1, outElemType);
+      outputTensor =
+          rewriter.create<linalg::FillOp>(loc, one, outputTensor).getResult(0);
+    }
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, outputTensor);
     return success();
   }
 };
@@ -3704,8 +3705,15 @@ class ConvertTorchToLinalg
     patterns.add<ConvertAtenSizeIntOp>(typeConverter, context);
     target.addIllegalOp<AtenEmbeddingOp>();
     patterns.add<ConvertAtenEmbeddingOp>(typeConverter, context);
-    target.addIllegalOp<AtenOnesOp, AtenZerosOp>();
-    patterns.add<ConvertAtenOnesZerosOp>(typeConverter, context);
+    target.addIllegalOp<AtenEmptyMemoryFormatOp>();
+    patterns.add<ConvertConstantTensorAllocOp<AtenEmptyMemoryFormatOp>>(
+        typeConverter, context);
+    target.addIllegalOp<AtenZerosOp>();
+    patterns.add<ConvertConstantTensorAllocOp<AtenZerosOp>>(typeConverter,
+                                                            context);
+    target.addIllegalOp<AtenOnesOp>();
+    patterns.add<ConvertConstantTensorAllocOp<AtenOnesOp>>(typeConverter,
+                                                           context);
     target.addIllegalOp<AtenContiguousOp>();
     patterns.add<ConvertAtenContiguousOp>(typeConverter, context);
     target.addIllegalOp<AtenIntTensorOp>();
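
To summarize the guards in the templatized pattern registered above (a reading aid, not part of the commit): the conversion only fires when the optional arguments are left at their defaults, otherwise it reports a match failure. The sketch below is a hypothetical Python mirror of that control flow; `can_lower` and its parameters are made-up names for illustration, not torch-mlir API.

```python
# Hypothetical mirror of the notifyMatchFailure checks in
# ConvertConstantTensorAllocOp::matchAndRewrite (illustration only).
def can_lower(layout=None, pin_memory=None, memory_format=None,
              is_empty_memory_format=False):
    if layout is not None:
        return False, "unimplemented: only default layout is supported"
    if pin_memory not in (None, False):
        return False, "unimplemented: pin_memory must be either None or false"
    if is_empty_memory_format and memory_format is not None:
        return False, "unimplemented: only default memory format is supported"
    return True, ""

print(can_lower())                                  # (True, '')
print(can_lower(pin_memory=True)[1])                # pin_memory message
print(can_lower(is_empty_memory_format=True,
                memory_format="channels_last")[1])  # memory format message
```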

0 commit comments
