@@ -161,29 +161,33 @@ getWarpsPerCTA(const ArrayRef<int64_t> tensorShape,
161161
162162// Contains some helper functions for both Load and Store conversions.
163163struct LoadStoreConversionBase {
164- explicit LoadStoreConversionBase (const triton::intel::TargetInfo &targetInfo,
165- ModuleAxisInfoAnalysis &axisAnalysisPass)
164+ explicit LoadStoreConversionBase (
165+ const triton::intel::TargetInfo &targetInfo,
166+ const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass)
166167 : targetInfo(targetInfo), axisAnalysisPass(axisAnalysisPass) {}
167168
168169 unsigned getContiguity (Value ptr) const {
169- auto tensorTy = dyn_cast<RankedTensorType>(ptr.getType ());
170- if (!tensorTy)
171- return 1 ;
172- return axisAnalysisPass.getPtrContiguity (ptr);
170+ return const_cast <triton::intel::ModuleAxisInfoAnalysis &>(axisAnalysisPass)
171+ .getPtrContiguity (ptr);
173172 }
174173
175174 unsigned getVectorSize (Value ptr) const {
176- auto tensorTy = dyn_cast<RankedTensorType> (ptr.getType ());
175+ auto tensorTy = getRankedTensorType (ptr.getType ());
177176 if (!tensorTy)
178177 return 1 ;
179- auto contiguity = getContiguity (ptr);
180- auto pointeeBitWidth = triton::getPointeeBitWidth (tensorTy);
178+
179+ unsigned contiguity = getContiguity (ptr);
180+ unsigned pointeeBitWidth =
181+ isTensorPointerType (ptr.getType ())
182+ ? tensorTy.getElementType ().getIntOrFloatBitWidth ()
183+ : triton::getPointeeBitWidth (tensorTy);
181184 // The maximum vector size is 128 bits.
182185 return std::min<unsigned >(128 / pointeeBitWidth, contiguity);
183186 }
184187
185188 unsigned getMaskAlignment (Value mask) const {
186- return axisAnalysisPass.getMaskAlignment (mask);
189+ return const_cast <triton::intel::ModuleAxisInfoAnalysis &>(axisAnalysisPass)
190+ .getMaskAlignment (mask);
187191 }
188192
189193 std::tuple<SmallVector<Value>, SmallVector<Value>, SmallVector<Value>>
@@ -289,7 +293,7 @@ struct LoadStoreConversionBase {
289293 }
290294
291295protected:
292- ModuleAxisInfoAnalysis &axisAnalysisPass;
296+ const triton::intel:: ModuleAxisInfoAnalysis &axisAnalysisPass;
293297 const triton::intel::TargetInfo &targetInfo;
294298};
295299
@@ -299,10 +303,11 @@ struct PrefetchOpConversion
299303 using ConvertTritonGPUOpToLLVMPattern<
300304 triton::gpu::intel::PrefetchOp>::ConvertTritonGPUOpToLLVMPattern;
301305
302- PrefetchOpConversion (TritonGPUToLLVMTypeConverter &converter,
303- const triton::intel::TargetInfo &targetInfo,
304- ModuleAxisInfoAnalysis &axisAnalysisPass,
305- PatternBenefit benefit)
306+ PrefetchOpConversion (
307+ TritonGPUToLLVMTypeConverter &converter,
308+ const triton::intel::TargetInfo &targetInfo,
309+ const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass,
310+ PatternBenefit benefit)
306311 : ConvertTritonGPUOpToLLVMPattern<triton::gpu::intel::PrefetchOp>(
307312 converter, benefit),
308313 LoadStoreConversionBase (targetInfo, axisAnalysisPass) {}
@@ -475,10 +480,11 @@ struct LoadOpConversion
475480
476481 using ValueTable = std::map<std::pair<int , int >, Value>;
477482
478- LoadOpConversion (TritonIntelGPUToLLVMTypeConverter &converter,
479- const triton::intel::TargetInfo &targetInfo,
480- ModuleAxisInfoAnalysis &axisAnalysisPass,
481- PatternBenefit benefit)
483+ LoadOpConversion (
484+ TritonIntelGPUToLLVMTypeConverter &converter,
485+ const triton::intel::TargetInfo &targetInfo,
486+ const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass,
487+ PatternBenefit benefit)
482488 : ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>(converter, benefit),
483489 LoadStoreConversionBase (targetInfo, axisAnalysisPass) {}
484490
@@ -824,37 +830,32 @@ struct LoadOpConversion
824830 Location loc = op->getLoc ();
825831 auto typeConverter = getTypeConverter ();
826832 MLIRContext *ctx = rewriter.getContext ();
833+ Value ptr = op.getPtr ();
834+ Value mask = op.getMask ();
835+ Value llMask = adaptor.getMask ();
827836
828837 // Determine the vectorization size
829838 Type valueElemTy =
830839 typeConverter->convertType (getElementTypeOrSelf (op.getType ()));
831840 unsigned numElems = getTotalElemsPerThread (op.getType ());
832- unsigned vec = 1 ;
841+ unsigned vec = getVectorSize (ptr);
842+ if (llMask)
843+ vec = std::min<size_t >(vec, getMaskAlignment (mask));
833844
834845 SmallVector<Value> ptrElems, maskElems, otherElems;
835846 bool otherIsSplatConstInt = false ;
836847 int64_t splatVal = 0 ;
837848
838- if (isTensorPointerType (op.getPtr ().getType ())) {
839- // TODO: (johnlu) set the vector size > 1; Need to prove the memory is
840- // contiguous on the fast changing dim when fallback to gather load.
849+ if (isTensorPointerType (ptr.getType ())) {
850+ // fallback to gather load.
841851 auto tensorType = cast<RankedTensorType>(op.getType ());
842852 std::tie (ptrElems, maskElems, otherElems) = convertBlockPtrToTensorOfPtr (
843853 loc, adaptor.getPtr (), tensorType, valueElemTy, rewriter,
844854 op.getBoundaryCheck (), op.getPadding ());
845855 } else {
846- // original values
847- Value ptr = op.getPtr ();
848856 Value other = op.getOther ();
849- Value mask = op.getMask ();
850-
851- // adaptor values
852857 Value llPtr = adaptor.getPtr ();
853- Value llMask = adaptor.getMask ();
854858 Value llOther = adaptor.getOther ();
855- vec = getVectorSize (ptr);
856- if (llMask)
857- vec = std::min<size_t >(vec, getMaskAlignment (mask));
858859
859860 // Get the LLVM values for pointers
860861 ptrElems = unpackLLElements (loc, llPtr, rewriter);
@@ -987,10 +988,11 @@ struct StoreOpConversion
987988 using ConvertTritonGPUOpToLLVMPattern<
988989 triton::StoreOp>::ConvertTritonGPUOpToLLVMPattern;
989990
990- StoreOpConversion (TritonIntelGPUToLLVMTypeConverter &converter,
991- const triton::intel::TargetInfo &targetInfo,
992- ModuleAxisInfoAnalysis &axisAnalysisPass,
993- PatternBenefit benefit)
991+ StoreOpConversion (
992+ TritonIntelGPUToLLVMTypeConverter &converter,
993+ const triton::intel::TargetInfo &targetInfo,
994+ const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass,
995+ PatternBenefit benefit)
994996 : ConvertTritonGPUOpToLLVMPattern<triton::StoreOp>(converter, benefit),
995997 LoadStoreConversionBase (targetInfo, axisAnalysisPass) {}
996998
@@ -1128,14 +1130,20 @@ struct StoreOpConversion
11281130 return success ();
11291131
11301132 Location loc = op->getLoc ();
1133+ auto *typeConverter = getTypeConverter ();
11311134 MLIRContext *ctx = rewriter.getContext ();
11321135 Value ptr = op.getPtr ();
1133- Value value = op.getValue ();
1134- Type valueTy = value.getType ();
1136+ Value mask = op.getMask ();
1137+ Value llMask = adaptor.getMask ();
1138+
1139+ // Determine the vectorization size
1140+ Type valueTy = op.getValue ().getType ();
11351141 Type valueElemTy =
11361142 typeConverter->convertType (getElementTypeOrSelf (valueTy));
11371143 SmallVector<Value> ptrElems, maskElems;
1138- unsigned vec = 1 ;
1144+ unsigned vec = getVectorSize (ptr);
1145+ if (llMask)
1146+ vec = std::min<size_t >(vec, getMaskAlignment (mask));
11391147
11401148 if (isTensorPointerType (ptr.getType ())) {
11411149 // fallback to scatter store.
@@ -1146,20 +1154,9 @@ struct StoreOpConversion
11461154 op.getBoundaryCheck ());
11471155 } else {
11481156 Value llPtr = adaptor.getPtr ();
1149- Value llMask = adaptor.getMask ();
1150-
1151- vec = getVectorSize (ptr);
1152-
11531157 ptrElems = unpackLLElements (loc, llPtr, rewriter);
1154-
1155- // Determine the vectorization size
1156- if (llMask) {
1157- Value mask = op.getMask ();
1158+ if (llMask)
11581159 maskElems = unpackLLElements (loc, llMask, rewriter);
1159-
1160- unsigned maskAlign = getMaskAlignment (mask);
1161- vec = std::min (vec, maskAlign);
1162- }
11631160 }
11641161
11651162 Value llValue = adaptor.getValue ();
@@ -1168,7 +1165,7 @@ struct StoreOpConversion
11681165 assert (!maskElems.size () ||
11691166 valueElems.size () == maskElems.size () && " Mask size mismatch" );
11701167
1171- Value mask = redundantDataMask (valueTy, rewriter, loc, targetInfo);
1168+ mask = redundantDataMask (valueTy, rewriter, loc, targetInfo);
11721169 const size_t dtsize =
11731170 std::max<int >(1 , valueElemTy.getIntOrFloatBitWidth () / 8 );
11741171 const size_t valueElemNBits = dtsize * 8 ;
@@ -1247,10 +1244,11 @@ struct AtomicCASOpConversion
12471244 using ConvertTritonGPUOpToLLVMPattern<
12481245 triton::AtomicCASOp>::ConvertTritonGPUOpToLLVMPattern;
12491246
1250- AtomicCASOpConversion (TritonIntelGPUToLLVMTypeConverter &converter,
1251- const triton::intel::TargetInfo &targetInfo,
1252- ModuleAxisInfoAnalysis &axisAnalysisPass,
1253- PatternBenefit benefit)
1247+ AtomicCASOpConversion (
1248+ TritonIntelGPUToLLVMTypeConverter &converter,
1249+ const triton::intel::TargetInfo &targetInfo,
1250+ const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass,
1251+ PatternBenefit benefit)
12541252 : ConvertTritonGPUOpToLLVMPattern<triton::AtomicCASOp>(converter,
12551253 benefit),
12561254 LoadStoreConversionBase (targetInfo, axisAnalysisPass) {}
@@ -1364,10 +1362,11 @@ struct AtomicRMWOpConversion
13641362 using ConvertTritonGPUOpToLLVMPattern<
13651363 triton::AtomicRMWOp>::ConvertTritonGPUOpToLLVMPattern;
13661364
1367- AtomicRMWOpConversion (TritonIntelGPUToLLVMTypeConverter &converter,
1368- const triton::intel::TargetInfo &targetInfo,
1369- ModuleAxisInfoAnalysis &axisAnalysisPass,
1370- PatternBenefit benefit)
1365+ AtomicRMWOpConversion (
1366+ TritonIntelGPUToLLVMTypeConverter &converter,
1367+ const triton::intel::TargetInfo &targetInfo,
1368+ const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass,
1369+ PatternBenefit benefit)
13711370 : ConvertTritonGPUOpToLLVMPattern<triton::AtomicRMWOp>(converter,
13721371 benefit),
13731372 LoadStoreConversionBase (targetInfo, axisAnalysisPass) {}
@@ -1627,7 +1626,8 @@ struct AtomicRMWOpConversion
16271626void mlir::triton::intel::populateLoadStoreOpToLLVMPatterns (
16281627 TritonIntelGPUToLLVMTypeConverter &typeConverter,
16291628 const TargetInfo &targetInfo, RewritePatternSet &patterns,
1630- ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit) {
1629+ const intel::ModuleAxisInfoAnalysis &axisInfoAnalysis,
1630+ PatternBenefit benefit) {
16311631 patterns.add <AtomicCASOpConversion, AtomicRMWOpConversion, LoadOpConversion,
16321632 StoreOpConversion, PrefetchOpConversion>(
16331633 typeConverter, targetInfo, axisInfoAnalysis, benefit);
0 commit comments