Skip to content

Commit b3fc429

Browse files
Jokeren authored and meta-codesync[bot] committed
[Cherry-pick] [PROTON] Filter out all intrinsics when counting the number of triton functions (#8021) (#556)
Summary: Cherry-picked from upstream OAI repository. Original Commit: 23a4421 Original Author: Keren Zhou Original Date: 2025-09-01 18:27:52 -0400 Original commit message: ``` [PROTON] Filter out all intrinsics when counting the number of triton functions (#8021) ``` This PR was automatically cherry-picked from the upstream triton-lang/triton repository. Pull Request resolved: #556 Reviewed By: agron911 Differential Revision: D85909904 Pulled By: dshi7 fbshipit-source-id: 60f11154a4ffb226fafd20b347192dda2ac32a3c
1 parent 2e51006 commit b3fc429

File tree

5 files changed

+37
-32
lines changed

5 files changed

+37
-32
lines changed

test/Proton/amd/add_sched_barriers.mlir

Lines changed: 12 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
// RUN: triton-opt %s -split-input-file -add-sched-barriers --verify-diagnostics | FileCheck --check-prefix=CHECK %s
1+
// RUN: triton-opt %s -split-input-file -add-sched-barriers --verify-diagnostics | FileCheck --check-prefix=CHECK %s
22

33
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
44
#smem = #ttg.shared_memory
@@ -75,3 +75,14 @@ module attributes {"ttg.num-warps" = 8 : i32, ttg.profile_scratch_memory_alignme
7575
llvm.return
7676
}
7777
}
78+
79+
// -----
80+
81+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 3072 : i32, ttg.target = "hip:gfx90a", "ttg.threads-per-warp" = 64 : i32} {
82+
llvm.func @llvm.exp2.f32(f32) -> f32 attributes {libname = "", libpath = ""}
83+
// CHECK-LABEL: two_functions
84+
llvm.func @two_functions(%arg: f32) -> f32 {
85+
%1 = llvm.call @llvm.exp2.f32(%arg) : (f32) -> f32
86+
llvm.return %1 : f32
87+
}
88+
}

third_party/proton/Dialect/include/Conversion/ProtonGPUToLLVM/Utility.h

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,8 @@ CircularStoreDataPack
4545
lowerCircularStoreOpHelper(CircularStoreOp op, Value segmentStruct,
4646
ConversionPatternRewriter &rewriter);
4747

48+
SmallVector<FunctionOpInterface> getTritonFunctions(ModuleOp mod);
49+
4850
} // namespace proton::gpu
4951
} // namespace triton
5052

third_party/proton/Dialect/lib/ProtonGPUToLLVM/AllocateProtonGlobalScratchBuffer.cpp

Lines changed: 4 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
#include "Conversion/ProtonGPUToLLVM/Passes.h"
2+
#include "Conversion/ProtonGPUToLLVM/Utility.h"
23
#include "Dialect/ProtonGPU/IR/Dialect.h"
34
#include "mlir/Pass/Pass.h"
45
#include "triton/Dialect/Triton/IR/Dialect.h"
@@ -17,25 +18,13 @@ struct AllocateProtonGlobalScratchBufferPass
1718
MLIRContext *ctx = &getContext();
1819
OpBuilder builder(ctx);
1920

20-
int numFuncOps = 0;
21-
FunctionOpInterface func;
22-
mod.walk([&](FunctionOpInterface op) {
23-
// Ignore any intrinsic functions. On AMD the predicate load/store ops
24-
// are currently pseduo instrunctions at this point and will get picked up
25-
// here and trigger the FunctionOpInterface range based assert below
26-
StringRef funcName(op.getNameAttr());
27-
if (!funcName.contains("__")) {
28-
numFuncOps += 1;
29-
func = op;
30-
}
31-
});
32-
33-
assert(numFuncOps == 1);
21+
auto funcOps = triton::proton::gpu::getTritonFunctions(mod);
22+
assert(funcOps.size() == 1 && "Expected exactly one funcOp");
3423

3524
int32_t cumulativeMemorySize = 0; // bytes
3625
std::vector<uint32_t> alignments;
3726

38-
func.walk([&](proton::gpu::GlobalScratchAllocOp op) {
27+
funcOps[0].walk([&](proton::gpu::GlobalScratchAllocOp op) {
3928
int offset = llvm::alignTo(cumulativeMemorySize,
4029
proton::gpu::getBytesPerClockEntry());
4130
op->setAttr("offset",

third_party/proton/Dialect/lib/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/AddSchedBarriers.cpp

Lines changed: 5 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
#include "Conversion/ProtonGPUToLLVM/Passes.h"
2+
#include "Conversion/ProtonGPUToLLVM/Utility.h"
23
#include "Dialect/ProtonGPU/IR/Dialect.h"
34
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
45
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
@@ -28,33 +29,21 @@ struct AddSchedBarriers
2829
MLIRContext *ctx = &getContext();
2930
OpBuilder builder(ctx);
3031

31-
int numFuncOps = 0;
32-
FunctionOpInterface func;
33-
mod.walk([&](FunctionOpInterface op) {
34-
// Ignore any intrinsic functions. On AMD the predicate load/store ops
35-
// are currently pseduo instrunctions at this point and may get picked up
36-
// here and trigger the FunctionOpInterface range based assert below
37-
StringRef funcName(op.getNameAttr());
38-
if (!funcName.contains("__")) {
39-
numFuncOps += 1;
40-
func = op;
41-
}
42-
});
43-
44-
assert(numFuncOps == 1);
32+
auto funcOps = triton::proton::gpu::getTritonFunctions(mod);
33+
assert(funcOps.size() == 1 && "Expected exactly one funcOp");
4534

4635
IntegerAttr zeroAttrValue =
4736
builder.getI32IntegerAttr(static_cast<int32_t>(0));
4837

49-
func.walk([&](mlir::triton::proton::gpu::ReadCounterOp op) {
38+
funcOps[0].walk([&](mlir::triton::proton::gpu::ReadCounterOp op) {
5039
auto loc = op.getLoc();
5140
if (!isa_and_nonnull<ROCDL::SchedBarrier>(op->getPrevNode())) {
5241
builder.setInsertionPoint(op);
5342
builder.create<ROCDL::SchedBarrier>(loc, zeroAttrValue);
5443
}
5544
});
5645

57-
func.walk([&](mlir::triton::proton::gpu::CircularStoreOp op) {
46+
funcOps[0].walk([&](mlir::triton::proton::gpu::CircularStoreOp op) {
5847
auto loc = op.getLoc();
5948
if (!isa_and_nonnull<ROCDL::SchedBarrier>(op->getNextNode())) {
6049
builder.setInsertionPointAfter(op);

third_party/proton/Dialect/lib/ProtonGPUToLLVM/Utility.cpp

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -156,6 +156,20 @@ lowerCircularStoreOpHelper(CircularStoreOp op, Value segmentStruct,
156156
return {isWriter, valsVec, vecPtr, addrSpace};
157157
}
158158

159+
SmallVector<FunctionOpInterface> getTritonFunctions(ModuleOp mod) {
160+
SmallVector<FunctionOpInterface> funcOps;
161+
mod.walk([&](FunctionOpInterface funcOp) {
162+
// Ignore any intrinsic functions which have an empty body.
163+
// For example, on AMD the predicate load/store ops are currently pseudo
164+
// instructions at this point and may get picked up here and trigger the
165+
// FunctionOpInterface range based assert below.
166+
if (funcOp.empty())
167+
return;
168+
funcOps.push_back(funcOp);
169+
});
170+
return funcOps;
171+
}
172+
159173
} // namespace proton::gpu
160174
} // namespace triton
161175

0 commit comments

Comments (0)