Merged
Changes from 1 commit
2 changes: 1 addition & 1 deletion third_party/intel/backend/compiler.py
@@ -238,7 +238,7 @@ def make_ttgir(mod, metadata, opt, properties):
intel.passes.ttgpuir.add_rewrite_tensor_pointer(pm)
intel.passes.ttgpuir.add_pipeline(pm, opt.num_stages, False)

passes.ttgpuir.add_coalesce(pm)
intel.passes.ttgpuir.add_coalesce(pm)
intel.passes.ttgpuir.add_remove_layout_conversions(pm)
passes.ttgpuir.add_optimize_thread_locality(pm)
passes.ttgpuir.add_optimize_dot_operands(pm, True)
30 changes: 28 additions & 2 deletions third_party/intel/include/Dialect/TritonIntelGPU/IR/Utils.h
@@ -9,11 +9,37 @@
#ifndef TRITON_DIALECT_TRITON_INTEL_GPU_IR_UTILS_H
#define TRITON_DIALECT_TRITON_INTEL_GPU_IR_UTILS_H

#include <optional>

#include "intel/include/Analysis/AxisInfo.h"
#include "mlir/IR/Operation.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include <triton/Tools/Sys/GetEnv.hpp>

namespace mlir::triton::gpu::intel {

/// Calculate the optimal number of elements per thread for a given operation
/// along the axis with the greatest contiguity.
inline unsigned getNumElementsPerThread(
Operation *op, SmallVector<unsigned> order,
mlir::triton::intel::ModuleAxisInfoAnalysis &axisInfoAnalysis) {
Value val = getMemAccessPtr(op);
Type valTy = val.getType();
auto ty =
isTensorPointerType(valTy)
? cast<RankedTensorType>(cast<PointerType>(valTy).getPointeeType())
: cast<RankedTensorType>(valTy);
auto shapePerCTA = getShapePerCTA(ty);
mlir::triton::intel::AxisInfo &valInfo = *axisInfoAnalysis.getAxisInfo(val);

unsigned elemNumBits = getElementBitWidth(ty);
unsigned elemNumBytes = std::max(elemNumBits / 8, 1u);
unsigned maxMultipleBytes = valInfo.getDivisibility(order[0]);
unsigned maxMultiple = std::max(maxMultipleBytes / elemNumBytes, 1u);
unsigned maxContig =
std::min(valInfo.getContiguity(order[0]), shapePerCTA[order[0]]);
unsigned alignment = std::min(maxMultiple, maxContig);
return std::min(alignment, 128 / elemNumBits);
}

/// Check whether transposed reduction should be performed.
///
/// See: https://github.com/intel/intel-xpu-backend-for-triton/issues/1637
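To make the arithmetic in getNumElementsPerThread concrete, here is a small standalone sketch (not part of the PR) that mirrors it with hypothetical values: a row-major 64x64 f16 tensor whose pointer is 16-byte divisible along the most contiguous axis and contiguous across the whole innermost dimension.

// Standalone sketch mirroring getNumElementsPerThread with hypothetical
// values; the real function reads these quantities from AxisInfo and the
// tensor type.
#include <algorithm>
#include <cstdio>

int main() {
  unsigned elemNumBits = 16;                                            // f16
  unsigned elemNumBytes = std::max(elemNumBits / 8, 1u);                // 2
  unsigned maxMultipleBytes = 16;     // divisibility along order[0], in bytes
  unsigned maxMultiple = std::max(maxMultipleBytes / elemNumBytes, 1u); // 8
  unsigned maxContig = std::min<unsigned>(/*contiguity=*/64,
                                          /*shapePerCTA[order[0]]=*/64); // 64
  unsigned alignment = std::min(maxMultiple, maxContig);                 // 8
  unsigned perThread = std::min(alignment, 128 / elemNumBits); // 128-bit cap
  std::printf("elements per thread = %u\n", perThread);        // prints 8
  return 0;
}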
third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td
@@ -27,6 +27,22 @@ def TritonIntelGPUAccelerateMatmul
];
}

def TritonIntelGPUCoalesce
: Pass<"tritonintelgpu-coalesce", "mlir::ModuleOp"> {
let summary = "Improve memory coalescing of load/store operations";

let description = [{
The pass analyzes loads/stores with type `tensor<tt.ptr<>>` or
`tt.ptr<tensor<>>` and replaces the layouts of these operations with
coalesced layouts, i.e., cache-friendly access patterns.
Layout conversions are inserted before and after the load/store op
to maintain consistency with the rest of the program.
}];

let dependentDialects = ["mlir::triton::TritonDialect",
"mlir::triton::gpu::TritonGPUDialect"];
}

def TritonIntelGPUDistributeToWarps
: Pass<"tritonintelgpu-distribute-to-warps", "mlir::ModuleOp"> {
let summary = "distribute the thread block workload to the warps";
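For context, a minimal C++ sketch (not part of this PR) of how the new pass can be scheduled from a pass manager. It uses the tablegen-generated factory createTritonIntelGPUCoalesce that the triton_xpu.cc hunk further below exposes to Python; the include paths and the helper function name are assumptions based on the headers used in Coalesce.cpp.

// Minimal scheduling sketch; assumes the generated factory is declared via
// the Intel Passes.h header, as is conventional for MLIR tablegen'd passes.
#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

mlir::LogicalResult runIntelCoalesce(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  // C++ counterpart of `intel.passes.ttgpuir.add_coalesce(pm)` in compiler.py.
  pm.addPass(mlir::triton::gpu::intel::createTritonIntelGPUCoalesce());
  return pm.run(module);
}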
third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt
@@ -1,5 +1,6 @@
add_triton_library(TritonIntelGPUTransforms
AccelerateMatmul.cpp
Coalesce.cpp
DistributeToWarps.cpp
MatchTargetSize.cpp
MaterializeBlockPointer.cpp
199 changes: 199 additions & 0 deletions third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp
@@ -0,0 +1,199 @@
#include "intel/include/Analysis/AxisInfo.h"
#include "intel/include/Dialect/TritonIntelGPU/IR/Utils.h"
#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Support/LLVM.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "triton/Tools/StrUtil.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "tritonintelgpu-coalesce"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

namespace mlir::triton::gpu::intel {
#define GEN_PASS_DEF_TRITONINTELGPUCOALESCE
#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h.inc"
} // namespace mlir::triton::gpu::intel

using namespace mlir;
namespace tt = mlir::triton;
namespace ttgi = mlir::triton::gpu::intel;

namespace {

struct CoalescePass
: public ttgi::impl::TritonIntelGPUCoalesceBase<CoalescePass> {
void
setCoalescedEncoding(tt::intel::ModuleAxisInfoAnalysis &axisInfoAnalysis,
Operation *op, int numWarps, int threadsPerWarp,
llvm::MapVector<Operation *, Attribute> &layoutMap) {
Value ptr = getMemAccessPtr(op);
auto refTensorType = cast<RankedTensorType>(ptr.getType());

LDBG("Considering op: " << *op);
LLVM_DEBUG({
DBGS() << "axis info of pointer: ";
axisInfoAnalysis.getAxisInfo(ptr)->print(llvm::dbgs());
llvm::dbgs() << "\n";
});

auto contiguity = axisInfoAnalysis.getAxisInfo(ptr)->getContiguity();
SmallVector<unsigned> order = argSort(contiguity);
LDBG("order=[" << triton::join(order, ", ") << "]");

auto matchesShape = [&refTensorType](const Value &val) {
auto rttType = dyn_cast<RankedTensorType>(val.getType());
return rttType && rttType.getShape() == refTensorType.getShape();
};

// The desired divisibility is the maximum divisibility among all dependent
// pointers which have the same shape and order as `ptr`.
llvm::SmallSetVector<Operation *, 32> memAccessesSameOrder;
memAccessesSameOrder.insert(op);
if (ptr.getDefiningOp()) {
for (Operation *use : mlir::multiRootGetSlice(op)) {
Value val = getMemAccessPtr(use);
if (!val || !matchesShape(val) || memAccessesSameOrder.contains(use))
continue;
auto currOrder =
argSort(axisInfoAnalysis.getAxisInfo(val)->getContiguity());
if (order == currOrder) {
LDBG("multi-root-slice: insert to memAccessesSameOrder " << *use);
memAccessesSameOrder.insert(use);
}
}
}

auto shapePerCTA = triton::gpu::getShapePerCTA(refTensorType);
LDBG("shapePerCTA=[" << triton::join(shapePerCTA, ", ") << "]");

int numElems = product<int64_t>(shapePerCTA);
int numThreads = numWarps * threadsPerWarp;

unsigned perThread =
ttgi::getNumElementsPerThread(op, order, axisInfoAnalysis);
LDBG("perThread for op: " << perThread);

for (Operation *opSameOrder : memAccessesSameOrder) {
if (opSameOrder == op)
continue;
unsigned currPerThread =
ttgi::getNumElementsPerThread(opSameOrder, order, axisInfoAnalysis);
LDBG("perThread for opSameOrder: " << currPerThread);
perThread = std::max(perThread, currPerThread);
}

perThread = std::min<int>(perThread, std::max(numElems / numThreads, 1));
LDBG("perThread: " << perThread);

if (!isa<triton::LoadOp>(op)) {
// For ops that can result in a global memory write, we should enforce
// that each thread handles at most 128 bits, which is the widest
// available vectorized store op; otherwise, the store will have "gaps"
// in the memory write at the warp level, resulting in worse performance.
// For loads, we can expect that the gaps won't matter due to the L1
// cache.
perThread = std::min<int>(perThread, ttgi::getNumElementsPerThread(
op, order, axisInfoAnalysis));
}
SmallVector<unsigned> sizePerThread(refTensorType.getRank(), 1);
sizePerThread[order[0]] = perThread;

auto CTALayout = triton::gpu::getCTALayout(refTensorType.getEncoding());
layoutMap[op] = triton::gpu::BlockedEncodingAttr::get(
&getContext(), refTensorType.getShape(), sizePerThread, order, numWarps,
threadsPerWarp, CTALayout);
}

static Type getNewType(Type type, Attribute encoding) {
RankedTensorType tensorType = cast<RankedTensorType>(type);
return RankedTensorType::get(tensorType.getShape(),
tensorType.getElementType(), encoding);
}

void coalesceOp(Attribute encoding, Operation *op) {
OpBuilder builder(op);
// Convert operands.
// For load/store ops on tensor pointers, the operands' types do not need
// to change here; instead, the result type of `make_tensor_ptr` is
// updated.
SmallVector<Value, 4> newArgs;
for (auto operand : op->getOperands()) {
auto tensorType = dyn_cast<RankedTensorType>(operand.getType());
if (tensorType &&
!isa<triton::gpu::SharedEncodingAttr>(tensorType.getEncoding())) {
Type newType = getNewType(tensorType, encoding);
newArgs.push_back(builder.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), newType, operand));
} else {
newArgs.push_back(operand);
}
}

// Convert output types
SmallVector<Type, 4> newTypes;
for (auto t : op->getResultTypes()) {
bool isAsync = isa<triton::gpu::AsyncCopyGlobalToLocalOp>(op);
newTypes.push_back(isAsync ? t : getNewType(t, encoding));
}

// Construct new op with the new encoding
Operation *newOp =
builder.create(op->getLoc(), op->getName().getIdentifier(), newArgs,
newTypes, op->getAttrs());

// Cast the results back to the original layout
for (size_t i = 0; i < op->getNumResults(); i++) {
Value newResult = newOp->getResult(i);
if (newTypes[i] != op->getResultTypes()[i]) {
newResult = builder.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), op->getResult(i).getType(), newResult);
}
op->getResult(i).replaceAllUsesWith(newResult);
}
op->erase();
}

void runOnOperation() override {
// Run axis info analysis
ModuleOp moduleOp = getOperation();
tt::intel::ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp);

// For each i/o operation, we determine what layout
// the pointers should have for best memory coalescing
llvm::MapVector<Operation *, Attribute> layoutMap;
moduleOp.walk([&](Operation *curr) {
Value ptr = getMemAccessPtr(curr);
if (!ptr)
return;
// We only convert `tensor<tt.ptr<>>` load/store
bool isPtrTensor = false;
if (auto tensorType = dyn_cast<RankedTensorType>(ptr.getType()))
isPtrTensor = isa<tt::PointerType>(tensorType.getElementType());
if (!isPtrTensor)
return;
auto mod = curr->getParentOfType<ModuleOp>();
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
int threadsPerWarp =
triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
setCoalescedEncoding(axisInfoAnalysis, curr, numWarps, threadsPerWarp,
layoutMap);
});

// For each memory op that has a layout L1:
// 1. Create a coalesced memory layout L2 of the pointer operands
// 2. Convert all operands from layout L1 to layout L2
// 3. Create a new memory op that consumes these operands and
// produces a tensor with layout L2
// 4. Convert the output of this new memory op back to L1
// 5. Replace all the uses of the original memory op by the new one
for (auto &kv : layoutMap) {
coalesceOp(kv.second, kv.first);
}
}
};

} // namespace
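As a worked illustration of the clamping in setCoalescedEncoding (hypothetical numbers, not taken from the PR): for a 64x64 f16 access handled by 4 warps of 32 threads, with 8 elements per thread coming from getNumElementsPerThread, the even-split clamp leaves perThread at 8, which then becomes sizePerThread[order[0]] in the blocked encoding.

// Worked example of the perThread clamp; all numbers are hypothetical.
#include <algorithm>
#include <cstdio>

int main() {
  unsigned perThread = 8;  // max over the same-order memory accesses
  int numElems = 64 * 64;  // product of shapePerCTA
  int numThreads = 4 * 32; // numWarps * threadsPerWarp
  // Never assign more elements per thread than an even split allows.
  perThread = std::min<int>(perThread, std::max(numElems / numThreads, 1));
  // For stores, the pass additionally re-applies the 128-bit-per-thread cap
  // from getNumElementsPerThread (8 x f16 here), so the value stays at 8.
  std::printf("sizePerThread[order[0]] = %u\n", perThread);
  return 0;
}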
1 change: 1 addition & 0 deletions third_party/intel/triton_xpu.cc
@@ -82,6 +82,7 @@ void init_triton_intel_passes_ttgpuir(py::module &&m) {
gpu::intel::createTritonIntelGPURemoveLayoutConversions);
ADD_PASS_WRAPPER_0("add_rewrite_tensor_pointer",
gpu::intel::createTritonIntelGPURewriteTensorPointer);
ADD_PASS_WRAPPER_0("add_coalesce", gpu::intel::createTritonIntelGPUCoalesce);
ADD_PASS_WRAPPER_OPT_2("add_prefetch_block",
gpu::intel::createTritonIntelGPUPrefetchBlock, int,
bool);
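The ADD_PASS_WRAPPER_0 macro is defined elsewhere in Triton's Python binding helpers. As an assumed, rough sketch (details may differ from the real macro), the registration above amounts to exposing a Python-callable that appends the new pass to a pass manager:

// Rough, assumed expansion of the add_coalesce registration; the actual
// ADD_PASS_WRAPPER_0 definition lives in Triton's binding helpers and may
// differ in detail.
#include <pybind11/pybind11.h>
#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"

namespace py = pybind11;

void registerAddCoalesce(py::module &m) {
  m.def("add_coalesce", [](mlir::PassManager &pm) {
    pm.addPass(mlir::triton::gpu::intel::createTritonIntelGPUCoalesce());
  });
}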