Commit ecb52e4

Merge OpenAI Triton commit 34758e4 (#4552)
This PR changes the Triton base from 0e9706c to 34758e4 (Jun 18). Pass rate: 97.12%.
2 parents 0275e97 + 5ef3686 commit ecb52e4

34 files changed (+636, -406 lines)

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 1 addition & 1 deletion
@@ -88,7 +88,7 @@ def TTG_AsyncCopyGlobalToLocalOp : TTG_Op<"async_copy_global_to_local", [
   let description = [{
     This operation copies data from global memory to local memory asynchronously.
     This is analogue to tt.load except the data are copied to local memory pointed
-    by by the memory descriptor instead of a distributed tensor. The rest of the
+    to by the memory descriptor instead of a distributed tensor. The rest of the
     operands are the same as tt.load.
   }];

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td

Lines changed: 9 additions & 0 deletions
@@ -262,6 +262,15 @@ def TTNG_ArriveBarrierOp : TTNG_Op<"arrive_barrier"> {
   let hasVerifier = 1;
 }
 
+def TTNG_AsyncCopyMbarrierArriveOp : TTNG_Op<"async_copy_mbarrier_arrive"> {
+  let summary = "arrive on mbarrier once all previously issued copies are completed";
+  let arguments = (ins
+    Arg<TTG_MemDescType, "", [MemWrite<SharedMemory>]>:$barrier,
+    UnitAttr:$noIncrement
+  );
+  let assemblyFormat = "$barrier attr-dict `:` qualified(type($barrier))";
+}
+
 def TTNG_AsyncTMACopyGlobalToLocalOp : TTNG_Op<"async_tma_copy_global_to_local"> {
   let summary = "copy data based on descriptor from global memory to local memory asynchronously";
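
For reference, the new op gives cp.async traffic an mbarrier-based completion path: the barrier is arrived on once every previously issued async copy has landed in shared memory. Below is a minimal Gluon-level sketch of the pattern, trimmed from the async_copy_mbarrier_kernel test added to python/test/gluon/test_core.py later in this commit; the shapes and layouts come from that test and are illustrative, not requirements of the op.

from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl
from triton.experimental.gluon.language.nvidia.ampere import async_copy, mbarrier


@gluon.jit
def async_copy_mbarrier_sketch(inp, xnumel, XBLOCK: ttgl.constexpr, YBLOCK: ttgl.constexpr):
    # Stage an [XBLOCK, YBLOCK] tile of `inp` into shared memory.
    smem = ttgl.allocate_shared_memory(inp.dtype.element_ty, [XBLOCK, YBLOCK],
                                       ttgl.SwizzledSharedLayout(1, 1, 1, order=[1, 0]))
    layout: ttgl.constexpr = ttgl.BlockedLayout([1, 4], [1, 32], [4, 1], [1, 0])
    xindex = ttgl.arange(0, XBLOCK, ttgl.SliceLayout(1, layout))[:, None]
    yindex = ttgl.arange(0, YBLOCK, ttgl.SliceLayout(0, layout))[None, :]
    # Issue the asynchronous global->shared copy (ttg.async_copy_global_to_local).
    async_copy.async_copy_global_to_shared(smem, inp + xindex * YBLOCK + yindex,
                                           xindex < xnumel)
    # ttng.async_copy_mbarrier_arrive: the copy hardware arrives on the barrier once
    # the data is resident in shared memory; the kernel then waits on phase 0.
    bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout())
    mbarrier.init(bar, count=1)
    async_copy.mbarrier_arrive(bar)
    mbarrier.arrive(bar)
    mbarrier.wait(bar, 0)
    val = smem.load(layout)  # safe to read now that the copy has completed
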

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 0 additions & 3 deletions
@@ -15,7 +15,6 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     // clang-format off
     "AMDGCN_ENABLE_DUMP",
     "AMDGCN_USE_BUFFER_OPS",
-    "DISABLE_FAST_REDUCTION",
     "DISABLE_LLVM_OPT",
     "DISABLE_MMA_V3",
     "DISABLE_MMA_V5",
@@ -30,7 +29,6 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     "MLIR_DISABLE_MULTITHREADING",
     "TRITON_DEFAULT_FP_FUSION",
     "TRITON_DISABLE_LINE_INFO",
-    "TRITON_DISABLE_RESHAPE_ENCODING_INFERENCE",
     "TRITON_ENABLE_LLVM_DEBUG",
     "TRITON_HIP_GLOBAL_PREFETCH",
     "TRITON_HIP_LOCAL_PREFETCH",
@@ -42,7 +40,6 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     "TRITON_OVERRIDE_ARCH",
     "USE_IR_LOC",
     "NVPTX_ENABLE_DUMP",
-    "STORE_TMEM_TO_GLOBAL_BYPASS_SMEM",
     "ALLOW_LHS_TMEM_LAYOUT_CONVERSION",
     "TRITON_F32_DEFAULT",
     "TRITON_PREFER_TMEM_16x256_LAYOUT",

lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp

Lines changed: 2 additions & 2 deletions
@@ -84,9 +84,9 @@ struct AssertOpConversion : public ConvertOpToLLVMPattern<triton::AssertOp> {
     // Split a block after the call.
     Block *thenBlock = rewriter.splitBlock(ifBlock, op->getIterator());
     rewriter.setInsertionPointToEnd(ifBlock);
-    rewriter.create<cf::BranchOp>(loc, thenBlock);
+    rewriter.create<LLVM::BrOp>(loc, thenBlock);
     rewriter.setInsertionPointToEnd(prevBlock);
-    rewriter.create<cf::CondBranchOp>(loc, condition, ifBlock, thenBlock);
+    rewriter.create<LLVM::CondBrOp>(loc, condition, ifBlock, thenBlock);
     rewriter.setInsertionPointToStart(thenBlock);
   }

lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h

Lines changed: 3 additions & 3 deletions
@@ -97,12 +97,12 @@ inline SmallVector<Value> applyCombineOp(Location loc,
     thenBlockArgs.push_back(undef);
     thenBlock->addArgument(ty, loc);
   }
-  rewriter.create<cf::CondBranchOp>(loc, pred, &newCombine, combineArgs,
-                                    thenBlock, thenBlockArgs);
+  rewriter.create<LLVM::CondBrOp>(loc, pred, &newCombine, combineArgs,
+                                  thenBlock, thenBlockArgs);
 
   // Split a block after the call.
   rewriter.setInsertionPointToEnd(&newCombine);
-  rewriter.replaceOpWithNewOp<cf::BranchOp>(returnOp, thenBlock, results);
+  rewriter.replaceOpWithNewOp<LLVM::BrOp>(returnOp, results, thenBlock);
   rewriter.setInsertionPointToStart(thenBlock);
   return SmallVector<Value>(thenBlock->getArguments());
 }

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 0 additions & 7 deletions
@@ -2277,13 +2277,6 @@ struct TritonGPUInferLayoutInterface
       return success();
     }
 
-    // Feature flag to disable this routine while it's relatively new.
-    // TODO(jlebar): Remove this once we're confident in the code.
-    if (triton::tools::getBoolEnv(
-            "TRITON_DISABLE_RESHAPE_ENCODING_INFERENCE")) {
-      return failure();
-    }
-
     // Cowardly refuse to handle encodings with multiple CTAs. CTAsPerCGA
     // should be like the other fields in blocked encoding, but I'm not sure how
     // to handle CTASplitNum.

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/LoadMMASpecialization.cpp

Lines changed: 14 additions & 11 deletions
@@ -47,7 +47,7 @@ struct PipelinedLoad {
 
   SmallVector<Operation *, 1> allocOps;
   SmallVector<Operation *, 1> liveBeforeOps;
-  SmallVector<Operation *, 0> liveUntilOps;
+  SmallVector<std::pair<Operation *, bool>, 0> liveUntilOps;
   SmallVector<Operation *, 1> asyncUsers;
 };
 
@@ -252,8 +252,6 @@ LogicalResult PipelinedLoad::determineLiveRange(Block &container,
     // memory must be live until after this operation.
     Operation *lastShmemSink =
        findNearestCommonPostDominator(shmemTerminals, postDomInfo);
-    if (lastShmemSink)
-      lastShmemSink = lastShmemSink->getNextNode();
 
     // The memory only needs to be live until before the first register user.
     Operation *liveUntilReg = findNearestCommonDominator(regSink, domInfo);
@@ -262,14 +260,16 @@
 
     // The memory is live until before the first register user or after the last
    // shmem terminal, whichever is later.
-    Operation *liveUntilOp;
+    std::pair<Operation *, bool> liveUntilOp{nullptr, false};
     if (lastShmemSink && liveUntilReg) {
-      liveUntilOp = liveUntilReg->isBeforeInBlock(lastShmemSink) ? lastShmemSink
-                                                                 : liveUntilReg;
+      if (liveUntilReg->isBeforeInBlock(lastShmemSink))
+        liveUntilOp = {lastShmemSink, /*after=*/true};
+      else
+        liveUntilOp = {liveUntilReg, /*after=*/false};
     } else if (liveUntilReg) {
-      liveUntilOp = liveUntilReg;
+      liveUntilOp = {liveUntilReg, /*after=*/false};
     } else {
-      liveUntilOp = lastShmemSink;
+      liveUntilOp = {lastShmemSink, /*after=*/true};
     }
     liveUntilOps.push_back(liveUntilOp);
   }
@@ -316,7 +316,7 @@ void PipelinedLoadGroup::allocateAref(scf::ForOp &loop, int numStages) {
   for (PipelinedLoad &load : loads) {
     distinctAsyncUsers.insert(load.asyncUsers.begin(), load.asyncUsers.end());
     int numLiveUntil =
-        llvm::count_if(load.liveUntilOps, [](Operation *op) { return !!op; });
+        llvm::count_if(load.liveUntilOps, [](auto p) { return !!p.first; });
     maxLiveUntil = std::max(maxLiveUntil, numLiveUntil);
   }
   int arriveCount = distinctAsyncUsers.size() + maxLiveUntil;
@@ -390,8 +390,11 @@ LogicalResult PipelinedLoadGroup::lowerLoads(WarpSchedule &schedule,
 
     SmallVector<Operation *> liveUntilOps;
     for (PipelinedLoad &load : loads) {
-      if (Operation *liveUntilOp = load.liveUntilOps[i])
-        liveUntilOps.push_back(liveUntilOp);
+      auto [liveUntilOp, after] = load.liveUntilOps[i];
+      if (liveUntilOp) {
+        liveUntilOps.push_back(after ? liveUntilOp->getNextNode()
+                                     : liveUntilOp);
+      }
     }
     if (!liveUntilOps.empty()) {
       Operation *liveUntilOp =

python/src/gluon_ir.cc

Lines changed: 24 additions & 1 deletion
@@ -130,7 +130,7 @@ py::object layoutToGluon(Attribute layout) {
     return layouts.DistributedLinearLayout(
         ll.getBases().lookup(kReg), ll.getBases().lookup(kLane),
         ll.getBases().lookup(kWarp), ll.getBases().lookup(kBlock),
-        ll.getOutDimSizes());
+        toStdVector(ArrayRef(llvm::to_vector(ll.getOutDimSizes()))));
   } else if (auto nvmma = dyn_cast<ttg::NVMMASharedEncodingAttr>(layout)) {
     auto ctaLayout = nvmma.getCTALayout();
     return layouts.NVMMASharedLayout(
@@ -279,6 +279,29 @@ void init_gluon_ir(py::module &&m) {
                 blockTy.getShape(), blockTy.getElementType(), layout);
             return triton::TensorDescType::get(ctx, blockTyLayout, isSigned);
           })
+      .def("create_async_copy_global_to_local",
+           [](GluonOpBuilder &self, Value smem, Value pointer, Value mask,
+              tt::CacheModifier cacheModifier,
+              tt::EvictionPolicy evictionPolicy, bool isVolatile) {
+             self.create<ttg::AsyncCopyGlobalToLocalOp>(
+                 pointer, smem, mask, /*other*/ Value{}, cacheModifier,
+                 evictionPolicy, isVolatile);
+           })
+      .def("create_async_copy_mbarrier_arrive",
+           [](GluonOpBuilder &self, Value mbarrier, bool incrementCount) {
+             self.create<ttng::AsyncCopyMbarrierArriveOp>(mbarrier,
+                                                          !incrementCount);
+           })
+      .def("create_async_commit_group",
+           [](GluonOpBuilder &self) {
+             ValueRange tokens;
+             self.create<ttg::AsyncCommitGroupOp>(tokens);
+           })
+      .def("create_async_wait_group",
+           [](GluonOpBuilder &self, int num) {
+             ValueRange tokens;
+             self.create<ttg::AsyncWaitOp>(tokens, num);
+           })
       .def("create_convert_layout",
            [](GluonOpBuilder &self, Type resultTy, Value value) -> Value {
              return self.create<ttg::ConvertLayoutOp>(resultTy, value);
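
Besides the copy itself and the mbarrier arrive, the bindings above also expose the classic commit-group/wait-group completion path (create_async_commit_group and create_async_wait_group, backed by ttg::AsyncCommitGroupOp and ttg::AsyncWaitOp). The test added below only exercises the mbarrier variant; here is a sketch of how its tail would look with group-based completion instead, assuming the Gluon-side wrappers are named async_copy.commit_group and async_copy.wait_group (those wrapper names are not shown in this diff).

# smem, inp, xindex, yindex, mask and block_layout are the names defined in
# async_copy_mbarrier_kernel below; only the completion mechanism differs.
async_copy.async_copy_global_to_shared(smem, inp + xindex * YBLOCK + yindex, mask)
async_copy.commit_group()  # close the current batch of async copies (AsyncCommitGroupOp)
async_copy.wait_group(0)   # block until no copy groups remain in flight (AsyncWaitOp with num = 0)
val = smem.load(block_layout)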

python/test/gluon/test_core.py

Lines changed: 36 additions & 2 deletions
@@ -1,9 +1,10 @@
 import torch
 import pytest
 
-from triton._internal_testing import is_cuda
+from triton._internal_testing import is_ampere_or_newer, is_hopper
 from triton.experimental import gluon
 from triton.experimental.gluon import language as ttgl
+from triton.experimental.gluon.language.nvidia.ampere import async_copy, mbarrier
 from triton.experimental.gluon.language.nvidia.hopper import tma
 
 
@@ -45,7 +46,7 @@ def tma_kernel(desc):
     alloc._keep_alive()
 
 
-@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires Hopper")
+@pytest.mark.skipif(not is_hopper(), reason="Requires Hopper")
 def test_tma():
     out = torch.ones((16, 16), dtype=torch.float16, device="cuda")
     layout = ttgl.NVMMASharedLayout(
@@ -59,3 +60,36 @@ def test_tma():
     desc = gluon.nvidia.hopper.TensorDescriptor.from_tensor(out, [16, 16], layout)
     tma_kernel[(1, )](desc)
     torch.testing.assert_close(out, torch.zeros_like(out))
+
+
+@gluon.jit
+def async_copy_mbarrier_kernel(out, inp, xnumel, XBLOCK: ttgl.constexpr, YBLOCK: ttgl.constexpr):
+    smem = ttgl.allocate_shared_memory(inp.dtype.element_ty, [XBLOCK, YBLOCK],
+                                       ttgl.SwizzledSharedLayout(1, 1, 1, order=[1, 0]))
+    block_layout: ttgl.constexpr = ttgl.BlockedLayout([1, 4], [1, 32], [4, 1], [1, 0])
+    xindex = ttgl.arange(0, XBLOCK, ttgl.SliceLayout(1, block_layout))[:, None]
+    yindex = ttgl.arange(0, YBLOCK, ttgl.SliceLayout(0, block_layout))[None, :]
+    mask = xindex < xnumel
+    async_copy.async_copy_global_to_shared(
+        smem,
+        inp + xindex * YBLOCK + yindex,
+        mask,
+    )
+    mbar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout())
+    mbarrier.init(mbar, count=1)
+    async_copy.mbarrier_arrive(mbar)
+    mbarrier.arrive(mbar)
+    mbarrier.wait(mbar, 0)
+
+    val = smem.load(block_layout)
+    ttgl.store(out + xindex * YBLOCK + yindex, val)
+
+
+@pytest.mark.skipif(not is_ampere_or_newer(), reason="Requires Ampere")
+def test_async_copy_mbarrier():
+    tensor_opts = dict(dtype=torch.float, device="cuda")
+    out = torch.empty((32, 32), **tensor_opts)
+    inp = torch.randn((20, 32), **tensor_opts)
+    async_copy_mbarrier_kernel[(1, )](out, inp, inp.shape[0], XBLOCK=32, YBLOCK=32)
+    torch.testing.assert_close(out[:20], inp)
+    torch.testing.assert_close(out[20:], torch.zeros((12, 32), **tensor_opts))
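
On an Ampere-or-newer CUDA GPU, the new test can be run on its own with pytest's -k filter, e.g. pytest python/test/gluon/test_core.py -k test_async_copy_mbarrier.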
