diff --git a/third_party/intel/backend/compiler.py b/third_party/intel/backend/compiler.py
index a13834f991..9b1738e8b7 100644
--- a/third_party/intel/backend/compiler.py
+++ b/third_party/intel/backend/compiler.py
@@ -235,7 +235,8 @@ def make_ttgir(mod, metadata, opt, properties):
         intel.passes.ttgpuir.add_accelerate_matmul(pm)
         intel.passes.ttgpuir.add_remove_layout_conversions(pm)
         intel.passes.ttgpuir.add_materialize_block_pointer(pm)
-        intel.passes.ttgpuir.add_rewrite_tensor_pointer(pm)
+        if os.getenv("TRITON_INTEL_REWRITE_TENSOR_POINTER", "0") == "1":
+            intel.passes.ttgpuir.add_rewrite_tensor_pointer(pm)
         intel.passes.ttgpuir.add_pipeline(pm, opt.num_stages, False)
         intel.passes.ttgpuir.add_coalesce(pm)
 
diff --git a/third_party/intel/include/Analysis/AxisInfo.h b/third_party/intel/include/Analysis/AxisInfo.h
index b0b90f7d10..5ba786f4b0 100644
--- a/third_party/intel/include/Analysis/AxisInfo.h
+++ b/third_party/intel/include/Analysis/AxisInfo.h
@@ -12,7 +12,6 @@ namespace mlir::triton::intel {
 // axis info based on the axis info of all the callers. In the future, we can
 // perform optimization using function cloning so that each call site will have
 // unique axis info.
-
 class ModuleAxisInfoAnalysis : public triton::ModuleAxisInfoAnalysis {
 public:
   explicit ModuleAxisInfoAnalysis(ModuleOp moduleOp)
diff --git a/third_party/intel/lib/Analysis/AxisInfo.cpp b/third_party/intel/lib/Analysis/AxisInfo.cpp
index 09da088e8c..7161dedf7a 100644
--- a/third_party/intel/lib/Analysis/AxisInfo.cpp
+++ b/third_party/intel/lib/Analysis/AxisInfo.cpp
@@ -1030,13 +1030,24 @@ class MakeTensorPtrOpAxisInfoVisitor final
           strideInfo[dim].getConstantValue() == 1 ? blkShape[dim] : 1);
       divisibility.push_back(
           contiguity[dim] > 1
-              ? std::min(ptrDivisibility,
-                         strideInfo[dim == 0 ? 1 : 0].getDivisibility()[0])
+              ? std::min(
+                    ptrDivisibility,
+                    (rank == 2 ? strideInfo[dim == 0 ? 1 : 0] : strideInfo[dim])
+                        .getDivisibility()[0])
               : 1);
       constancy.push_back(1);
     }
 
-    return AxisInfo(contiguity, divisibility, constancy);
+    auto axisInfo = AxisInfo(contiguity, divisibility, constancy);
+
+    LLVM_DEBUG({
+      std::string axisStr;
+      llvm::raw_string_ostream os(axisStr);
+      axisInfo.print(os);
+      LDBG("-- " << axisStr);
+    });
+
+    return axisInfo;
   }
 };
 
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
index 426726feba..cf2475c41b 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -161,29 +161,33 @@ getWarpsPerCTA(const ArrayRef<int64_t> tensorShape,
 // Contains some helper functions for both Load and Store conversions.
 struct LoadStoreConversionBase {
-  explicit LoadStoreConversionBase(const triton::intel::TargetInfo &targetInfo,
-                                   ModuleAxisInfoAnalysis &axisAnalysisPass)
+  explicit LoadStoreConversionBase(
+      const triton::intel::TargetInfo &targetInfo,
+      const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass)
       : targetInfo(targetInfo), axisAnalysisPass(axisAnalysisPass) {}
 
   unsigned getContiguity(Value ptr) const {
-    auto tensorTy = dyn_cast<RankedTensorType>(ptr.getType());
-    if (!tensorTy)
-      return 1;
-    return axisAnalysisPass.getPtrContiguity(ptr);
+    return const_cast<triton::intel::ModuleAxisInfoAnalysis &>(axisAnalysisPass)
+        .getPtrContiguity(ptr);
   }
 
   unsigned getVectorSize(Value ptr) const {
-    auto tensorTy = dyn_cast<RankedTensorType>(ptr.getType());
+    auto tensorTy = getRankedTensorType(ptr.getType());
     if (!tensorTy)
       return 1;
-    auto contiguity = getContiguity(ptr);
-    auto pointeeBitWidth = triton::getPointeeBitWidth(tensorTy);
+
+    unsigned contiguity = getContiguity(ptr);
+    unsigned pointeeBitWidth =
+        isTensorPointerType(ptr.getType())
+            ? tensorTy.getElementType().getIntOrFloatBitWidth()
+            : triton::getPointeeBitWidth(tensorTy);
     // The maximum vector size is 128 bits.
     return std::min(128 / pointeeBitWidth, contiguity);
   }
 
   unsigned getMaskAlignment(Value mask) const {
-    return axisAnalysisPass.getMaskAlignment(mask);
+    return const_cast<triton::intel::ModuleAxisInfoAnalysis &>(axisAnalysisPass)
+        .getMaskAlignment(mask);
   }
 
   std::tuple<SmallVector<Value>, SmallVector<Value>, SmallVector<Value>>
@@ -289,7 +293,7 @@ struct LoadStoreConversionBase {
   }
 
 protected:
-  ModuleAxisInfoAnalysis &axisAnalysisPass;
+  const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass;
   const triton::intel::TargetInfo &targetInfo;
 };
 
@@ -299,10 +303,11 @@ struct PrefetchOpConversion
   using ConvertTritonGPUOpToLLVMPattern<
       triton::gpu::intel::PrefetchOp>::ConvertTritonGPUOpToLLVMPattern;
 
-  PrefetchOpConversion(TritonGPUToLLVMTypeConverter &converter,
-                       const triton::intel::TargetInfo &targetInfo,
-                       ModuleAxisInfoAnalysis &axisAnalysisPass,
-                       PatternBenefit benefit)
+  PrefetchOpConversion(
+      TritonGPUToLLVMTypeConverter &converter,
+      const triton::intel::TargetInfo &targetInfo,
+      const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass,
+      PatternBenefit benefit)
       : ConvertTritonGPUOpToLLVMPattern<triton::gpu::intel::PrefetchOp>(
             converter, benefit),
         LoadStoreConversionBase(targetInfo, axisAnalysisPass) {}
@@ -475,10 +480,11 @@ struct LoadOpConversion
   using ValueTable = std::map<std::pair<int, int>, Value>;
 
-  LoadOpConversion(TritonIntelGPUToLLVMTypeConverter &converter,
-                   const triton::intel::TargetInfo &targetInfo,
-                   ModuleAxisInfoAnalysis &axisAnalysisPass,
-                   PatternBenefit benefit)
+  LoadOpConversion(
+      TritonIntelGPUToLLVMTypeConverter &converter,
+      const triton::intel::TargetInfo &targetInfo,
+      const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass,
+      PatternBenefit benefit)
       : ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>(converter, benefit),
         LoadStoreConversionBase(targetInfo, axisAnalysisPass) {}
 
@@ -824,37 +830,32 @@ struct LoadOpConversion
     Location loc = op->getLoc();
     auto typeConverter = getTypeConverter();
     MLIRContext *ctx = rewriter.getContext();
+    Value ptr = op.getPtr();
+    Value mask = op.getMask();
+    Value llMask = adaptor.getMask();
 
     // Determine the vectorization size
     Type valueElemTy =
         typeConverter->convertType(getElementTypeOrSelf(op.getType()));
     unsigned numElems = getTotalElemsPerThread(op.getType());
-    unsigned vec = 1;
+    unsigned vec = getVectorSize(ptr);
+    if (llMask)
+      vec = std::min(vec, getMaskAlignment(mask));
 
     SmallVector<Value> ptrElems, maskElems, otherElems;
     bool otherIsSplatConstInt = false;
     int64_t splatVal = 0;
 
-    if (isTensorPointerType(op.getPtr().getType())) {
-      // TODO: (johnlu) set the vector size > 1; Need to prove the memory is
-      // contiguous on the fast changing dim when fallback to gather load.
+    if (isTensorPointerType(ptr.getType())) {
+      // fallback to gather load.
       auto tensorType = cast<RankedTensorType>(op.getType());
       std::tie(ptrElems, maskElems, otherElems) = convertBlockPtrToTensorOfPtr(
           loc, adaptor.getPtr(), tensorType, valueElemTy, rewriter,
          op.getBoundaryCheck(), op.getPadding());
     } else {
-      // original values
-      Value ptr = op.getPtr();
       Value other = op.getOther();
-      Value mask = op.getMask();
-
-      // adaptor values
       Value llPtr = adaptor.getPtr();
-      Value llMask = adaptor.getMask();
       Value llOther = adaptor.getOther();
 
-      vec = getVectorSize(ptr);
-      if (llMask)
-        vec = std::min(vec, getMaskAlignment(mask));
 
       // Get the LLVM values for pointers
       ptrElems = unpackLLElements(loc, llPtr, rewriter);
@@ -987,10 +988,11 @@ struct StoreOpConversion
   using ConvertTritonGPUOpToLLVMPattern<
       triton::StoreOp>::ConvertTritonGPUOpToLLVMPattern;
 
-  StoreOpConversion(TritonIntelGPUToLLVMTypeConverter &converter,
-                    const triton::intel::TargetInfo &targetInfo,
-                    ModuleAxisInfoAnalysis &axisAnalysisPass,
-                    PatternBenefit benefit)
+  StoreOpConversion(
+      TritonIntelGPUToLLVMTypeConverter &converter,
+      const triton::intel::TargetInfo &targetInfo,
+      const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass,
+      PatternBenefit benefit)
       : ConvertTritonGPUOpToLLVMPattern<triton::StoreOp>(converter, benefit),
         LoadStoreConversionBase(targetInfo, axisAnalysisPass) {}
 
@@ -1128,14 +1130,20 @@ struct StoreOpConversion
       return success();
 
     Location loc = op->getLoc();
+    auto *typeConverter = getTypeConverter();
     MLIRContext *ctx = rewriter.getContext();
     Value ptr = op.getPtr();
-    Value value = op.getValue();
-    Type valueTy = value.getType();
+    Value mask = op.getMask();
+    Value llMask = adaptor.getMask();
+
+    // Determine the vectorization size
+    Type valueTy = op.getValue().getType();
     Type valueElemTy =
         typeConverter->convertType(getElementTypeOrSelf(valueTy));
     SmallVector<Value> ptrElems, maskElems;
-    unsigned vec = 1;
+    unsigned vec = getVectorSize(ptr);
+    if (llMask)
+      vec = std::min(vec, getMaskAlignment(mask));
 
     if (isTensorPointerType(ptr.getType())) {
       // fallback to scatter store.
@@ -1146,20 +1154,9 @@ struct StoreOpConversion
           op.getBoundaryCheck());
     } else {
       Value llPtr = adaptor.getPtr();
-      Value llMask = adaptor.getMask();
-
-      vec = getVectorSize(ptr);
-
       ptrElems = unpackLLElements(loc, llPtr, rewriter);
-
-      // Determine the vectorization size
-      if (llMask) {
-        Value mask = op.getMask();
+      if (llMask)
         maskElems = unpackLLElements(loc, llMask, rewriter);
-
-        unsigned maskAlign = getMaskAlignment(mask);
-        vec = std::min(vec, maskAlign);
-      }
     }
 
     Value llValue = adaptor.getValue();
@@ -1168,7 +1165,7 @@ struct StoreOpConversion
     assert(!maskElems.size() ||
            valueElems.size() == maskElems.size() && "Mask size mismatch");
 
-    Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo);
+    mask = redundantDataMask(valueTy, rewriter, loc, targetInfo);
     const size_t dtsize =
         std::max<int>(1, valueElemTy.getIntOrFloatBitWidth() / 8);
     const size_t valueElemNBits = dtsize * 8;
@@ -1247,10 +1244,11 @@ struct AtomicCASOpConversion
   using ConvertTritonGPUOpToLLVMPattern<
       triton::AtomicCASOp>::ConvertTritonGPUOpToLLVMPattern;
 
-  AtomicCASOpConversion(TritonIntelGPUToLLVMTypeConverter &converter,
-                        const triton::intel::TargetInfo &targetInfo,
-                        ModuleAxisInfoAnalysis &axisAnalysisPass,
-                        PatternBenefit benefit)
+  AtomicCASOpConversion(
+      TritonIntelGPUToLLVMTypeConverter &converter,
+      const triton::intel::TargetInfo &targetInfo,
+      const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass,
+      PatternBenefit benefit)
       : ConvertTritonGPUOpToLLVMPattern<triton::AtomicCASOp>(converter,
                                                              benefit),
         LoadStoreConversionBase(targetInfo, axisAnalysisPass) {}
@@ -1364,10 +1362,11 @@ struct AtomicRMWOpConversion
   using ConvertTritonGPUOpToLLVMPattern<
       triton::AtomicRMWOp>::ConvertTritonGPUOpToLLVMPattern;
 
-  AtomicRMWOpConversion(TritonIntelGPUToLLVMTypeConverter &converter,
-                        const triton::intel::TargetInfo &targetInfo,
-                        ModuleAxisInfoAnalysis &axisAnalysisPass,
-                        PatternBenefit benefit)
+  AtomicRMWOpConversion(
+      TritonIntelGPUToLLVMTypeConverter &converter,
+      const triton::intel::TargetInfo &targetInfo,
+      const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass,
+      PatternBenefit benefit)
      : ConvertTritonGPUOpToLLVMPattern<triton::AtomicRMWOp>(converter,
                                                              benefit),
         LoadStoreConversionBase(targetInfo, axisAnalysisPass) {}
@@ -1627,7 +1626,8 @@ struct AtomicRMWOpConversion
 void mlir::triton::intel::populateLoadStoreOpToLLVMPatterns(
     TritonIntelGPUToLLVMTypeConverter &typeConverter,
     const TargetInfo &targetInfo, RewritePatternSet &patterns,
-    ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit) {
+    const intel::ModuleAxisInfoAnalysis &axisInfoAnalysis,
+    PatternBenefit benefit) {
   patterns.add<PrefetchOpConversion, LoadOpConversion, StoreOpConversion,
                AtomicCASOpConversion, AtomicRMWOpConversion>(
       typeConverter, targetInfo, axisInfoAnalysis, benefit);
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/PatternTritonGPUOpToLLVM.h b/third_party/intel/lib/TritonIntelGPUToLLVM/PatternTritonGPUOpToLLVM.h
index 36223e3245..40116a17ca 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/PatternTritonGPUOpToLLVM.h
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/PatternTritonGPUOpToLLVM.h
@@ -53,7 +53,7 @@ void populateHistogramOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
 void populateLoadStoreOpToLLVMPatterns(
     TritonIntelGPUToLLVMTypeConverter &typeConverter,
     const TargetInfo &targetInfo, RewritePatternSet &patterns,
-    ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit);
+    const ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit);
 
 void populateReduceOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
                                     RewritePatternSet &patterns,
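Note for reviewers: with the compiler.py hunk above, the rewrite-tensor-pointer pass no longer runs by default; it is gated on TRITON_INTEL_REWRITE_TENSOR_POINTER. A minimal sketch of how one might opt back in from a test script follows (everything around the environment variable is illustrative; only the variable name and its "1" sentinel come from this patch):

import os

# Re-enable the rewrite-tensor-pointer pass, which this patch turns off by
# default. make_ttgir evaluates os.getenv(...) == "1" when the kernel is
# compiled, so set the variable before the first kernel compilation; any
# other value, or leaving it unset, keeps the pass disabled.
os.environ["TRITON_INTEL_REWRITE_TENSOR_POINTER"] = "1"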