Commit c3ba504

Choose device in tutorials through 'driver.active.get_current_target().backend'
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 3f6ffa5 commit c3ba504

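The pattern is the same across all four tutorials: ask the active Triton driver for the current target and use its backend name (e.g. "cuda", "hip" or "xpu") wherever a device string was previously hard-coded. A minimal sketch of the idea, assuming a working Triton install with an active GPU backend whose name is also accepted by torch as a device string:

import torch
import triton

# Backend name reported by the active Triton driver, e.g. "cuda" or "xpu".
DEVICE = triton.runtime.driver.active.get_current_target().backend

# torch accepts the same string as a device argument for these backends.
x = torch.rand(4, device=DEVICE)
print(DEVICE, x.device)
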
4 files changed: +28 additions, -19 deletions
python/tutorials/01-vector-add.py
Lines changed: 8 additions & 5 deletions

@@ -23,6 +23,8 @@
 import triton
 import triton.language as tl
 
+DEVICE = triton.runtime.driver.active.get_current_target().backend
+
 
 @triton.jit
 def add_kernel(x_ptr,  # *Pointer* to first input vector.
@@ -60,7 +62,8 @@ def add_kernel(x_ptr,  # *Pointer* to first input vector.
 def add(x: torch.Tensor, y: torch.Tensor):
     # We need to preallocate the output.
     output = torch.empty_like(x)
-    assert x.is_xpu and y.is_xpu and output.is_xpu
+    is_dvc = f"is_{DEVICE}"
+    assert getattr(x, is_dvc) and getattr(y, is_dvc) and getattr(output, is_dvc)
     n_elements = output.numel()
     # The SPMD launch grid denotes the number of kernel instances that run in parallel.
     # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int].
@@ -81,8 +84,8 @@ def add(x: torch.Tensor, y: torch.Tensor):
 
 torch.manual_seed(0)
 size = 98432
-x = torch.rand(size, device='xpu')
-y = torch.rand(size, device='xpu')
+x = torch.rand(size, device=DEVICE)
+y = torch.rand(size, device=DEVICE)
 output_torch = x + y
 output_triton = add(x, y)
 print(output_torch.cpu())
@@ -116,8 +119,8 @@ def add(x: torch.Tensor, y: torch.Tensor):
         args={},  # Values for function arguments not in `x_names` and `y_name`.
     ))
 def benchmark(size, provider):
-    x = torch.rand(size, device='xpu', dtype=torch.float32)
-    y = torch.rand(size, device='xpu', dtype=torch.float32)
+    x = torch.rand(size, device=DEVICE, dtype=torch.float32)
+    y = torch.rand(size, device=DEVICE, dtype=torch.float32)
     quantiles = [0.5, 0.2, 0.8]
     if provider == 'torch':
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles)

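The rewritten assertion relies on torch tensors exposing a per-backend boolean attribute (x.is_cuda, x.is_xpu, ...). A small sketch of that check in isolation, assuming DEVICE resolves to a backend for which such an attribute exists:

import torch
import triton

DEVICE = triton.runtime.driver.active.get_current_target().backend

x = torch.rand(8, device=DEVICE)
is_dvc = f"is_{DEVICE}"        # e.g. "is_cuda" or "is_xpu"
assert getattr(x, is_dvc)      # same effect as x.is_cuda / x.is_xpu
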
python/tutorials/02-fused-softmax.py
Lines changed: 7 additions & 5 deletions

@@ -27,6 +27,8 @@
 import triton.language as tl
 from triton.runtime import driver
 
+DEVICE = triton.runtime.driver.active.get_current_target().backend
+
 
 def is_hip():
     return triton.runtime.driver.active.get_current_target().backend == "hip"
@@ -110,7 +112,7 @@ def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n
 # %%
 # We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
 
-device = torch.xpu.current_device()
+device = getattr(torch, DEVICE).current_device()
 properties = driver.active.utils.get_device_properties(device)
 NUM_SM = properties["multiprocessor_count"]
 SIZE_SMEM = properties["max_shared_mem"]
@@ -194,7 +196,7 @@ def allocated_slm_size(size_smem):
 # This will allow us to verify that our padding mechanism works.
 
 torch.manual_seed(0)
-x = torch.randn(1823, 781, device='xpu')
+x = torch.randn(1823, 781, device=DEVICE)
 y_triton = softmax(x)
 y_torch = torch.softmax(x, axis=1)
 assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
@@ -226,9 +228,9 @@ def allocated_slm_size(size_smem):
         args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
     ))
 def benchmark(M, N, provider):
-    x = torch.randn(M, N, device='xpu', dtype=torch.float32)
-    stream = torch.xpu.Stream()
-    torch.xpu.set_stream(stream)
+    x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
+    stream = getattr(torch, DEVICE).Stream()
+    getattr(torch, DEVICE).set_stream(stream)
     if provider == 'torch':
         ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
     if provider == 'triton':

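In this tutorial the backend name is also used to look up the matching torch submodule (torch.cuda, torch.xpu), so device queries and stream handling stay backend-agnostic. A sketch of that lookup, assuming the resolved submodule exposes the same current_device/Stream/set_stream API that the diff relies on:

import torch
import triton

DEVICE = triton.runtime.driver.active.get_current_target().backend

backend_mod = getattr(torch, DEVICE)       # torch.cuda, torch.xpu, ...
device = backend_mod.current_device()      # index of the current device
stream = backend_mod.Stream()              # backend-specific stream object
backend_mod.set_stream(stream)
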
python/tutorials/03-matrix-multiplication.py
Lines changed: 8 additions & 6 deletions

@@ -154,6 +154,8 @@
 import triton
 import triton.language as tl
 
+DEVICE = triton.runtime.driver.active.get_current_target().backend
+
 
 def is_cuda():
     return triton.runtime.driver.active.get_current_target().backend == "cuda"
@@ -390,8 +392,8 @@ def matmul(a, b, activation=""):
 # We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS).
 
 torch.manual_seed(0)
-a = torch.randn((512, 512), device='xpu', dtype=torch.float16)
-b = torch.randn((512, 512), device='xpu', dtype=torch.float16)
+a = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
+b = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
 triton_output = matmul(a, b)
 torch_output = torch.matmul(a, b)
 print(f"triton_output_with_fp16_inputs={triton_output}")
@@ -408,8 +410,8 @@ def matmul(a, b, activation=""):
 TORCH_HAS_FP8 = hasattr(torch, "float8_e5m2")
 if TORCH_HAS_FP8 and is_cuda():
     torch.manual_seed(0)
-    a = torch.randn((512, 512), device="xpu", dtype=torch.float16)
-    b = torch.randn((512, 512), device="xpu", dtype=torch.float16)
+    a = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
+    b = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
     a = a.to(torch.float8_e5m2)
     # pre-transpose b for efficiency.
     b = b.T
@@ -458,8 +460,8 @@ def matmul(a, b, activation=""):
 
 @triton.testing.perf_report(configs)
 def benchmark(M, N, K, provider, fp8_inputs):
-    a = torch.randn((M, K), device='xpu', dtype=torch.float16)
-    b = torch.randn((K, N), device='xpu', dtype=torch.float16)
+    a = torch.randn((M, K), device=DEVICE, dtype=torch.float16)
+    b = torch.randn((K, N), device=DEVICE, dtype=torch.float16)
     if TORCH_HAS_FP8 and fp8_inputs:
         a = a.to(torch.float8_e5m2)
         b = b.T

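Only the tensor allocations change in this file; the fp8 path still gates on is_cuda(), so it runs only when the active backend is CUDA. A sketch of that backend-dependent branching, reusing the DEVICE constant and the is_cuda() helper from the tutorial:

import torch
import triton

DEVICE = triton.runtime.driver.active.get_current_target().backend


def is_cuda():
    return triton.runtime.driver.active.get_current_target().backend == "cuda"


a = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
if hasattr(torch, "float8_e5m2") and is_cuda():
    a = a.to(torch.float8_e5m2)   # fp8 conversion only on the CUDA path, as in the tutorial
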
python/tutorials/04-low-memory-dropout.py
Lines changed: 5 additions & 3 deletions

@@ -38,6 +38,8 @@
 import triton
 import triton.language as tl
 
+DEVICE = triton.runtime.driver.active.get_current_target().backend
+
 
 @triton.jit
 def _dropout(
@@ -71,10 +73,10 @@ def dropout(x, x_keep, p):
 
 
 # Input tensor
-x = torch.randn(size=(10, )).xpu()
+x = getattr(torch.randn(size=(10, )), DEVICE)()
 # Dropout mask
 p = 0.5
-x_keep = (torch.rand(size=(10, )) > p).to(torch.int32).xpu()
+x_keep = getattr((torch.rand(size=(10, )) > p).to(torch.int32), DEVICE)()
 #
 output = dropout(x, x_keep=x_keep, p=p)
 print(tabulate.tabulate([
@@ -138,7 +140,7 @@ def seeded_dropout(x, p, seed):
     return output
 
 
-x = torch.randn(size=(10, )).xpu()
+x = getattr(torch.randn(size=(10, )), DEVICE)()
 # Compare this to the baseline - dropout mask is never instantiated!
 output = seeded_dropout(x, p=0.5, seed=123)
 output2 = seeded_dropout(x, p=0.5, seed=123)

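Here the tensors are created on the CPU first, so the device move is done by calling the method named after the backend (tensor.cuda(), tensor.xpu()). A sketch of that pattern, assuming such a method exists for the active backend; tensor.to(DEVICE) would be an equivalent spelling:

import torch
import triton

DEVICE = triton.runtime.driver.active.get_current_target().backend

x_cpu = torch.randn(size=(10, ))
x = getattr(x_cpu, DEVICE)()   # resolves to x_cpu.cuda() or x_cpu.xpu()
# equivalent: x = x_cpu.to(DEVICE)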