
Commit 105cb56

Use get_current_target function to select the device to run tutorials on (triton-lang#5286)
This pull request contains changes for all tutorials except `09-persistent-matmul.py`, as it contains a lot of CUDA-specific functions.

---------

Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 9743ec0 commit 105cb56
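For context, the device-selection pattern these diffs introduce looks like the following minimal sketch (the tensor size is taken from the vector-add tutorial and is illustrative only):

```python
import torch
import triton

# Ask the active Triton driver which torch device it targets,
# instead of hard-coding device='cuda' in every tutorial.
DEVICE = triton.runtime.driver.active.get_active_torch_device()

# Illustrative usage: allocate tutorial inputs on that device.
x = torch.rand(98432, device=DEVICE)
y = torch.rand(98432, device=DEVICE)
output = x + y
```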

11 files changed, 76 insertions(+), 48 deletions(-)


python/triton/backends/driver.py

Lines changed: 7 additions & 2 deletions
@@ -1,4 +1,4 @@
-from abc import ABCMeta, abstractmethod, abstractclassmethod
+from abc import ABCMeta, abstractmethod
 from typing import Callable, List, Protocol, Sequence
 
 
@@ -10,14 +10,19 @@ def __call__(self, kernel_call: Callable, *, quantiles: List[float], **kwargs) -
 
 class DriverBase(metaclass=ABCMeta):
 
-    @abstractclassmethod
+    @classmethod
+    @abstractmethod
     def is_active(self):
         pass
 
     @abstractmethod
     def get_current_target(self):
         pass
 
+    @abstractmethod
+    def get_active_torch_device(self):
+        pass
+
     @abstractmethod
     def get_benchmarker(self) -> Benchmarker:
         """

python/tutorials/01-vector-add.py

Lines changed: 7 additions & 5 deletions
@@ -23,6 +23,8 @@
 import triton
 import triton.language as tl
 
+DEVICE = triton.runtime.driver.active.get_active_torch_device()
+
 
 @triton.jit
 def add_kernel(x_ptr,  # *Pointer* to first input vector.
@@ -60,7 +62,7 @@ def add_kernel(x_ptr, # *Pointer* to first input vector.
 def add(x: torch.Tensor, y: torch.Tensor):
     # We need to preallocate the output.
     output = torch.empty_like(x)
-    assert x.is_cuda and y.is_cuda and output.is_cuda
+    assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE
     n_elements = output.numel()
     # The SPMD launch grid denotes the number of kernel instances that run in parallel.
     # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int].
@@ -81,8 +83,8 @@ def add(x: torch.Tensor, y: torch.Tensor):
 
 torch.manual_seed(0)
 size = 98432
-x = torch.rand(size, device='cuda')
-y = torch.rand(size, device='cuda')
+x = torch.rand(size, device=DEVICE)
+y = torch.rand(size, device=DEVICE)
 output_torch = x + y
 output_triton = add(x, y)
 print(output_torch)
@@ -116,8 +118,8 @@ def add(x: torch.Tensor, y: torch.Tensor):
         args={},  # Values for function arguments not in `x_names` and `y_name`.
     ))
 def benchmark(size, provider):
-    x = torch.rand(size, device='cuda', dtype=torch.float32)
-    y = torch.rand(size, device='cuda', dtype=torch.float32)
+    x = torch.rand(size, device=DEVICE, dtype=torch.float32)
+    y = torch.rand(size, device=DEVICE, dtype=torch.float32)
     quantiles = [0.5, 0.2, 0.8]
     if provider == 'torch':
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles)

python/tutorials/02-fused-softmax.py

Lines changed: 7 additions & 6 deletions
@@ -27,6 +27,8 @@
 import triton.language as tl
 from triton.runtime import driver
 
+DEVICE = triton.runtime.driver.active.get_active_torch_device()
+
 
 def is_hip():
     return triton.runtime.driver.active.get_current_target().backend == "hip"
@@ -110,8 +112,7 @@ def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n
 # %%
 # We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
 
-device = torch.cuda.current_device()
-properties = driver.active.utils.get_device_properties(device)
+properties = driver.active.utils.get_device_properties(DEVICE.index)
 NUM_SM = properties["multiprocessor_count"]
 NUM_REGS = properties["max_num_regs"]
 SIZE_SMEM = properties["max_shared_mem"]
@@ -189,7 +190,7 @@ def softmax(x):
 # This will allow us to verify that our padding mechanism works.
 
 torch.manual_seed(0)
-x = torch.randn(1823, 781, device='cuda')
+x = torch.randn(1823, 781, device=DEVICE)
 y_triton = softmax(x)
 y_torch = torch.softmax(x, axis=1)
 assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
@@ -221,9 +222,9 @@ def softmax(x):
         args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
     ))
 def benchmark(M, N, provider):
-    x = torch.randn(M, N, device='cuda', dtype=torch.float32)
-    stream = torch.cuda.Stream()
-    torch.cuda.set_stream(stream)
+    x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
+    stream = getattr(torch, DEVICE.type).Stream()
+    getattr(torch, DEVICE.type).set_stream(stream)
     if provider == 'torch':
         ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
     if provider == 'triton':
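The `getattr(torch, DEVICE.type)` indirection in the hunk above resolves to the device-specific torch module (`torch.cuda`, `torch.xpu`, ...) at run time, so the benchmark no longer assumes CUDA streams. A minimal sketch of the same idea, assuming `DEVICE` was obtained as in the tutorial header:

```python
import torch
import triton

DEVICE = triton.runtime.driver.active.get_active_torch_device()

# DEVICE.type is e.g. "cuda", so this resolves to the torch.cuda module.
device_module = getattr(torch, DEVICE.type)

stream = device_module.Stream()   # device-specific stream object
device_module.set_stream(stream)  # make it the current stream
device_module.synchronize()       # wait for all queued work on the device
```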

python/tutorials/03-matrix-multiplication.py

Lines changed: 8 additions & 6 deletions
@@ -154,6 +154,8 @@
 import triton
 import triton.language as tl
 
+DEVICE = triton.runtime.driver.active.get_active_torch_device()
+
 
 def is_cuda():
     return triton.runtime.driver.active.get_current_target().backend == "cuda"
@@ -355,8 +357,8 @@ def matmul(a, b, activation=""):
 # We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS).
 
 torch.manual_seed(0)
-a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
-b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
+a = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
+b = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
 triton_output = matmul(a, b)
 torch_output = torch.matmul(a, b)
 print(f"triton_output_with_fp16_inputs={triton_output}")
@@ -373,8 +375,8 @@ def matmul(a, b, activation=""):
 TORCH_HAS_FP8 = hasattr(torch, "float8_e5m2")
 if TORCH_HAS_FP8 and is_cuda():
     torch.manual_seed(0)
-    a = torch.randn((512, 512), device="cuda", dtype=torch.float16)
-    b = torch.randn((512, 512), device="cuda", dtype=torch.float16)
+    a = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
+    b = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
     a = a.to(torch.float8_e5m2)
     # pre-transpose b for efficiency.
     b = b.T
@@ -423,8 +425,8 @@ def matmul(a, b, activation=""):
 
 @triton.testing.perf_report(configs)
 def benchmark(M, N, K, provider, fp8_inputs):
-    a = torch.randn((M, K), device='cuda', dtype=torch.float16)
-    b = torch.randn((K, N), device='cuda', dtype=torch.float16)
+    a = torch.randn((M, K), device=DEVICE, dtype=torch.float16)
+    b = torch.randn((K, N), device=DEVICE, dtype=torch.float16)
     if TORCH_HAS_FP8 and fp8_inputs:
         a = a.to(torch.float8_e5m2)
         b = b.T

python/tutorials/04-low-memory-dropout.py

Lines changed: 5 additions & 3 deletions
@@ -38,6 +38,8 @@
 import triton
 import triton.language as tl
 
+DEVICE = triton.runtime.driver.active.get_active_torch_device()
+
 
 @triton.jit
 def _dropout(
@@ -71,10 +73,10 @@ def dropout(x, x_keep, p):
 
 
 # Input tensor
-x = torch.randn(size=(10, )).cuda()
+x = torch.randn(size=(10, ), device=DEVICE)
 # Dropout mask
 p = 0.5
-x_keep = (torch.rand(size=(10, )) > p).to(torch.int32).cuda()
+x_keep = (torch.rand(size=(10, ), device=DEVICE) > p).to(torch.int32)
 #
 output = dropout(x, x_keep=x_keep, p=p)
 print(tabulate.tabulate([
@@ -138,7 +140,7 @@ def seeded_dropout(x, p, seed):
     return output
 
 
-x = torch.randn(size=(10, )).cuda()
+x = torch.randn(size=(10, ), device=DEVICE)
 # Compare this to the baseline - dropout mask is never instantiated!
 output = seeded_dropout(x, p=0.5, seed=123)
 output2 = seeded_dropout(x, p=0.5, seed=123)

python/tutorials/05-layer-norm.py

Lines changed: 4 additions & 2 deletions
@@ -42,6 +42,8 @@
 except ModuleNotFoundError:
     HAS_APEX = False
 
+DEVICE = triton.runtime.driver.active.get_active_torch_device()
+
 
 @triton.jit
 def _layer_norm_fwd_fused(
@@ -290,7 +292,7 @@ def backward(ctx, dy):
 layer_norm = LayerNorm.apply
 
 
-def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
+def test_layer_norm(M, N, dtype, eps=1e-5, device=DEVICE):
     # create data
     x_shape = (M, N)
     w_shape = (x_shape[-1], )
@@ -328,7 +330,7 @@ def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
         plot_name='layer-norm-backward',
         args={'M': 4096, 'dtype': torch.float16, 'mode': 'backward'},
     ))
-def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='cuda'):
+def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device=DEVICE):
     # create data
     x_shape = (M, N)
     w_shape = (x_shape[-1], )

python/tutorials/06-fused-attention.py

Lines changed: 7 additions & 5 deletions
@@ -19,6 +19,8 @@
 import triton
 import triton.language as tl
 
+DEVICE = triton.runtime.driver.active.get_active_torch_device()
+
 
 def is_hip():
     return triton.runtime.driver.active.get_current_target().backend == "hip"
@@ -526,13 +528,13 @@ def backward(ctx, do):
 @pytest.mark.parametrize("causal", [True])
 def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16):
     torch.manual_seed(20)
-    q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())
-    k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())
-    v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())
+    q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
+    k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
+    v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
     sm_scale = 0.5
     dout = torch.randn_like(q)
     # reference implementation
-    M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
+    M = torch.tril(torch.ones((N_CTX, N_CTX), device=DEVICE))
     p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
     if causal:
         p[:, :, M == 0] = float("-inf")
@@ -599,7 +601,7 @@ def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16):
 
 
 @triton.testing.perf_report(configs)
-def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, mode, provider, device="cuda"):
+def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, mode, provider, device=DEVICE):
     assert mode in ["fwd", "bwd"]
     dtype = torch.float16
     if "triton" in provider:

python/tutorials/07-extern-functions.py

Lines changed: 4 additions & 2 deletions
@@ -25,6 +25,8 @@
 
 from pathlib import Path
 
+DEVICE = triton.runtime.driver.active.get_active_torch_device()
+
 
 @triton.jit
 def asin_kernel(
@@ -49,8 +51,8 @@ def asin_kernel(
 
 torch.manual_seed(0)
 size = 98432
-x = torch.rand(size, device='cuda')
-output_triton = torch.zeros(size, device='cuda')
+x = torch.rand(size, device=DEVICE)
+output_triton = torch.zeros(size, device=DEVICE)
 output_torch = torch.asin(x)
 assert x.is_cuda and output_triton.is_cuda
 n_elements = output_torch.numel()

python/tutorials/08-grouped-gemm.py

Lines changed: 18 additions & 17 deletions
@@ -31,6 +31,8 @@
 import triton
 import triton.language as tl
 
+DEVICE = triton.runtime.driver.active.get_active_torch_device()
+
 
 @triton.autotune(
     configs=[
@@ -141,7 +143,6 @@ def grouped_matmul_kernel(
 
 
 def group_gemm_fn(group_A, group_B):
-    device = torch.device('cuda')
     assert len(group_A) == len(group_B)
     group_size = len(group_A)
 
@@ -157,7 +158,7 @@ def group_gemm_fn(group_A, group_B):
         assert A.shape[1] == B.shape[0]
         M, K = A.shape
         K, N = B.shape
-        C = torch.empty((M, N), device=device, dtype=A.dtype)
+        C = torch.empty((M, N), device=DEVICE, dtype=A.dtype)
         group_C.append(C)
         A_addrs.append(A.data_ptr())
         B_addrs.append(B.data_ptr())
@@ -166,11 +167,11 @@ def group_gemm_fn(group_A, group_B):
         g_lds += [A.stride(0), B.stride(0), C.stride(0)]
 
     # note these are device tensors
-    d_a_ptrs = torch.tensor(A_addrs, device=device)
-    d_b_ptrs = torch.tensor(B_addrs, device=device)
-    d_c_ptrs = torch.tensor(C_addrs, device=device)
-    d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=device)
-    d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=device)
+    d_a_ptrs = torch.tensor(A_addrs, device=DEVICE)
+    d_b_ptrs = torch.tensor(B_addrs, device=DEVICE)
+    d_c_ptrs = torch.tensor(C_addrs, device=DEVICE)
+    d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE)
+    d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE)
     # we use a fixed number of CTA, and it's auto-tunable
     grid = lambda META: (META['NUM_SM'], )
     grouped_matmul_kernel[grid](
@@ -197,8 +198,8 @@ def group_gemm_fn(group_A, group_B):
     M = group_m[i]
     N = group_n[i]
    K = group_k[i]
-    A = torch.rand((M, K), device="cuda", dtype=torch.float16)
-    B = torch.rand((K, N), device="cuda", dtype=torch.float16)
+    A = torch.rand((M, K), device=DEVICE, dtype=torch.float16)
+    B = torch.rand((K, N), device=DEVICE, dtype=torch.float16)
     group_A.append(A)
     group_B.append(B)
 
@@ -255,9 +256,9 @@ def benchmark(N, provider):
     g_lds = []
     group_C = []
     for i in range(group_size):
-        A = torch.rand((N, N), device="cuda", dtype=torch.float16)
-        B = torch.rand((N, N), device="cuda", dtype=torch.float16)
-        C = torch.empty((N, N), device="cuda", dtype=torch.float16)
+        A = torch.rand((N, N), device=DEVICE, dtype=torch.float16)
+        B = torch.rand((N, N), device=DEVICE, dtype=torch.float16)
+        C = torch.empty((N, N), device=DEVICE, dtype=torch.float16)
         group_A.append(A)
         group_B.append(B)
         group_C.append(C)
@@ -267,11 +268,11 @@ def benchmark(N, provider):
         g_sizes += [N, N, N]
         g_lds += [N, N, N]
 
-    d_a_ptrs = torch.tensor(A_addrs, device="cuda")
-    d_b_ptrs = torch.tensor(B_addrs, device="cuda")
-    d_c_ptrs = torch.tensor(C_addrs, device="cuda")
-    d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device="cuda")
-    d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device="cuda")
+    d_a_ptrs = torch.tensor(A_addrs, device=DEVICE)
+    d_b_ptrs = torch.tensor(B_addrs, device=DEVICE)
+    d_c_ptrs = torch.tensor(C_addrs, device=DEVICE)
+    d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE)
+    d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE)
 
     quantiles = [0.5, 0.2, 0.8]
     if provider == 'cublas':

third_party/amd/backend/driver.py

Lines changed: 5 additions & 0 deletions
@@ -505,6 +505,11 @@ def get_current_target(self):
         warp_size = device_properties['warpSize']
         return GPUTarget("hip", arch.split(':')[0], warp_size)
 
+    def get_active_torch_device(self):
+        import torch
+        # when using hip devices, the device string in pytorch is "cuda"
+        return torch.device("cuda", self.get_current_device())
+
     def get_benchmarker(self):
         from triton.testing import do_bench
         return do_bench
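The inline comment is the key detail: ROCm builds of PyTorch expose HIP devices through the `cuda` device type, so returning `torch.device("cuda", ...)` is correct for the AMD backend as well. A small, hedged illustration of what the tutorials' `DEVICE` resolves to when this backend is active:

```python
import triton

# With the HIP backend active, the AMD driver above returns a "cuda" device.
DEVICE = triton.runtime.driver.active.get_active_torch_device()
print(DEVICE)        # e.g. cuda:0
print(DEVICE.type)   # "cuda" (PyTorch's device string for HIP devices as well)
print(DEVICE.index)  # index of the currently active GPU
```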
