Skip to content

Commit dab71a1

Browse files
committed
Fix multi batch correctness
1 parent 06eef9d commit dab71a1

File tree

6 files changed

+44
-30
lines changed

6 files changed

+44
-30
lines changed

python/src/ir.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1622,7 +1622,7 @@ void init_triton_ir(py::module &&m) {
16221622
if (haveDump) {
16231623
auto printingFlags = OpPrintingFlags();
16241624
printingFlags.elideLargeElementsAttrs(16);
1625-
printingFlags.enableDebugInfo();
1625+
// printingFlags.enableDebugInfo();
16261626
auto printAlways = [funcToDump](Pass *, Operation *op) -> bool {
16271627
if (funcToDump.empty())
16281628
return true;

third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -180,12 +180,6 @@ DpasEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const {
180180
elemsPerThread[rank - 2] = sizePerThread[rank - 2] * tilesRow;
181181
elemsPerThread[rank - 1] = sizePerThread[rank - 1] * tilesCol;
182182

183-
// if (rank == 3)
184-
// std::cout << "elemsPerThread: " << elemsPerThread[0] << ", " <<
185-
// elemsPerThread[1] << ", " << elemsPerThread[2] << std::endl;
186-
// else
187-
// std::cout << "elemsPerThread: " << elemsPerThread[0] << ", " <<
188-
// elemsPerThread[1] << std::endl;
189183
return elemsPerThread;
190184
}
191185

third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandDPAS.cpp

Lines changed: 38 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#include "../TritonGPUToLLVMBase.h"
22
#include "../Utility.h"
33
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
4+
#include "mlir/Support/LLVM.h"
5+
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
46
#include "llvm/Support/ErrorHandling.h"
57

68
using ValueTable = std::map<std::array<int, 3>, Value>;
@@ -16,16 +18,18 @@ template <unsigned opIdx> class DpasMatmulLoader {
1618
DpasMatmulLoader(DpasEncodingAttr dpasLayout, MemDescType descTy,
1719
unsigned warpsPerTile, ArrayRef<Value> smemStrides,
1820
const SmallVector<unsigned> &warpShape,
21+
SmallVector<Value> multiDimWarpId,
1922
ConversionPatternRewriter &rewriter,
2023
const LLVMTypeConverter *typeConverter, Location loc)
2124
: dpasLayout(dpasLayout), descTy(descTy), smemStrides(smemStrides),
22-
rewriter(rewriter), loc(loc) {
25+
multiDimWarpId(multiDimWarpId), rewriter(rewriter), loc(loc) {
2326
static_assert(opIdx == 0 || opIdx == 1);
2427

2528
size_t rank = warpShape.size();
26-
unsigned kDim = (opIdx == 0) ? rank - 1 : rank - 2;
27-
unsigned nonKDim = (opIdx == 0) ? rank - 2 : rank - 1;
28-
repBatchDimStride = rank == 3 ? smemStrides[0] : i32_val(1);
29+
unsigned kDim = opIdx ? rank - 2 : rank - 1;
30+
unsigned nonKDim = opIdx ? rank - 1 : rank - 2;
31+
// Assume that smem is created with layout offset {2, 1, 0}
32+
repBatchDimStride = smemStrides[0];
2933
repKDimStride = mul(i32_val(warpShape[kDim]), smemStrides[kDim]);
3034
repNonKDimStride =
3135
mul(i32_val(warpShape[nonKDim] * warpsPerTile), smemStrides[nonKDim]);
@@ -60,6 +64,7 @@ template <unsigned opIdx> class DpasMatmulLoader {
6064
MemDescType descTy;
6165

6266
SmallVector<Value> smemStrides;
67+
SmallVector<Value> multiDimWarpId;
6368
Value repBatchDimStride;
6469
Value repNonKDimStride;
6570
Value repKDimStride;
@@ -133,6 +138,7 @@ DpasMatmulLoader<opIdx>::computeLdsMatOffs(Value warpId, Value laneId,
133138
SmallVector<unsigned> instShape = opIdx == 0 ? dpasLayout.getDPASInstShapeA()
134139
: dpasLayout.getDPASInstShapeB();
135140
ArrayRef<int64_t> shareMemoryShape = descTy.getShape();
141+
SmallVector<int64_t> shapePerCTA = getShapePerCTA(descTy);
136142

137143
SmallVector<Value> offs(numPtrs);
138144
const unsigned repClusterSize =
@@ -160,8 +166,8 @@ DpasMatmulLoader<opIdx>::computeLdsMatOffs(Value warpId, Value laneId,
160166

161167
// round the offset into the tensor's shape limitation. (Rounded
162168
// broadcast)
163-
iBase = urem(iBase, i32_val(shareMemoryShape[0]));
164-
jBase = urem(jBase, i32_val(shareMemoryShape[1]));
169+
iBase = urem(iBase, i32_val(shareMemoryShape[rank - 2]));
170+
jBase = urem(jBase, i32_val(shareMemoryShape[rank - 1]));
165171

166172
// inner index offset
167173
Value jOff = zeroVal;
@@ -170,10 +176,17 @@ DpasMatmulLoader<opIdx>::computeLdsMatOffs(Value warpId, Value laneId,
170176
jOff = add(jOff, udiv(cSwizzleOffset, vecVal));
171177
jOff = mul(xor_(jOff, phase), vecVal);
172178

173-
Value i = add(mul(iBase, smemStrides[0]), iOff);
174-
Value j = add(mul(jBase, smemStrides[1]), jOff);
179+
Value i = add(mul(iBase, smemStrides[rank - 2]), iOff);
180+
Value j = add(mul(jBase, smemStrides[rank - 1]), jOff);
175181

176-
offs[index++] = add(i, j);
182+
Value baseOff;
183+
if (shapePerCTA.size() == 3 && shapePerCTA[0] > 1) {
184+
Value batchOffset =
185+
mul(multiDimWarpId[0], i32_val(shapePerCTA[1] * shapePerCTA[2]));
186+
offs[index++] = add(batchOffset, add(i, j));
187+
} else {
188+
offs[index++] = add(i, j);
189+
}
177190
}
178191
}
179192
}
@@ -190,13 +203,14 @@ Value DpasMatmulLoader<opIdx>::loadMatrix(
190203
llvm::any_of(structTy.getBody(), [&](Type ty) { return ty == elemTy; }) &&
191204
"The struct should have the same element types.");
192205

193-
Value offsetBatch = mul(i32_val(repBatch), repBatchDimStride);
194206
Value offsetOuter = mul(i32_val(repOuter), repNonKDimStride);
195207
Value offsetInner = mul(i32_val(repInner), repKDimStride);
196208
Value offset = add(offsetOuter, offsetInner);
197-
// FIXME: repBatchSize and
209+
SmallVector<unsigned> warpsPerCTA = dpasLayout.getWarpsPerCTA();
210+
// 3DTODO: check if repBatch * warpsPerCTA[0] is correct for the offset.
198211
if (repBatch > 0) {
199-
Value offsetBatch = mul(i32_val(repBatch), repBatchDimStride);
212+
Value offsetBatch =
213+
mul(i32_val(repBatch * warpsPerCTA[0]), repBatchDimStride);
200214
offset = add(offset, offsetBatch);
201215
}
202216

@@ -260,7 +274,8 @@ template <unsigned opIdx>
260274
std::function<void(int, int, int)>
261275
getLoadMatrixFn(MemDescType descTy, const SharedMemoryObject &smemObj,
262276
DpasEncodingAttr dpasLayout, unsigned warpsPerTile,
263-
SmallVector<unsigned> instrShape, Value warpId,
277+
SmallVector<unsigned> shapePerWarp,
278+
SmallVector<Value> multiDimWarpId, Value warpId,
264279
Value outerWarpDim, Value laneId, ValueTable &vals,
265280
const LLVMTypeConverter *typeConverter,
266281
ConversionPatternRewriter &rewriter, Location loc) {
@@ -271,13 +286,14 @@ getLoadMatrixFn(MemDescType descTy, const SharedMemoryObject &smemObj,
271286

272287
auto sharedLayout = cast<SharedEncodingAttr>(descTy.getEncoding());
273288
ArrayRef<unsigned> order = sharedLayout.getOrder();
289+
unsigned rank = order.size();
274290

275291
// (a, b) is the coordinate.
276-
auto load = [=, &rewriter, &smemObj, &instrShape, &vals](int batch, int outer,
277-
int inner) {
278-
DpasMatmulLoader<opIdx> loader(dpasLayout, descTy, warpsPerTile,
279-
smemObj.strides, instrShape, rewriter,
280-
typeConverter, loc);
292+
auto load = [=, &rewriter, &smemObj, &shapePerWarp, &multiDimWarpId,
293+
&vals](int batch, int outer, int inner) {
294+
DpasMatmulLoader<opIdx> loader(
295+
dpasLayout, descTy, warpsPerTile, smemObj.strides, shapePerWarp,
296+
multiDimWarpId, rewriter, typeConverter, loc);
281297

282298
// Offset of a slice within the original tensor in shared memory.
283299
Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
@@ -295,7 +311,7 @@ getLoadMatrixFn(MemDescType descTy, const SharedMemoryObject &smemObj,
295311
gep(ptr_ty(rewriter.getContext(), 3), smemTy, smemBase, offs[i]);
296312

297313
// Load from shared memory.
298-
unsigned totalElem = product<unsigned>(instrShape);
314+
unsigned totalElem = product<unsigned>(shapePerWarp);
299315
unsigned threadsPerWarp = product<unsigned>(getThreadsPerWarp(dpasLayout));
300316
auto matTy = LLVM::LLVMStructType::getLiteral(
301317
eltTy.getContext(),
@@ -348,8 +364,9 @@ Value loadOperand(ConversionPatternRewriter &rewriter, Location loc,
348364
// Get the function to use to load the operand.
349365
ValueTable vals;
350366
std::function<void(int, int, int)> loadFn = getLoadMatrixFn<opIdx>(
351-
descTy, smemObj, dpasLayout, warpsPerTile, std::move(shape), warpId,
352-
outerWarpDim, laneId, vals, typeConverter, rewriter, loc);
367+
descTy, smemObj, dpasLayout, warpsPerTile, std::move(shape),
368+
std::move(multiDimWarpId), warpId, outerWarpDim, laneId, vals,
369+
typeConverter, rewriter, loc);
353370

354371
// Load the operand.
355372
int64_t numRepBatch = numReps[0];

third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,8 @@ class DotOpDPASConversionHelper {
299299

300300
size_t totalElems = elems.size();
301301
size_t numElemsPerOperand =
302-
totalElems / ((outer * inner) * (repClusterOuter * repClusterInner));
302+
totalElems /
303+
((batch * outer * inner) * (repClusterOuter * repClusterInner));
303304
VectorType dotOpTy = vec_ty(elemTy, numElemsPerOperand);
304305

305306
int offset = 0;

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,7 @@ struct PrefetchOpConversion
349349
Type eltTy = tensorType.getElementType();
350350
const ArrayRef<int64_t> shapeRef = tensorType.getShape();
351351
SmallVector<int64_t> tensorShape{shapeRef.begin(), shapeRef.end()};
352+
assert(tensorShape.size() == 2 && "Only 2D tensors are prefetch supported");
352353

353354
if (!memoryRowMajor) {
354355
// Swap the shape to make it row major and then get the tiling

third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -568,7 +568,6 @@ emitBaseIndexForLayout(Location loc, RewriterBase &rewriter,
568568

569569
inline SmallVector<SmallVector<unsigned>>
570570
emitOffsetForLayout(Attribute layout, RankedTensorType type) {
571-
std::cout << "~! emitOffsetForLayout\n";
572571
if (auto dpasLayout = dyn_cast<DpasEncodingAttr>(layout))
573572
return emitOffsetForDpasLayout(dpasLayout, type);
574573
if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout))
@@ -657,7 +656,9 @@ inline DenseMap<unsigned, Value> getSwizzledSharedPtrs(
657656
// Order
658657
auto inOrder = triton::gpu::getOrder(srcEncoding);
659658
auto outOrder = triton::gpu::getOrder(resSharedLayout);
659+
unsigned rank = outOrder.size();
660660
assert(maxPhase == 1 ||
661+
// outVec * maxPhase <= srcShape[outOrder[rank-2]] &&
661662
outVec * maxPhase <= srcShape[outOrder[0]] &&
662663
"Swizzling would generate out of bounds memory accesses");
663664
// Tensor indices held by the current thread, as LLVM values

0 commit comments

Comments
 (0)