
Commit 7ec882d

Merge OpenAI Triton commit d412906 (#5045)
This PR changes the Triton base from 818e892 to d412906 (Aug 26). Pass rate: 98.78% -> 98.74%
2 parents ebfa3ce + 90494c0 commit 7ec882d

53 files changed: 1987 additions & 1994 deletions


Makefile

Lines changed: 1 addition & 1 deletion
@@ -54,7 +54,7 @@ test-distributed: all
 .PHONY: test-gluon
 test-gluon: all
 	$(PYTEST) -s -n $(NUM_PROCS) python/test/gluon
-	$(PYTEST) -vs python/tutorials/gluon/01-attention-forward.py
+	$(PYTEST) -vs python/examples/gluon/01-attention-forward.py
 
 .PHONY: test-regression
 test-regression: all

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 2 additions & 0 deletions
@@ -377,6 +377,8 @@ class SharedMemoryObject {
 
   Value getShmemOffset(Location loc, RewriterBase &rewriter,
                        triton::gpu::MemDescType srcTy) const;
+  Value getShmemAffineBase(Location loc, RewriterBase &rewriter,
+                           triton::gpu::MemDescType srcTy) const;
 
   // TODO(Keren): deprecate the method once AMD backend has cleaned up
   Value getCSwizzleOffset(int dim) const {

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td

Lines changed: 8 additions & 3 deletions
@@ -701,8 +701,13 @@ def TTNG_TMEMCopyOp : TTNG_Op<"tmem_copy"> {
     for the completion of the copy before MMA, since tcgen05.cp followed by tcgen05.mma is guaranteed to
     execute in that order.
 
-    This op lowers to the PTX instruction tcgen05.cp. Right now, we only support 1CTA and the warpx4.32x128b
-    variant of the instruction. Each 32x128b block in SMEM is duplicated over 4 warps and stored into 128 rows
+    This op lowers to the PTX instruction tcgen05.cp. It supports writing to both the scales TMEM layout and the default TMEM layout.
+    Currently the semantics differ when writing to the TMEM scales layout.
+
+    In the case of the default layout, the copy does not change the logical elements between the source and destination memdesc.
+
+    In the case of the scales layout:
+    each 32x128b block in SMEM is duplicated over 4 warps and stored into 128 rows
     and 4 columns of TMEM. The primary use case of this op is to copy blocked scales from SMEM to TMEM.
 
     The shape of the input SMEM can be flexibly chosen depending on use cases. In the simplest case (e.g. unit test),
@@ -741,7 +746,7 @@ def TTNG_TMEMCopyOp : TTNG_Op<"tmem_copy"> {
     Optional<TTG_MemDescType>:$barrier
   );
 
-  let assemblyFormat = [{$src `,` $dst `,` $barrier attr-dict `:` functional-type(operands, results)}];
+  let assemblyFormat = [{$src `,` $dst (`,` $barrier^)? attr-dict `:` qualified(type(operands))}];
   let hasVerifier = 1;
 }
 
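To make the two destination layouts concrete, here is a minimal Gluon-style sketch. The allocation helpers follow python/test/gluon/test_consan.py later in this commit; the tcgen05_cp frontend wrapper and the layout placeholders are assumptions inferred from the gluon_ir.cc bindings below, not APIs confirmed by this diff.

# Sketch only: `tcgen05_cp`, `swizzled_nvmma_layout`, `scales_smem`, and
# `scales_tmem` are illustrative placeholders, not names defined in this diff.

# Default TMEM layout: shapes must match and logical elements are preserved
# (per the new TMEMCopyOp verifier in lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp).
tmem_layout: ttgl.constexpr = blackwell.TensorMemoryLayout([128, 32], unpacked=True, cta_split_num=[1, 1])
smem = ttgl.allocate_shared_memory(ttgl.float32, [128, 32], swizzled_nvmma_layout)
tmem = blackwell.allocate_tensor_memory(ttgl.float32, [128, 32], tmem_layout)
blackwell.tcgen05_cp(smem, tmem)  # plain copy, no barrier operand

# Scales TMEM layout: each 32x128b SMEM block is duplicated over 4 warps into
# 128 rows x 4 columns of TMEM; the source must be unswizzled and row-major.
blackwell.tcgen05_cp(scales_smem, scales_tmem)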
lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 8 additions & 0 deletions
@@ -1209,6 +1209,14 @@ Value SharedMemoryObject::getShmemOffset(Location loc, RewriterBase &rewriter,
   return offset;
 }
 
+Value SharedMemoryObject::getShmemAffineBase(
+    Location loc, RewriterBase &rewriter,
+    triton::gpu::MemDescType srcTy) const {
+  auto b = TritonLLVMOpBuilder(loc, rewriter);
+  Value offset = getShmemOffset(loc, rewriter, srcTy);
+  return b.gep(base.getType(), baseElemType, base, offset);
+}
+
 Value getStructFromSharedMemoryObject(Location loc,
                                       const SharedMemoryObject &smemObj,
                                       RewriterBase &rewriter) {

lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp

Lines changed: 2 additions & 2 deletions
@@ -561,8 +561,8 @@ struct MemDescReinterpretOpConversion
 
     auto smemObj =
         getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), srcElemTy, b);
-    SharedMemoryObject newObj(smemObj.getBase(), dstElemTy, dstTy.getRank(),
-                              loc, b);
+    Value newBase = smemObj.getShmemAffineBase(loc, b, srcTy);
+    SharedMemoryObject newObj(newBase, dstElemTy, dstTy.getRank(), loc, b);
     b.replaceOp(op, getStructFromSharedMemoryObject(loc, newObj, b));
     return success();
   }

lib/Dialect/TritonInstrument/Transforms/ConcurrencySanitizer.cpp

Lines changed: 21 additions & 0 deletions
@@ -552,6 +552,27 @@ class ConcurrencySanitizerPass
       }
     }
     if (auto commitOp = dyn_cast<ttng::TCGen5CommitOp>(op)) {
+      // Workaround: scan towards the beginning of the current block looking
+      // for mmav5s and mark their operands as reads guarded by the barrier.
+      Operation *prevOp = op->getPrevNode();
+      while (prevOp) {
+        auto setBarrier = [&](TypedValue<ttg::MemDescType> buf) {
+          MemType memType = MemType::TENSOR_MEM;
+          if (isa<ttg::SharedEncodingTrait>(buf.getType().getEncoding())) {
+            memType = MemType::SHARED_MEM;
+          }
+          b.create<tti::ExperimentalSetReadBarrierOp>(
+              buf, commitOp.getBarrier(), buffersTensor[(int)memType],
+              barriers, readBarriersAlloc[(int)memType],
+              readBarriersType[(int)memType], commitOp.getPred());
+        };
+        if (auto mmav5Op = dyn_cast<ttng::TCGen5MMAOp>(prevOp)) {
+          setBarrier(mmav5Op.getA());
+          setBarrier(mmav5Op.getB());
+        }
+        prevOp = prevOp->getPrevNode();
+      }
+
       b.create<tti::ExperimentalCommitWriteWithBarrierOp>(
           commitOp.getBarrier(), barriers,
           writeBarriersAlloc[(int)MemType::TENSOR_MEM],
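The pattern this workaround covers appears in the updated test_consan.py below: an MMA issued without inline mbarriers, followed by a separate tcgen05_commit that binds a barrier to the completion of the preceding MMAs. A minimal kernel fragment, with names taken from that test:

# Fragment of a Gluon kernel (names as in test_consan.py in this commit).
blackwell.tcgen05_mma(smemA, smemB.permute([1, 0]), acc)  # reads smemA/smemB, no inline barrier
blackwell.tcgen05_commit(bar.index(0))  # ConSan marks the smemA/smemB reads as guarded by this barrier
mbarrier.wait(bar.index(0), 0)  # only after the wait may smemA/smemB be safely overwritten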

lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp

Lines changed: 32 additions & 13 deletions
@@ -588,9 +588,6 @@ LogicalResult TMEMCopyOp::verify() {
   if (!isa<triton::gpu::SharedMemorySpaceAttr>(
           getSrc().getType().getMemorySpace()))
     return emitOpError("The source must be a shared memory buffer");
-  if (!isa<TensorMemoryEncodingAttr, TensorMemoryScalesEncodingAttr>(
-          getDst().getType().getEncoding()))
-    return emitOpError("The destination must be a tensor memory buffer.");
 
   if (getBarrier() && !isa<triton::gpu::SharedMemorySpaceAttr>(
                           getBarrier().getType().getMemorySpace())) {
@@ -599,19 +596,41 @@ LogicalResult TMEMCopyOp::verify() {
   if (!getDst().getType().getMutableMemory()) {
     return emitOpError("Cannot copy into an immutable alloc");
   }
-
   auto srcTy = cast<triton::gpu::MemDescType>(getSrc().getType());
   auto sharedEnc =
-      cast<triton::gpu::NVMMASharedEncodingAttr>(srcTy.getEncoding());
-
-  if (!sharedEnc || sharedEnc.getTransposed() || sharedEnc.getFp4Padded() ||
-      sharedEnc.getSwizzlingByteWidth() != 0)
-    return emitOpError("The source should not have swizzling applied for now");
-
-  if (!triton::gpu::isInnermostContiguous(srcTy, 512)) {
-    return emitOpError("The source must be in a row-major order.");
+      dyn_cast<triton::gpu::NVMMASharedEncodingAttr>(srcTy.getEncoding());
+  if (!sharedEnc) {
+    return emitOpError("Source must have nvmma layout.");
+  }
+  if (sharedEnc.getTransposed() || sharedEnc.getFp4Padded())
+    return emitOpError("The source should not be transposed or padded");
+  if (isa<TensorMemoryScalesEncodingAttr>(getDst().getType().getEncoding())) {
+    if (sharedEnc.getSwizzlingByteWidth() != 0) {
+      return emitOpError("The source should not be swizzled for now");
+    }
+    if (!triton::gpu::isInnermostContiguous(srcTy, 512)) {
+      return emitOpError("The source must be in a row-major order.");
+    }
+  } else {
+    if (getSrc().getType().getShape() != getDst().getType().getShape()) {
+      return emitOpError(
+          "The source and destination must have the same shape.");
+    }
+    auto tmemEnc = dyn_cast<triton::nvidia_gpu::TensorMemoryEncodingAttr>(
+        getDst().getType().getEncoding());
+    if (!tmemEnc) {
+      return emitOpError("Incorrect tmem layout.");
+    }
+    if (tmemEnc.getBlockM() != 128) {
+      return emitOpError("Tmem layout should have M=128.");
+    }
+    if (sharedEnc.getSwizzlingByteWidth() == 0) {
+      return emitOpError("Source layout should be swizzled.");
+    }
+    if (srcTy.getElementType().getIntOrFloatBitWidth() != 32) {
+      return emitOpError("Source element type should be 32-bit.");
+    }
   }
-
   // Given that we want to support flexible input SMEM shapes, kinds of shape
   // checking we can do here are limited. For simplicity, shape checking is
   // omitted.

pyproject.toml

Lines changed: 3 additions & 0 deletions
@@ -35,3 +35,6 @@ line-length = 120
 
 [tool.ruff.lint]
 ignore = ["E501", "E701", "E731", "E741"]
+
+[tool.ruff.lint.per-file-ignores]
+"__init__.py" = ["F401"]

python/src/gluon_ir.cc

Lines changed: 16 additions & 0 deletions
@@ -403,6 +403,14 @@ void init_gluon_ir(py::module &&m) {
                  ctx, block[0], block[1], unpacked, ctaSplitNum[0],
                  ctaSplitNum[1]);
            })
+      .def("get_tensor_memory_scales_layout",
+           [](GluonOpBuilder &self,
+              std::vector<unsigned> &ctaSplitNum) -> Attribute {
+             auto ctx = self.getContext();
+             assert(ctaSplitNum.size() == 2);
+             return self.getChecked<ttng::TensorMemoryScalesEncodingAttr>(
+                 ctx, ctaSplitNum[0], ctaSplitNum[1]);
+           })
       .def("get_gluon_layout_from_tensor",
            [](GluonOpBuilder &self, Value tensor) -> py::object {
              auto ty = dyn_cast<RankedTensorType>(tensor.getType());
@@ -548,6 +556,10 @@ void init_gluon_ir(py::module &&m) {
           [](GluonOpBuilder &self, Type resultTy, Value memDesc) -> Value {
             return self.create<ttng::TMEMLoadOp>(resultTy, memDesc);
           })
+      .def("create_tmem_copy",
+           [](GluonOpBuilder &self, Value src, Value dst) {
+             self.create<ttng::TMEMCopyOp>(src, dst, /*barrier=*/Value());
+           })
       .def("create_tmem_subslice",
            [](GluonOpBuilder &self, Type resultTy, Value memDesc,
               int N) -> Value {
@@ -585,6 +597,10 @@ void init_gluon_ir(py::module &&m) {
                                               pred, two_ctas, mbarriers,
                                               mbarrier_preds);
           })
+      .def("create_tcgen05_cp",
+           [](GluonOpBuilder &self, Value src, Value dst) {
+             self.create<ttng::TMEMCopyOp>(src, dst, Value());
+           })
       .def("create_tcgen05_commit",
           [](GluonOpBuilder &self, Value &barrier) {
             self.create<ttng::TCGen5CommitOp>(barrier);
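A hedged sketch of driving the new hooks from the Python side. Only the method names and arities come from this diff; `builder`, `smem_scales`, and `tmem_scales` are illustrative placeholders:

# Sketch; assumes an existing GluonOpBuilder handle and memdesc values.
scales_layout = builder.get_tensor_memory_scales_layout([1, 1])  # ctaSplitNum must have length 2
builder.create_tmem_copy(smem_scales, tmem_scales)   # builds a TMEMCopyOp with no barrier
builder.create_tcgen05_cp(smem_scales, tmem_scales)  # same op, exposed under the tcgen05 name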

python/test/gluon/test_consan.py

Lines changed: 28 additions & 27 deletions
@@ -228,7 +228,8 @@ def tcgen5_mma_kernel(input_desc, XBLOCK: ttgl.constexpr, FAILURE: ttgl.constexp
     mbarrier.init(bar.index(0), count=1)
     mbarrier.init(bar.index(1), count=1)
 
-    blackwell.tcgen05_mma(smemA, smemB.permute([1, 0]), acc, mbarriers=[bar.index(0)])
+    blackwell.tcgen05_mma(smemA, smemB.permute([1, 0]), acc)
+    blackwell.tcgen05_commit(bar.index(0))
 
     if not FAILURE:
         mbarrier.wait(bar.index(0), 0)
@@ -285,32 +286,6 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]):
     tcgen5_mma_kernel[(1, )](input_desc, XBLOCK, FAILURE=FAILURE, MEM_ACCESS_KIND=MEM_ACCESS_KIND, num_warps=4)
 
 
-@gluon.jit
-def tcgen5_mma_multibar_kernel(input_desc, XBLOCK: ttgl.constexpr, BUF_IDX: ttgl.constexpr, BAR_IDX: ttgl.constexpr):
-    acc_layout: ttgl.constexpr = blackwell.TensorMemoryLayout([XBLOCK, XBLOCK], unpacked=True, cta_split_num=[1, 1])
-    blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, XBLOCK], threads_per_warp=[32, 1],
-                                                        warps_per_cta=[4, 1], order=[0, 1])
-    smemA = ttgl.allocate_shared_memory(ttgl.float16, [XBLOCK, XBLOCK], input_desc.layout)
-    smemB = ttgl.allocate_shared_memory(ttgl.float16, [XBLOCK, XBLOCK], input_desc.layout)
-    bar = ttgl.allocate_shared_memory(ttgl.int64, [4, 1], mbarrier.MBarrierLayout())
-    acc = blackwell.allocate_tensor_memory(ttgl.float32, [2, XBLOCK, XBLOCK], acc_layout)
-    for i in range(4):
-        mbarrier.init(bar.index(i), count=1)
-
-    blackwell.tcgen05_mma(smemA, smemB.permute([1, 0]), acc.index(0), mbarriers=[bar.index(0),
-                                                                                 bar.index(1)],
-                          mbarrier_preds=[False, True])
-    blackwell.tcgen05_mma(smemA, smemB.permute([1, 0]), acc.index(1), mbarriers=[bar.index(2)])
-    blackwell.tcgen05_commit(bar.index(3))
-
-    mbarrier.wait(bar.index(BAR_IDX), 0)
-
-    acc.index(BUF_IDX).store(ttgl.full([XBLOCK, XBLOCK], 42, ttgl.float32, blocked_layout))
-
-    for i in range(4):
-        mbarrier.invalidate(bar.index(i))
-
-
 @gluon.jit
 def warpgroup_mma_kernel(input, XBLOCK: ttgl.constexpr, FAILURE: ttgl.constexpr):
     smem_layout: ttgl.constexpr = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2)
@@ -405,6 +380,32 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]):
     warpgroup_mma_kernel[(1, )](input, XBLOCK, FAILURE=FAILURE)
 
 
+@gluon.jit
+def tcgen5_mma_multibar_kernel(input_desc, XBLOCK: ttgl.constexpr, BUF_IDX: ttgl.constexpr, BAR_IDX: ttgl.constexpr):
+    acc_layout: ttgl.constexpr = blackwell.TensorMemoryLayout([XBLOCK, XBLOCK], unpacked=True, cta_split_num=[1, 1])
+    blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, XBLOCK], threads_per_warp=[32, 1],
+                                                        warps_per_cta=[4, 1], order=[0, 1])
+    smemA = ttgl.allocate_shared_memory(ttgl.float16, [XBLOCK, XBLOCK], input_desc.layout)
+    smemB = ttgl.allocate_shared_memory(ttgl.float16, [XBLOCK, XBLOCK], input_desc.layout)
+    bar = ttgl.allocate_shared_memory(ttgl.int64, [4, 1], mbarrier.MBarrierLayout())
+    acc = blackwell.allocate_tensor_memory(ttgl.float32, [2, XBLOCK, XBLOCK], acc_layout)
+    for i in range(4):
+        mbarrier.init(bar.index(i), count=1)
+
+    blackwell.tcgen05_mma(smemA, smemB.permute([1, 0]), acc.index(0), mbarriers=[bar.index(0),
+                                                                                 bar.index(1)],
+                          mbarrier_preds=[False, True])
+    blackwell.tcgen05_mma(smemA, smemB.permute([1, 0]), acc.index(1), mbarriers=[bar.index(2)])
+    blackwell.tcgen05_commit(bar.index(3))
+
+    mbarrier.wait(bar.index(BAR_IDX), 0)
+
+    acc.index(BUF_IDX).store(ttgl.full([XBLOCK, XBLOCK], 42, ttgl.float32, blocked_layout))
+
+    for i in range(4):
+        mbarrier.invalidate(bar.index(i))
+
+
 @pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 10, reason="Requires blackwell or newer")
 @pytest.mark.parametrize("BUF_IDX", [0, 1])
 @pytest.mark.parametrize("BAR_IDX", [0, 1, 2, 3])
