Skip to content

Commit 5df82a0

Browse files
committed
add 'get_active_torch_device'
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 4564326 commit 5df82a0

File tree

14 files changed: +32 −21 lines changed

benchmarks/triton_kernels_benchmark/benchmark_driver.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,9 @@ def get_current_target(self):
         warp_size = 32
         return GPUTarget("xpu", dev_property, warp_size)
 
+    def get_active_torch_device(self):
+        return torch.device("xpu", self.get_current_device())
+
     @staticmethod
     def is_active():
         return torch.xpu.is_available()

python/triton/backends/driver.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
-from abc import ABCMeta, abstractmethod, abstractclassmethod
+from abc import ABCMeta, abstractmethod
 from typing import Callable, List, Protocol, Sequence
 
 
@@ -10,14 +10,19 @@ def __call__(self, kernel_call: Callable, *, quantiles: List[float], **kwargs) -
 
 class DriverBase(metaclass=ABCMeta):
 
-    @abstractclassmethod
+    @classmethod
+    @abstractmethod
     def is_active(self):
         pass
 
     @abstractmethod
     def get_current_target(self):
         pass
 
+    @abstractmethod
+    def get_active_torch_device(self):
+        pass
+
     @abstractmethod
     def get_benchmarker(self) -> Benchmarker:
         """

python/tutorials/01-vector-add.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,7 @@
 import triton
 import triton.language as tl
 
-DEVICE = torch.device(triton.runtime.driver.active.get_current_target().backend,
-                      triton.runtime.driver.active.get_current_device())
+DEVICE = triton.runtime.driver.active.get_active_torch_device()
 
 
 @triton.jit

python/tutorials/02-fused-softmax.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@
 import triton.language as tl
 from triton.runtime import driver
 
-DEVICE = torch.device(triton.runtime.driver.active.get_current_target().backend,
-                      triton.runtime.driver.active.get_current_device())
+DEVICE = triton.runtime.driver.active.get_active_torch_device()
 
 
 def is_hip():
@@ -122,7 +121,6 @@ def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n
     WARP_SIZE = properties["sub_group_sizes"][-1]
     WG_SIZE = properties["max_work_group_size"]
     max_num_warps = WG_SIZE // WARP_SIZE
-    target = triton.runtime.driver.active.get_current_target()
     warps_per_sm = WARPS_PER_EU * EU_PER_SM
     max_num_resident_warps = NUM_SM * warps_per_sm
     kernels = {}

python/tutorials/03-matrix-multiplication.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,8 +154,7 @@
 import triton
 import triton.language as tl
 
-DEVICE = torch.device(triton.runtime.driver.active.get_current_target().backend,
-                      triton.runtime.driver.active.get_current_device())
+DEVICE = triton.runtime.driver.active.get_active_torch_device()
 
 
 def is_cuda():

python/tutorials/04-low-memory-dropout.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,7 @@
 import triton
 import triton.language as tl
 
-DEVICE = torch.device(triton.runtime.driver.active.get_current_target().backend,
-                      triton.runtime.driver.active.get_current_device())
+DEVICE = triton.runtime.driver.active.get_active_torch_device()
 
 
 @triton.jit

python/tutorials/05-layer-norm.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,7 @@
 except ModuleNotFoundError:
     HAS_APEX = False
 
-DEVICE = torch.device(triton.runtime.driver.active.get_current_target().backend,
-                      triton.runtime.driver.active.get_current_device())
+DEVICE = triton.runtime.driver.active.get_active_torch_device()
 
 
 @triton.jit

python/tutorials/06-fused-attention.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,7 @@
 import triton
 import triton.language as tl
 
-DEVICE = torch.device(triton.runtime.driver.active.get_current_target().backend,
-                      triton.runtime.driver.active.get_current_device())
+DEVICE = triton.runtime.driver.active.get_active_torch_device()
 
 
 def is_hip():

python/tutorials/07-extern-functions.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,7 @@
 
 from pathlib import Path
 
-DEVICE = torch.device(triton.runtime.driver.active.get_current_target().backend,
-                      triton.runtime.driver.active.get_current_device())
+DEVICE = triton.runtime.driver.active.get_active_torch_device()
 
 
 @triton.jit

python/tutorials/08-grouped-gemm.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,7 @@
 import triton
 import triton.language as tl
 
-DEVICE = torch.device(triton.runtime.driver.active.get_current_target().backend,
-                      triton.runtime.driver.active.get_current_device())
+DEVICE = triton.runtime.driver.active.get_active_torch_device()
 
 
 def is_cuda():

0 commit comments

Comments
 (0)