Skip to content

Commit 1028c8f

Browse files
raikonenfnu and AlexAUT authored
[AMD] Enable async pingpong for F16 GEMMs (#796)
* [AMD] Generalize PingPong to support different types of Load/Store Ops. The main motivation behind this commit is to add support for PingPong with AsyncOps. To accomplish that, we made these changes: - Fork "determineDotMemoryOps" into "determineDotAsyncMemoryOps", which handles async memory ops. - Refactor validation and pruning of memory ops into "pruneDotMemoryOps" so that we have a clean interface for its async memory ops counterpart, "pruneAsyncDotMemoryOps". - Plumb "useBlockPingpong" into StreamPipeliner so that it can adjust the AsyncWait stage/cluster: hoist the first AsyncWait and allow setting the AsyncWait towards the end of the loop, which makes it easier for the 4 PP cluster to move it before the 3rd dot-slice / 2 s_barrier before localLoads; this is to ensure there are no race conditions. - Add a check to enable handling of dotSOps (dot scaled) vs. dotOps (dot). Signed-off-by: Stanley Winata <[email protected]> Co-authored-by: Alexander Weinrauch <[email protected]>
1 parent 247f4f4 commit 1028c8f

File tree

12 files changed

+467
-170
lines changed

12 files changed

+467
-170
lines changed

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
3535
"TRITON_HIP_LOCAL_PREFETCH",
3636
"TRITON_HIP_USE_ASYNC_COPY",
3737
"TRITON_HIP_ASYNC_COPY_BYPASS_PERMUTE",
38+
"TRITON_HIP_ENABLE_F16_ASYNC_PINGPONG",
3839
"TRITON_HIP_USE_BLOCK_PINGPONG",
3940
"TRITON_HIP_USE_IN_THREAD_TRANSPOSE",
4041
"TRITON_LLVM_DEBUG_ONLY",

third_party/amd/backend/compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ def make_ttgir(mod, metadata, options):
262262
amd.passes.ttgpuir.add_reorder_instructions(pm)
263263
use_block_pingpong = is_pingpong_schedule_enabled(options.arch)
264264
if use_block_pingpong and options.num_stages in [2, 4]:
265-
amd.passes.ttgpuir.add_block_pingpong(pm, options.num_stages)
265+
amd.passes.ttgpuir.add_block_pingpong(pm, options.num_stages, use_async_copy)
266266

267267
if knobs.amd.use_buffer_ops:
268268
amd.passes.ttgpuir.add_canonicalize_pointers(pm)
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
#ifndef TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTOLLVM_ASYNCUTILITY_H_
2+
#define TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTOLLVM_ASYNCUTILITY_H_
3+
4+
#include "mlir/IR/Value.h"
5+
6+
namespace mlir::triton::AMD {
7+
// Traverses the def-chain including control flow of the token and returns true
8+
// if all defining operations are an AsyncWait
9+
bool comesFromAsyncWait(mlir::Value value);
10+
} // namespace mlir::triton::AMD
11+
12+
#endif

third_party/amd/include/TritonAMDGPUTransforms/Passes.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ std::unique_ptr<Pass> createTritonAMDGPUConvertToBufferOpsPass(
3434
std::string archGenName = std::string());
3535

3636
std::unique_ptr<Pass>
37-
createTritonAMDGPUBlockPingpongPass(int32_t numStages = 2);
37+
createTritonAMDGPUBlockPingpongPass(int32_t numStages = 2,
38+
bool useAsyncCopy = false);
3839

3940
std::unique_ptr<Pass> createTritonAMDGPUInThreadTransposePass();
4041

third_party/amd/include/TritonAMDGPUTransforms/Passes.td

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -168,11 +168,12 @@ def TritonAMDGPUBlockPingpong: Pass<"tritonamdgpu-block-pingpong", "mlir::Module
168168

169169
let dependentDialects = ["mlir::ROCDL::ROCDLDialect, mlir::triton::amdgpu::TritonAMDGPUDialect"];
170170

171-
let options = [
172-
Option<"numStages", "num-stages",
173-
"int32_t", /*default*/"2",
174-
"Number of Pipeline stages">,
175-
];
171+
let options =
172+
[Option<"numStages", "num-stages", "int32_t", /*default*/ "2",
173+
"Number of Pipeline stages">,
174+
Option<"useAsyncCopy", "use_async_copy", "bool", /*default*/ "false",
175+
"Use AsyncCopyGlobalToLocal to directly load to shared memory">,
176+
];
176177
}
177178

178179
def TritonAMDGPUInThreadTranspose: Pass<"tritonamdgpu-in-thread-transpose", "mlir::triton::FuncOp"> {
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
#include "third_party/amd/include/TritonAMDGPUToLLVM/AsyncUtility.h"
2+
#include "Dialect/TritonAMDGPU/IR/Dialect.h"
3+
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
4+
#include "mlir/IR/Operation.h"
5+
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
6+
7+
namespace mlir::triton::AMD {
8+
9+
// Traverses the def-chain including control flow of the token and returns true
10+
// if all defining operations are an AsyncWait
11+
bool comesFromAsyncWait(mlir::Value token) {
12+
if (auto defOp = token.getDefiningOp()) {
13+
if (isa<triton::gpu::AsyncWaitOp>(defOp))
14+
return true;
15+
else if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(defOp))
16+
return comesFromAsyncWait(castOp.getInputs()[0]);
17+
else
18+
return false;
19+
}
20+
21+
auto blockArg = llvm::dyn_cast<mlir::BlockArgument>(token);
22+
// If the token has no defining op and is not an BlockArgument bail out
23+
if (!blockArg) {
24+
return false;
25+
}
26+
27+
auto block = blockArg.getOwner();
28+
auto argId = blockArg.getArgNumber();
29+
30+
auto destOperandFromAsyncWait = [argId](auto &&operands) {
31+
assert(argId < operands.size());
32+
return comesFromAsyncWait(operands[argId]);
33+
};
34+
35+
// Check all predecessor block's terminator and follow the passed value at
36+
// argId to see if they are immediately an AsyncWait.
37+
for (auto *pred : block->getPredecessors()) {
38+
auto terminator = pred->getTerminator();
39+
if (auto br = llvm::dyn_cast<cf::BranchOp>(terminator)) {
40+
if (!destOperandFromAsyncWait(br.getDestOperands()))
41+
return false;
42+
} else if (auto condBr = llvm::dyn_cast<cf::CondBranchOp>(terminator)) {
43+
if (condBr.getTrueDest() == block) {
44+
if (!destOperandFromAsyncWait(condBr.getTrueDestOperands()))
45+
return false;
46+
}
47+
if (condBr.getFalseDest() == block) {
48+
if (!destOperandFromAsyncWait(condBr.getFalseDestOperands()))
49+
return false;
50+
}
51+
} else if (auto br = llvm::dyn_cast<LLVM::BrOp>(terminator)) {
52+
if (!destOperandFromAsyncWait(br.getDestOperands()))
53+
return false;
54+
} else {
55+
llvm::dbgs() << "no terminator!" << *terminator << "\n";
56+
return false;
57+
}
58+
}
59+
return true;
60+
}
61+
62+
} // namespace mlir::triton::AMD

third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
add_triton_library(TritonAMDGPUToLLVM
2+
AsyncUtility.cpp
23
AtomicRMWOpsEmitter.cpp
34
BufferOpsEmitter.cpp
45
ConvertLayoutOpToLLVM/SharedToDotOperandHelper.cpp

third_party/amd/lib/TritonAMDGPUToLLVM/MembarUtility.cpp

Lines changed: 1 addition & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,55 +1,12 @@
11
#include "third_party/amd/include/TritonAMDGPUToLLVM/MembarUtility.h"
22
#include "Dialect/TritonAMDGPU/IR/Dialect.h"
33
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
4+
#include "third_party/amd/include/TritonAMDGPUToLLVM/AsyncUtility.h"
45
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
56

67
namespace mlir::triton::AMD {
78
namespace {
89

9-
// Traverses the def-chain including control flow of the token and returns true
10-
// if all defining operations are an AsyncWait
11-
bool comesFromAsyncWait(Value token) {
12-
if (auto defOp = token.getDefiningOp()) {
13-
return isa<triton::gpu::AsyncWaitOp>(defOp);
14-
}
15-
16-
auto blockArg = dyn_cast<BlockArgument>(token);
17-
// If the token has no defining op and is not an BlockArgument bail out
18-
if (!blockArg) {
19-
return false;
20-
}
21-
22-
auto block = blockArg.getOwner();
23-
auto argId = blockArg.getArgNumber();
24-
25-
auto destOperandFromAsyncWait = [argId](auto &&operands) {
26-
assert(argId < operands.size());
27-
return comesFromAsyncWait(operands[argId]);
28-
};
29-
30-
// Check all predecessor block's terminator and follow the passed value at
31-
// argId to see if they are immediately an AsyncWait.
32-
for (auto *pred : block->getPredecessors()) {
33-
auto terminator = pred->getTerminator();
34-
if (auto br = dyn_cast<cf::BranchOp>(terminator)) {
35-
if (!destOperandFromAsyncWait(br.getDestOperands()))
36-
return false;
37-
} else if (auto condBr = dyn_cast<cf::CondBranchOp>(terminator)) {
38-
if (condBr.getTrueDest() == block) {
39-
if (!destOperandFromAsyncWait(condBr.getTrueDestOperands()))
40-
return false;
41-
}
42-
if (condBr.getFalseDest() == block) {
43-
if (!destOperandFromAsyncWait(condBr.getFalseDestOperands()))
44-
return false;
45-
}
46-
} else {
47-
return false;
48-
}
49-
}
50-
return true;
51-
}
52-
5310
// Returns true if one of the operands is a LocalLoad synced via AsyncWait.
5411
bool filterAsyncLocalLoadsDeppendencies(Operation *op1, Operation *op2) {
5512
auto isAsyncLoad = [](Operation *op) {

third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@
44
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
55
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
66
#include "mlir/IR/PatternMatch.h"
7+
#include "third_party/amd/include/TritonAMDGPUToLLVM/AsyncUtility.h"
78
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
89
#include "triton/Dialect/Triton/IR/Dialect.h"
910
#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
1011

1112
namespace tt = mlir::triton;
1213
using mlir::triton::ModuleAxisInfoAnalysis;
14+
using mlir::triton::AMD::comesFromAsyncWait;
1315
using mlir::triton::AMD::DppCtrl;
1416
using mlir::triton::AMD::ISAFamily;
1517
using mlir::triton::gpu::appendOrGetExternFuncOp;
@@ -734,8 +736,9 @@ void addAsyncCopyAliasScope(AliasAnalysisOpInterface directToLdsOp) {
734736
void addLocalLoadNoAliasScope(triton::gpu::LocalLoadOp localLoadOp,
735737
AliasAnalysisOpInterface llLoadOp) {
736738
auto token = localLoadOp.getToken();
737-
if (!token || !token.getDefiningOp<tt::gpu::AsyncWaitOp>())
739+
if (!token || !comesFromAsyncWait(token)) {
738740
return;
741+
}
739742

740743
return addLocalLoadNoAliasScope(llLoadOp);
741744
}

0 commit comments

Comments
 (0)