
Commit 43d1349

manman-ren and htyu authored
Pick WarpSpec PRs and fixes to 3.4 release (triton-lang#7462)
Including:
- Update pipeline to get GEMM/FA working (triton-lang#7136)
- Use required layout for buffers (triton-lang#7284)
- Add back support of async_task

Verified GEMM + WarpSpec and tma_ws TritonBench FA

Co-authored-by: Hongtao Yu <[email protected]>
1 parent 3ba7d6d commit 43d1349

File tree: 20 files changed, +697 / -85 lines


lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp

Lines changed: 4 additions & 4 deletions
@@ -33,12 +33,13 @@ namespace gpu {
 #define GEN_PASS_DEF_TRITONGPUPIPELINE
 #include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"

-static void pipelineWgmma(ModuleOp moduleOp) {
+static void pipelineWgmma(ModuleOp moduleOp, unsigned numStages) {
   SmallVector<scf::ForOp> loops;
   moduleOp->walk([&](scf::ForOp forOp) { loops.push_back(forOp); });

   for (scf::ForOp forOp : loops) {
-    mlir::triton::asyncLaunchDots(forOp);
+    if (getNumStagesOrDefault(forOp, numStages) >= 1)
+      mlir::triton::asyncLaunchDots(forOp);
   }
 }

@@ -223,7 +224,6 @@ struct PipelinePass : public impl::TritonGPUPipelineBase<PipelinePass> {

   void runOnOperation() override {
     ModuleOp moduleOp = getOperation();
-
     // Transform the loop by introducing async operations to prepare it for
     // pipeline expansion.
     lowerLoops(moduleOp);

@@ -244,7 +244,7 @@ struct PipelinePass : public impl::TritonGPUPipelineBase<PipelinePass> {
     // Cleanup the IR from the pipeline attributes.
     removeAttributes(moduleOp);

-    pipelineWgmma(moduleOp);
+    pipelineWgmma(moduleOp, numStages);

     // schedule the waits
     mlir::triton::updateWaits(getOperation());

python/src/ir.cc

Lines changed: 20 additions & 0 deletions
@@ -38,6 +38,20 @@

 #include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h"

+#include "llvm/ADT/SmallVector.h"
+
+void setAsyncTaskIds(mlir::Operation *op,
+                     llvm::ArrayRef<AsyncTaskId> asyncTaskIds) {
+  llvm::SmallVector<AsyncTaskId> sortedAsyncTaskIds(asyncTaskIds.begin(),
+                                                    asyncTaskIds.end());
+  sort(sortedAsyncTaskIds);
+  auto i32Ty = IntegerType::get(op->getContext(), 32);
+  auto size = static_cast<int64_t>(sortedAsyncTaskIds.size());
+  auto vecTy = VectorType::get(size, i32Ty);
+  op->setAttr("async_task_id",
+              DenseI32ArrayAttr::get(op->getContext(), sortedAsyncTaskIds));
+}
+
 namespace {

 namespace py = pybind11;

@@ -744,6 +758,12 @@ void init_triton_ir(py::module &&m) {
           [](TritonOpBuilder &self, OpBuilder::InsertPoint pt) {
             self.restoreInsertionPoint(pt);
           })
+      .def("set_async_task_ids",
+           [](TritonOpBuilder &self, std::vector<int> v) {
+             self.setAsyncTaskIds(v);
+           })
+      .def("unset_async_task_ids",
+           [](TritonOpBuilder &self) { self.unsetAsyncTaskIds(); })
       // Attr
       .def(
          "get_unit_attr",

python/src/ir.h

Lines changed: 16 additions & 1 deletion
@@ -1,8 +1,13 @@
 #pragma once
 #include "mlir/IR/Builders.h"
 #include "triton/Tools/Sys/GetEnv.hpp"
+#include "llvm/ADT/ArrayRef.h"
 #include <memory>

+typedef int AsyncTaskId;
+void setAsyncTaskIds(mlir::Operation *op,
+                     llvm::ArrayRef<AsyncTaskId> asyncTaskIds);
+
 // A custom op builder that keeps track of the last location
 class TritonOpBuilder {
 public:

@@ -62,7 +67,10 @@ class TritonOpBuilder {

   template <typename OpTy, typename... Args> OpTy create(Args &&...args) {
     auto loc = getLastLoc();
-    return builder->create<OpTy>(loc, std::forward<Args>(args)...);
+    auto ret = builder->create<OpTy>(loc, std::forward<Args>(args)...);
+    if (asyncTaskIds)
+      ::setAsyncTaskIds(ret, *asyncTaskIds);
+    return ret;
   }

   // Overload to create or fold a single result operation.

@@ -82,9 +90,16 @@
     return builder->createOrFold<OpTy>(loc, std::forward<Args>(args)...);
   }

+  void setAsyncTaskIds(std::vector<int> taskIds) {
+    this->asyncTaskIds = taskIds;
+  }
+
+  void unsetAsyncTaskIds() { this->asyncTaskIds = std::nullopt; }
+
 private:
   std::unique_ptr<mlir::OpBuilder> builder;
   std::unique_ptr<mlir::Location> lastLoc;
+  std::optional<std::vector<int>> asyncTaskIds;
   bool lineInfoEnabled =
       !mlir::triton::tools::getBoolEnv("TRITON_DISABLE_LINE_INFO");
 };

python/triton/compiler/code_generator.py

Lines changed: 14 additions & 0 deletions
@@ -926,6 +926,20 @@ def _verify_loop_carried_variable(self, name, loop_val, live_val):
             f'but is re-assigned to {loop_val.type} in loop! '\
             f'Please make sure that the type stays consistent.'

+    def visit_withitem(self, node):
+        return self.visit(node.context_expr)
+
+    def visit_With(self, node):
+        assert len(node.items) == 1
+        context = node.items[0].context_expr
+        withitemClass = self.visit(context.func)
+        if withitemClass == language.async_task:
+            args = [self.visit(arg) for arg in context.args]
+            with withitemClass(*args, _builder=self.builder):
+                self.visit_compound_statement(node.body)
+        else:
+            self.visit_compound_statement(node.body)
+
     def visit_While(self, node):
         with enter_sub_region(self) as sr:
             liveins, insert_block = sr

python/triton/language/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -39,6 +39,7 @@
     arange,
     associative_scan,
     assume,
+    async_task,
     atomic_add,
     atomic_and,
     atomic_cas,

@@ -145,6 +146,7 @@
     "argmin",
     "associative_scan",
     "assume",
+    "async_task",
     "atomic_add",
     "atomic_and",
     "atomic_cas",

python/triton/language/core.py

Lines changed: 16 additions & 0 deletions
@@ -3138,6 +3138,22 @@ def __next__(self):
         raise RuntimeError("static_range can only be used in @triton.jit'd functions")


+class async_task:
+    """
+    Context manager to run code fragments asynchronously.
+    """
+
+    def __init__(self, task_ids, _builder=None):
+        self.task_ids = list({_unwrap_if_constexpr(tid) for tid in task_ids})
+        self.builder = _builder
+
+    def __enter__(self):
+        self.builder.set_async_task_ids(self.task_ids)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.builder.unset_async_task_ids()
+
+
 class range:
     """
     Iterator that counts upward forever.
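An illustrative kernel (hypothetical, not part of this commit) showing how the restored tl.async_task context manager is meant to be used inside @triton.jit code: ops emitted under each `with` block are tagged with the listed task ids, matching the async_task_id attributes in the ws_code_partition.mlir test below. It only has an effect with a backend and pass pipeline that support warp specialization (e.g. Hopper).

import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    with tl.async_task([0]):
        # producer task: issue the loads
        x = tl.load(x_ptr + offs, mask=mask)
        y = tl.load(y_ptr + offs, mask=mask)
    with tl.async_task([1]):
        # consumer task: compute and store
        tl.store(out_ptr + offs, x + y, mask=mask)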

python/tutorials/09-persistent-matmul.py

Lines changed: 44 additions & 20 deletions
@@ -47,8 +47,12 @@ def supports_tma():
     return is_cuda() and torch.cuda.get_device_capability()[0] >= 9


+def is_hopper():
+    return torch.cuda.get_device_capability()[0] == 9
+
+
 def supports_ws():
-    return is_cuda() and torch.cuda.get_device_capability()[0] >= 10
+    return is_cuda() and torch.cuda.get_device_capability()[0] >= 9


 def _matmul_launch_metadata(grid, kernel, args):

@@ -465,21 +469,31 @@ def grid(META):
     return c


-@triton.autotune(
-    configs=matmul_tma_persistent_get_configs(),
-    key=["M", "N", "K", "WARP_SPECIALIZE"],
-)
+def prune_invalid_configs(configs, named_args, **kwargs):
+    FLATTEN = kwargs["FLATTEN"]
+    # Filter out configs where EPILOGUE_SUBTILE is true and HOPPER is true
+    return [conf for conf in configs if not (conf.kwargs.get("EPILOGUE_SUBTILE", True) and FLATTEN is False)]
+
+
+@triton.autotune(configs=matmul_tma_persistent_get_configs(), key=["M", "N", "K", "WARP_SPECIALIZE", "FLATTEN"],
+                 prune_configs_by={'early_config_prune': prune_invalid_configs})
 @triton.jit(launch_metadata=_matmul_launch_metadata)
-def matmul_kernel_descriptor_persistent(a_ptr, b_ptr, c_ptr,  #
-                                        M, N, K,  #
-                                        BLOCK_SIZE_M: tl.constexpr,  #
-                                        BLOCK_SIZE_N: tl.constexpr,  #
-                                        BLOCK_SIZE_K: tl.constexpr,  #
-                                        GROUP_SIZE_M: tl.constexpr,  #
-                                        EPILOGUE_SUBTILE: tl.constexpr,  #
-                                        NUM_SMS: tl.constexpr,  #
-                                        WARP_SPECIALIZE: tl.constexpr,  #
-                                        ):
+def matmul_kernel_descriptor_persistent(
+        a_ptr,
+        b_ptr,
+        c_ptr,  #
+        M,
+        N,
+        K,  #
+        BLOCK_SIZE_M: tl.constexpr,  #
+        BLOCK_SIZE_N: tl.constexpr,  #
+        BLOCK_SIZE_K: tl.constexpr,  #
+        GROUP_SIZE_M: tl.constexpr,  #
+        EPILOGUE_SUBTILE: tl.constexpr,  #
+        NUM_SMS: tl.constexpr,  #
+        WARP_SPECIALIZE: tl.constexpr,  #
+        FLATTEN: tl.constexpr,
+):
     # Matmul using TMA and device-side descriptor creation
     dtype = c_ptr.dtype.element_ty
     start_pid = tl.program_id(axis=0)

@@ -512,7 +526,7 @@ def matmul_kernel_descriptor_persistent(a_ptr, b_ptr, c_ptr,  #
     tile_id_c = start_pid - NUM_SMS
     num_pid_in_group = GROUP_SIZE_M * num_pid_n

-    for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True, warp_specialize=WARP_SPECIALIZE):
+    for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=FLATTEN, warp_specialize=WARP_SPECIALIZE):
         pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS)
         offs_am = pid_m * BLOCK_SIZE_M
         offs_bn = pid_n * BLOCK_SIZE_N

@@ -560,12 +574,19 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]):

     triton.set_allocator(alloc_fn)

+    # Hopper warpspec doesn't work with flatten
+    flatten = False if (warp_specialize and is_hopper()) else True
     grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])), )
     matmul_kernel_descriptor_persistent[grid](
-        a, b, c,  #
-        M, N, K,  #
+        a,
+        b,
+        c,  #
+        M,
+        N,
+        K,  #
         NUM_SMS=NUM_SMS,  #
         WARP_SPECIALIZE=warp_specialize,  #
+        FLATTEN=flatten,
     )
     return c

@@ -632,7 +653,8 @@ def bench(K, dtype, reps=10000, warmup_reps=10000):
     warp_specialize = [False, True] if HAS_WARP_SPECIALIZE else [False]
     for ws in warp_specialize:
         ws_str = "_ws" if ws else ""
-        if HAS_HOST_TENSOR_DESC:
+        # disable on-host warpspec on Hopper
+        if HAS_HOST_TENSOR_DESC and not (is_hopper() and ws):
            bench_fn(f"tma_persistent{ws_str}", reps, warmup_reps, lambda a, b: matmul_tma_persistent(a, b, ws), a, b)
            bench_fn(f"tma{ws_str}", reps, warmup_reps, lambda a, b: matmul_tma(a, b, ws), a, b)
     if HAS_TENSOR_DESC:

@@ -671,7 +693,9 @@ def validate(M, N, K, dtype):

     for (kernel, label, enabled), warp_specialize in itertools.product(kernels, warp_specialize):
         label = f"{label} (warp_specialize={warp_specialize})"
-        enabled = enabled and (not warp_specialize or HAS_TENSOR_DESC)
+        # skip if hopper and warp_specialize and not on-device
+        skipped = is_hopper() and warp_specialize and kernel != matmul_descriptor_persistent
+        enabled = enabled and (not warp_specialize or HAS_TENSOR_DESC) and (not skipped)
         run_test(naive_result, lambda a, b: kernel(a, b, warp_specialize), a, b, label, enabled)
     print()
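A condensed restatement (illustrative only, assuming the tutorial's own helpers such as is_hopper and HAS_HOST_TENSOR_DESC) of the Hopper-specific gating this tutorial change spreads across the autotune pruning, the host wrapper, and the benchmark/validation loops:

def pick_flatten(warp_specialize: bool) -> bool:
    # device-side descriptor kernel: Hopper warpspec does not support a flattened loop
    return not (warp_specialize and is_hopper())

def host_tma_bench_enabled(ws: bool) -> bool:
    # bench(): skip host-side TMA benchmarks when warp specializing on Hopper
    return HAS_HOST_TENSOR_DESC and not (is_hopper() and ws)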

test/Hopper/WarpSpecialization/ws_code_partition.mlir

Lines changed: 45 additions & 0 deletions
@@ -260,3 +260,48 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
     tt.return
   }
 }
+
+
+// -----
+
+// CHECK-DAG: #[[$SHARED:.*]] = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
+// CHECK-DAG: #[[$SHARED1:.*]] = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 8}>
+// CHECK-LABEL: @_fbgemm_grouped_gemm_fp8_rowwise_ws
+// CHECK: ttg.local_alloc : () -> !ttg.memdesc<1x64x64xf8E4M3FN, #[[$SHARED1]], #smem, mutable>
+// CHECK: ttg.local_alloc : () -> !ttg.memdesc<1x128x64xf8E4M3FN, #[[$SHARED1]], #smem, mutable>
+// CHECK: ttg.local_alloc : () -> !ttg.memdesc<1x128xf32, #[[$SHARED]], #smem, mutable>
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 32]}>
+#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 8}>
+#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
+#shared2 = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = true, elementBitWidth = 8}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @_fbgemm_grouped_gemm_fp8_rowwise_ws(%arg0: !tt.ptr<i8, 0> {tt.nv_tma_desc = 1 : i32}, %arg1: i32, %arg2: !tt.ptr<i8, 0> {tt.nv_tma_desc = 1 : i32}, %arg3: !tt.ptr<i8, 0> {tt.nv_tma_desc = 1 : i32}) attributes {noinline = false} {
+    %c0_i32 = arith.constant {async_task_id = array<i32: 0, 1, 2>} 0 : i32
+    %c2048_i32 = arith.constant {async_task_id = array<i32: 0, 1, 2>} 2048 : i32
+    %c64_i32 = arith.constant {async_task_id = array<i32: 0, 1, 2>} 64 : i32
+    %cst = arith.constant {async_task_id = array<i32: 0, 1, 2>} dense<0.000000e+00> : tensor<64x128xf32, #mma>
+    %0 = tt.get_program_id x {async_task_id = array<i32: 0, 1, 2>} : i32
+    %1 = ttng.reinterpret_tensor_descriptor %arg0 {async_task_id = array<i32: 0>} : !tt.ptr<i8, 0> to !tt.tensordesc<tensor<64x64xf8E4M3FN, #shared>>
+    %2 = ttng.reinterpret_tensor_descriptor %arg2 {async_task_id = array<i32: 0>} : !tt.ptr<i8, 0> to !tt.tensordesc<tensor<128x64xf8E4M3FN, #shared>>
+    %3 = ttng.reinterpret_tensor_descriptor %arg3 {async_task_id = array<i32: 0>} : !tt.ptr<i8, 0> to !tt.tensordesc<tensor<128xf32, #shared1>>
+    scf.for %arg4 = %0 to %arg1 step %c64_i32 : i32 {
+      %4 = arith.muli %arg4, %c2048_i32 {async_task_id = array<i32: 0>} : i32
+      %5 = scf.for %arg5 = %c0_i32 to %c2048_i32 step %c64_i32 iter_args(%arg6 = %cst) -> (tensor<64x128xf32, #mma>) : i32 {
+        %8 = tt.descriptor_load %1[%4, %arg5] {async_task_id = array<i32: 0>} : !tt.tensordesc<tensor<64x64xf8E4M3FN, #shared>> -> tensor<64x64xf8E4M3FN, #blocked>
+        %9 = ttg.local_alloc %8 {async_task_id = array<i32: 1>} : (tensor<64x64xf8E4M3FN, #blocked>) -> !ttg.memdesc<64x64xf8E4M3FN, #shared, #smem>
+        %10 = tt.descriptor_load %2[%4, %arg5] {async_task_id = array<i32: 0>} : !tt.tensordesc<tensor<128x64xf8E4M3FN, #shared>> -> tensor<128x64xf8E4M3FN, #blocked>
+        %11 = ttg.local_alloc %10 {async_task_id = array<i32: 1, 2>} : (tensor<128x64xf8E4M3FN, #blocked>) -> !ttg.memdesc<128x64xf8E4M3FN, #shared, #smem>
+        %12 = ttg.memdesc_trans %11 {async_task_id = array<i32: 1, 2>, order = array<i32: 1, 0>} : !ttg.memdesc<128x64xf8E4M3FN, #shared, #smem> -> !ttg.memdesc<64x128xf8E4M3FN, #shared2, #smem>
+        %13 = ttng.warp_group_dot %9, %12, %arg6 {async_task_id = array<i32: 1>, inputPrecision = 0 : i32, maxNumImpreciseAcc = 1073741824 : i32} : !ttg.memdesc<64x64xf8E4M3FN, #shared, #smem> * !ttg.memdesc<64x128xf8E4M3FN, #shared2, #smem> -> tensor<64x128xf32, #mma>
+        scf.yield {async_task_id = array<i32: 1, 2>} %13 : tensor<64x128xf32, #mma>
+      } {async_task_id = array<i32: 0, 1, 2>}
+      %6 = tt.descriptor_load %3[%4] {async_task_id = array<i32: 0>} : !tt.tensordesc<tensor<128xf32, #shared1>> -> tensor<128xf32, #blocked1>
+      %7 = ttg.convert_layout %6 {async_task_id = array<i32: 1, 2>} : tensor<128xf32, #blocked1> -> tensor<128xf32, #ttg.slice<{dim = 0, parent = #blocked}>>
+    } {async_task_id = array<i32: 1, 2>}
+    tt.return
+  }
+}

third_party/nvidia/backend/compiler.py

Lines changed: 1 addition & 0 deletions
@@ -259,6 +259,7 @@ def make_ttgir(mod, metadata, opt, capability):
        passes.ttir.add_triton_licm(pm)
        passes.common.add_canonicalizer(pm)
        passes.ttgpuir.add_combine_tensor_select_and_if(pm)
+       nvidia.passes.hopper.add_hopper_warpspec(pm, opt.num_stages, dump_enabled)
        passes.ttgpuir.add_assign_latencies(pm, opt.num_stages)
        passes.ttgpuir.add_schedule_loops(pm)
        passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)

third_party/nvidia/hopper/include/Transforms/Passes.td

Lines changed: 5 additions & 2 deletions
@@ -14,9 +14,12 @@ def NVGPUWarpSpecialization : Pass<"nvgpu-warp-specialization", "mlir::ModuleOp"

   let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
   let options = [
-    Option<"numWarpGroups", "num-warp-groups",
+    Option<"numStages", "num-stages",
            "int32_t", /*default*/"0",
-           "number of warp groups for warp specialization">
+           "number of buffers for warp specialization">,
+    Option<"dumpIntermediateSteps", "dump-intermediate-steps",
+           "bool", /*default*/"false",
+           "Dump intermediate steps">
   ];
 }
