
Commit ffc614d

[Hopper][WS] Add TaskIdPropagate pass (#7038)
This change adds a pass that propagates `async_task_id` annotations from anchor ops to their dependencies using a sparse backward dataflow analysis, effectively partitioning the IR into multiple async tasks based on the initial annotations.
1 parent 859dcf0 commit ffc614d

File tree: 8 files changed, +470 −5 lines changed
Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
// RUN: triton-opt %s -split-input-file --nvgpu-test-taskid-propagate=num-warp-groups=2 | FileCheck %s

#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 256, 16]}>
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 0}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {

  // CHECK-LABEL: @matmul_persistent_tma_ws_cooperative_kernel
  // CHECK: %[[C0:.*]] = arith.constant {async_task_id = array<i32: 0, 1, 2>} 0 : i32
  // CHECK-NEXT: %[[C1:.*]] = arith.constant {async_task_id = array<i32: 0, 1, 2>} 1 : i32
  // CHECK-NEXT: %[[C64:.*]] = arith.constant {async_task_id = array<i32: 0>} 64 : i32
  // CHECK-NEXT: %[[INIT:.*]] = arith.constant {async_task_id = array<i32: 1, 2>} dense<0.000000e+00> : tensor<128x256xf32, #mma>
  // CHECK-NEXT: %[[PID:.*]] = tt.get_program_id x {async_task_id = array<i32: 0, 1, 2>} : i32
  // CHECK-NEXT: %[[NUM:.*]] = tt.get_num_programs x {async_task_id = array<i32: 0, 1, 2>} : i32
  // CHECK-NEXT: scf.for %[[IV:.*]] = %[[PID]] to %[[UB:.*]] step %[[NUM]] : i32 {
  // CHECK-NEXT: %[[FOR:.*]]:2 = scf.for %{{.*}} = %[[C0]] to %{{.*}} step %[[C1]] iter_args(%[[ACC:.*]] = %[[INIT]], %[[OFF:.*]] = %[[C0]])
  // CHECK-NEXT: %[[LOAD1:.*]] = tt.descriptor_load %[[INPUT1:.*]][%[[IV]], %[[OFF]]] {async_task_id = array<i32: 0>}
  // CHECK-NEXT: %[[ALLOC1:.*]] = ttg.local_alloc %[[LOAD1]] {async_task_id = array<i32: 1, 2>}
  // CHECK-NEXT: %[[LOAD2:.*]] = tt.descriptor_load %[[INPUT2:.*]][%[[OFF]], %[[IV]]] {async_task_id = array<i32: 0>}
  // CHECK-NEXT: %[[ALLOC2:.*]] = ttg.local_alloc %[[LOAD2]] {async_task_id = array<i32: 1, 2>}
  // CHECK-NEXT: %[[DOT:.*]] = ttng.warp_group_dot %[[ALLOC1]], %[[ALLOC2]], %[[ACC]] {async_task_id = array<i32: 1, 2>, inputPrecision = 0 : i32}
  // CHECK-NEXT: %[[ADD:.*]] = arith.addi %[[OFF]], %[[C64]] {async_task_id = array<i32: 0>}
  // CHECK-NEXT: scf.yield {async_task_id = array<i32: 0, 1, 2>} %[[DOT]], %[[ADD]]
  // CHECK-NEXT: } {async_task_id = array<i32: 0, 1, 2>}
  // CHECK-NEXT: arith.truncf %[[FOR]]#0 {async_task_id = array<i32: 1, 2>}
  // CHECK-NEXT: ttg.convert_layout %{{.*}} {async_task_id = array<i32: 1, 2>}
  // CHECK-NEXT: tt.descriptor_store %[[OUTPUT:.*]][%[[IV]], %[[IV]]], %{{.*}} {async_task_id = array<i32: 1, 2>}
  // CHECK-NEXT: } {async_task_id = array<i32: 0, 1, 2>}

  tt.func public @matmul_persistent_tma_ws_cooperative_kernel(%arg0: !tt.tensordesc<tensor<128x64xf16>>, %arg1: !tt.tensordesc<tensor<64x256xf16>>, %arg2: !tt.tensordesc<tensor<128x256xf16>>, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
    %c0_i32 = arith.constant 0 : i32
    %c1_i32 = arith.constant 1 : i32
    %c64_i32 = arith.constant 64 : i32
    %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma>
    %0 = tt.get_program_id x : i32
    %1 = tt.get_num_programs x : i32
    scf.for %arg6 = %0 to %arg3 step %1 : i32 {
      %2:2 = scf.for %arg7 = %c0_i32 to %arg5 step %c1_i32 iter_args(%arg8 = %cst, %arg9 = %c0_i32) -> (tensor<128x256xf32, #mma>, i32) : i32 {
        %5 = tt.descriptor_load %arg0[%arg6, %arg9] {async_task_id = array<i32: 0>} : !tt.tensordesc<tensor<128x64xf16>> -> tensor<128x64xf16, #blocked>
        %6 = ttg.local_alloc %5 : (tensor<128x64xf16, #blocked>) -> !ttg.memdesc<128x64xf16, #shared, #smem>
        %7 = tt.descriptor_load %arg1[%arg9, %arg6] {async_task_id = array<i32: 0>} : !tt.tensordesc<tensor<64x256xf16>> -> tensor<64x256xf16, #blocked1>
        %8 = ttg.local_alloc %7 : (tensor<64x256xf16, #blocked1>) -> !ttg.memdesc<64x256xf16, #shared, #smem>
        %9 = ttng.warp_group_dot %6, %8, %arg8 {async_task_id = array<i32: 1, 2>, inputPrecision = 0 : i32} : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x256xf16, #shared, #smem> -> tensor<128x256xf32, #mma>
        %10 = arith.addi %arg9, %c64_i32 : i32
        scf.yield %9, %10 : tensor<128x256xf32, #mma>, i32
      }
      %3 = arith.truncf %2#0 : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma>
      %4 = ttg.convert_layout %3 : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked1>
      tt.descriptor_store %arg2[%arg6, %arg6], %4 {async_task_id = array<i32: 1, 2>} : !tt.tensordesc<tensor<128x256xf16>>, tensor<128x256xf16, #blocked1>
    }
    tt.return
  }
}
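
The annotation the CHECK lines match is an ordinary `DenseI32ArrayAttr` named `async_task_id`. As a point of reference, here is a minimal, hypothetical sketch of reading and writing that attribute from C++; the repository's real helpers (e.g. the `getAsyncTaskIds` mentioned in a TODO below) live in the WarpSpecialization utilities and may differ:

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/SmallVector.h"

// Hypothetical helpers, for illustration only.
static llvm::SmallVector<int32_t> readTaskIds(mlir::Operation *op) {
  if (auto attr = op->getAttrOfType<mlir::DenseI32ArrayAttr>("async_task_id"))
    return {attr.asArrayRef().begin(), attr.asArrayRef().end()};
  return {};
}

static void writeTaskIds(mlir::Operation *op, llvm::ArrayRef<int32_t> ids) {
  op->setAttr("async_task_id",
              mlir::DenseI32ArrayAttr::get(op->getContext(), ids));
}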

third_party/nvidia/hopper/include/Transforms/Passes.td

Lines changed: 19 additions & 1 deletion
@@ -4,7 +4,7 @@
 include "mlir/Pass/PassBase.td"

 def NVGPUWarpSpecialization : Pass<"nvgpu-warp-specialization", "mlir::ModuleOp"> {
-  let summary = "Automaticl Warp specialization for NVIDIA GPU";
+  let summary = "Automatic Warp specialization for NVIDIA GPU";

   let description = [{
     This pass automatically partitions user-defined kernels into
@@ -33,6 +33,24 @@ def NVGPUTestWSTaskPartition : Pass<"nvgpu-test-ws-task-partition", "mlir::Modul
   ];
 }

+def NVGPUTestWSTaskIdPropagate : Pass<"nvgpu-test-taskid-propagate", "mlir::ModuleOp"> {
+  let summary = "test warp specialization task id propagation";
+
+  let description = [{
+    This pass propagates the `async_task_id` annotation to the dependencies
+    of any op that has it set. This has the functional effect of partitioning
+    the graph into multiple async tasks, based on the initial annotation.
+  }];
+
+  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
+
+  let options = [
+    Option<"numWarpGroups", "num-warp-groups",
+           "int32_t", /*default*/"0",
+           "number of warp groups for warp specialization">
+  ];
+}
+
 def NVGPUTestWSDataPartition : Pass<"nvgpu-test-ws-data-partition", "mlir::ModuleOp"> {
   let summary = "test warp specialization data partition";

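The TableGen `Option` above becomes a `numWarpGroups` member on the generated pass class. A rough sketch of how the test-pass driver might consume it, assuming MLIR's usual GEN_PASS_DEF pattern (the actual driver, WSTaskIdPropagate.cpp, is listed in the CMake diff below but not reproduced on this page; class and base names here are assumptions):

// Sketch only: base-class name assumed from the TableGen def.
namespace {
struct NVGPUTestWSTaskIdPropagatePass
    : public impl::NVGPUTestWSTaskIdPropagateBase<
          NVGPUTestWSTaskIdPropagatePass> {
  void runOnOperation() override {
    mlir::ModuleOp module = getOperation();
    // `numWarpGroups` is generated from the "num-warp-groups" option
    // (default 0).
    if (numWarpGroups == 0)
      return;
    // ... seed anchor ops and run the backward propagation over `module` ...
  }
};
} // namespace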
third_party/nvidia/hopper/lib/Transforms/CMakeLists.txt

Lines changed: 4 additions & 2 deletions
@@ -1,12 +1,14 @@
 add_triton_library(NVHopperTransforms
   WarpSpecialization.cpp
   WarpSpecialization/CodePartitionUtility.cpp
+  WarpSpecialization/TaskIdPropagation.cpp
+  WarpSpecialization/Utility.cpp
   WarpSpecialization/WSBuffer.cpp
   WarpSpecialization/WSCodePartition.cpp
+  WarpSpecialization/WSDataPartition.cpp
   WarpSpecialization/WSLowerMem.cpp
   WarpSpecialization/WSSpecialize.cpp
-  WarpSpecialization/Utility.cpp
-  WarpSpecialization/WSDataPartition.cpp
+  WarpSpecialization/WSTaskIdPropagate.cpp
   WarpSpecialization/WSTaskPartition.cpp

   DEPENDS
Lines changed: 173 additions & 0 deletions
@@ -0,0 +1,173 @@
#include "TaskIdPropagation.h"
#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Support/LLVM.h"
#include "nvidia/hopper/lib/Transforms/WarpSpecialization/Utility.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "task-id-propagation"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

using namespace mlir;
using namespace mlir::dataflow;

namespace mlir::triton::gpu {

//===----------------------------------------------------------------------===//
// TaskId
//===----------------------------------------------------------------------===//

void TaskId::print(raw_ostream &os) const {
  if (isUninitialized()) {
    os << "<UNINITIALIZED>";
    return;
  }
  if (isUnknown()) {
    os << "<UNKNOWN>";
    return;
  }
  getTaskIds().print(os);
}

TaskId TaskId::join(const TaskId &lhs, const TaskId &rhs) {
  // Only the meet direction is meaningful for this backward analysis.
  return TaskId::getUnknownTaskId();
}

TaskId TaskId::meet(const TaskId &lhs, const TaskId &rhs) {
  if (lhs.isUnknown() || rhs.isUnknown())
    return TaskId::getUnknownTaskId();
  if (lhs.isUninitialized())
    return rhs;
  if (rhs.isUninitialized())
    return lhs;
  if (lhs == rhs)
    return lhs;

  auto context = lhs.getTaskIds().getContext();
  auto lhsTasks = lhs.getTaskIds().asArrayRef();
  auto rhsTasks = rhs.getTaskIds().asArrayRef();
  // Meet the task ids by merging and deduplicating them.
  SmallVector<AsyncTaskId> result(lhsTasks.begin(), lhsTasks.end());
  result.insert(result.end(), rhsTasks.begin(), rhsTasks.end());
  llvm::sort(result);
  result.erase(std::unique(result.begin(), result.end()), result.end());
  return TaskId(DenseI32ArrayAttr::get(context, ArrayRef<AsyncTaskId>(result)));
}
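
// For example, meet({0}, {1, 2}) = {0, 1, 2}: a value reachable from anchors
// in several async tasks must end up visible to all of them.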

//===----------------------------------------------------------------------===//
// TaskIdBackwardPropagation
//===----------------------------------------------------------------------===//

void TaskIdBackwardPropagation::propagateToYield(
    scf::YieldOp yieldOp, SmallVector<TaskId> &lattices) {
  for (auto [lattice, yieldOperand] :
       llvm::zip_equal(lattices, yieldOp->getOperands())) {
    auto yieldLattice = getLatticeElement(yieldOperand);
    ChangeResult changed = yieldLattice->meet(lattice);
    propagateIfChanged(yieldLattice, changed);
  }
}
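
// An annotated op nested in control flow also constrains its enclosing
// regions: e.g. a ttng.warp_group_dot tagged {1, 2} inside an scf.for forces
// the loop's control operands (lower bound, upper bound, step) and any
// enclosing scf.if condition to carry at least {1, 2}, walking up to the
// enclosing tt.func.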
void TaskIdBackwardPropagation::propagateToParent(Operation *op,
                                                  const TaskId &taskId) {
  auto parentOp = op->getParentOp();
  while (parentOp && !isa<triton::FuncOp>(parentOp)) {
    if (auto forOp = dyn_cast<scf::ForOp>(parentOp)) {
      // Propagate to the control operands of the for op.
      for (auto controlOperand :
           forOp.getOperands().take_front(forOp.getNumControlOperands())) {
        auto controlLattice = getLatticeElement(controlOperand);
        ChangeResult changed = controlLattice->meet(taskId);
        propagateIfChanged(controlLattice, changed);
      }
    } else if (auto ifOp = dyn_cast<scf::IfOp>(parentOp)) {
      auto cond = ifOp.getCondition();
      auto condLattice = getLatticeElement(cond);
      ChangeResult changed = condLattice->meet(taskId);
      propagateIfChanged(condLattice, changed);
    } else {
      // The loop condition already excludes triton::FuncOp.
      llvm_unreachable("Other parent ops are not supported.");
    }
    parentOp = parentOp->getParentOp();
  }
}

LogicalResult TaskIdBackwardPropagation::visitOperation(
    Operation *op, ArrayRef<TaskIdLattice *> operands,
    ArrayRef<const TaskIdLattice *> results) {
  // Already annotated: treat the op as an anchor and meet its task ids into
  // every operand.
  // TODO(Arda): Replace the following with getAsyncTaskIds when we no longer
  // need to dump the task ids into the IR.
  auto taskIdAttr = op->getAttrOfType<DenseI32ArrayAttr>("async_task_id");
  if (taskIdAttr) {
    const auto annotated = TaskId(taskIdAttr);
    for (auto operandLattice : operands) {
      ChangeResult changed = operandLattice->meet(annotated);
      propagateIfChanged(operandLattice, changed);
    }
    // Propagate to parent ops such as enclosing control flow.
    propagateToParent(op, annotated);
    return success();
  }
  // If the op is not annotated by the user, propagate from its results to its
  // operands.
  for (const auto resultLattice : results) {
    for (auto operandLattice : operands) {
      ChangeResult changed = operandLattice->meet(resultLattice->getValue());
      propagateIfChanged(operandLattice, changed);
    }
  }

  for (const auto resultLattice : results)
    propagateToParent(op, resultLattice->getValue());

  return success();
}

void TaskIdBackwardPropagation::visitBranchOperand(OpOperand &operand) {
  auto defOp = operand.getOwner();
  assert(isa<scf::IfOp>(defOp) || isa<scf::ForOp>(defOp));

  SmallVector<TaskId> lattices(defOp->getNumResults(),
                               TaskId::getUninitialized());
  for (auto [i, result] : llvm::enumerate(defOp->getResults())) {
    auto resultLattice = getLatticeElement(result);
    // Wait for all the results to be initialized.
    if (resultLattice->getValue().isUninitialized())
      return;
    lattices[i] = TaskId::meet(lattices[i], resultLattice->getValue());
  }

  // Propagate to the yield ops.
  if (auto forOp = dyn_cast<scf::ForOp>(defOp)) {
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    propagateToYield(yieldOp, lattices);
  } else if (auto ifOp = dyn_cast<scf::IfOp>(defOp)) {
    propagateToYield(ifOp.thenYield(), lattices);
    if (!ifOp.getElseRegion().empty())
      propagateToYield(ifOp.elseYield(), lattices);
  } else {
    llvm_unreachable("Unknown branch operation");
  }

  // TODO(Arda): Address what happens when loop is annotated.
}

void TaskIdBackwardPropagation::visitCallOperand(OpOperand &operand) {
  llvm_unreachable(
      "Should not have any call operands in the IR after inlining.");
}

// No exit-state constraint: task ids are seeded only by annotated anchor ops.
void TaskIdBackwardPropagation::setToExitState(TaskIdLattice *lattice) {}

} // namespace mlir::triton::gpu
Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
#ifndef NVHOPPER_ANALYSIS_TASKIDPROPAGATION_H
#define NVHOPPER_ANALYSIS_TASKIDPROPAGATION_H

#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Support/LLVM.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include <optional>

using namespace mlir::dataflow;

namespace mlir::triton::gpu {

//===----------------------------------------------------------------------===//
// TaskId
//===----------------------------------------------------------------------===//

/// This lattice value represents the known information about the
/// async_task_id set of an SSA value.
class TaskId {
public:
  /// Construct a taskId value as uninitialized.
  explicit TaskId() = default;

  /// Construct a taskId value with a known set of task ids.
  TaskId(DenseI32ArrayAttr taskIds) : taskIds(std::move(taskIds)) {}

  /// Get the task ids. Returns null if the state is unknown.
  DenseI32ArrayAttr getTaskIds() const {
    assert(!isUninitialized());
    return *taskIds;
  }

  /// Compare the taskId values.
  bool operator==(const TaskId &rhs) const { return taskIds == rhs.taskIds; }

  /// Print the taskId value.
  void print(raw_ostream &os) const;

  /// The state where the taskIds value is uninitialized. This happens when the
  /// state hasn't been set during the analysis.
  static TaskId getUninitialized() { return TaskId{}; }

  /// Whether the state is uninitialized.
  bool isUninitialized() const { return !taskIds.has_value(); }

  /// Whether the state is unknown.
  bool isUnknown() const { return taskIds == nullptr; }

  /// The state where the taskId value is unknown.
  static TaskId getUnknownTaskId() { return TaskId{/*taskIds=*/nullptr}; }

  static TaskId meet(const TaskId &lhs, const TaskId &rhs);

  static TaskId join(const TaskId &lhs, const TaskId &rhs);

private:
  std::optional<DenseI32ArrayAttr> taskIds;
};
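
// State encoding recap: an empty optional is <UNINITIALIZED> (not yet
// visited), an engaged optional holding a null attribute is <UNKNOWN>, and a
// non-null DenseI32ArrayAttr is a concrete set of task ids.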

//===----------------------------------------------------------------------===//
// TaskIdLattice
//===----------------------------------------------------------------------===//

class TaskIdLattice : public Lattice<TaskId> {
public:
  using Lattice::Lattice;
};

//===----------------------------------------------------------------------===//
// TaskIdBackwardPropagation
//===----------------------------------------------------------------------===//

/// This analysis implements sparse backward propagation, which attempts to
/// determine the async_task_id of an SSA value.
class TaskIdBackwardPropagation
    : public SparseBackwardDataFlowAnalysis<TaskIdLattice> {
public:
  using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;

  LogicalResult
  visitOperation(Operation *op, ArrayRef<TaskIdLattice *> operands,
                 ArrayRef<const TaskIdLattice *> results) override;

  void visitBranchOperand(OpOperand &operand) override;

  void visitCallOperand(OpOperand &operand) override;

  void setToExitState(TaskIdLattice *lattice) override;

  void propagateToYield(scf::YieldOp yieldOp, SmallVector<TaskId> &lattices);

  void propagateToParent(Operation *op, const TaskId &taskId);
};

} // namespace mlir::triton::gpu

#endif // NVHOPPER_ANALYSIS_TASKIDPROPAGATION_H
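
For context, a hedged sketch of how a client pass might drive this analysis with MLIR's dataflow framework. The concrete driver (WSTaskIdPropagate.cpp) is not reproduced on this page; the auxiliary analyses loaded below follow the framework's usual prerequisites for sparse analyses rather than anything stated in this diff:

#include "TaskIdPropagation.h"
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/SymbolTable.h"

// Sketch: run the backward task-id propagation and query each value's lattice.
mlir::LogicalResult runTaskIdAnalysis(mlir::ModuleOp module) {
  mlir::SymbolTableCollection symbolTable;
  mlir::DataFlowSolver solver;
  // Dead-code and constant analyses keep the sparse framework's liveness
  // information populated.
  solver.load<mlir::dataflow::DeadCodeAnalysis>();
  solver.load<mlir::dataflow::SparseConstantPropagation>();
  solver.load<mlir::triton::gpu::TaskIdBackwardPropagation>(symbolTable);
  if (failed(solver.initializeAndRun(module)))
    return mlir::failure();

  // Query the computed lattice for each SSA value, e.g. to re-emit
  // async_task_id attributes into the IR as the test pass does.
  module.walk([&](mlir::Operation *op) {
    for (mlir::Value result : op->getResults()) {
      if (auto *lattice =
              solver.lookupState<mlir::triton::gpu::TaskIdLattice>(result))
        (void)lattice; // ... materialize lattice->getValue() ...
    }
  });
  return mlir::success();
}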
