Skip to content

Commit d06ec83

Browse files
authored
[Instrumentation][Proton] Add MLIR/LLVM level compiler instrumentation pass support in Proton (triton-lang#5067)
Basic functionality to print load/store address spaces chosen by the compiler. Usage/example with matmul Proton tutorial: ``` $ proton --instrument=print-mem-spaces matmul.py 0 matmul_kernel matmul.py:180:20 SHARED STORE 1 matmul_kernel matmul.py:181:20 SHARED STORE 2 matmul_kernel matmul.py:180:20 SHARED LOAD 3 matmul_kernel matmul.py:181:20 SHARED LOAD matmul-performance: M N K cuBLAS Triton 0 256.0 256.0 256.0 2.231013 1.691252 1 384.0 384.0 384.0 5.947805 4.626071 2 512.0 512.0 512.0 12.336188 8.924051 3 640.0 640.0 640.0 26.006348 14.628980 4 768.0 768.0 768.0 36.065672 20.972006 5 896.0 896.0 896.0 51.974214 29.480457 6 1024.0 1024.0 1024.0 63.913206 27.560463 7 1152.0 1152.0 1152.0 52.790876 34.125533 ```
1 parent 0b68388 commit d06ec83

File tree

9 files changed

+302
-13
lines changed

9 files changed

+302
-13
lines changed

.github/workflows/integration-tests.yml

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -279,8 +279,9 @@ jobs:
279279
ctest -j32
280280
- name: Run Proton tests
281281
run: |
282-
cd third_party/proton
283-
python3 -m pytest -s test
282+
cd third_party/proton/test
283+
python3 -m pytest -s .
284+
cd ..
284285
- # If we're on branch `main`, save the ccache Triton compilation artifacts
285286
# to the cache so they can be used by other (non-main) CI runs.
286287
#
@@ -425,8 +426,9 @@ jobs:
425426
python3 -m pytest -s -n 8 ./test_cast_matmul.py
426427
- name: Run Proton tests
427428
run: |
428-
cd third_party/proton
429-
python3 -m pytest -s test
429+
cd third_party/proton/test
430+
python3 -m pytest -s .
431+
cd ..
430432
- name: Run C++ unittests
431433
run: |
432434
cd python

.github/workflows/integration-tests.yml.in

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -319,8 +319,9 @@ jobs:
319319
- &run-proton-tests-step
320320
name: Run Proton tests
321321
run: |
322-
cd third_party/proton
323-
python3 -m pytest -s test
322+
cd third_party/proton/test
323+
python3 -m pytest -s .
324+
cd ..
324325

325326
# If we're on branch `main`, save the ccache Triton compilation artifacts
326327
# to the cache so they can be used by other (non-main) CI runs.

lib/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ add_subdirectory(Conversion)
33
add_subdirectory(Dialect)
44
add_subdirectory(Target)
55
add_subdirectory(Tools)
6+
add_subdirectory(Instrumentation)

lib/Instrumentation/CMakeLists.txt

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Names of the instrumentation passes to build; each one becomes its own
# shared library that is loaded as an LLVM pass plugin.
set(GPU_INSTRUMENTATION_PASSES
    PrintLoadStoreMemSpaces
)

# Source list for a pass <name> is looked up through the variable
# <name>_SOURCES inside the loop below.
set(PrintLoadStoreMemSpaces_SOURCES
    PrintLoadStoreMemSpaces.cpp
)

foreach(plugin IN LISTS GPU_INSTRUMENTATION_PASSES)
  add_library(${plugin} SHARED ${${plugin}_SOURCES})

  target_link_libraries(${plugin}
    PRIVATE
      LLVMCore
      LLVMSupport
      LLVMTransformUtils
      "$<$<PLATFORM_ID:Darwin>:-undefined dynamic_lookup>"
  )

  # CMAKE_LIBRARY_OUTPUT_DIRECTORY is only set during the Python
  # build. It is empty if building directly from the root
  # CMakeLists.txt file. Therefore if not building from Python just
  # use the default CMake shared lib path otherwise this causes a hard
  # build error
  if(DEFINED CMAKE_LIBRARY_OUTPUT_DIRECTORY)
    set_target_properties(${plugin} PROPERTIES
      LIBRARY_OUTPUT_DIRECTORY
        "${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/../instrumentation"
    )
  endif()

  # This is set to -fvisibility=hidden in the top level CMake file
  # which causes the llvmGetPassPluginInfo symbol to be hidden and
  # an "entry point not found" error. Reset it just for this target
  target_compile_options(${plugin} PRIVATE -fvisibility=default -fno-rtti)
endforeach()
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Passes/PassPlugin.h"
#include "llvm/Support/raw_ostream.h"
#include <cstdint>
#include <map>
#include <stdexcept>
#include <string>
6+
7+
using namespace llvm;
8+
9+
namespace {
10+
11+
struct LoadStoreMemSpace : public PassInfoMixin<LoadStoreMemSpace> {
12+
PreservedAnalyses run(llvm::Module &module, ModuleAnalysisManager &) {
13+
bool modifiedCodeGen = runOnModule(module);
14+
15+
return (modifiedCodeGen ? llvm::PreservedAnalyses::none()
16+
: llvm::PreservedAnalyses::all());
17+
}
18+
bool runOnModule(llvm::Module &module);
19+
// isRequired being set to true keeps this pass from being skipped
20+
// if it has the optnone LLVM attribute
21+
static bool isRequired() { return true; }
22+
};
23+
24+
} // end anonymous namespace
25+
26+
std::map<int, std::string> AddrSpaceMap = {
27+
{0, "FLAT"}, {1, "GLOBAL"}, {3, "SHARED"}, {4, "CONSTANT"}, {5, "SCRATCH"}};
28+
29+
std::map<std::string, uint32_t> LocationCounterSourceMap;
30+
31+
std::string LoadOrStoreMap(const BasicBlock::iterator &I) {
32+
if (LoadInst *LI = dyn_cast<LoadInst>(I))
33+
return "LOAD";
34+
else if (StoreInst *SI = dyn_cast<StoreInst>(I))
35+
return "STORE";
36+
else
37+
throw std::runtime_error("Error: unknown operation type");
38+
}
39+
template <typename LoadOrStoreInst>
40+
void InstrumentationFunction(const BasicBlock::iterator &I, const Function &F,
41+
const llvm::Module &M, uint32_t &LocationCounter) {
42+
auto LSI = dyn_cast<LoadOrStoreInst>(I);
43+
if (not LSI)
44+
return;
45+
Value *Op = LSI->getPointerOperand()->stripPointerCasts();
46+
uint32_t AddrSpace = cast<PointerType>(Op->getType())->getAddressSpace();
47+
DILocation *DL = dyn_cast<Instruction>(I)->getDebugLoc();
48+
49+
std::string SourceAndAddrSpaceInfo =
50+
(F.getName() + " " + DL->getFilename() + ":" + Twine(DL->getLine()) +
51+
":" + Twine(DL->getColumn()))
52+
.str() +
53+
" " + AddrSpaceMap[AddrSpace] + " " + LoadOrStoreMap(I);
54+
55+
if (LocationCounterSourceMap.find(SourceAndAddrSpaceInfo) ==
56+
LocationCounterSourceMap.end()) {
57+
errs() << LocationCounter << " " << SourceAndAddrSpaceInfo << "\n";
58+
LocationCounterSourceMap[SourceAndAddrSpaceInfo] = LocationCounter;
59+
LocationCounter++;
60+
}
61+
}
62+
63+
/// Walks every non-intrinsic function that uses the AMDGPU_KERNEL or
/// PTX_Kernel calling convention, or whose name contains "kernel", and hands
/// each load/store instruction to InstrumentationFunction for reporting.
/// Always returns false.
bool LoadStoreMemSpace::runOnModule(Module &M) {
  uint32_t NextLocationId = 0;

  for (Function &Fn : M) {
    if (Fn.isIntrinsic())
      continue;

    const bool IsKernel =
        Fn.getCallingConv() == CallingConv::AMDGPU_KERNEL ||
        Fn.getCallingConv() == CallingConv::PTX_Kernel ||
        Fn.getName().contains("kernel");
    if (!IsKernel)
      continue;

    for (BasicBlock &Block : Fn) {
      // InstrumentationFunction takes an iterator, so iterate explicitly.
      for (BasicBlock::iterator It = Block.begin(); It != Block.end(); ++It) {
        if (isa<LoadInst>(*It))
          InstrumentationFunction<LoadInst>(It, Fn, M, NextLocationId);
        else if (isa<StoreInst>(*It))
          InstrumentationFunction<StoreInst>(It, Fn, M, NextLocationId);
      }
    }
  }

  // This function itself never sets a "modified" flag.
  return false;
}
86+
87+
/// Builds the plugin descriptor: on registration, a LoadStoreMemSpace pass is
/// appended through PassBuilder's optimizer-last extension point.
PassPluginLibraryInfo getPassPluginInfo() {
  const auto callback = [](PassBuilder &PB) {
    // The extension-point callback has no result to return and needs no
    // captures; the trailing parameters are unused.
    PB.registerOptimizerLastEPCallback(
        [](ModulePassManager &MPM, auto, auto) {
          MPM.addPass(LoadStoreMemSpace());
        });
  };

  return {LLVM_PLUGIN_API_VERSION, "print-mem-space", LLVM_VERSION_STRING,
          callback};
}
98+
99+
// C-linkage symbol exported by the plugin library; it simply forwards to
// getPassPluginInfo().
extern "C" LLVM_ATTRIBUTE_WEAK PassPluginLibraryInfo llvmGetPassPluginInfo() {
  PassPluginLibraryInfo Info = getPassPluginInfo();
  return Info;
}

third_party/proton/README.md

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ The following examples demonstrate how to use Proton command-line.
128128
proton [options] script.py [script_args] [script_options]
129129
proton [options] pytest [pytest_args] [script_options]
130130
python -m triton.profiler.proton [options] script.py [script_args] [script_options]
131+
proton --instrument=[instrumentation pass] script.py
131132
```
132133

133134
When profiling in the command line mode, the `proton.start` and `proton.finalize` functions are automatically called before and after the script execution. Any `proton.start` and `proton.finalize` functions in the script are ignored. Also, in the command line mode, only a single *session* is supported. Therefore, `proton.deactivate(session_id=1)` is invalid, while `proton.deactivate(session_id=0)` is valid.
@@ -156,6 +157,23 @@ More options can be found by running the following command.
156157
proton-viewer -h
157158
```
158159

160+
### Advanced features
161+
In addition to profiling, Proton also incorporates MLIR/LLVM based compiler instrumentation passes to get Triton level analysis
162+
and optimization information. This feature is under active development and the list of available passes is expected to grow.
163+
164+
#### Available passes
165+
print-mem-spaces: this pass prints the load and store address spaces (e.g. global, flat, shared) chosen by the compiler and attributes them back to the corresponding Triton source locations.
166+
167+
Example usage with the Proton matmul tutorial:
168+
```bash
169+
$ proton --instrument=print-mem-spaces matmul.py
170+
0 matmul_kernel matmul.py:180:20 SHARED STORE
171+
1 matmul_kernel matmul.py:181:20 SHARED STORE
172+
2 matmul_kernel matmul.py:180:20 SHARED LOAD
173+
3 matmul_kernel matmul.py:181:20 SHARED LOAD
174+
```
175+
Notes: The instrument functionality is currently only available from the command line. Additionally, the instrument and profile command line arguments cannot be used simultaneously.
176+
159177
### Instruction sampling (experimental)
160178

161179
Proton supports instruction sampling on NVIDIA GPUs.

third_party/proton/proton/proton.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
import argparse
22
import sys
33
import os
4+
from glob import glob
5+
import pathlib
46
from .profile import start, finalize, _select_backend
57
from .flags import set_command_line
8+
import triton
69

710

811
def parse_arguments():
@@ -19,6 +22,8 @@ def parse_arguments():
1922
choices=["shadow", "python"])
2023
parser.add_argument("-d", "--data", type=str, help="Profiling data", default="tree", choices=["tree"])
2124
parser.add_argument("-k", "--hook", type=str, help="Profiling hook", default=None, choices=[None, "triton"])
25+
parser.add_argument("-i", "--instrument", type=str, help="Instrumentation analysis type", default=None,
26+
choices=[None, "print-mem-spaces"])
2227
parser.add_argument('target_args', nargs=argparse.REMAINDER, help='Subcommand and its arguments')
2328
args = parser.parse_args()
2429
return args, args.target_args
@@ -28,7 +33,7 @@ def is_pytest(script):
2833
return os.path.basename(script) == 'pytest'
2934

3035

31-
def execute_as_main(script, args):
36+
def execute_as_main(script, args, instrumentation_pass=None):
3237
script_path = os.path.abspath(script)
3338
# Prepare a clean global environment
3439
clean_globals = {
@@ -42,6 +47,14 @@ def execute_as_main(script, args):
4247
sys.argv = [script] + args
4348
# Append the script's directory in case the script uses relative imports
4449
sys.path.append(os.path.dirname(script_path))
50+
top_level_triton_path = os.path.dirname(triton.__file__)
51+
52+
if instrumentation_pass == "print-mem-spaces":
53+
instrumentation_pass_path = str(
54+
next(pathlib.Path(top_level_triton_path).rglob("libPrintLoadStoreMemSpaces.so"), None))
55+
os.environ['TRITON_ALWAYS_COMPILE'] = "1"
56+
os.environ['TRITON_DISABLE_LINE_INFO'] = "0"
57+
os.environ['LLVM_PASS_PLUGIN_PATH'] = instrumentation_pass_path
4558

4659
# Execute in the isolated environment
4760
try:
@@ -54,11 +67,7 @@ def execute_as_main(script, args):
5467
sys.argv = original_argv
5568

5669

57-
def run_profiling(args, target_args):
58-
backend = args.backend if args.backend else _select_backend()
59-
60-
start(args.name, context=args.context, data=args.data, backend=backend, hook=args.hook)
61-
70+
def do_setup_and_execute(target_args, instrumentation_pass=None):
6271
# Set the command line mode to avoid any `start` calls in the script.
6372
set_command_line()
6473

@@ -68,13 +77,29 @@ def run_profiling(args, target_args):
6877
import pytest
6978
pytest.main(script_args)
7079
else:
71-
execute_as_main(script, script_args)
80+
execute_as_main(script, script_args, instrumentation_pass)
81+
82+
83+
def run_profiling(args, target_args):
    """Start a profiling session, execute the target, then finalize.

    `args` is the parsed command line; `target_args` is the remaining
    command (script or pytest plus its arguments) to execute.
    """
    # Use the backend given on the command line, otherwise auto-select one.
    backend = args.backend or _select_backend()

    session_options = dict(context=args.context, data=args.data, backend=backend, hook=args.hook)
    start(args.name, **session_options)

    do_setup_and_execute(target_args)

    finalize()
7491

7592

93+
def run_instrumentation(args, target_args):
    """Execute the target with the instrumentation pass named by `args.instrument`.

    No profiling session is started or finalized here.
    """
    # NOTE(review): the selected backend is not used below. The call is kept
    # in case `_select_backend()` has side effects callers rely on; confirm
    # and drop it if not.
    _unused_backend = args.backend or _select_backend()
    do_setup_and_execute(target_args, instrumentation_pass=args.instrument)
96+
97+
7698
def main():
    """Command-line entry point: instrument when --instrument is given, profile otherwise."""
    args, target_args = parse_arguments()
    # The two modes are mutually exclusive; instrumentation takes precedence.
    runner = run_instrumentation if args.instrument else run_profiling
    runner(args, target_args)
79104

80105

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import torch
2+
3+
import triton
4+
import triton.language as tl
5+
6+
7+
@triton.jit  # Tiled C = A @ B kernel. Comments are inline only: the instrumentation output reports source line:column, so the layout is kept stable.
def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak,  #
                  stride_bk, stride_bn,  #
                  stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
                  BLOCK_SIZE_K: tl.constexpr,  #
                  GROUP_SIZE_M: tl.constexpr,  #
                  ):
    pid = tl.program_id(axis=0)  # linear program index on launch axis 0
    tiles_m = tl.cdiv(M, BLOCK_SIZE_M)  # number of tiles along M
    tiles_n = tl.cdiv(N, BLOCK_SIZE_N)  # number of tiles along N
    programs_per_group = GROUP_SIZE_M * tiles_n
    group_index = pid // programs_per_group
    group_first_m = group_index * GROUP_SIZE_M
    rows_in_group = min(tiles_m - group_first_m, GROUP_SIZE_M)  # the last group may be shorter
    tile_m = group_first_m + (pid % rows_in_group)
    tile_n = (pid % programs_per_group) // rows_in_group

    a_rows = (tile_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M  # wrapped with % M
    b_cols = (tile_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N  # wrapped with % N
    k_offsets = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (a_rows[:, None] * stride_am + k_offsets[None, :] * stride_ak)
    b_ptrs = b_ptr + (k_offsets[:, None] * stride_bk + b_cols[None, :] * stride_bn)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)  # accumulate in float32
    for k_step in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs, mask=k_offsets[None, :] < K - k_step * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=k_offsets[:, None] < K - k_step * BLOCK_SIZE_K, other=0.0)
        accumulator += tl.dot(a, b)
        a_ptrs += BLOCK_SIZE_K * stride_ak  # advance to the next K block
        b_ptrs += BLOCK_SIZE_K * stride_bk
    out_tile = accumulator.to(tl.float16)
    out_rows = tile_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    out_cols = tile_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    out_ptrs = c_ptr + stride_cm * out_rows[:, None] + stride_cn * out_cols[None, :]
    out_mask = (out_rows[:, None] < M) & (out_cols[None, :] < N)  # skip out-of-range elements
    tl.store(out_ptrs, out_tile, mask=out_mask)
43+
44+
45+
def matmul(a, b, activation=""):
    """Launch `matmul_kernel` on `a` and `b` and return the output tensor.

    `a` and `b` are 2-D tensors; the result has shape
    (a.shape[0], b.shape[1]) and uses `a`'s device and dtype.
    `activation` is accepted but not used.
    """
    # Check constraints.
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    M, K = a.shape
    K, N = b.shape
    # Allocates output.
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)

    # A single program is launched. Triton calls a callable grid with the
    # launch meta-parameters, so the function must accept one argument even
    # though it is not needed here.
    def grid(META):
        return (1, )

    matmul_kernel[grid](
        a, b, c,  #
        M, N, K,  #
        a.stride(0), a.stride(1),  #
        b.stride(0), b.stride(1),  #
        c.stride(0), c.stride(1),  #
        128, 256, 64, 8)
    return c
64+
65+
66+
# Script body: run the kernel once on two random 32x32 float16 CUDA tensors.
_shape = (32, 32)
a = torch.randn(_shape, device="cuda", dtype=torch.float16)
b = torch.randn(_shape, device="cuda", dtype=torch.float16)
matmul(a, b)

0 commit comments

Comments
 (0)