Commit c2e9c4a

address review comments
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent bd35f62 commit c2e9c4a

File tree

9 files changed: +22 -14 lines changed


python/tutorials/01-vector-add.py

Lines changed: 3 additions & 2 deletions
@@ -23,7 +23,8 @@
 import triton
 import triton.language as tl
 
-DEVICE = triton.runtime.driver.active.get_current_target().backend
+DEVICE = torch.device(triton.runtime.driver.active.get_current_target().backend,
+                      triton.runtime.driver.active.get_current_device())
 
 
 @triton.jit
@@ -62,7 +63,7 @@ def add_kernel(x_ptr,  # *Pointer* to first input vector.
 def add(x: torch.Tensor, y: torch.Tensor):
     # We need to preallocate the output.
     output = torch.empty_like(x)
-    assert x.device.type == DEVICE and y.device.type == DEVICE and output.device.type == DEVICE
+    assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE
     n_elements = output.numel()
     # The SPMD launch grid denotes the number of kernel instances that run in parallel.
     # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int].
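The assert changes because DEVICE is now a torch.device rather than a backend string, so it can be compared directly against tensor.device, which checks the device index as well as the type. A minimal sketch of the distinction, assuming a CUDA build of PyTorch (the device values here are illustrative, not from the diff):

import torch

DEVICE = torch.device("cuda", 0)  # illustrative stand-in for the tutorial's DEVICE

x = torch.empty(4, device=DEVICE)
assert x.device == DEVICE                    # new check: type AND index must match
assert x.device.type == DEVICE.type          # old check compared only the type string
assert x.device != torch.device("cuda", 1)   # a tensor on another device ordinal now fails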

python/tutorials/02-fused-softmax.py

Lines changed: 5 additions & 5 deletions
@@ -27,7 +27,8 @@
 import triton.language as tl
 from triton.runtime import driver
 
-DEVICE = triton.runtime.driver.active.get_current_target().backend
+DEVICE = torch.device(triton.runtime.driver.active.get_current_target().backend,
+                      triton.runtime.driver.active.get_current_device())
 
 
 def is_hip():
@@ -112,8 +113,7 @@ def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n
 # %%
 # We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
 
-device = getattr(torch, DEVICE).current_device()
-properties = driver.active.utils.get_device_properties(device)
+properties = driver.active.utils.get_device_properties(DEVICE.index)
 NUM_SM = properties["multiprocessor_count"]
 SIZE_SMEM = properties["max_shared_mem"]
 WARPS_PER_EU = 8  # TODO: Get from properties
@@ -229,8 +229,8 @@ def allocated_slm_size(size_smem):
 ))
 def benchmark(M, N, provider):
     x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
-    stream = getattr(torch, DEVICE).Stream()
-    getattr(torch, DEVICE).set_stream(stream)
+    stream = getattr(torch, DEVICE.type).Stream()
+    getattr(torch, DEVICE.type).set_stream(stream)
     if provider == 'torch':
         ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
     if provider == 'triton':
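This file uses both halves of the new torch.device: DEVICE.index feeds the driver's property query, and DEVICE.type selects the matching torch submodule for stream management. A minimal sketch of the two attributes, again with an illustrative CUDA device rather than one taken from the diff:

import torch

DEVICE = torch.device("cuda", 0)  # illustrative; the tutorial builds this from the Triton driver

print(DEVICE.type)   # "cuda" -> getattr(torch, DEVICE.type) resolves to torch.cuda
print(DEVICE.index)  # 0      -> passed to driver.active.utils.get_device_properties(...)

# Equivalent of the stream setup above, assuming a CUDA build of PyTorch:
stream = getattr(torch, DEVICE.type).Stream()
getattr(torch, DEVICE.type).set_stream(stream)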

python/tutorials/03-matrix-multiplication.py

Lines changed: 2 additions & 1 deletion
@@ -154,7 +154,8 @@
 import triton
 import triton.language as tl
 
-DEVICE = triton.runtime.driver.active.get_current_target().backend
+DEVICE = torch.device(triton.runtime.driver.active.get_current_target().backend,
+                      triton.runtime.driver.active.get_current_device())
 
 
 def is_cuda():

python/tutorials/04-low-memory-dropout.py

Lines changed: 2 additions & 1 deletion
@@ -38,7 +38,8 @@
 import triton
 import triton.language as tl
 
-DEVICE = triton.runtime.driver.active.get_current_target().backend
+DEVICE = torch.device(triton.runtime.driver.active.get_current_target().backend,
+                      triton.runtime.driver.active.get_current_device())
 
 
 @triton.jit

python/tutorials/05-layer-norm.py

Lines changed: 2 additions & 1 deletion
@@ -42,7 +42,8 @@
 except ModuleNotFoundError:
     HAS_APEX = False
 
-DEVICE = triton.runtime.driver.active.get_current_target().backend
+DEVICE = torch.device(triton.runtime.driver.active.get_current_target().backend,
+                      triton.runtime.driver.active.get_current_device())
 
 
 @triton.jit

python/tutorials/06-fused-attention.py

Lines changed: 2 additions & 1 deletion
@@ -19,7 +19,8 @@
 import triton
 import triton.language as tl
 
-DEVICE = triton.runtime.driver.active.get_current_target().backend
+DEVICE = torch.device(triton.runtime.driver.active.get_current_target().backend,
+                      triton.runtime.driver.active.get_current_device())
 
 
 def is_hip():

python/tutorials/07-extern-functions.py

Lines changed: 2 additions & 1 deletion
@@ -25,7 +25,8 @@
 
 from pathlib import Path
 
-DEVICE = triton.runtime.driver.active.get_current_target().backend
+DEVICE = torch.device(triton.runtime.driver.active.get_current_target().backend,
+                      triton.runtime.driver.active.get_current_device())
 
 
 @triton.jit

python/tutorials/08-grouped-gemm.py

Lines changed: 2 additions & 1 deletion
@@ -31,7 +31,8 @@
 import triton
 import triton.language as tl
 
-DEVICE = triton.runtime.driver.active.get_current_target().backend
+DEVICE = torch.device(triton.runtime.driver.active.get_current_target().backend,
+                      triton.runtime.driver.active.get_current_device())
 
 
 def is_cuda():

python/tutorials/10-experimental-block-pointer.py

Lines changed: 2 additions & 1 deletion
@@ -95,7 +95,8 @@
 import triton
 import triton.language as tl
 
-DEVICE = triton.runtime.driver.active.get_current_target().backend
+DEVICE = torch.device(triton.runtime.driver.active.get_current_target().backend,
+                      triton.runtime.driver.active.get_current_device())
 
 
 @triton.autotune(
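The same two-line change is repeated in every tutorial. A hypothetical helper (not part of this commit) consolidating the pattern, assuming torch.device accepts the backend name reported by the active Triton driver:

import torch
import triton

def get_triton_device() -> torch.device:
    # Backend name reported by the active driver, e.g. "cuda" or "xpu".
    backend = triton.runtime.driver.active.get_current_target().backend
    # Ordinal of the device the driver is currently bound to.
    index = triton.runtime.driver.active.get_current_device()
    return torch.device(backend, index)

DEVICE = get_triton_device()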
