
#include "intel/include/Dialect/TritonIntelGPU/IR/Attributes.h"
#include "intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h"
-#include "triton/Dialect/TritonGPU/IR/Attributes.h"

using namespace mlir;
using namespace mlir::triton;
@@ -194,7 +193,7 @@ struct LoadStoreConversionBase {
      ArrayRef<int32_t> boundaryCheck = {},
      std::optional<PaddingOption> padding = std::nullopt) const {

-    auto rank = tensorType.getRank();
+    size_t rank = tensorType.getRank();
    // The block pointer struct is expected to have the following layout:
    // Struct {
    //   Value offset[rank];
@@ -818,36 +817,31 @@ struct LoadOpConversion
  LogicalResult
  matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
-    auto loc = op->getLoc();
+    if (isTensorPointerType(op.getPtr().getType()))
+      if (rewriteTensorPointerLoad(op, adaptor, rewriter).succeeded())
+        return success();
+
+    Location loc = op->getLoc();
    auto typeConverter = getTypeConverter();
-    auto *ctx = rewriter.getContext();
+    MLIRContext *ctx = rewriter.getContext();

    // Determine the vectorization size
    Type valueElemTy =
        typeConverter->convertType(getElementTypeOrSelf(op.getType()));
    unsigned numElems = getTotalElemsPerThread(op.getType());
    unsigned vec = 1;

-    SmallVector<Value> ptrElems;
-    SmallVector<Value> maskElems;
-
+    SmallVector<Value> ptrElems, maskElems, otherElems;
    bool otherIsSplatConstInt = false;
    int64_t splatVal = 0;
-    SmallVector<Value> otherElems;

    if (isTensorPointerType(op.getPtr().getType())) {
-      if (rewriteTensorPointerLoad(op, adaptor, rewriter).succeeded()) {
-        return success();
-      } else {
-        // TODO: (johnlu) set the vector size > 1; Need to prove the memory is
-        // contiguous on the fast changing dim when fallback to gather load.
-        Type resultType = op.getType();
-        auto tensorType = cast<RankedTensorType>(resultType);
-        std::tie(ptrElems, maskElems, otherElems) =
-            convertBlockPtrToTensorOfPtr(
-                loc, adaptor.getPtr(), tensorType, valueElemTy, rewriter,
-                op.getBoundaryCheck(), op.getPadding());
-      }
+      // TODO: (johnlu) set the vector size > 1; Need to prove the memory is
+      // contiguous on the fast changing dim when fallback to gather load.
+      auto tensorType = cast<RankedTensorType>(op.getType());
+      std::tie(ptrElems, maskElems, otherElems) = convertBlockPtrToTensorOfPtr(
+          loc, adaptor.getPtr(), tensorType, valueElemTy, rewriter,
+          op.getBoundaryCheck(), op.getPadding());
    } else {
      // original values
      Value ptr = op.getPtr();
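Note: condensed from the '+' lines above into a short sketch (not a verbatim copy of the file, and with surrounding declarations elided), the load conversion now attempts the block load rewrite first and only falls back to a gather-style load when that rewrite fails:

    // Sketch of the new control flow in LoadOpConversion::matchAndRewrite.
    if (isTensorPointerType(op.getPtr().getType()))
      if (rewriteTensorPointerLoad(op, adaptor, rewriter).succeeded())
        return success();
    // Fallback: lower the block pointer to per-element pointers, masks, and
    // padding ("other") values, then continue down the ordinary gather path.
    if (isTensorPointerType(op.getPtr().getType())) {
      std::tie(ptrElems, maskElems, otherElems) = convertBlockPtrToTensorOfPtr(
          loc, adaptor.getPtr(), cast<RankedTensorType>(op.getType()),
          valueElemTy, rewriter, op.getBoundaryCheck(), op.getPadding());
    }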
@@ -922,7 +916,7 @@ struct LoadOpConversion
          for (size_t s = 0; s < size; ++s) {
            Value falseVal = otherElems[vecStart + ii * size + s];
            Value sVal = createIndexAttrConstant(
-                rewriter, loc, this->getTypeConverter()->getIndexType(), s);
+                rewriter, loc, typeConverter->getIndexType(), s);
            v = insert_element(vecTy, v, falseVal, sVal);
          }
          v = bitcast(v, IntegerType::get(ctx, width));
@@ -934,7 +928,7 @@ struct LoadOpConversion
        }

        Value iiVal = createIndexAttrConstant(
-            rewriter, loc, this->getTypeConverter()->getIndexType(), ii);
+            rewriter, loc, typeConverter->getIndexType(), ii);
        if (nWords > 1) {
          other_ = insert_element(retTy, other_, v, iiVal);
        } else {
@@ -1129,31 +1123,27 @@ struct StoreOpConversion
  LogicalResult
  matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
-    auto loc = op->getLoc();
-    MLIRContext *ctx = rewriter.getContext();
+    if (isTensorPointerType(op.getPtr().getType()))
+      if (rewriteTensorPointerStore(op, adaptor, rewriter).succeeded())
+        return success();

+    Location loc = op->getLoc();
+    MLIRContext *ctx = rewriter.getContext();
    Value ptr = op.getPtr();
    Value value = op.getValue();
-
-    auto valueTy = value.getType();
+    Type valueTy = value.getType();
    Type valueElemTy =
        typeConverter->convertType(getElementTypeOrSelf(valueTy));
-    SmallVector<Value> ptrElems;
-    SmallVector<Value> maskElems;
+    SmallVector<Value> ptrElems, maskElems;
    unsigned vec = 1;

    if (isTensorPointerType(ptr.getType())) {
-      if (rewriteTensorPointerStore(op, adaptor, rewriter).succeeded()) {
-        return success();
-      } else {
-        // fallback to scatter store.
-        auto tensorType = cast<RankedTensorType>(valueTy);
-        SmallVector<Value> dummyOther;
-        std::tie(ptrElems, maskElems, dummyOther) =
-            convertBlockPtrToTensorOfPtr(loc, adaptor.getPtr(), tensorType,
-                                         valueElemTy, rewriter,
-                                         op.getBoundaryCheck());
-      }
+      // fallback to scatter store.
+      auto tensorType = cast<RankedTensorType>(valueTy);
+      SmallVector<Value> dummyOther;
+      std::tie(ptrElems, maskElems, dummyOther) = convertBlockPtrToTensorOfPtr(
+          loc, adaptor.getPtr(), tensorType, valueElemTy, rewriter,
+          op.getBoundaryCheck());
    } else {
      Value llPtr = adaptor.getPtr();
      Value llMask = adaptor.getMask();
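The store conversion follows the same pattern; again a sketch condensed from the '+' lines above, not verbatim: the block store rewrite is tried first, and on failure the block pointer is lowered to per-element pointers and masks for a scatter store, with the unused "other" values discarded.

    // Sketch of the new control flow in StoreOpConversion::matchAndRewrite.
    if (isTensorPointerType(op.getPtr().getType()))
      if (rewriteTensorPointerStore(op, adaptor, rewriter).succeeded())
        return success();
    // Fallback to a scatter store; padding values are not needed here.
    SmallVector<Value> dummyOther;
    std::tie(ptrElems, maskElems, dummyOther) = convertBlockPtrToTensorOfPtr(
        loc, adaptor.getPtr(), cast<RankedTensorType>(valueTy), valueElemTy,
        rewriter, op.getBoundaryCheck());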
@@ -1561,8 +1551,8 @@ struct AtomicRMWOpConversion

    rmwVal = bitcast(rmwVal, valueElemTy);

-    // Align pointer by 4 bytes by zeroing lower address bits. Atomically read
-    // a vector of two fp16 values as a single i32. The second lowest bit is
+    // Align pointer by 4 bytes by zeroing lower address bits. Atomically read a
+    // vector of two fp16 values as a single i32. The second lowest bit is
    // extracted to later be used as an index to extract the required vector
    // element.
    assert(isa<LLVM::LLVMPointerType>(rmwPtr.getType()));