Commit 5355880
[NFI] Copy Coalesce pass for further customization (#2514)
Copy the coalescing pass into the Intel-specific directory. Subsequent PRs will add support for block pointers to this pass.

Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent 8d72744 commit 5355880

File tree (7 files changed, +252 −5 lines changed):

  third_party/intel/backend/compiler.py
  third_party/intel/include/Dialect/TritonIntelGPU/IR/Utils.h
  third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td
  third_party/intel/lib/Analysis/AxisInfo.cpp
  third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt
  third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp
  third_party/intel/triton_xpu.cc

third_party/intel/backend/compiler.py

Lines changed: 1 addition & 1 deletion

@@ -238,7 +238,7 @@ def make_ttgir(mod, metadata, opt, properties):
         intel.passes.ttgpuir.add_rewrite_tensor_pointer(pm)
         intel.passes.ttgpuir.add_pipeline(pm, opt.num_stages, False)

-        passes.ttgpuir.add_coalesce(pm)
+        intel.passes.ttgpuir.add_coalesce(pm)
         intel.passes.ttgpuir.add_remove_layout_conversions(pm)
         passes.ttgpuir.add_optimize_thread_locality(pm)
         passes.ttgpuir.add_optimize_dot_operands(pm, True)

third_party/intel/include/Dialect/TritonIntelGPU/IR/Utils.h

Lines changed: 28 additions & 2 deletions

@@ -9,11 +9,37 @@
 #ifndef TRITON_DIALECT_TRITON_INTEL_GPU_IR_UTILS_H
 #define TRITON_DIALECT_TRITON_INTEL_GPU_IR_UTILS_H

-#include <optional>
-
+#include "intel/include/Analysis/AxisInfo.h"
+#include "mlir/IR/Operation.h"
+#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
 #include <triton/Tools/Sys/GetEnv.hpp>

 namespace mlir::triton::gpu::intel {
+
+/// Calculate the optimal number of elements per thread for a given operation
+/// along an axis with greatest continuity.
+inline unsigned getNumElementsPerThread(
+    Operation *op, SmallVector<unsigned> order,
+    mlir::triton::intel::ModuleAxisInfoAnalysis &axisInfoAnalysis) {
+  Value val = getMemAccessPtr(op);
+  Type valTy = val.getType();
+  auto ty =
+      isTensorPointerType(valTy)
+          ? cast<RankedTensorType>(cast<PointerType>(valTy).getPointeeType())
+          : cast<RankedTensorType>(valTy);
+  auto shapePerCTA = getShapePerCTA(ty);
+  mlir::triton::intel::AxisInfo &valInfo = *axisInfoAnalysis.getAxisInfo(val);
+
+  unsigned elemNumBits = getElementBitWidth(ty);
+  unsigned elemNumBytes = std::max(elemNumBits / 8, 1u);
+  unsigned maxMultipleBytes = valInfo.getDivisibility(order[0]);
+  unsigned maxMultiple = std::max(maxMultipleBytes / elemNumBytes, 1u);
+  unsigned maxContig =
+      std::min(valInfo.getContiguity(order[0]), shapePerCTA[order[0]]);
+  unsigned alignment = std::min(maxMultiple, maxContig);
+  return std::min(alignment, 128 / elemNumBits);
+}
+
 /// Check whether transposed reduction should be performed.
 ///
 /// See: https://github.com/intel/intel-xpu-backend-for-triton/issues/1637
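
For intuition, the following is a small standalone sketch, not part of the patch, that reproduces the arithmetic of getNumElementsPerThread with hypothetical axis-info values; the element width, divisibility, contiguity, and shape used below are assumptions chosen only for illustration.

// Standalone sketch of the per-thread element computation above, using
// hypothetical values for the axis-info results: f16 elements, divisibility of
// 16 bytes and contiguity of 64 along the fastest-varying axis.
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  unsigned elemNumBits = 16;                                    // f16 elements
  unsigned elemNumBytes = std::max(elemNumBits / 8, 1u);        // 2 bytes
  unsigned maxMultipleBytes = 16;                               // assumed divisibility
  unsigned maxMultiple = std::max(maxMultipleBytes / elemNumBytes, 1u); // 8 elements
  uint64_t contiguity = 64, shapePerCTA = 64;                   // assumed
  unsigned maxContig = std::min(contiguity, shapePerCTA);       // 64
  unsigned alignment = std::min(maxMultiple, maxContig);        // 8
  unsigned perThread = std::min(alignment, 128 / elemNumBits);  // min(8, 8) = 8
  std::printf("elements per thread: %u\n", perThread);
  return 0;
}

The 128 / elemNumBits cap corresponds to the 128-bit limit on vectorized accesses that the coalescing pass also enforces for stores (see the comment in Coalesce.cpp below).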

third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td

Lines changed: 16 additions & 0 deletions

@@ -27,6 +27,22 @@ def TritonIntelGPUAccelerateMatmul
   ];
 }

+def TritonIntelGPUCoalesce
+    : Pass<"tritonintelgpu-coalesce", "mlir::ModuleOp"> {
+  let summary = "Intel Coalesce";
+
+  let description = [{
+    The pass analyses loads/stores with type `tensor<tt.ptr<>>` or
+    `tt.ptr<tensor<>>` and replaces the layouts of these operations with
+    coalesced layouts, i.e. cache friendly access patterns.
+    Layout conversions are inserted before and after the load/store op
+    to maintain consistency with the rest of the program.
+  }];
+
+  let dependentDialects = ["mlir::triton::TritonDialect",
+                           "mlir::triton::gpu::TritonGPUDialect"];
+}
+
 def TritonIntelGPUDistributeToWarps
     : Pass<"tritonintelgpu-distribute-to-warps", "mlir::ModuleOp"> {
   let summary = "distribute the thread block workload to the warps";
third_party/intel/lib/Analysis/AxisInfo.cpp

Lines changed: 6 additions & 2 deletions

@@ -1010,8 +1010,12 @@ class MakeTensorPtrOpAxisInfoVisitor final
   getAxisInfo(triton::MakeTensorPtrOp op,
               ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
     LDBG("MakeTensorPtrOpAxisInfoVisitor: " << *op);
-    assert(op.getShape().size() == 2 && operands.size() == 7 &&
-           "MakeTensorPtrOp should have 2D shape");
+
+    // TODO: Extend to higher dimension tensor pointers.
+    if (op.getShape().size() != 2)
+      return AxisInfo();
+
+    assert(operands.size() == 7 && "MakeTensorPtrOp should have 2D shape");

     AxisInfo ptrInfo = operands[0]->getValue();
     AxisInfo shapeInfo0 = operands[1]->getValue();

third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -1,5 +1,6 @@
 add_triton_library(TritonIntelGPUTransforms
   AccelerateMatmul.cpp
+  Coalesce.cpp
   DistributeToWarps.cpp
   MatchTargetSize.cpp
   MaterializeBlockPointer.cpp
third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp

Lines changed: 199 additions & 0 deletions

@@ -0,0 +1,199 @@
+#include "intel/include/Analysis/AxisInfo.h"
+#include "intel/include/Dialect/TritonIntelGPU/IR/Utils.h"
+#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Support/LLVM.h"
+#include "triton/Dialect/Triton/IR/Utility.h"
+#include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
+#include "triton/Tools/StrUtil.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "tritonintelgpu-coalesce"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
+namespace mlir::triton::gpu::intel {
+#define GEN_PASS_DEF_TRITONINTELGPUCOALESCE
+#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h.inc"
+} // namespace mlir::triton::gpu::intel
+
+using namespace mlir;
+namespace tt = mlir::triton;
+namespace ttgi = mlir::triton::gpu::intel;
+
+namespace {
+
+struct CoalescePass
+    : public ttgi::impl::TritonIntelGPUCoalesceBase<CoalescePass> {
+  void
+  setCoalescedEncoding(tt::intel::ModuleAxisInfoAnalysis &axisInfoAnalysis,
+                       Operation *op, int numWarps, int threadsPerWarp,
+                       llvm::MapVector<Operation *, Attribute> &layoutMap) {
+    Value ptr = getMemAccessPtr(op);
+    auto refTensorType = cast<RankedTensorType>(ptr.getType());
+
+    LDBG("Considering op: " << *op);
+    LLVM_DEBUG({
+      DBGS() << "axis info of pointer: ";
+      axisInfoAnalysis.getAxisInfo(ptr)->print(llvm::dbgs());
+      llvm::dbgs() << "\n";
+    });
+
+    auto contiguity = axisInfoAnalysis.getAxisInfo(ptr)->getContiguity();
+    SmallVector<unsigned> order = argSort(contiguity);
+    LDBG("order=[" << triton::join(order, ", ") << "]");
+
+    auto matchesShape = [&refTensorType](const Value &val) {
+      auto rttType = dyn_cast<RankedTensorType>(val.getType());
+      return rttType && rttType.getShape() == refTensorType.getShape();
+    };
+
+    // The desired divisibility is the maximum divisibility among all dependent
+    // pointers which have the same shape and order as `ptr`.
+    llvm::SmallSetVector<Operation *, 32> memAccessesSameOrder;
+    memAccessesSameOrder.insert(op);
+    if (ptr.getDefiningOp()) {
+      for (Operation *use : mlir::multiRootGetSlice(op)) {
+        Value val = getMemAccessPtr(use);
+        if (!val || !matchesShape(val) || memAccessesSameOrder.contains(use))
+          continue;
+        auto currOrder =
+            argSort(axisInfoAnalysis.getAxisInfo(val)->getContiguity());
+        if (order == currOrder) {
+          LDBG("multi-root-slice: insert to memAccessesSameOrder " << *use);
+          memAccessesSameOrder.insert(use);
+        }
+      }
+    }
+
+    auto shapePerCTA = triton::gpu::getShapePerCTA(refTensorType);
+    LDBG("shapePerCTA=[" << triton::join(shapePerCTA, ", ") << "]");
+
+    int numElems = product<int64_t>(shapePerCTA);
+    int numThreads = numWarps * threadsPerWarp;
+
+    unsigned perThread =
+        ttgi::getNumElementsPerThread(op, order, axisInfoAnalysis);
+    LDBG("perThread for op: " << perThread);
+
+    for (Operation *opSameOrder : memAccessesSameOrder) {
+      if (opSameOrder == op)
+        continue;
+      unsigned currPerThread =
+          ttgi::getNumElementsPerThread(opSameOrder, order, axisInfoAnalysis);
+      LDBG("perThread for opSameOrder: " << currPerThread);
+      perThread = std::max(perThread, currPerThread);
+    }
+
+    perThread = std::min<int>(perThread, std::max(numElems / numThreads, 1));
+    LDBG("perThread: " << perThread);
+
+    if (!dyn_cast<triton::LoadOp>(op)) {
+      // For ops that can result in a global memory write, we should enforce
+      // that each thread handles at most 128 bits, which is the widest
+      // available vectorized store op; otherwise, the store will have "gaps"
+      // in the memory write at the warp level, resulting in worse performance.
+      // For loads, we can expect that the gaps won't matter due to the L1
+      // cache.
+      perThread = std::min<int>(perThread, ttgi::getNumElementsPerThread(
+                                               op, order, axisInfoAnalysis));
+    }
+    SmallVector<unsigned> sizePerThread(refTensorType.getRank(), 1);
+    sizePerThread[order[0]] = perThread;
+
+    auto CTALayout = triton::gpu::getCTALayout(refTensorType.getEncoding());
+    layoutMap[op] = triton::gpu::BlockedEncodingAttr::get(
+        &getContext(), refTensorType.getShape(), sizePerThread, order, numWarps,
+        threadsPerWarp, CTALayout);
+  }
+
+  static Type getNewType(Type type, Attribute encoding) {
+    RankedTensorType tensorType = cast<RankedTensorType>(type);
+    return RankedTensorType::get(tensorType.getShape(),
+                                 tensorType.getElementType(), encoding);
+  }
+
+  void coalesceOp(Attribute encoding, Operation *op) {
+    OpBuilder builder(op);
+    // Convert operands
+    // For load/store with tensor pointers, we don't have to change the
+    // operands' type, we do this by changing the outputs' type of
+    // `make_tensor_ptr`
+    SmallVector<Value, 4> newArgs;
+    for (auto operand : op->getOperands()) {
+      auto tensorType = dyn_cast<RankedTensorType>(operand.getType());
+      if (tensorType &&
+          !isa<triton::gpu::SharedEncodingAttr>(tensorType.getEncoding())) {
+        Type newType = getNewType(tensorType, encoding);
+        newArgs.push_back(builder.create<triton::gpu::ConvertLayoutOp>(
+            op->getLoc(), newType, operand));
+      } else {
+        newArgs.push_back(operand);
+      }
+    }
+
+    // Convert output types
+    SmallVector<Type, 4> newTypes;
+    for (auto t : op->getResultTypes()) {
+      bool isAsync = isa<triton::gpu::AsyncCopyGlobalToLocalOp>(op);
+      newTypes.push_back(isAsync ? t : getNewType(t, encoding));
+    }
+
+    // Construct new op with the new encoding
+    Operation *newOp =
+        builder.create(op->getLoc(), op->getName().getIdentifier(), newArgs,
+                       newTypes, op->getAttrs());
+
+    // Cast the results back to the original layout
+    for (size_t i = 0; i < op->getNumResults(); i++) {
+      Value newResult = newOp->getResult(i);
+      if (newTypes[i] != op->getResultTypes()[i]) {
+        newResult = builder.create<triton::gpu::ConvertLayoutOp>(
+            op->getLoc(), op->getResult(i).getType(), newResult);
+      }
+      op->getResult(i).replaceAllUsesWith(newResult);
+    }
+    op->erase();
+  }
+
+  void runOnOperation() override {
+    // Run axis info analysis
+    ModuleOp moduleOp = getOperation();
+    tt::intel::ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp);
+
+    // For each i/o operation, we determine what layout
+    // the pointers should have for best memory coalescing
+    llvm::MapVector<Operation *, Attribute> layoutMap;
+    moduleOp.walk([&](Operation *curr) {
+      Value ptr = getMemAccessPtr(curr);
+      if (!ptr)
+        return;
+      // We only convert `tensor<tt.ptr<>>` load/store
+      bool isPtrTensor = false;
+      if (auto tensorType = dyn_cast<RankedTensorType>(ptr.getType()))
+        isPtrTensor = isa<tt::PointerType>(tensorType.getElementType());
+      if (!isPtrTensor)
+        return;
+      auto mod = curr->getParentOfType<ModuleOp>();
+      int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
+      int threadsPerWarp =
+          triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
+      setCoalescedEncoding(axisInfoAnalysis, curr, numWarps, threadsPerWarp,
+                           layoutMap);
+    });
+
+    // For each memory op that has a layout L1:
+    // 1. Create a coalesced memory layout L2 of the pointer operands
+    // 2. Convert all operands from layout L1 to layout L2
+    // 3. Create a new memory op that consumes these operands and
+    //    produces a tensor with layout L2
+    // 4. Convert the output of this new memory op back to L1
+    // 5. Replace all the uses of the original memory op by the new one
+    for (auto &kv : layoutMap) {
+      coalesceOp(kv.second, kv.first);
+    }
+  }
+};
+
+} // namespace
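
To make the layout selection in setCoalescedEncoding concrete, here is a standalone sketch that walks the clamping and sizePerThread construction with hypothetical numbers; the tensor shape, warp count, threads per warp, and the perThread seed value below are assumptions, not values computed by the pass.

// Standalone sketch of how setCoalescedEncoding derives sizePerThread,
// with hypothetical inputs (shape, warps, threads per warp, perThread).
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int64_t> shapePerCTA = {64, 256}; // assumed tensor shape per CTA
  std::vector<unsigned> order = {1, 0};         // axis 1 is the most contiguous
  int numWarps = 8, threadsPerWarp = 16;        // assumed launch configuration
  int numElems = 64 * 256;                      // 16384
  int numThreads = numWarps * threadsPerWarp;   // 128
  int perThread = 8;                            // assumed getNumElementsPerThread result

  // Do not assign more elements per thread than actually exist per thread.
  perThread = std::min(perThread, std::max(numElems / numThreads, 1)); // min(8, 128) = 8

  // One element per dimension, except along the most contiguous axis.
  std::vector<unsigned> sizePerThread(shapePerCTA.size(), 1);
  sizePerThread[order[0]] = perThread; // {1, 8}
  std::printf("sizePerThread = {%u, %u}\n", sizePerThread[0], sizePerThread[1]);
  return 0;
}

The resulting sizePerThread is then combined with order, numWarps, threadsPerWarp, and the CTA layout to build the BlockedEncodingAttr recorded in layoutMap.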

third_party/intel/triton_xpu.cc

Lines changed: 1 addition & 0 deletions

@@ -82,6 +82,7 @@ void init_triton_intel_passes_ttgpuir(py::module &&m) {
                      gpu::intel::createTritonIntelGPURemoveLayoutConversions);
   ADD_PASS_WRAPPER_0("add_rewrite_tensor_pointer",
                      gpu::intel::createTritonIntelGPURewriteTensorPointer);
+  ADD_PASS_WRAPPER_0("add_coalesce", gpu::intel::createTritonIntelGPUCoalesce);
   ADD_PASS_WRAPPER_OPT_2("add_prefetch_block",
                          gpu::intel::createTritonIntelGPUPrefetchBlock, int,
                          bool);
