55//
66// ===----------------------------------------------------------------------===//
77
8+ #include " triton/Dialect/Triton/IR/Types.h"
9+
10+ #include " triton-shared/Analysis/OpFoldResultUtils.h"
811#include " triton-shared/Conversion/StructuredToMemref/StructuredToMemref.h"
9- #include " mlir /Dialect/Arith /IR/Arith .h"
10- # include " mlir/Dialect/SCF/IR/SCF.h "
12+ #include " triton-shared /Dialect/TritonStructured /IR/TritonStructuredDialect .h"
13+
1114#include " mlir/IR/Builders.h"
1215#include " mlir/IR/BuiltinOps.h"
1316#include " mlir/IR/BuiltinTypeInterfaces.h"
1821#include " mlir/IR/Types.h"
1922#include " mlir/Support/LogicalResult.h"
2023#include " mlir/Transforms/DialectConversion.h"
21- #include " triton-shared/Analysis/OpFoldResultUtils.h"
22- #include " triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h"
23-
24- #include " triton/Dialect/Triton/IR/Dialect.h"
2524
26- #include " mlir/Dialect/Affine /IR/AffineOps .h"
25+ #include " mlir/Dialect/Arith /IR/Arith .h"
2726#include " mlir/Dialect/Bufferization/IR/Bufferization.h"
2827#include " mlir/Dialect/Linalg/IR/Linalg.h"
2928#include " mlir/Dialect/MemRef/IR//MemRef.h"
30- #include " triton/Dialect/Triton/IR/Types.h"
31-
29+ #include " mlir/Dialect/SCF/IR/SCF.h"
3230#include " mlir/Dialect/Utils/StaticValueUtils.h"
31+
3332#include " llvm/ADT/ArrayRef.h"
3433#include " llvm/ADT/STLExtras.h"
3534#include " llvm/ADT/SmallVector.h"
36- #include " llvm/Support/Debug.h"
3735
3836#include < algorithm>
3937#include < cassert>
@@ -61,6 +59,13 @@ static memref::SubViewOp getSubview(int rank, ArrayRef<OpFoldResult> dims,
6159 offsets, dims, strides);
6260}
6361
62+ static Value getPtr (Value v) {
63+ while (auto op = v.getDefiningOp ()) {
64+ v = op->getOperand (0 );
65+ }
66+ return v;
67+ }
68+
6469namespace {
6570
6671struct MakeTensorPtrConverter
@@ -373,28 +378,27 @@ struct MakeTensorPtrConverter
373378 op, staticTargetOffset.value_or (ShapedType::kDynamic ), staticStrides,
374379 resultShape);
375380
376- // The base ptr, which is from one of the args, would have already been
377- // converted to memref<*> at this point, so get the base from adaptor.
378- //
379381 // For block pointers, the base could come from a sequence of `tt.addptr`,
380- // which at this point has already been lowered to a sequence of
381- // `memref.reinterpret_cast` ops. The offset in such cases are dynamic.
382- // (see test/Conversion/StructuredToMemref/block_ptr_complex_offset.mlir)
383- //
384- // For non-block pointer cases, the base is the reinterpret_cast of a
385- // function argument. Assert that the offset is a constant 0 in such cases.
386- auto ptr = adaptor.getBase ();
387- if (auto reinterpretCast = ptr.getDefiningOp <memref::ReinterpretCastOp>()) {
388- auto offset = reinterpretCast.getMixedOffsets ()[0 ];
389- auto intAttr = getIntAttr (offset);
390- assert (isBlockPtr || (intAttr.has_value () && intAttr.value () == 0 ));
391- targetOffset = addOFRs (targetOffset, reinterpretCast.getMixedOffsets ()[0 ],
392- op->getLoc (), rewriter);
382+ // which at this point has already been lowered a single
383+ // tts.make_unstructured_tptr via --fold-unstructured-ptr. The offset in
384+ // such cases is dynamic and comes from the second operand of
385+ // tts.make_unstructured_tptr. See
386+ // test/Conversion/StructuredToMemref/block_ptr_complex_offset.mlir
387+ // Accumulate the target offset with the offset from
388+ // tts.make_unstructured_tptr.
389+ if (auto makePtr =
390+ op.getBase ().getDefiningOp <tts::MakeUnstructuredTensorPtrOp>()) {
391+ auto prevOff =
392+ rewriter
393+ .create <arith::IndexCastOp>(op.getLoc (), rewriter.getIndexType (),
394+ makePtr.getOffset ())
395+ .getResult ();
396+ targetOffset = addOFRs (prevOff, targetOffset, op->getLoc (), rewriter);
393397 }
394398
395399 auto castOp = rewriter.create <memref::ReinterpretCastOp>(
396- op.getLoc (), resultType, ptr, targetOffset, op. getMixedSizes () ,
397- mixedStrides);
400+ op.getLoc (), resultType, adaptor. getBase (), targetOffset ,
401+ op. getMixedSizes (), mixedStrides);
398402
399403 rewriter.replaceOp (op, castOp);
400404
@@ -421,6 +425,10 @@ struct MakeTensorPtrConverter
421425 }
422426
423427public:
428+ MakeTensorPtrConverter (const TypeConverter &typeConverter,
429+ MLIRContext *context)
430+ : OpConversionPattern<tts::MakeTensorPtrOp>(typeConverter, context) {}
431+
424432 LogicalResult
425433 matchAndRewrite (tts::MakeTensorPtrOp op, OpAdaptor adaptor,
426434 ConversionPatternRewriter &rewriter) const override {
@@ -594,8 +602,13 @@ struct LoadConverter : public OpConversionPattern<tts::LoadOp> {
594602 // No mask
595603 assert (!other && " other value used in non-masked load" );
596604
597- if (auto unrealizedCast = ptr.getDefiningOp <UnrealizedConversionCastOp>()) {
605+ auto ptrDefiningOp = ptr.getDefiningOp ();
606+ if (ptrDefiningOp->hasAttr (WRAP_SIDE_BY_SIDE) ||
607+ ptrDefiningOp->hasAttr (WRAP_STACKED)) {
608+
609+ auto unrealizedCast = cast<UnrealizedConversionCastOp>(ptrDefiningOp);
598610 auto memrefs = unrealizedCast.getOperands ();
611+ assert (memrefs.size () == 2 );
599612 auto block1 = memrefs[0 ];
600613 auto block2 = memrefs[1 ];
601614
@@ -664,9 +677,14 @@ struct LoadConverter : public OpConversionPattern<tts::LoadOp> {
664677 });
665678 }
666679
667- if (auto unrealizedCast = ptr.getDefiningOp <UnrealizedConversionCastOp>()) {
680+ auto ptrDefiningOp = ptr.getDefiningOp ();
681+ if (ptrDefiningOp->hasAttr (WRAP_SIDE_BY_SIDE) ||
682+ ptrDefiningOp->hasAttr (WRAP_STACKED)) {
683+
684+ auto unrealizedCast = cast<UnrealizedConversionCastOp>(ptrDefiningOp);
668685
669686 auto memrefs = unrealizedCast.getOperands ();
687+ assert (memrefs.size () == 2 );
670688 auto block1 = memrefs[0 ];
671689 auto block2 = memrefs[1 ];
672690
@@ -700,6 +718,9 @@ struct LoadConverter : public OpConversionPattern<tts::LoadOp> {
700718 }
701719
702720public:
721+ LoadConverter (const TypeConverter &typeConverter, MLIRContext *context)
722+ : OpConversionPattern<tts::LoadOp>(typeConverter, context) {}
723+
703724 LogicalResult
704725 matchAndRewrite (tts::LoadOp op, OpAdaptor adaptor,
705726 ConversionPatternRewriter &rewriter) const override {
@@ -730,6 +751,9 @@ struct StoreConverter : public OpConversionPattern<tts::StoreOp> {
730751 }
731752
732753public:
754+ StoreConverter (const TypeConverter &typeConverter, MLIRContext *context)
755+ : OpConversionPattern<tts::StoreOp>(typeConverter, context) {}
756+
733757 LogicalResult
734758 matchAndRewrite (tts::StoreOp op, OpAdaptor adaptor,
735759 ConversionPatternRewriter &rewriter) const override {
@@ -759,101 +783,10 @@ struct StoreConverter : public OpConversionPattern<tts::StoreOp> {
759783 }
760784};
761785
762- struct ScalarLoadConverter : public OpConversionPattern <triton::LoadOp> {
763- using OpConversionPattern<triton::LoadOp>::OpConversionPattern;
764-
765- LogicalResult
766- matchAndRewrite (triton::LoadOp op, OpAdaptor adaptor,
767- ConversionPatternRewriter &rewriter) const override {
768- if (!op.getType ().isIntOrIndexOrFloat ()) {
769- return failure ();
770- }
771-
772- auto loc = op->getLoc ();
773- auto memrefPtr = adaptor.getPtr ();
774- auto zeroMap = AffineMap::getConstantMap (0 , rewriter.getContext ());
775- auto loadOp = rewriter.create <affine::AffineLoadOp>(loc, memrefPtr, zeroMap,
776- std::nullopt );
777- rewriter.replaceOp (op, loadOp.getResult ());
778-
779- return success ();
780- }
781- };
782-
783- struct ScalarStoreConverter : public OpConversionPattern <triton::StoreOp> {
784- private:
785- using OpConversionPattern<triton::StoreOp>::OpConversionPattern;
786-
787- public:
788- LogicalResult
789- matchAndRewrite (triton::StoreOp op, OpAdaptor adaptor,
790- ConversionPatternRewriter &rewriter) const override {
791-
792- if (!op.getValue ().getType ().isIntOrIndexOrFloat ()) {
793- return failure ();
794- }
795-
796- auto loc = op->getLoc ();
797- auto memrefPtr = adaptor.getPtr ();
798- auto val = op.getValue ();
799- auto zeroMap = AffineMap::getConstantMap (0 , rewriter.getContext ());
800-
801- rewriter.create <affine::AffineStoreOp>(loc, val, memrefPtr, zeroMap,
802- std::nullopt );
803- rewriter.eraseOp (op);
804-
805- return success ();
806- }
807- };
808-
809- struct UnrealizedCastConverter
810- : public OpConversionPattern<UnrealizedConversionCastOp> {
811- private:
812- using OpConversionPattern<UnrealizedConversionCastOp>::OpConversionPattern;
813-
814- public:
815- UnrealizedCastConverter (TypeConverter &typeConverter, MLIRContext *context)
816- : OpConversionPattern<UnrealizedConversionCastOp>(typeConverter,
817- context) {}
818-
819- LogicalResult
820- matchAndRewrite (UnrealizedConversionCastOp op, OpAdaptor adaptor,
821- ConversionPatternRewriter &rewriter) const override {
822- auto resType = op->getResultTypes ()[0 ];
823- auto input = op.getInputs ()[0 ];
824- auto inputType = input.getType ();
825-
826- if (!isa<triton::PointerType>(resType) ||
827- !isa<MemRefType, UnrankedMemRefType>(inputType)) {
828- return failure ();
829- }
830-
831- if (auto reinterpretCast =
832- input.getDefiningOp <memref::ReinterpretCastOp>()) {
833- rewriter.replaceOp (op, reinterpretCast);
834- } else {
835- auto ptrType = cast<triton::PointerType>(resType);
836- auto memrefType =
837- cast<MemRefType>(getTypeConverter ()->convertType (ptrType));
838-
839- auto cast = rewriter.create <memref::ReinterpretCastOp>(
840- op->getLoc (), memrefType, op.getInputs ()[0 ], 0 /* offset*/ ,
841- SmallVector<int64_t >{1 } /* sizes*/ ,
842- SmallVector<int64_t >{1 } /* strides*/ );
843-
844- rewriter.replaceOp (op, cast);
845- }
846-
847- return success ();
848- }
849- };
850-
851786} // namespace
852787
853788void mlir::triton::populateStructuredToMemrefConversionPatterns (
854789 RewritePatternSet &patterns, TypeConverter &typeConverter) {
855- patterns.add <UnrealizedCastConverter>(typeConverter, patterns.getContext ());
856- patterns.add <MakeTensorPtrConverter, LoadConverter, StoreConverter,
857- ScalarLoadConverter, ScalarStoreConverter>(
858- patterns.getContext ());
790+ patterns.add <MakeTensorPtrConverter>(typeConverter, patterns.getContext ());
791+ patterns.add <LoadConverter, StoreConverter>(patterns.getContext ());
859792}
0 commit comments