
Commit 029056e

[Hopper] Verify WarpGroupDotWaitOp has at least 1 dep (#7732)
Also do a bit of code cleanup
1 parent 225f111 commit 029056e

File tree: 8 files changed, +61 −31 lines


include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td

Lines changed: 1 addition & 0 deletions
@@ -123,6 +123,7 @@ def TTNG_WarpGroupDotWaitOp : TTNG_Op<"warp_group_dot_wait", [DeclareOpInterface
   }];
 
   let assemblyFormat = "$inputs attr-dict `:` type($inputs)";
+  let hasVerifier = 1;
 }
 
 def TTNG_InitBarrierOp : TTNG_Op<"init_barrier"> {

lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp

Lines changed: 10 additions & 5 deletions
@@ -135,13 +135,18 @@ bool WarpGroupDotOp::verifyDims() {
 
 // -- WarpGroupDotWaitOp --
 LogicalResult WarpGroupDotWaitOp::inferReturnTypes(
-    ::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location,
-    ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
-    ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
-    ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
+    MLIRContext *context, std::optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
   for (Value operand : operands)
     inferredReturnTypes.push_back(operand.getType());
-  return mlir::success();
+  return success();
+}
+
+LogicalResult WarpGroupDotWaitOp::verify() {
+  if (getOperands().empty())
+    return emitOpError("expected to be waiting on at least one dependency");
+  return success();
 }
 
 // -- InitBarrierOp --
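
To see what the new verifier catches, a minimal Gluon-level sketch (hedged: the kernel context and the `acc` tensor are hypothetical, and the exact diagnostic rendering may vary):

# OK: the wait names at least one dependency.
acc = hopper.warpgroup_mma_wait(num_outstanding=0, deps=[acc])

# A wait built with no deps yields a ttng.warp_group_dot_wait with zero
# operands; verification now fails with roughly:
#   error: 'ttng.warp_group_dot_wait' op expected to be waiting on at least
#   one dependency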

python/src/gluon_ir.cc

Lines changed: 4 additions & 1 deletion
@@ -485,7 +485,10 @@ void init_gluon_ir(py::module &&m) {
       })
       .def("create_warpgroup_mma_wait",
            [](GluonOpBuilder &self, std::vector<Value> &deps, int pendings) {
-             self.create<ttng::WarpGroupDotWaitOp>(deps, pendings);
+             std::vector<Value> results;
+             auto wait = self.create<ttng::WarpGroupDotWaitOp>(deps, pendings);
+             llvm::append_range(results, wait.getResults());
+             return results;
            })
      .def("create_tmem_alloc",
           [](GluonOpBuilder &self, Type resultTy, Value value) -> Value {
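
At the binding level, the practical effect is that `create_warpgroup_mma_wait` now returns a list of IR value handles, one per dependency, instead of `None`. A rough sketch (hedged: `builder`, `dep_handles`, and `num_outstanding` are assumed names for illustration):

# Each waited-on value comes back as a fresh IR handle.
handles = builder.create_warpgroup_mma_wait(dep_handles, num_outstanding)
assert len(handles) == len(dep_handles)  # inferReturnTypes yields one result per operand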

python/test/gluon/test_core.py

Lines changed: 1 addition & 1 deletion
@@ -124,7 +124,7 @@ def warpgroup_mma_kernel(a, b, out, M: ttgl.constexpr, N: ttgl.constexpr, K: ttg
     acc = hopper.warpgroup_mma(a_shmem, b_shmem, acc, is_async=ASYNC)
 
     if ASYNC:
-        hopper.warpgroup_mma_wait(num_outstanding=1, deps=[acc])
+        acc = hopper.warpgroup_mma_wait(num_outstanding=1, deps=[acc])
 
     ttgl.store(out + out_offs_m * N + out_offs_n, acc)

python/test/gluon/test_frontend.py

Lines changed: 1 addition & 1 deletion
@@ -618,7 +618,7 @@ def test_warpgroup_mma():
 def warpgroup_mma_wait_kernel():
     layout: ttgl.constexpr = ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1], instr_shape=[16, 32, 16])
     acc = ttgl.full([128, 128], 0, dtype=ttgl.float16, layout=layout)
-    hopper.warpgroup_mma_wait(num_outstanding=1, deps=[acc])
+    acc = hopper.warpgroup_mma_wait(num_outstanding=1, deps=[acc])
 
 
 def test_warpgroup_mma_wait():

python/triton/experimental/gluon/language/nvidia/hopper/__init__.py

Lines changed: 7 additions & 2 deletions
@@ -1,3 +1,4 @@
+from triton.compiler.code_generator import unflatten_ir_values
 from ..ampere import async_copy
 from . import mbarrier, tma
 from ... import _core
@@ -70,6 +71,10 @@ def warpgroup_mma_wait(num_outstanding=0, deps=None, _semantic=None):
         num_outstanding (int): Number of outstanding warpgroup MMA operations to wait for. Defaults to 0.
         deps (Sequence[tensor]): List of dependencies that need to be kept alive while the mma is unfinished.
     """
-    deps = [x.handle for x in deps] if deps is not None else []
+    deps_handles = [x.handle for x in deps] if deps is not None else []
     num_outstanding = _core._unwrap_if_constexpr(num_outstanding)
-    _semantic.builder.create_warpgroup_mma_wait(deps, num_outstanding)
+    results = _semantic.builder.create_warpgroup_mma_wait(deps_handles, num_outstanding)
+    results = tuple(unflatten_ir_values(results, [dep.type for dep in deps]))
+    if len(results) == 1:
+        return results[0]
+    return tuple(results)
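
The return shape this implies, as a short usage sketch (tensor names are hypothetical): one dep comes back as a bare tensor, several deps as a tuple in dep order.

# Single dependency: the waited-on tensor is returned directly.
acc = hopper.warpgroup_mma_wait(num_outstanding=1, deps=[acc])

# Multiple dependencies: a tuple with one tensor per dep.
acc0, acc1 = hopper.warpgroup_mma_wait(num_outstanding=0, deps=[acc0, acc1])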

test/Conversion/tritongpu_to_llvm_hopper.mlir

Lines changed: 21 additions & 0 deletions
@@ -538,3 +538,24 @@ module attributes {ttg.global_scratch_memory_alignment = 1 : i32, ttg.global_scr
   tt.return
 }
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+
+module attributes {"ttg.target" = "cuda:90", "ttg.num-warps" = 4 : i32} {
+
+// CHECK-LABEL: @warpgroup_dot_wait_1_input
+tt.func @warpgroup_dot_wait_1_input(%arg0: tensor<128xf32, #blocked>) {
+  // CHECK: nvgpu.wgmma_wait_group
+  ttng.warp_group_dot_wait %arg0 {pendings = 0 : i32} : tensor<128xf32, #blocked>
+  tt.return
+}
+
+tt.func @warpgroup_dot_wait_2_inputs(%arg0: tensor<128xf32, #blocked>, %arg1: tensor<128xf32, #blocked>) {
+  // CHECK: nvgpu.wgmma_wait_group
+  ttng.warp_group_dot_wait %arg0, %arg1 {pendings = 0 : i32} : tensor<128xf32, #blocked>, tensor<128xf32, #blocked>
+  tt.return
+}
+
+}

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM.cpp

Lines changed: 16 additions & 21 deletions
@@ -108,31 +108,28 @@ struct WarpGroupDotWaitOpConversion
                   ConversionPatternRewriter &rewriter) const override {
     auto pendings = op.getPendings();
     Location loc = op.getLoc();
-    if (adaptor.getInputs().size() <= 1) {
-      Value input =
-          adaptor.getInputs().size() == 1 ? adaptor.getInputs()[0] : Value();
-      rewriter.replaceOpWithNewOp<triton::nvgpu::WGMMAWaitGroupOp>(op, input,
-                                                                   pendings);
+    ValueRange inputs = adaptor.getInputs();
+    if (inputs.size() == 1) {
+      rewriter.replaceOpWithNewOp<triton::nvgpu::WGMMAWaitGroupOp>(
+          op, inputs.front(), pendings);
       return success();
     }
-    std::vector<Type> types;
+    SmallVector<Type> types;
     // Pack the inputs into a single struct.
-    for (Value input : adaptor.getInputs()) {
-      auto structType = dyn_cast<LLVM::LLVMStructType>(input.getType());
+    for (Type type : inputs.getTypes()) {
+      auto structType = dyn_cast<LLVM::LLVMStructType>(type);
       if (!structType)
         return failure();
-      for (Type type : structType.getBody())
-        types.push_back(type);
+      llvm::append_range(types, structType.getBody());
     }
     auto packedType =
         LLVM::LLVMStructType::getLiteral(rewriter.getContext(), types);
     Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType);
     unsigned outputStructIndex = 0;
-    for (Value input : adaptor.getInputs()) {
-      auto structType = dyn_cast<LLVM::LLVMStructType>(input.getType());
-      for (unsigned i = 0; i < structType.getBody().size(); ++i) {
-        Value value = rewriter.create<LLVM::ExtractValueOp>(
-            loc, structType.getBody()[i], input, i);
+    for (Value input : inputs) {
+      for (auto [i, type] : llvm::enumerate(
+               cast<LLVM::LLVMStructType>(input.getType()).getBody())) {
+        Value value = rewriter.create<LLVM::ExtractValueOp>(loc, input, i);
         packed = rewriter.create<LLVM::InsertValueOp>(
             loc, packedType, packed, value, outputStructIndex++);
       }
@@ -142,14 +139,12 @@ struct WarpGroupDotWaitOpConversion
     // Unpack the output into the original struct types.
     SmallVector<Value> outputs;
     outputStructIndex = 0;
-    for (Value input : adaptor.getInputs()) {
-      auto structType = cast<LLVM::LLVMStructType>(input.getType());
+    for (Type type : inputs.getTypes()) {
+      auto structType = cast<LLVM::LLVMStructType>(type);
       Value unpacked = rewriter.create<LLVM::UndefOp>(loc, structType);
-      for (unsigned i = 0; i < structType.getBody().size(); ++i) {
+      for (auto [i, type] : llvm::enumerate(structType.getBody())) {
         Value value = rewriter.create<LLVM::ExtractValueOp>(
-            loc, packedType.getBody()[outputStructIndex], packedOutput,
-            outputStructIndex);
-        outputStructIndex++;
+            loc, packedOutput, outputStructIndex++);
         unpacked = rewriter.create<LLVM::InsertValueOp>(loc, structType,
                                                         unpacked, value, i);
       }
