Skip to content

Commit 50ab08d

Browse files
agron911 authored and meta-codesync[bot] committed
[triton][beta] [Cherry-pick] '[BENCH] Integrate hipblas in roofline measurement (#8216)' (#995)
Summary: Pull Request resolved: #995 This is a cherry-pick of an upstream PR: triton-lang/triton#8216 Upstream commit message: ``` > [BENCH] Integrate hipblas in roofline measurement (#8216) ``` ***Do not remove the following line from this commit*** Reactor Cherry-pick Revision: 11b19e4 --- This diff was generated by running: ``` buck run fbcode//triton/tools/reactor:reactor -- cherrypick --num-commits 1 ``` Reviewed By: dshi7 Differential Revision: D94471625 fbshipit-source-id: 76103b2d01e3d89c61026093be278aabefd766d5
1 parent 07c0acc commit 50ab08d

File tree

4 files changed

+59
-26
lines changed

4 files changed

+59
-26
lines changed

python/test/unit/runtime/test_blaslt.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@ def test_blaslt(m, n, k, dtype_str, device):
2323
if dtype_str == "float8_e4m3fn" and not is_hip_cdna4():
2424
pytest.skip("float8_e4m3fn is only supported on HIP CDNA4")
2525
c_dtype = torch.float16 if dtype_str in ("float8_e4m3fnuz", "float8_e4m3fn") else dtype
26-
make_handle = lambda workspace: vendor.hipblas.HipBlasLt(workspace)
26+
make_handle = lambda workspace: vendor.hipblas.HipblasLt(workspace)
2727
else:
2828
pytest.skip("test_blaslt is only supported on CUDA or HIP")
2929

python/triton_kernels/triton_kernels/roofline.py

Lines changed: 51 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,12 @@
11
import ctypes
22
import matplotlib.pyplot as plt
33
import triton
4-
from triton._C.libtriton import nvidia
4+
from triton._C.libtriton import nvidia, amd
55
import torch
66
import csv
77
from dataclasses import dataclass
88
import inspect
9+
from .target_info import is_hip, is_cuda
910

1011

1112
@dataclass
@@ -84,23 +85,48 @@ def inject_proxy_and_call(val, args, kwargs):
8485

8586

8687
def get_memset_tbps():
87-
# Measure device memory set bandwidth using CUDA driver API (cuMemsetD8Async)
88-
if torch.version.cuda is None:
89-
raise RuntimeError("get_memset_tbps is only supported on CUDA")
90-
# load cuda
91-
cuda = ctypes.CDLL("libcuda.so")
92-
cuda.cuInit.argtypes = [ctypes.c_uint]
93-
cuda.cuInit.restype = ctypes.c_int
94-
if cuda.cuInit(0) != 0:
95-
raise RuntimeError("cuInit failed")
96-
# initialize cuMemsetD8Async
97-
cuda.cuMemsetD8Async.argtypes = [ctypes.c_uint64, ctypes.c_ubyte, ctypes.c_size_t, ctypes.c_void_p]
98-
cuda.cuMemsetD8Async.restype = ctypes.c_int
99-
# benchmark `cuMemsetD8Async`
10088
n_bytes = 1 << 32
10189
buf = torch.empty(n_bytes, device="cuda", dtype=torch.uint8)
102-
dptr = ctypes.c_uint64(buf.data_ptr())
103-
fn = lambda: cuda.cuMemsetD8Async(dptr, ctypes.c_ubyte(0), ctypes.c_size_t(n_bytes), ctypes.c_void_p(0))
90+
stream0 = ctypes.c_void_p(0)
91+
92+
if is_cuda():
93+
libname = "libcuda.so"
94+
init_name = "cuInit"
95+
memset_name = "cuMemsetD8Async"
96+
memset_argtypes = [ctypes.c_uint64, ctypes.c_ubyte, ctypes.c_size_t, ctypes.c_void_p]
97+
dptr = ctypes.c_uint64(buf.data_ptr())
98+
value = ctypes.c_ubyte(0)
99+
elif is_hip():
100+
libname = "libamdhip64.so"
101+
init_name = "hipInit"
102+
memset_name = "hipMemsetAsync"
103+
memset_argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t, ctypes.c_void_p]
104+
dptr = ctypes.c_void_p(buf.data_ptr())
105+
value = ctypes.c_int(0)
106+
else:
107+
raise RuntimeError("Unsupported platform: neither CUDA nor ROCm detected")
108+
109+
lib = ctypes.CDLL(libname)
110+
111+
# optional init
112+
if hasattr(lib, init_name):
113+
init_fn = getattr(lib, init_name)
114+
init_fn.argtypes = [ctypes.c_uint]
115+
init_fn.restype = ctypes.c_int
116+
init_fn(0)
117+
118+
if not hasattr(lib, memset_name):
119+
raise RuntimeError(f"{memset_name} not found in {libname}")
120+
121+
memset_fn = getattr(lib, memset_name)
122+
memset_fn.argtypes = memset_argtypes
123+
memset_fn.restype = ctypes.c_int
124+
125+
def fn():
126+
err = memset_fn(dptr, value, ctypes.c_size_t(n_bytes), stream0)
127+
if err != 0:
128+
raise RuntimeError(f"{memset_name} failed with error {err}")
129+
104130
time_ms = triton.testing.do_bench(fn, rep=1000)
105131
tbps = (n_bytes / (time_ms * 1e-3)) * 1e-12
106132
return tbps
@@ -109,13 +135,20 @@ def get_memset_tbps():
109135
def get_cublas_tflops(dtype):
110136
dtype = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.float8_e4m3fn}[dtype]
111137
cublas_workspace = torch.empty(32 * 1024 * 1024, device="cuda", dtype=torch.uint8)
112-
cublas = nvidia.cublas.CublasLt(cublas_workspace)
138+
if is_cuda():
139+
cublas = nvidia.cublas.CublasLt(cublas_workspace)
140+
bench_fn = cublas.matmul
141+
elif is_hip():
142+
hipblas = amd.hipblas.HipblasLt(cublas_workspace)
143+
bench_fn = hipblas.matmul
144+
else:
145+
raise RuntimeError("Unsupported platform: neither CUDA nor ROCm detected")
113146
device = "cuda"
114147
M, N, K = 8192, 8192, 8192
115148
a = torch.randn(M, K, device=device, dtype=torch.float32).to(dtype)
116149
b = torch.randn(K, N, device=device, dtype=torch.float32).to(dtype).T
117150
c = torch.empty((M, N), device=device, dtype=dtype)
118-
time_ms = triton.testing.do_bench(lambda: cublas.matmul(a, b, c), rep=1000)
151+
time_ms = triton.testing.do_bench(lambda: bench_fn(a, b, c), rep=1000)
119152
return 2 * M * N * K / time_ms * 1e-9
120153

121154

third_party/amd/include/hipblas_instance.h

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@
1212
constexpr int HIPBLAS_COMPUTE_32F_FAST_F8 = 104;
1313
constexpr int HIPBLAS_COMPUTE_32F_FAST_FBF_OCP = 105;
1414

15-
class HipBlasLtInstance {
15+
class HipblasLtInstance {
1616
// Typedefs for hipblas functions
1717
typedef hipblasStatus_t (*hipblasLtCreate_t)(hipblasLtHandle_t *);
1818
typedef hipblasStatus_t (*hipblasLtDestroy_t)(hipblasLtHandle_t);
@@ -264,7 +264,7 @@ class HipBlasLtInstance {
264264
}
265265

266266
public:
267-
HipBlasLtInstance(uint64_t workspace, size_t workspaceSize)
267+
HipblasLtInstance(uint64_t workspace, size_t workspaceSize)
268268
: workspace((void *)workspace), workspaceSize(workspaceSize) {
269269
loadHipBlasDylib();
270270
successOrExit(hipblasLtCreate(&ltHandle));
@@ -273,7 +273,7 @@ class HipBlasLtInstance {
273273
preference, HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize,
274274
sizeof(workspaceSize)));
275275
}
276-
~HipBlasLtInstance() {
276+
~HipblasLtInstance() {
277277
if (preference)
278278
successOrExit(hipblasLtMatmulPreferenceDestroy(preference));
279279

third_party/amd/python/triton_amd.cc

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -515,15 +515,15 @@ void init_triton_amd(py::module &&m) {
515515
});
516516

517517
auto hipBlas = m.def_submodule("hipblas");
518-
py::class_<HipBlasLtInstance>(hipBlas, "HipBlasLt")
518+
py::class_<HipblasLtInstance>(hipBlas, "HipblasLt")
519519
.def(py::init<>([&](py::object &workspace) {
520520
auto wrk_ptr = workspace.attr("data_ptr")().cast<uint64_t>();
521521
auto wrk_size = workspace.attr("numel")().cast<size_t>() *
522522
workspace.attr("element_size")().cast<size_t>();
523-
return new HipBlasLtInstance(wrk_ptr, wrk_size);
523+
return new HipblasLtInstance(wrk_ptr, wrk_size);
524524
}))
525525
.def("matmul",
526-
[](HipBlasLtInstance &self, py::object &A, py::object &B,
526+
[](HipblasLtInstance &self, py::object &A, py::object &B,
527527
py::object &C) {
528528
auto A_ptr = A.attr("data_ptr")().cast<uint64_t>();
529529
auto B_ptr = B.attr("data_ptr")().cast<uint64_t>();
@@ -532,7 +532,7 @@ void init_triton_amd(py::module &&m) {
532532
self.matmul(init.m, init.n, init.k, A_ptr, B_ptr, C_ptr,
533533
init.dtype);
534534
})
535-
.def("gemm", [](HipBlasLtInstance &self, py::object &A, py::object &B,
535+
.def("gemm", [](HipblasLtInstance &self, py::object &A, py::object &B,
536536
py::object &C, py::object &D, float alpha, float beta) {
537537
auto A_ptr = A.attr("data_ptr")().cast<uint64_t>();
538538
auto B_ptr = B.attr("data_ptr")().cast<uint64_t>();

0 commit comments

Comments (0)