Skip to content

Commit a6ecc75

Browse files
davidberard98 and htyu authored
[AMD] StreamPipeline V1: fix depArg return mapping (#4832)
Previously, if an arg inside the loop was marked as a depArg, then a new iter_arg would be added to the for loop to handle the arg; but any usages of these variables _after_ the for loop would not be updated; those usages would get the wrong value. This PR fixes this by updating the return mapping. See the comment added in StreamPipeline.cpp for an example. Co-authored-by: Hongtao Yu <[email protected]>
1 parent 256ef34 commit a6ecc75

File tree

2 files changed

+75
-1
lines changed

2 files changed

+75
-1
lines changed
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline | FileCheck %s

// Regression test for the depArg result-mapping fix in StreamPipeline:
// after pipelining, values yielded by the original scf.for that were marked
// as depArgs (here: the two pointer iter_args %arg8/%arg9) are carried in
// *new* iter_args of the pipelined loop. Uses of those results *after* the
// loop (the tt.return below) must be remapped to the new result indices
// (#3 and #4), not the original ones (#1 and #2).
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#loc = loc("/data/users/dberard/triton-env/scripts/matmul.py":6:0)
#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [32, 32], isTransposed = false}>
module attributes {"triton_gpu.target" = "hip:gfx942", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
// CHECK-LABEL: tt.func @use_dep_args
tt.func @use_dep_args(%a_ptrs: tensor<64x32x!tt.ptr<bf16>, #blocked>, %b_ptrs: tensor<32x64x!tt.ptr<bf16>, #blocked1>, %loop_range: i32) -> (tensor<64x64xf32, #mma>, tensor<64x32x!tt.ptr<bf16>, #blocked>, tensor<32x64x!tt.ptr<bf16>, #blocked1>) {
%cst = arith.constant dense<32> : tensor<64x32xi32, #blocked>
%cst2 = arith.constant dense<2048> : tensor<32x64xi32, #blocked1>
%cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma>
%c0_i32 = arith.constant 0 : i32
%c8_i32 = arith.constant 8 : i32
%c32_i32 = arith.constant 32 : i32
// The prologue load must be hoisted out of the loop by the pipeliner.
// CHECK: tt.load
// CHECK: [[FOR_OUT:%[a-z0-9_]+]]:{{[0-9]+}} = scf.for
%for:3 = scf.for %arg6 = %c0_i32 to %loop_range step %c32_i32 iter_args(%arg7 = %cst_0, %arg8 = %a_ptrs, %arg9 = %b_ptrs) -> (tensor<64x64xf32, #mma>, tensor<64x32x!tt.ptr<bf16>, #blocked>, tensor<32x64x!tt.ptr<bf16>, #blocked1>) : i32 {
%63 = tt.load %arg8 : tensor<64x32x!tt.ptr<bf16>, #blocked>
%64 = tt.load %arg9 : tensor<32x64x!tt.ptr<bf16>, #blocked1>
%65 = triton_gpu.convert_layout %63 : tensor<64x32xbf16, #blocked> -> tensor<64x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
%66 = triton_gpu.convert_layout %64 : tensor<32x64xbf16, #blocked1> -> tensor<32x64xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
%67 = tt.dot %65, %66, %arg7 : tensor<64x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x64xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<64x64xf32, #mma>
%68 = tt.addptr %arg8, %cst : tensor<64x32x!tt.ptr<bf16>, #blocked>, tensor<64x32xi32, #blocked>
%69 = tt.addptr %arg9, %cst2 : tensor<32x64x!tt.ptr<bf16>, #blocked1>, tensor<32x64xi32, #blocked1>
scf.yield %67, %68, %69 : tensor<64x64xf32, #mma>, tensor<64x32x!tt.ptr<bf16>, #blocked>, tensor<32x64x!tt.ptr<bf16>, #blocked1>
}
// The pointer results must come from the remapped (depArg) result slots #3/#4.
// CHECK: tt.return {{[^,]+}}, [[FOR_OUT]]#3, [[FOR_OUT]]#4
tt.return %for#0, %for#1, %for#2 : tensor<64x64xf32, #mma>, tensor<64x32x!tt.ptr<bf16>, #blocked>, tensor<32x64x!tt.ptr<bf16>, #blocked1>
}
}

third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ class LoopPipeliner {
7171
/// shared mem and a next buffer stored in regs.
7272
int numStages = 2;
7373

74-
/// Arg indicies
74+
/// Arg indices in pplForOp
7575
size_t depArgsBeginIdx;
7676
DenseMap<BlockArgument, size_t> depArgsIdx;
7777

@@ -165,6 +165,9 @@ class LoopPipeliner {
165165
/// Collect loads to pipeline. Return success if we can pipeline this loop
166166
LogicalResult initialize();
167167

168+
// Update mapping from old forOp results to new pplForOp results
169+
void setResultMapping(DenseMap<Value, Value> &newResults);
170+
168171
/// Emit pipelined loads (before loop body)
169172
void emitPrologue();
170173

@@ -548,6 +551,45 @@ void LoopPipeliner::emitPrologue() {
548551
} // for (Operation *op : orderedDeps)
549552
}
550553

554+
/// Remap results of the original loop to the pipelined loop's results.
///
/// Pipelining may carry a loop-carried dependency (a "depArg") in a *new*
/// iter_arg of pplForOp rather than in the iter_arg slot it occupied in the
/// original forOp. Any use of such a result *after* the loop must then read
/// the new result index, not the old one, or it observes a stale value.
///
/// For example, starting from
///
///   ptr = ...
///   c = [zeros]
///   ret = scf.for iter_args(a_ptr=ptr, c=c)
///     a = load(a_ptr)
///     c += dot(a, ...)
///     a_ptr_new = a_ptr + N
///     scf.yield %a_ptr_new, %c
///
/// the pipeliner rewrites the loop so the pointer advances through a fresh
/// iter_arg:
///
///   ptr = ...
///   c = [zeros]
///   load_pre = load(ptr)
///   ptr_new = ptr + N
///   ret = scf.for iter_args(a_ptr=ptr, c=c, ld=load_pre, A_ptr_1=ptr_new)
///     a_next = load(A_ptr_1)
///     c += dot(ld, ...)
///     A_ptr_new = A_ptr_1 + N
///     scf.yield a_ptr, c, a_next, A_ptr_new
///
/// so downstream users of a_ptr must now reference ret#3 instead of ret#0.
void LoopPipeliner::setResultMapping(DenseMap<Value, Value> &newResults) {
  auto iterArgs = forOp.getRegionIterArgs();
  for (size_t oldIdx = 0, numArgs = iterArgs.size(); oldIdx < numArgs;
       ++oldIdx) {
    BlockArgument iterArg = iterArgs[oldIdx];
    // Only depArgs were relocated; other results keep their original index.
    if (!depArgs.contains(iterArg))
      continue;
    size_t newIdx = depArgsIdx[iterArg];
    newResults[forOp->getResult(oldIdx)] = pplForOp->getResult(newIdx);
  }
}
592+
551593
void LoopPipeliner::emitEpilogue(DenseMap<Value, Value> &newResults) {
552594
if (!peelLastIter)
553595
return;
@@ -846,6 +888,7 @@ struct PipelinePass : public TritonAMDGPUStreamPipelineBase<PipelinePass> {
846888
DenseMap<Value, Value> newResults;
847889
for (unsigned i = 0; i < forOp->getNumResults(); ++i)
848890
newResults[forOp->getResult(i)] = pplForOp->getResult(i);
891+
pipeliner.setResultMapping(newResults);
849892
pipeliner.emitEpilogue(newResults);
850893

851894
// Replace the original loop

0 commit comments

Comments
 (0)