
Commit 06889c8

more tutorials
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent c3ba504 commit 06889c8


5 files changed: +42 additions, -31 deletions

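All five tutorials get the same treatment: instead of hardcoding device='xpu', each file defines a module-level DEVICE constant from the backend name reported by Triton's active driver and uses it for every tensor allocation and default device argument. A minimal sketch of the pattern, assuming the reported backend name (e.g. "xpu" on Intel GPUs, "cuda" on NVIDIA) is also a valid torch device string on the installed PyTorch build:

    import torch
    import triton

    # Resolve the device from the active Triton driver instead of hardcoding 'xpu'.
    DEVICE = triton.runtime.driver.active.get_current_target().backend

    x = torch.rand(1024, device=DEVICE)  # was: torch.rand(1024, device='xpu')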

python/tutorials/05-layer-norm.py

Lines changed: 5 additions & 2 deletions
@@ -34,6 +34,9 @@
 import triton
 import triton.language as tl
 
+DEVICE = triton.runtime.driver.active.get_current_target().backend
+
+
 try:
     # This is https://github.com/NVIDIA/apex, NOT the apex on PyPi, so it
     # should not be added to extras_require in setup.py.
@@ -290,7 +293,7 @@ def backward(ctx, dy):
 layer_norm = LayerNorm.apply
 
 
-def test_layer_norm(M, N, dtype, eps=1e-5, device='xpu'):
+def test_layer_norm(M, N, dtype, eps=1e-5, device=DEVICE):
     # create data
     x_shape = (M, N)
     w_shape = (x_shape[-1], )
@@ -329,7 +332,7 @@ def test_layer_norm(M, N, dtype, eps=1e-5, device='xpu'):
         plot_name='layer-norm-backward',
         args={'M': 4096, 'dtype': torch.float16, 'mode': 'backward'},
     ))
-def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='xpu'):
+def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device=DEVICE):
     # create data
     x_shape = (M, N)
     w_shape = (x_shape[-1], )

python/tutorials/06-fused-attention.py

Lines changed: 7 additions & 5 deletions
@@ -19,6 +19,8 @@
 import triton
 import triton.language as tl
 
+DEVICE = triton.runtime.driver.active.get_current_target().backend
+
 
 def is_hip():
     return triton.runtime.driver.active.get_current_target().backend == "hip"
@@ -526,13 +528,13 @@ def backward(ctx, do):
 @pytest.mark.parametrize("causal", [True])
 def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16):
     torch.manual_seed(20)
-    q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="xpu").normal_(mean=0.0, std=0.5).requires_grad_())
-    k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="xpu").normal_(mean=0.0, std=0.5).requires_grad_())
-    v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="xpu").normal_(mean=0.0, std=0.5).requires_grad_())
+    q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
+    k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
+    v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
     sm_scale = 0.5
     dout = torch.randn_like(q)
     # reference implementation
-    M = torch.tril(torch.ones((N_CTX, N_CTX), device="xpu"))
+    M = torch.tril(torch.ones((N_CTX, N_CTX), device=DEVICE))
     p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
     if causal:
         p[:, :, M == 0] = float("-inf")
@@ -600,7 +602,7 @@ def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16):
 
 
 @triton.testing.perf_report(configs)
-def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, mode, provider, device="xpu"):
+def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, mode, provider, device=DEVICE):
     assert mode in ["fwd", "bwd"]
     dtype = torch.float16
     if "triton" in provider:

python/tutorials/07-extern-functions.py

Lines changed: 4 additions & 2 deletions
@@ -25,6 +25,8 @@
 
 from pathlib import Path
 
+DEVICE = triton.runtime.driver.active.get_current_target().backend
+
 
 @triton.jit
 def asin_kernel(
@@ -49,8 +51,8 @@ def asin_kernel(
 
 torch.manual_seed(0)
 size = 98432
-x = torch.rand(size, device='xpu')
-output_triton = torch.zeros(size, device='xpu')
+x = torch.rand(size, device=DEVICE)
+output_triton = torch.zeros(size, device=DEVICE)
 output_torch = torch.asin(x)
 n_elements = output_torch.numel()
 grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )

python/tutorials/08-grouped-gemm.py

Lines changed: 13 additions & 11 deletions
@@ -31,6 +31,8 @@
 import triton
 import triton.language as tl
 
+DEVICE = triton.runtime.driver.active.get_current_target().backend
+
 
 def is_cuda():
     return triton.runtime.driver.active.get_current_target().backend == "cuda"
@@ -145,7 +147,7 @@ def grouped_matmul_kernel(
 
 
 def group_gemm_fn(group_A, group_B):
-    device = torch.device('xpu')
+    device = torch.device(DEVICE)
     assert len(group_A) == len(group_B)
     group_size = len(group_A)
 
@@ -201,8 +203,8 @@ def group_gemm_fn(group_A, group_B):
     M = group_m[i]
     N = group_n[i]
     K = group_k[i]
-    A = torch.rand((M, K), device="xpu", dtype=torch.float16)
-    B = torch.rand((K, N), device="xpu", dtype=torch.float16)
+    A = torch.rand((M, K), device=DEVICE, dtype=torch.float16)
+    B = torch.rand((K, N), device=DEVICE, dtype=torch.float16)
     group_A.append(A)
     group_B.append(B)
 
@@ -264,9 +266,9 @@ def benchmark(N, provider):
     g_lds = []
     group_C = []
     for i in range(group_size):
-        A = torch.rand((N, N), device="xpu", dtype=torch.float16)
-        B = torch.rand((N, N), device="xpu", dtype=torch.float16)
-        C = torch.empty((N, N), device="xpu", dtype=torch.float16)
+        A = torch.rand((N, N), device=DEVICE, dtype=torch.float16)
+        B = torch.rand((N, N), device=DEVICE, dtype=torch.float16)
+        C = torch.empty((N, N), device=DEVICE, dtype=torch.float16)
         group_A.append(A)
         group_B.append(B)
         group_C.append(C)
@@ -276,11 +278,11 @@ def benchmark(N, provider):
        g_sizes += [N, N, N]
        g_lds += [N, N, N]
 
-    d_a_ptrs = torch.tensor(A_addrs, device="xpu")
-    d_b_ptrs = torch.tensor(B_addrs, device="xpu")
-    d_c_ptrs = torch.tensor(C_addrs, device="xpu")
-    d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device="xpu")
-    d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device="xpu")
+    d_a_ptrs = torch.tensor(A_addrs, device=DEVICE)
+    d_b_ptrs = torch.tensor(B_addrs, device=DEVICE)
+    d_c_ptrs = torch.tensor(C_addrs, device=DEVICE)
+    d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE)
+    d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE)
 
     quantiles = [0.5, 0.2, 0.8]
     if provider == ref_lib.lower():

python/tutorials/10-experimental-block-pointer.py

Lines changed: 13 additions & 11 deletions
@@ -95,6 +95,8 @@
 import triton
 import triton.language as tl
 
+DEVICE = triton.runtime.driver.active.get_current_target().backend
+
 
 @triton.autotune(
     configs=[
@@ -345,23 +347,23 @@ def matmul(a, b, accum_dtype, res_dtype):
         # [ 1 1 1 ... ],
         # [ 0 1 1 ... ], ... ]
         # in order only add 3 values per result matrix element.
-        a = torch.randn(shape, device='xpu', dtype=dtype)
-        b = torch.eye(shape[-2], device='xpu', dtype=dtype) + torch.diag(
-            torch.ones(shape[-2] - 1, device='xpu', dtype=dtype), diagonal=1) + torch.diag(
-            torch.ones(shape[-2] - 1, device='xpu', dtype=dtype), diagonal=-1)
+        a = torch.randn(shape, device=DEVICE, dtype=dtype)
+        b = torch.eye(shape[-2], device=DEVICE, dtype=dtype) + torch.diag(
+            torch.ones(shape[-2] - 1, device=DEVICE, dtype=dtype), diagonal=1) + torch.diag(
+            torch.ones(shape[-2] - 1, device=DEVICE, dtype=dtype), diagonal=-1)
         # duplicate b on batch dimension.
         if len(shape) == 3:
            b = b.unsqueeze(0).repeat(shape[0], 1, 1)
     else:
-        a = torch.randn(shape, device='xpu', dtype=dtype)
-        b = torch.randn(shape, device='xpu', dtype=dtype)
+        a = torch.randn(shape, device=DEVICE, dtype=dtype)
+        b = torch.randn(shape, device=DEVICE, dtype=dtype)
     torch_output = torch.matmul(a, b).to(dtype=res_dtype)
 else:
-    a = torch.randint(low=-127, high=128, size=shape, device='xpu', dtype=dtype)
-    b = torch.randint(low=-127, high=128, size=shape, device='xpu', dtype=dtype)
+    a = torch.randint(low=-127, high=128, size=shape, device=DEVICE, dtype=dtype)
+    b = torch.randint(low=-127, high=128, size=shape, device=DEVICE, dtype=dtype)
     # torch.matmul clamps values to input dtype; IPEX doesn't support int32 matmul
     torch_output = torch.matmul(a.to(device='cpu', dtype=accum_dtype),
-                                b.to(device='cpu', dtype=accum_dtype)).to(device='xpu', dtype=res_dtype)
+                                b.to(device='cpu', dtype=accum_dtype)).to(device=DEVICE, dtype=res_dtype)
 
 triton_output = matmul(a, b, accum_dtype, res_dtype)
 
@@ -408,8 +410,8 @@ def matmul(a, b, accum_dtype, res_dtype):
 
 @triton.testing.perf_report(configs)
 def benchmark(M, N, K, provider):
-    a = torch.randn((M, K), device='xpu', dtype=torch.float16)
-    b = torch.randn((K, N), device='xpu', dtype=torch.float16)
+    a = torch.randn((M, K), device=DEVICE, dtype=torch.float16)
+    b = torch.randn((K, N), device=DEVICE, dtype=torch.float16)
 
     quantiles = [0.5, 0.2, 0.8]
     if provider == ref_lib.lower():
