Skip to content

Commit a36827a

Browse files
authored
Merge OpenAI commit 84ecfd0 (#5167)
This PR changes the Triton base from 1b6a74c to 84ecfd0 (Sep 16). Pass rate: 97.47%
2 parents ca793fa + e6e66f4 commit a36827a

File tree

20 files changed

+573
-135
lines changed

20 files changed

+573
-135
lines changed

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -428,7 +428,7 @@ and `pN` to mean padding:
428428
x1, x3, p2, p3
429429
...]
430430

431-
2. 2D single interval-padding with rearanged rows.
431+
2. 2D single interval-padding with rearranged rows.
432432

433433
#ttg.padded_shared<[16:+1] {offset = [[0, 1], [0, 2], /*gap, stride by 2 rows*/[2, 0], [4, 0], [1, 0]]], block = []}>
434434
[

include/triton/Tools/GenericSwizzling.h

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -40,18 +40,17 @@ optimalSwizzling(const LinearLayout &src, const LinearLayout &dst,
4040
LinearLayout optimalSwizzlingLdSt(const LinearLayout &src,
4141
const LinearLayout &dst, int32_t bitwidth);
4242

43-
std::pair<int, int> logBankConflictsLdSt(const LinearLayout &src,
44-
const LinearLayout &dst,
45-
const LinearLayout &smem,
46-
int32_t bitwidth);
47-
48-
int logBankConflictsMemDesc(const LinearLayout &reg, const LinearLayout &smem,
49-
int32_t bitwidth);
50-
51-
std::pair<int, int> logBankConflicts(llvm::ArrayRef<int32_t> tileSrc,
52-
llvm::ArrayRef<int32_t> tileDst,
53-
const LinearLayout &smem,
54-
int32_t bitwidth);
43+
std::pair<int, int> bankConflictsLdSt(const LinearLayout &src,
44+
const LinearLayout &dst,
45+
const LinearLayout &smem,
46+
int32_t bitwidth);
47+
48+
int bankConflictsMemDesc(const LinearLayout &reg, const LinearLayout &smem,
49+
int32_t bitwidth);
50+
51+
std::pair<int, int> bankConflicts(llvm::ArrayRef<int32_t> tileSrc,
52+
llvm::ArrayRef<int32_t> tileDst,
53+
const LinearLayout &smem);
5554
} // namespace mlir::triton::gpu
5655

5756
#endif // TRITON_GENERIC_SWIZZLING_H

include/triton/Tools/LayoutUtils.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,12 @@ LinearLayout reshapeLayout(MLIRContext *ctx, LinearLayout layout,
141141
// order.
142142
LinearLayout transposeLinearLayout(LinearLayout layout, ArrayRef<int> order);
143143

144+
// Given a distributed-into-shmem layout, return the largest vectorisation
145+
// that can be used to lower the layout via ld/st.
146+
std::pair<int, ColumnAction>
147+
largestVectorisation(MLIRContext *ctx, const LinearLayout &cvt, int bitwidth,
148+
std::optional<int> maybeMaxVecElems = std::nullopt);
149+
144150
} // namespace mlir::triton
145151

146152
#endif // TRITON_TOOLS_LAYOUTUTILS_H

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -525,41 +525,6 @@ Value emitPadding(Location loc, RewriterBase &rewriter,
525525
return padOffset;
526526
}
527527

528-
namespace {
529-
std::pair<int, ColumnAction>
530-
largestVectorisation(MLIRContext *ctx, const LinearLayout &cvt, int bitwidth,
531-
std::optional<int> maybeMaxVecElems = std::nullopt) {
532-
// Find the largest vectorisation we can use:
533-
StringAttr kReg = str_attr("register");
534-
StringAttr kOffset = str_attr("offset");
535-
LinearLayout quot;
536-
LinearLayout tile;
537-
ColumnAction permutation;
538-
// If there are restrictions on the vectorisation, we don't allow
539-
// permutations.
540-
auto allowPerm = !maybeMaxVecElems.has_value();
541-
auto maxVecElems = maybeMaxVecElems.value_or(128 / bitwidth);
542-
for (int v = maxVecElems; v >= 1; v /= 2) {
543-
tile = LinearLayout::identity1D(v, kReg, kOffset);
544-
auto maybePerm = regPermForDivide(cvt, tile, /*left=*/true);
545-
if (!maybePerm) {
546-
continue;
547-
}
548-
permutation = *maybePerm;
549-
if (!allowPerm && !permutation.isIdentity()) {
550-
continue;
551-
}
552-
auto newCvt = permutation.apply(cvt);
553-
auto maybeQuot = divideLeft(newCvt, tile);
554-
if (!maybeQuot) {
555-
continue;
556-
}
557-
return {v, permutation};
558-
}
559-
llvm_unreachable("Vectorization < 1 is not valid");
560-
}
561-
} // namespace
562-
563528
SmallVector<Value>
564529
lowerLdStShared(Location loc, MLIRContext *ctx, LinearLayout cvt,
565530
ArrayRef<Value> valsArray, // Input for store, output for load

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1938,8 +1938,8 @@ LogicalResult PaddedSharedEncodingAttr::verify(
19381938
}
19391939
// Ensure all non zero elements are a power of 2. Combined with the
19401940
// broadcast check above this prevents per element swizzling. The intent of
1941-
// the linear component is to rearange whole rows or cache-line sized chunks
1942-
// of rows.
1941+
// the linear component is to rearrange whole rows or cache-line sized
1942+
// chunks of rows.
19431943
if (!llvm::all_of(dimBases, [&](const auto &basis) {
19441944
return llvm::all_of(
19451945
basis, [](auto v) { return v == 0 || llvm::isPowerOf2_32(v); });

lib/Dialect/TritonGPU/Transforms/HoistTMEMAlloc.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,11 @@ class CombineTMEMStoreAndAlloc : public OpRewritePattern<ttng::TMEMStoreOp> {
181181
return failure();
182182
if (alloc->getBlock() != store->getBlock())
183183
return failure();
184+
if (auto srcDef = store.getSrc().getDefiningOp()) {
185+
if (alloc->getBlock() == srcDef->getBlock() &&
186+
alloc->isBeforeInBlock(srcDef))
187+
return failure();
188+
}
184189
alloc.getSrcMutable().assign(store.getSrc());
185190
rewriter.replaceOp(store, alloc.getToken());
186191
return success();

lib/Tools/GenericSwizzling.cpp

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -231,37 +231,30 @@ SmallVector<int32_t> intersectionBasis(ArrayRef<int32_t> b1,
231231
}
232232
}
233233

234-
std::pair<int, int> logBankConflicts(ArrayRef<int32_t> tileSrc,
235-
ArrayRef<int32_t> tileDst,
236-
const LinearLayout &smem,
237-
int32_t bitwidth) {
234+
std::pair<int, int> bankConflicts(ArrayRef<int32_t> tileSrc,
235+
ArrayRef<int32_t> tileDst,
236+
const LinearLayout &smem) {
238237
auto *ctx = smem.getOutDimNames().begin()->getContext();
239238
auto smemFlat = smem.flattenOuts();
240239
auto inDim = *smem.getInDimNames().begin();
241-
// Take all the bases in the first bank (32 bits)
242-
auto smemBases =
243-
flatten(smemFlat.flattenIns(), *smemFlat.getInDimNames().begin());
244-
auto nBankZero = llvm::Log2_32(std::max<int32_t>(1, 32 / bitwidth));
245-
if (smemBases.size() >= nBankZero) {
246-
smemBases.resize(nBankZero);
247-
}
248-
// And segments
240+
// Look at the intersection between the segment bases and the tile bases
241+
// We don't need to intersect with the bases that cover the bank (as in
242+
// the first 32 / bitwidth bases) because if we hit any of those, broadcasting
243+
// will avoid the bank conflict
249244
auto segment = StringAttr::get(ctx, "segment");
250245
auto segmentBases = flatten(smemFlat, segment);
251-
auto bankZero =
252-
llvm::to_vector(llvm::concat<int32_t>(smemBases, segmentBases));
253246

254247
int32_t rank = smem.getTotalOutDimSizeLog2();
255248
// compute conflicts
256-
int write = intersectionBasis(bankZero, tileSrc, rank).size();
257-
int read = intersectionBasis(bankZero, tileDst, rank).size();
258-
return {read, write};
249+
int write = 1 << intersectionBasis(segmentBases, tileSrc, rank).size();
250+
int read = 1 << intersectionBasis(segmentBases, tileDst, rank).size();
251+
return {read - 1, write - 1};
259252
}
260253

261-
std::pair<int, int> logBankConflictsLdSt(const LinearLayout &src,
262-
const LinearLayout &dst,
263-
const LinearLayout &smem,
264-
int32_t bitwidth) {
254+
std::pair<int, int> bankConflictsLdSt(const LinearLayout &src,
255+
const LinearLayout &dst,
256+
const LinearLayout &smem,
257+
int32_t bitwidth) {
265258
auto srcFlat = src.flattenOuts();
266259
auto dstFlat = dst.flattenOuts();
267260
auto *ctx = smem.getOutDimNames().begin()->getContext();
@@ -273,19 +266,24 @@ std::pair<int, int> logBankConflictsLdSt(const LinearLayout &src,
273266
llvm::Log2_32(std::max(smem.getInDimSize(kVec) * bitwidth / 32, 1));
274267
srcLane.resize(srcLane.size() - log2Vec);
275268
dstLane.resize(dstLane.size() - log2Vec);
276-
return logBankConflicts(srcLane, dstLane, smem, bitwidth);
269+
return bankConflicts(srcLane, dstLane, smem);
277270
}
278271

279-
int logBankConflictsMemDesc(const LinearLayout &reg, const LinearLayout &smem,
280-
int32_t bitwidth) {
272+
int bankConflictsMemDesc(const LinearLayout &reg, const LinearLayout &smem,
273+
int32_t bitwidth) {
281274
auto *ctx = smem.getInDimNames().begin()->getContext();
282275
auto S = [ctx](StringRef str) { return StringAttr::get(ctx, str); };
283276

284277
assert(smem.hasInDim(S("offset")) && "shared layout must have an offset dim");
285278
assert(reg.hasInDim(S("register")) &&
286279
"register layout must have a register dim");
280+
auto regNoBroadcast = actionRemoveBroadcastedRegs(reg).apply(reg);
281+
auto regToShared = regNoBroadcast.invertAndCompose(smem);
282+
auto [elemsPerVec, permutation] =
283+
largestVectorisation(ctx, regToShared, bitwidth);
284+
regNoBroadcast = permutation.apply(regNoBroadcast);
287285

288-
int32_t vecSize = reg.invertAndCompose(smem).getNumConsecutiveInOut();
286+
int32_t vecSize = elemsPerVec;
289287
int32_t bankSize =
290288
std::min(32 * 32 / (vecSize * bitwidth), smem.getTotalInDimSize());
291289
int32_t segmentSize = smem.getTotalInDimSize() / (bankSize * vecSize);
@@ -295,7 +293,9 @@ int logBankConflictsMemDesc(const LinearLayout &reg, const LinearLayout &smem,
295293
{S("segment"), segmentSize},
296294
};
297295
auto smemReshaped = smem.reshapeIns(newInDims);
298-
return logBankConflictsLdSt(reg, reg, smemReshaped, bitwidth).first;
296+
return bankConflictsLdSt(regNoBroadcast, regNoBroadcast, smemReshaped,
297+
bitwidth)
298+
.first;
299299
}
300300

301301
std::optional<SmallVector<int32_t>> optimalSwizzlingTile(
@@ -675,7 +675,7 @@ optimalSwizzling(const LinearLayout &src, const LinearLayout &dst,
675675
for (auto [instrs, vbasis, tileSrc, tileDst] : tiles) {
676676
auto smem = optimalSwizzling(srcFlat, dstFlat, bitwidth, vbasis, tileSrc,
677677
tileDst, src.getOutDims());
678-
auto [read, write] = logBankConflicts(tileSrc, tileDst, smem, bitwidth);
678+
auto [read, write] = bankConflicts(tileSrc, tileDst, smem);
679679
smems.push_back({read + write, smem, {instrs.first, instrs.second}});
680680
}
681681
// Current heuristic: Minimise total bank conflicts

lib/Tools/LayoutUtils.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,4 +443,38 @@ LinearLayout transposeLinearLayout(LinearLayout layout, ArrayRef<int> order) {
443443
to_vector(layout.getOutDimNames()));
444444
}
445445

446+
std::pair<int, ColumnAction>
447+
largestVectorisation(MLIRContext *ctx, const LinearLayout &cvt, int bitwidth,
448+
std::optional<int> maybeMaxVecElems) {
449+
// Find the largest vectorisation we can use:
450+
auto S = [ctx](StringRef str) { return StringAttr::get(ctx, str); };
451+
StringAttr kReg = S("register");
452+
StringAttr kOffset = S("offset");
453+
LinearLayout quot;
454+
LinearLayout tile;
455+
ColumnAction permutation;
456+
// If there are restrictions on the vectorisation, we don't allow
457+
// permutations.
458+
auto allowPerm = !maybeMaxVecElems.has_value();
459+
auto maxVecElems = maybeMaxVecElems.value_or(128 / bitwidth);
460+
for (int v = maxVecElems; v >= 1; v /= 2) {
461+
tile = LinearLayout::identity1D(v, kReg, kOffset);
462+
auto maybePerm = regPermForDivide(cvt, tile, /*left=*/true);
463+
if (!maybePerm) {
464+
continue;
465+
}
466+
permutation = *maybePerm;
467+
if (!allowPerm && !permutation.isIdentity()) {
468+
continue;
469+
}
470+
auto newCvt = permutation.apply(cvt);
471+
auto maybeQuot = divideLeft(newCvt, tile);
472+
if (!maybeQuot) {
473+
continue;
474+
}
475+
return {v, permutation};
476+
}
477+
llvm_unreachable("Vectorization < 1 is not valid");
478+
}
479+
446480
} // namespace mlir::triton

python/src/gluon_ir.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -554,8 +554,7 @@ void init_gluon_ir(py::module &&m) {
554554
int bitwidth) -> int {
555555
auto regLayout = ttg::toLinearLayout(shape, regLayoutAttr);
556556
auto smemLayout = ttg::toLinearLayout(shape, sharedLayoutAttr);
557-
return 1 << ttg::logBankConflictsMemDesc(regLayout, smemLayout,
558-
bitwidth);
557+
return ttg::bankConflictsMemDesc(regLayout, smemLayout, bitwidth);
559558
})
560559
.def("create_local_dealloc",
561560
[](GluonOpBuilder &self, Value memDesc) -> Operation * {

python/test/gluon/test_frontend.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1336,37 +1336,36 @@ def test_static_assert():
13361336

13371337

13381338
@pytest.mark.parametrize("reg_layout, shared_layout, shape, bitwidth, ref_conflicts", [
1339-
(ttgl.BlockedLayout([1], [32], [4], [0]), ttgl.SwizzledSharedLayout(1, 1, 1, order=[0]), [32], 32, 1),
1340-
# FIXME: This one should be zero conflicts due to broadcasting.
1341-
(ttgl.BlockedLayout([1], [32], [4], [0]), ttgl.SwizzledSharedLayout(1, 1, 1, order=[0]), [32], 16, 2),
1339+
(ttgl.BlockedLayout([1], [32], [4], [0]), ttgl.SwizzledSharedLayout(1, 1, 1, order=[0]), [32], 32, 0),
1340+
(ttgl.BlockedLayout([1], [32], [4], [0]), ttgl.SwizzledSharedLayout(1, 1, 1, order=[0]), [32], 16, 0),
13421341
# MMAv3 accumulator tile lowered with the 128B swizzle (WGMMA default path).
13431342
(ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1], instr_shape=[16, 32, 16]),
1344-
ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2), [128, 128], 16, 1),
1343+
ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2), [128, 128], 16, 0),
13451344
# Small-M tiles disable swizzling entirely.
13461345
# MMAv2 rhs operand emitted with the 64B swizzle.
13471346
(ttgl.DotOperandLayout(
13481347
operand_index=1, parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[1, 4], instr_shape=[16, 8]),
1349-
k_width=2), ttgl.NVMMASharedLayout(swizzle_byte_width=64, element_bitwidth=16, rank=2), [64, 32], 16, 2),
1348+
k_width=2), ttgl.NVMMASharedLayout(swizzle_byte_width=64, element_bitwidth=16, rank=2), [64, 32], 16, 0),
13501349
# MMAv2 lhs operand uses the transposed 64B swizzle flavour.
13511350
(ttgl.DotOperandLayout(
13521351
operand_index=0, parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[1, 4], instr_shape=[16, 8]),
13531352
k_width=2), ttgl.NVMMASharedLayout(swizzle_byte_width=64, element_bitwidth=16, rank=2,
1354-
transposed=True), [32, 64], 16, 2),
1353+
transposed=True), [32, 64], 16, 0),
13551354
# int8 tensor-core tiles follow the 32B swizzle path.
13561355
(ttgl.DotOperandLayout(
13571356
operand_index=1, parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[1, 4], instr_shape=[16, 8]),
1358-
k_width=1), ttgl.NVMMASharedLayout(swizzle_byte_width=32, element_bitwidth=8, rank=2), [8, 32], 8, 4),
1357+
k_width=1), ttgl.NVMMASharedLayout(swizzle_byte_width=32, element_bitwidth=8, rank=2), [8, 32], 8, 0),
13591358
# Small-M tiles disable swizzling entirely.
13601359
(ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[4, 1], instr_shape=[16, 8]),
1361-
ttgl.NVMMASharedLayout(swizzle_byte_width=64, element_bitwidth=16, rank=2, transposed=True), [64, 64], 16, 2),
1360+
ttgl.NVMMASharedLayout(swizzle_byte_width=64, element_bitwidth=16, rank=2, transposed=True), [64, 64], 16, 0),
13621361
(ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[2, 2], instr_shape=[16, 32, 16]),
1363-
ttgl.NVMMASharedLayout(swizzle_byte_width=64, element_bitwidth=16, rank=2), [64, 32], 16, 1),
1362+
ttgl.NVMMASharedLayout(swizzle_byte_width=64, element_bitwidth=16, rank=2), [64, 32], 16, 0),
13641363
(ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[4, 1], instr_shape=[16, 8]),
1365-
ttgl.NVMMASharedLayout(swizzle_byte_width=32, element_bitwidth=8, rank=2), [32, 32], 8, 2),
1364+
ttgl.NVMMASharedLayout(swizzle_byte_width=32, element_bitwidth=8, rank=2), [32, 32], 8, 0),
13661365
(ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[2, 4], instr_shape=[16, 8]),
1367-
ttgl.NVMMASharedLayout(swizzle_byte_width=0, element_bitwidth=16, rank=2), [4, 64], 16, 4),
1366+
ttgl.NVMMASharedLayout(swizzle_byte_width=0, element_bitwidth=16, rank=2), [4, 64], 16, 3),
13681367
(ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1], instr_shape=[16, 32, 16]),
1369-
ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=32, rank=2), [128, 64], 32, 2),
1368+
ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=32, rank=2), [128, 64], 32, 1),
13701369
])
13711370
def test_bank_conflicts(reg_layout, shared_layout, shape, bitwidth, ref_conflicts):
13721371
dtype = {8: ttgl.int8, 16: ttgl.float16, 32: ttgl.float32}[bitwidth]

0 commit comments

Comments
 (0)