Skip to content

Commit 148b0a9

Browse files
Merge OpenAI Triton commit 00d5ca7 (#4638)
This PR changes the Triton base from 3b27971 to 00d5ca7 (Jul 7). Pass rate: 97.08%
2 parents d90db37 + a317379 commit 148b0a9

File tree

20 files changed

+90
-562
lines changed

20 files changed

+90
-562
lines changed

include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,11 +98,6 @@ class TargetInfoBase {
9898
virtual bool supportStMatrix() const { return false; }
9999
virtual bool isCuda() const { return false; }
100100

101-
// Annotate target specific information to local store operations during
102-
// lowering to LLVM.
103-
virtual void localStoreOpAnnotation(triton::gpu::LocalStoreOp op,
104-
size_t localStoreOpCount,
105-
Type type) const {}
106101
// Annotate target specific information to local load operations during
107102
// lowering to LLVM. `llLoadOp` is the generated LLVM load op.
108103
virtual void localLoadOpAnnotation(triton::gpu::LocalLoadOp localLoadOp,

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -548,11 +548,12 @@ SmallVector<Value> loadSharedToDistributed(triton::gpu::LocalLoadOp localLoadOp,
548548
Location loc, RewriterBase &rewriter,
549549
const TargetInfoBase &target);
550550

551-
void storeDistributedToShared(
552-
triton::gpu::MemDescType dstTy, RankedTensorType srcTy, Type elemLlvmTy,
553-
ArrayRef<Value> srcVals, const SharedMemoryObject &smemObj, Location loc,
554-
RewriterBase &rewriter, const TargetInfoBase &target,
555-
std::pair<size_t, Type> *const llvmOpCount = nullptr);
551+
void storeDistributedToShared(triton::gpu::MemDescType dstTy,
552+
RankedTensorType srcTy, Type elemLlvmTy,
553+
ArrayRef<Value> srcVals,
554+
const SharedMemoryObject &smemObj, Location loc,
555+
RewriterBase &rewriter,
556+
const TargetInfoBase &target);
556557

557558
// Close cousin of lowerLdStMatrix in MemoryOpToLLVM.cpp
558559
// We might want to merge them at some point, but having to support

lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,19 @@ using namespace mlir::triton::gpu;
1515
// blocked -> shared.
1616
// Swizzling in shared memory to avoid bank conflict. Normally used for
1717
// A/B operands of dots.
18-
void lowerDistributedToShared(
19-
Location loc, Value src, Value dst, Value adaptorSrc,
20-
const SharedMemoryObject &smemObj, const LLVMTypeConverter *typeConverter,
21-
ConversionPatternRewriter &rewriter, const TargetInfoBase &targetInfo,
22-
std::pair<size_t, Type> *const llvmOpCount = nullptr) {
18+
void lowerDistributedToShared(Location loc, Value src, Value dst,
19+
Value adaptorSrc,
20+
const SharedMemoryObject &smemObj,
21+
const LLVMTypeConverter *typeConverter,
22+
ConversionPatternRewriter &rewriter,
23+
const TargetInfoBase &targetInfo) {
2324
auto srcTy = cast<RankedTensorType>(src.getType());
2425
auto dstTy = cast<MemDescType>(dst.getType());
2526
auto elemTy = typeConverter->convertType(srcTy.getElementType());
2627

2728
auto inVals = unpackLLElements(loc, adaptorSrc, rewriter);
2829
storeDistributedToShared(dstTy, srcTy, elemTy, inVals, smemObj, loc, rewriter,
29-
targetInfo, llvmOpCount);
30+
targetInfo);
3031
}
3132

3233
LogicalResult lowerLocalStore(Location loc, MLIRContext *ctx, Value regVal,
@@ -245,20 +246,16 @@ struct LocalStoreOpConversion
245246
auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, adaptor.getDst(),
246247
llvmElemTy, rewriter);
247248
auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
248-
std::pair<size_t, Type> llvmOpCount;
249249
if (targetInfo.isCuda()) {
250250
if (failed(lowerLocalStore(loc, ctx, regVal, memDescTy, smemObj, inVals,
251251
typeConverter, rewriter, targetInfo))) {
252252
return failure();
253253
}
254254
} else {
255255
lowerDistributedToShared(loc, regVal, memDescVal, adaptor.getSrc(),
256-
smemObj, typeConverter, rewriter, targetInfo,
257-
&llvmOpCount);
256+
smemObj, typeConverter, rewriter, targetInfo);
258257
}
259258

260-
targetInfo.localStoreOpAnnotation(op, llvmOpCount.first,
261-
llvmOpCount.second);
262259
rewriter.eraseOp(op);
263260
return success();
264261
}

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 69 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,53 @@ std::pair<Value, Value> getLaneAndWarpId(OpBuilder &rewriter, Location loc) {
296296
return {laneId, warpId};
297297
}
298298

299+
// Helper function: applies linear layout vectorized over register indices
300+
SmallVector<SmallVector<std::pair<StringAttr, Value>>>
301+
applyLinearLayoutVec(Location loc, RewriterBase &rewriter,
302+
const LinearLayout &layout,
303+
ArrayRef<std::pair<StringAttr, Value>> indices,
304+
ArrayRef<uint32_t> registers) {
305+
auto b = TritonLLVMOpBuilder(loc, rewriter);
306+
MLIRContext *ctx = rewriter.getContext();
307+
308+
StringAttr kRegister = str_attr("register");
309+
310+
// Precompute the base (with register = 0)
311+
SmallVector<std::pair<StringAttr, Value>> indicesWithZeroReg;
312+
for (const auto &[attr, val] : indices) {
313+
if (attr == kRegister)
314+
indicesWithZeroReg.emplace_back(attr, b.i32_val(0));
315+
else
316+
indicesWithZeroReg.emplace_back(attr, val);
317+
}
318+
319+
auto baseIndices =
320+
applyLinearLayout(loc, rewriter, layout, indicesWithZeroReg);
321+
322+
SmallVector<SmallVector<std::pair<StringAttr, Value>>> ret;
323+
324+
// Iterate over registers, applying XOR trick
325+
for (auto reg : registers) {
326+
SmallVector<std::pair<StringAttr, int32_t>> constRegIndices;
327+
for (const auto &[attr, val] : indices) {
328+
constRegIndices.emplace_back(attr, attr == kRegister ? reg : 0);
329+
}
330+
auto regIndices = layout.apply(constRegIndices);
331+
332+
SmallVector<std::pair<StringAttr, Value>> combinedIndices;
333+
for (auto [base, regIdx] : llvm::zip(baseIndices, regIndices)) {
334+
assert(base.first == regIdx.first);
335+
Value combined = b.xor_(base.second, b.i32_val(regIdx.second));
336+
combinedIndices.emplace_back(base.first, combined);
337+
}
338+
339+
ret.push_back(combinedIndices);
340+
}
341+
342+
return ret;
343+
}
344+
345+
// Refactored emitIndices function using applyLinearLayoutVec
299346
SmallVector<SmallVector<Value>>
300347
emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
301348
Attribute layout, RankedTensorType type, bool withCTAOffset) {
@@ -305,8 +352,6 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
305352

306353
LinearLayout ll = triton::gpu::toLinearLayout(shape, layout);
307354

308-
// TODO(jlebar): We could add strong typing if we wanted; for now this is
309-
// "stringly typed".
310355
StringAttr kRegister = str_attr("register");
311356
StringAttr kLane = str_attr("lane");
312357
StringAttr kWarp = str_attr("warp");
@@ -315,38 +360,29 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
315360
auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc);
316361
Value blockId =
317362
withCTAOffset ? target.getClusterCTAId(rewriter, loc) : b.i32_val(0);
363+
364+
SmallVector<std::pair<StringAttr, Value>> commonIndices = {
365+
{kRegister, b.i32_val(0)},
366+
{kLane, laneId},
367+
{kWarp, warpId},
368+
{kBlock, blockId}};
369+
370+
// Vectorize over registers
371+
SmallVector<uint32_t> registerIndices;
372+
for (unsigned reg = 0; reg < ll.getInDimSize(kRegister); ++reg)
373+
registerIndices.push_back(reg);
374+
375+
auto vecIndices =
376+
applyLinearLayoutVec(loc, rewriter, ll, commonIndices, registerIndices);
377+
318378
unsigned rank = shape.size();
319379
SmallVector<SmallVector<Value>> ret;
320-
// Linear layout function is split in two parts below:
321-
// L(r, t, w, b) = L(0, t, w, b) xor L(r, 0, 0, 0)
322-
// idxs = idxsBase xor idxsReg
323-
//
324-
// L(0, t, w, b) part is the same for all registers,
325-
// so we hoist it out of the main register loop in the below.
326-
//
327-
// This approach produces code with lower register pressure and
328-
// less computations, compared to fused L(r,t,w,b) method.
329-
auto idxsBase = applyLinearLayout(loc, rewriter, ll,
330-
{{kRegister, b.i32_val(0)},
331-
{kLane, laneId},
332-
{kWarp, warpId},
333-
{kBlock, blockId}});
334-
for (unsigned reg = 0; reg < ll.getInDimSize(str_attr("register")); reg++) {
335-
auto idxsReg =
336-
ll.apply({{kRegister, reg}, {kLane, 0}, {kWarp, 0}, {kBlock, 0}});
337-
SmallVector<std::pair<StringAttr, Value>> idxs;
338-
for (auto [idxBase, idxReg] : llvm::zip(idxsBase, idxsReg)) {
339-
auto dimName = idxBase.first;
340-
assert(dimName == idxReg.first &&
341-
"dim names of block+warp+thread and register idx should be equal");
342-
auto idx = b.xor_(idxBase.second, b.i32_val(idxReg.second));
343-
idxs.emplace_back(dimName, idx);
344-
}
345-
assert(idxs.size() == rank);
346-
for (unsigned k = 0; k < rank; ++k) {
347-
assert(idxs[k].first == str_attr("dim" + std::to_string(k)));
348-
}
349-
ret.push_back(llvm::to_vector(llvm::make_second_range(idxs)));
380+
for (auto &indices : vecIndices) {
381+
SmallVector<Value> vals;
382+
assert(indices.size() == rank);
383+
for (auto &idx : indices)
384+
vals.push_back(idx.second);
385+
ret.push_back(vals);
350386
}
351387

352388
return ret;
@@ -781,8 +817,7 @@ void storeDistributedToShared(triton::gpu::MemDescType dstTy,
781817
ArrayRef<Value> srcVals,
782818
const SharedMemoryObject &smemObj, Location loc,
783819
RewriterBase &rewriter,
784-
const TargetInfoBase &target,
785-
std::pair<size_t, Type> *const llvmOpCount) {
820+
const TargetInfoBase &target) {
786821
auto b = TritonLLVMOpBuilder(loc, rewriter);
787822
bool success = emitTransferBetweenRegistersAndShared(
788823
srcTy, dstTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemObj, loc,
@@ -797,10 +832,6 @@ void storeDistributedToShared(triton::gpu::MemDescType dstTy,
797832
b.store(vec, vecAddr)
798833
.setAlignment(vecTy.getNumElements() *
799834
elemLlvmTy.getIntOrFloatBitWidth() / 8);
800-
if (llvmOpCount) {
801-
++(llvmOpCount->first);
802-
llvmOpCount->second = vecTy;
803-
}
804835
});
805836

806837
if (!success)

python/triton_kernels/triton_kernels/reduction_details/reduce_bitmatrix.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def sum_bitmatrix_rows(x, out_ret, partials_block_size=None, n_rows_raw=None):
8888
n_rows_pad, n_cols_raw = x.shape_pad[0], x.shape_raw[1]
8989
assert out_ret.shape == (n_cols_raw, )
9090

91-
TILE_SIZE = 2
91+
TILE_SIZE = max(1, 128 // PARTIALS_BLOCK_M)
9292
BLOCK_MM = PARTIALS_BLOCK_M * TILE_SIZE
9393

9494
pids_x = cdiv(n_rows_pad, BLOCK_MM)

python/triton_kernels/triton_kernels/routing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ class SortTokens(torch.autograd.Function):
9494

9595
@staticmethod
9696
def forward(ctx, expt_scal, expt_indx, bitmatrix):
97-
HIST_BLOCK_M = 64
97+
HIST_BLOCK_M = 32
9898
INDX_OFFS_BLOCK_M = 512
9999
MEMSET_BLOCK = 1024
100100
cdiv = triton.cdiv

python/triton_kernels/triton_kernels/topk_details/_topk_forward.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def streaming_topk(X, stride_xm, n_expts_tot, offs_m, mask_m, N_EXPTS_PAD: tl.co
6262
acc = tl.topk(x, N_EXPTS_ACT, dim=1)
6363

6464
# subsequent iterations:
65-
for _i in range(loop_iterations):
65+
for _i in (tl.static_range if loop_iterations <= 4 else range)(loop_iterations):
6666
acc = tl.bitonic_merge(acc) # ensure sorted ascending for the merge
6767
X_ptrs -= BLOCK_N
6868
offs_x_n -= BLOCK_N

test/Conversion/intel/dot_layout_offset.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32}
1414
// CHECK: %[[VAL_145:.*]] = llvm.mlir.constant(16 : i32) : i32
1515
// CHECK: %[[LANE_ID:.*]] = llvm.urem %[[THREAD_ID_I32]], %[[VAL_145]] : i32
1616
// CHECK: %[[WARP_ID:.*]] = llvm.udiv %[[THREAD_ID_I32]], %[[VAL_145]] : i32
17-
// CHECK-COUNT-3: %[[CST_0:.*]] = llvm.mlir.constant(0 : i32) : i32
17+
// CHECK-COUNT-4: %[[CST_0:.*]] = llvm.mlir.constant(0 : i32) : i32
1818
// CHECK: %[[VAL_149:.*]] = llvm.mlir.constant(1 : i32) : i32
1919
// CHECK: %[[VAL_150:.*]] = llvm.and %[[LANE_ID]], %[[VAL_149]] : i32
2020
// CHECK: %[[VAL_151:.*]] = llvm.icmp "eq" %[[VAL_150]], %[[CST_0]] : i32
@@ -336,7 +336,7 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, "ttg.thr
336336
// CHECK: %[[VAL_145:.*]] = llvm.mlir.constant(16 : i32) : i32
337337
// CHECK: %[[LANE_ID:.*]] = llvm.urem %[[THREAD_ID_I32]], %[[VAL_145]] : i32
338338
// CHECK: %[[WARP_ID:.*]] = llvm.udiv %[[THREAD_ID_I32]], %[[VAL_145]] : i32
339-
// CHECK-COUNT-3: %[[CST_0:.*]] = llvm.mlir.constant(0 : i32) : i32
339+
// CHECK-COUNT-4: %[[CST_0:.*]] = llvm.mlir.constant(0 : i32) : i32
340340
// CHECK: %[[VAL_149:.*]] = llvm.mlir.constant(1 : i32) : i32
341341
// CHECK: %[[VAL_150:.*]] = llvm.and %[[LANE_ID]], %[[VAL_149]] : i32
342342
// CHECK: %[[VAL_151:.*]] = llvm.icmp "eq" %[[VAL_150]], %[[CST_0]] : i32

0 commit comments

Comments
 (0)