From 3826a8bf578cc1821859782c2e8d31676c3d4f32 Mon Sep 17 00:00:00 2001
From: "Tiotto, Ettore"
Date: Fri, 18 Oct 2024 21:16:07 +0000
Subject: [PATCH 1/2] Copy Coalesce pass for customization

Signed-off-by: Tiotto, Ettore
---
 third_party/intel/backend/compiler.py         |   2 +-
 .../include/Dialect/TritonIntelGPU/IR/Utils.h |  30 ++-
 .../TritonIntelGPU/Transforms/Passes.td       |  16 ++
 .../TritonIntelGPUTransforms/CMakeLists.txt   |   1 +
 .../lib/TritonIntelGPUTransforms/Coalesce.cpp | 199 ++++++++++++++++++
 third_party/intel/triton_xpu.cc               |   1 +
 6 files changed, 246 insertions(+), 3 deletions(-)
 create mode 100644 third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp

diff --git a/third_party/intel/backend/compiler.py b/third_party/intel/backend/compiler.py
index da853d0d09..86948112b9 100644
--- a/third_party/intel/backend/compiler.py
+++ b/third_party/intel/backend/compiler.py
@@ -238,7 +238,7 @@ def make_ttgir(mod, metadata, opt, properties):
         intel.passes.ttgpuir.add_rewrite_tensor_pointer(pm)
         intel.passes.ttgpuir.add_pipeline(pm, opt.num_stages, False)
-        passes.ttgpuir.add_coalesce(pm)
+        intel.passes.ttgpuir.add_coalesce(pm)
         intel.passes.ttgpuir.add_remove_layout_conversions(pm)
         passes.ttgpuir.add_optimize_thread_locality(pm)
         passes.ttgpuir.add_optimize_dot_operands(pm, True)
diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/IR/Utils.h b/third_party/intel/include/Dialect/TritonIntelGPU/IR/Utils.h
index 4c0031e2dd..6357d4a8c2 100644
--- a/third_party/intel/include/Dialect/TritonIntelGPU/IR/Utils.h
+++ b/third_party/intel/include/Dialect/TritonIntelGPU/IR/Utils.h
@@ -9,11 +9,37 @@
 #ifndef TRITON_DIALECT_TRITON_INTEL_GPU_IR_UTILS_H
 #define TRITON_DIALECT_TRITON_INTEL_GPU_IR_UTILS_H
 
-#include 
-
+#include "intel/include/Analysis/AxisInfo.h"
+#include "mlir/IR/Operation.h"
+#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
 #include 
 
 namespace mlir::triton::gpu::intel {
+
+/// Calculate the optimal number of elements per thread for a given operation
+/// along the axis with greatest contiguity.
+inline unsigned getNumElementsPerThread(
+    Operation *op, SmallVector<unsigned> order,
+    mlir::triton::intel::ModuleAxisInfoAnalysis &axisInfoAnalysis) {
+  Value val = getMemAccessPtr(op);
+  Type valTy = val.getType();
+  auto ty =
+      isTensorPointerType(valTy)
+          ? cast<RankedTensorType>(cast<PointerType>(valTy).getPointeeType())
+          : cast<RankedTensorType>(valTy);
+  auto shapePerCTA = getShapePerCTA(ty);
+  mlir::triton::intel::AxisInfo &valInfo = *axisInfoAnalysis.getAxisInfo(val);
+
+  unsigned elemNumBits = getElementBitWidth(ty);
+  unsigned elemNumBytes = std::max(elemNumBits / 8, 1u);
+  unsigned maxMultipleBytes = valInfo.getDivisibility(order[0]);
+  unsigned maxMultiple = std::max(maxMultipleBytes / elemNumBytes, 1u);
+  unsigned maxContig =
+      std::min(valInfo.getContiguity(order[0]), shapePerCTA[order[0]]);
+  unsigned alignment = std::min(maxMultiple, maxContig);
+  return std::min(alignment, 128 / elemNumBits);
+}
+
 /// Check whether transposed reduction should be performed.
 ///
 /// See: https://github.com/intel/intel-xpu-backend-for-triton/issues/1637
diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td
index 42e386fe29..86a88dd8a1 100644
--- a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td
+++ b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td
@@ -27,6 +27,22 @@ def TritonIntelGPUAccelerateMatmul
   ];
 }
 
+def TritonIntelGPUCoalesce
+    : Pass<"tritonintelgpu-coalesce", "mlir::ModuleOp"> {
+  let summary = "Intel Coalesce";
+
+  let description = [{
+    The pass analyses loads/stores with type `tensor<tt.ptr<>>` or
+    `tt.ptr<tensor<>>` and replaces the layouts of these operations with
+    coalesced layouts, i.e. cache friendly access patterns.
+    Layout conversions are inserted before and after the load/store op
+    to maintain consistency with the rest of the program.
+  }];
+
+  let dependentDialects = ["mlir::triton::TritonDialect",
+                           "mlir::triton::gpu::TritonGPUDialect"];
+}
+
 def TritonIntelGPUDistributeToWarps
     : Pass<"tritonintelgpu-distribute-to-warps", "mlir::ModuleOp"> {
   let summary = "distribute the thread block workload to the warps";
diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt b/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt
index 8c2e290ada..9c02e5752c 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_triton_library(TritonIntelGPUTransforms
   AccelerateMatmul.cpp
+  Coalesce.cpp
   DistributeToWarps.cpp
   MatchTargetSize.cpp
   MaterializeBlockPointer.cpp
diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp
new file mode 100644
index 0000000000..28e119dc6d
--- /dev/null
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp
@@ -0,0 +1,199 @@
+#include "intel/include/Analysis/AxisInfo.h"
+#include "intel/include/Dialect/TritonIntelGPU/IR/Utils.h"
+#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Support/LLVM.h"
+#include "triton/Dialect/Triton/IR/Utility.h"
+#include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
+#include "triton/Tools/StrUtil.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "tritonintelgpu-coalesce"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
+namespace mlir::triton::gpu::intel {
+#define GEN_PASS_DEF_TRITONINTELGPUCOALESCE
+#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h.inc"
+} // namespace mlir::triton::gpu::intel
+
+using namespace mlir;
+namespace tt = mlir::triton;
+namespace ttgi = mlir::triton::gpu::intel;
+
+namespace {
+
+struct CoalescePass
+    : public ttgi::impl::TritonIntelGPUCoalesceBase<CoalescePass> {
+  void
+  setCoalescedEncoding(tt::intel::ModuleAxisInfoAnalysis &axisInfoAnalysis,
+                       Operation *op, int numWarps, int threadsPerWarp,
+                       llvm::MapVector<Operation *, Attribute> &layoutMap) {
+    Value ptr = getMemAccessPtr(op);
+    auto refTensorType = cast<RankedTensorType>(ptr.getType());
+
+    LDBG("Considering op: " << *op);
+    LLVM_DEBUG({
+      DBGS() << "axis info of pointer: ";
+      axisInfoAnalysis.getAxisInfo(ptr)->print(llvm::dbgs());
+      llvm::dbgs() << "\n";
+    });
+
+    auto contiguity = axisInfoAnalysis.getAxisInfo(ptr)->getContiguity();
+    SmallVector<unsigned> order = argSort(contiguity);
+    LDBG("order=[" << triton::join(order, ", ") << "]");
+
+    auto matchesShape = [&refTensorType](const Value &val) {
+      auto rttType = dyn_cast<RankedTensorType>(val.getType());
+      return rttType && rttType.getShape() == refTensorType.getShape();
+    };
+
+    // The desired divisibility is the maximum divisibility among all dependent
+    // pointers which have the same shape and order as `ptr`.
+    llvm::SmallSetVector<Operation *, 32> memAccessesSameOrder;
+    memAccessesSameOrder.insert(op);
+    if (ptr.getDefiningOp()) {
+      for (Operation *use : mlir::multiRootGetSlice(op)) {
+        Value val = getMemAccessPtr(use);
+        if (!val || !matchesShape(val) || memAccessesSameOrder.contains(use))
+          continue;
+        auto currOrder =
+            argSort(axisInfoAnalysis.getAxisInfo(val)->getContiguity());
+        if (order == currOrder) {
+          LDBG("multi-root-slice: insert to memAccessesSameOrder " << *use);
+          memAccessesSameOrder.insert(use);
+        }
+      }
+    }
+
+    auto shapePerCTA = triton::gpu::getShapePerCTA(refTensorType);
+    LDBG("shapePerCTA=[" << triton::join(shapePerCTA, ", ") << "]");
+
+    int numElems = product<int64_t>(shapePerCTA);
+    int numThreads = numWarps * threadsPerWarp;
+
+    unsigned perThread =
+        ttgi::getNumElementsPerThread(op, order, axisInfoAnalysis);
+    LDBG("perThread for op: " << perThread);
+
+    for (Operation *opSameOrder : memAccessesSameOrder) {
+      if (opSameOrder == op)
+        continue;
+      unsigned currPerThread =
+          ttgi::getNumElementsPerThread(opSameOrder, order, axisInfoAnalysis);
+      LDBG("perThread for opSameOrder: " << currPerThread);
+      perThread = std::max(perThread, currPerThread);
+    }
+
+    perThread = std::min<int>(perThread, std::max(numElems / numThreads, 1));
+    LDBG("perThread: " << perThread);
+
+    if (!dyn_cast<triton::LoadOp>(op)) {
+      // For ops that can result in a global memory write, we should enforce
+      // that each thread handles at most 128 bits, which is the widest
+      // available vectorized store op; otherwise, the store will have "gaps"
+      // in the memory write at the warp level, resulting in worse performance.
+      // For loads, we can expect that the gaps won't matter due to the L1
+      // cache.
+      perThread = std::min<int>(perThread, ttgi::getNumElementsPerThread(
+                                               op, order, axisInfoAnalysis));
+    }
+    SmallVector<unsigned> sizePerThread(refTensorType.getRank(), 1);
+    sizePerThread[order[0]] = perThread;
+
+    auto CTALayout = triton::gpu::getCTALayout(refTensorType.getEncoding());
+    layoutMap[op] = triton::gpu::BlockedEncodingAttr::get(
+        &getContext(), refTensorType.getShape(), sizePerThread, order, numWarps,
+        threadsPerWarp, CTALayout);
+  }
+
+  static Type getNewType(Type type, Attribute encoding) {
+    RankedTensorType tensorType = cast<RankedTensorType>(type);
+    return RankedTensorType::get(tensorType.getShape(),
+                                 tensorType.getElementType(), encoding);
+  }
+
+  void coalesceOp(Attribute encoding, Operation *op) {
+    OpBuilder builder(op);
+    // Convert operands
+    // For load/store with tensor pointers, we don't have to change the
+    // operands' type, we do this by changing the outputs' type of
+    // `make_tensor_ptr`
+    SmallVector<Value, 4> newArgs;
+    for (auto operand : op->getOperands()) {
+      auto tensorType = dyn_cast<RankedTensorType>(operand.getType());
+      if (tensorType &&
+          !isa<triton::gpu::SharedEncodingAttr>(tensorType.getEncoding())) {
+        Type newType = getNewType(tensorType, encoding);
+        newArgs.push_back(builder.create<triton::gpu::ConvertLayoutOp>(
+            op->getLoc(), newType, operand));
+      } else {
+        newArgs.push_back(operand);
+      }
+    }
+
+    // Convert output types
+    SmallVector<Type, 4> newTypes;
+    for (auto t : op->getResultTypes()) {
+      bool isAsync = isa<triton::gpu::AsyncCopyGlobalToLocalOp>(op);
+      newTypes.push_back(isAsync ? t : getNewType(t, encoding));
+    }
+
+    // Construct new op with the new encoding
+    Operation *newOp =
+        builder.create(op->getLoc(), op->getName().getIdentifier(), newArgs,
+                       newTypes, op->getAttrs());
+
+    // Cast the results back to the original layout
+    for (size_t i = 0; i < op->getNumResults(); i++) {
+      Value newResult = newOp->getResult(i);
+      if (newTypes[i] != op->getResultTypes()[i]) {
+        newResult = builder.create<triton::gpu::ConvertLayoutOp>(
+            op->getLoc(), op->getResult(i).getType(), newResult);
+      }
+      op->getResult(i).replaceAllUsesWith(newResult);
+    }
+    op->erase();
+  }
+
+  void runOnOperation() override {
+    // Run axis info analysis
+    ModuleOp moduleOp = getOperation();
+    tt::intel::ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp);
+
+    // For each i/o operation, we determine what layout
+    // the pointers should have for best memory coalescing
+    llvm::MapVector<Operation *, Attribute> layoutMap;
+    moduleOp.walk([&](Operation *curr) {
+      Value ptr = getMemAccessPtr(curr);
+      if (!ptr)
+        return;
+      // We only convert `tensor<tt.ptr<>>` load/store
+      bool isPtrTensor = false;
+      if (auto tensorType = dyn_cast<RankedTensorType>(ptr.getType()))
+        isPtrTensor = isa<tt::PointerType>(tensorType.getElementType());
+      if (!isPtrTensor)
+        return;
+      auto mod = curr->getParentOfType<ModuleOp>();
+      int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
+      int threadsPerWarp =
+          triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
+      setCoalescedEncoding(axisInfoAnalysis, curr, numWarps, threadsPerWarp,
+                           layoutMap);
+    });
+
+    // For each memory op that has a layout L1:
+    // 1. Create a coalesced memory layout L2 of the pointer operands
+    // 2. Convert all operands from layout L1 to layout L2
+    // 3. Create a new memory op that consumes these operands and
+    //    produces a tensor with layout L2
+    // 4. Convert the output of this new memory op back to L1
+    // 5. Replace all the uses of the original memory op by the new one
+    for (auto &kv : layoutMap) {
+      coalesceOp(kv.second, kv.first);
+    }
+  }
+};
+
+} // namespace
diff --git a/third_party/intel/triton_xpu.cc b/third_party/intel/triton_xpu.cc
index 951de6ce35..201ec17a74 100644
--- a/third_party/intel/triton_xpu.cc
+++ b/third_party/intel/triton_xpu.cc
@@ -82,6 +82,7 @@ void init_triton_intel_passes_ttgpuir(py::module &&m) {
                     gpu::intel::createTritonIntelGPURemoveLayoutConversions);
   ADD_PASS_WRAPPER_0("add_rewrite_tensor_pointer",
                      gpu::intel::createTritonIntelGPURewriteTensorPointer);
+  ADD_PASS_WRAPPER_0("add_coalesce", gpu::intel::createTritonIntelGPUCoalesce);
   ADD_PASS_WRAPPER_OPT_2("add_prefetch_block",
                          gpu::intel::createTritonIntelGPUPrefetchBlock, int,
                          bool);

From 7a60088553639f2a705c91ddf8da081c7bddce3f Mon Sep 17 00:00:00 2001
From: "Tiotto, Ettore"
Date: Mon, 21 Oct 2024 14:26:59 +0000
Subject: [PATCH 2/2] Fix regressions

Signed-off-by: Tiotto, Ettore
---
 third_party/intel/lib/Analysis/AxisInfo.cpp | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/third_party/intel/lib/Analysis/AxisInfo.cpp b/third_party/intel/lib/Analysis/AxisInfo.cpp
index aeb8e12b5b..82bd971695 100644
--- a/third_party/intel/lib/Analysis/AxisInfo.cpp
+++ b/third_party/intel/lib/Analysis/AxisInfo.cpp
@@ -1010,8 +1010,12 @@ class MakeTensorPtrOpAxisInfoVisitor final
   getAxisInfo(triton::MakeTensorPtrOp op,
               ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
     LDBG("MakeTensorPtrOpAxisInfoVisitor: " << *op);
-    assert(op.getShape().size() == 2 && operands.size() == 7 &&
-           "MakeTensorPtrOp should have 2D shape");
+
+    // TODO: Extend to higher dimension tensor pointers.
+    if (op.getShape().size() != 2)
+      return AxisInfo();
+
+    assert(operands.size() == 7 && "MakeTensorPtrOp should have 7 operands");
 
     AxisInfo ptrInfo = operands[0]->getValue();
     AxisInfo shapeInfo0 = operands[1]->getValue();
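
Note on the sizing logic introduced in Utils.h above: the per-thread vector width chosen by getNumElementsPerThread and consumed by setCoalescedEncoding boils down to clamping the access width by pointer divisibility, contiguity, the tensor shape, and a 128-bit cap, along the most contiguous dimension reported by axis info. The standalone C++ sketch below mirrors that arithmetic with plain integers; it is illustrative only, and the helper names argSortDescending and elementsPerThread are invented here rather than taken from the patches or the Triton APIs.

// Standalone illustration (not from the patches): how a coalesced layout's
// per-thread element count can be derived from contiguity, divisibility and
// element width, mirroring getNumElementsPerThread above.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

// Stand-in for argSort(contiguity): dimension indices sorted so that order[0]
// is the most contiguous (fastest-varying) dimension.
static std::vector<unsigned> argSortDescending(const std::vector<int64_t> &v) {
  std::vector<unsigned> order(v.size());
  std::iota(order.begin(), order.end(), 0u);
  std::stable_sort(order.begin(), order.end(),
                   [&](unsigned a, unsigned b) { return v[a] > v[b]; });
  return order;
}

// Mirrors the arithmetic in getNumElementsPerThread: the access width is
// limited by pointer divisibility (in bytes), contiguity, the shape along the
// chosen dimension, and the 128-bit maximum vector width.
static unsigned elementsPerThread(int64_t contiguity, int64_t divisibilityBytes,
                                  int64_t shapeAlongDim, unsigned elemNumBits) {
  unsigned elemNumBytes = std::max(elemNumBits / 8, 1u);
  unsigned maxMultiple =
      std::max<unsigned>(divisibilityBytes / elemNumBytes, 1u);
  unsigned maxContig = std::min<int64_t>(contiguity, shapeAlongDim);
  unsigned alignment = std::min(maxMultiple, maxContig);
  return std::min(alignment, 128 / elemNumBits);
}

int main() {
  // Example: a 64x128 f16 tensor with rows contiguous in memory and pointers
  // divisible by 16 bytes.
  std::vector<int64_t> contiguity = {1, 128};
  std::vector<unsigned> order = argSortDescending(contiguity);
  unsigned perThread = elementsPerThread(/*contiguity=*/128,
                                         /*divisibilityBytes=*/16,
                                         /*shapeAlongDim=*/128,
                                         /*elemNumBits=*/16);
  // order[0] == 1 (coalesce along the row dimension); perThread == 8, i.e.
  // min(16 bytes / 2 bytes, 128, 128 bits / 16 bits) elements per thread.
  std::cout << "order[0]=" << order[0] << ", perThread=" << perThread << "\n";
  return 0;
}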