3 changes: 3 additions & 0 deletions benchmarks/triton_kernels_benchmark/benchmark_driver.py
@@ -499,6 +499,9 @@ def get_current_target(self):
warp_size = 32
return GPUTarget("xpu", dev_property, warp_size)

+ def get_active_torch_device(self):
+     return torch.device("xpu", self.get_current_device())

@staticmethod
def is_active():
return torch.xpu.is_available()
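Reviewer note: the new hook just wraps the backend's current device index in a `torch.device`. A minimal sketch of how another backend could satisfy the same contract (illustrative names, not the actual upstream CUDA driver):

```python
import torch


class ExampleCudaDriverSketch:
    """Hypothetical backend driver; only the new hook is sketched."""

    def get_current_device(self):
        # Assumed helper: index of the device this backend is currently using.
        return torch.cuda.current_device()

    def get_active_torch_device(self):
        # Same contract as the XPU implementation above, but with this
        # backend's own device type string.
        return torch.device("cuda", self.get_current_device())
```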
9 changes: 7 additions & 2 deletions python/triton/backends/driver.py
@@ -1,4 +1,4 @@
- from abc import ABCMeta, abstractmethod, abstractclassmethod
+ from abc import ABCMeta, abstractmethod
from typing import Callable, List, Protocol, Sequence


@@ -10,14 +10,19 @@ def __call__(self, kernel_call: Callable, *, quantiles: List[float], **kwargs) -

class DriverBase(metaclass=ABCMeta):

- @abstractclassmethod
+ @classmethod
+ @abstractmethod
def is_active(self):
pass

@abstractmethod
def get_current_target(self):
pass

+ @abstractmethod
+ def get_active_torch_device(self):
+     pass

@abstractmethod
def get_benchmarker(self) -> Benchmarker:
"""
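With the abstract method in place, callers no longer hard-code a device string; a minimal usage sketch against whatever backend Triton selected (the output in the comments is illustrative):

```python
import triton

# The active driver is chosen at runtime (cuda, hip, xpu, ...).
DEVICE = triton.runtime.driver.active.get_active_torch_device()

print(DEVICE)        # e.g. device(type='xpu', index=0)
print(DEVICE.type)   # backend name, usable as getattr(torch, DEVICE.type)
print(DEVICE.index)  # integer index, usable with get_device_properties(...)
```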
12 changes: 7 additions & 5 deletions python/tutorials/01-vector-add.py
@@ -23,6 +23,8 @@
import triton
import triton.language as tl

+ DEVICE = triton.runtime.driver.active.get_active_torch_device()


@triton.jit
def add_kernel(x_ptr, # *Pointer* to first input vector.
@@ -60,7 +62,7 @@ def add_kernel(x_ptr, # *Pointer* to first input vector.
def add(x: torch.Tensor, y: torch.Tensor):
# We need to preallocate the output.
output = torch.empty_like(x)
- assert x.is_xpu and y.is_xpu and output.is_xpu
+ assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE
n_elements = output.numel()
# The SPMD launch grid denotes the number of kernel instances that run in parallel.
# It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int].
@@ -81,8 +83,8 @@ def add(x: torch.Tensor, y: torch.Tensor):

torch.manual_seed(0)
size = 98432
- x = torch.rand(size, device='xpu')
- y = torch.rand(size, device='xpu')
+ x = torch.rand(size, device=DEVICE)
+ y = torch.rand(size, device=DEVICE)
output_torch = x + y
output_triton = add(x, y)
print(output_torch.cpu())
@@ -116,8 +118,8 @@ def add(x: torch.Tensor, y: torch.Tensor):
args={}, # Values for function arguments not in `x_names` and `y_name`.
))
def benchmark(size, provider):
- x = torch.rand(size, device='xpu', dtype=torch.float32)
- y = torch.rand(size, device='xpu', dtype=torch.float32)
+ x = torch.rand(size, device=DEVICE, dtype=torch.float32)
+ y = torch.rand(size, device=DEVICE, dtype=torch.float32)
quantiles = [0.5, 0.2, 0.8]
if provider == 'torch':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles)
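One reason the tutorial allocates through DEVICE rather than a bare string: `torch.device` equality requires both type and index to match, so the new assert only holds when every tensor carries the same indexed device. A small illustration (assumes a single XPU; the same logic applies to CUDA):

```python
import torch

x = torch.rand(4, device="xpu")            # allocation resolves to xpu:0
print(x.device == torch.device("xpu", 0))  # True: type and index both match
print(x.device == torch.device("xpu"))     # False: bare device has index None
```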
14 changes: 7 additions & 7 deletions python/tutorials/02-fused-softmax.py
@@ -27,6 +27,8 @@
import triton.language as tl
from triton.runtime import driver

+ DEVICE = triton.runtime.driver.active.get_active_torch_device()


def is_hip():
return triton.runtime.driver.active.get_current_target().backend == "hip"
@@ -110,8 +112,7 @@ def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n
# %%
# We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.

- device = torch.xpu.current_device()
- properties = driver.active.utils.get_device_properties(device)
+ properties = driver.active.utils.get_device_properties(DEVICE.index)
NUM_SM = properties["multiprocessor_count"]
SIZE_SMEM = properties["max_shared_mem"]
WARPS_PER_EU = 8 # TODO: Get from properties
@@ -120,7 +121,6 @@ def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n
WARP_SIZE = properties["sub_group_sizes"][-1]
WG_SIZE = properties["max_work_group_size"]
max_num_warps = WG_SIZE // WARP_SIZE
- target = triton.runtime.driver.active.get_current_target()
warps_per_sm = WARPS_PER_EU * EU_PER_SM
max_num_resident_warps = NUM_SM * warps_per_sm
kernels = {}
@@ -194,7 +194,7 @@ def allocated_slm_size(size_smem):
# This will allow us to verify that our padding mechanism works.

torch.manual_seed(0)
- x = torch.randn(1823, 781, device='xpu')
+ x = torch.randn(1823, 781, device=DEVICE)
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
@@ -226,9 +226,9 @@ def allocated_slm_size(size_smem):
args={'M': 4096}, # values for function arguments not in `x_names` and `y_name`
))
def benchmark(M, N, provider):
- x = torch.randn(M, N, device='xpu', dtype=torch.float32)
- stream = torch.xpu.Stream()
- torch.xpu.set_stream(stream)
+ x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
+ stream = getattr(torch, DEVICE.type).Stream()
+ getattr(torch, DEVICE.type).set_stream(stream)
if provider == 'torch':
ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
if provider == 'triton':
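The `getattr` calls above are what keep the benchmark backend-agnostic: `DEVICE.type` names the torch submodule that owns streams for that backend. A minimal sketch, assuming the submodule exposes `Stream` and `set_stream` (true for `torch.cuda` and `torch.xpu`):

```python
import torch
import triton

DEVICE = triton.runtime.driver.active.get_active_torch_device()
backend = getattr(torch, DEVICE.type)  # torch.xpu, torch.cuda, ...

stream = backend.Stream()              # backend-specific stream object
backend.set_stream(stream)             # route subsequent kernels onto it
```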
14 changes: 8 additions & 6 deletions python/tutorials/03-matrix-multiplication.py
@@ -154,6 +154,8 @@
import triton
import triton.language as tl

+ DEVICE = triton.runtime.driver.active.get_active_torch_device()


def is_cuda():
return triton.runtime.driver.active.get_current_target().backend == "cuda"
@@ -390,8 +392,8 @@ def matmul(a, b, activation=""):
# We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS).

torch.manual_seed(0)
- a = torch.randn((512, 512), device='xpu', dtype=torch.float16)
- b = torch.randn((512, 512), device='xpu', dtype=torch.float16)
+ a = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
+ b = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
triton_output = matmul(a, b)
torch_output = torch.matmul(a, b)
print(f"triton_output_with_fp16_inputs={triton_output}")
@@ -408,8 +410,8 @@ def matmul(a, b, activation=""):
TORCH_HAS_FP8 = hasattr(torch, "float8_e5m2")
if TORCH_HAS_FP8 and is_cuda():
torch.manual_seed(0)
a = torch.randn((512, 512), device="xpu", dtype=torch.float16)
b = torch.randn((512, 512), device="xpu", dtype=torch.float16)
a = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
b = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
a = a.to(torch.float8_e5m2)
# pre-transpose b for efficiency.
b = b.T
@@ -458,8 +460,8 @@ def matmul(a, b, activation=""):

@triton.testing.perf_report(configs)
def benchmark(M, N, K, provider, fp8_inputs):
- a = torch.randn((M, K), device='xpu', dtype=torch.float16)
- b = torch.randn((K, N), device='xpu', dtype=torch.float16)
+ a = torch.randn((M, K), device=DEVICE, dtype=torch.float16)
+ b = torch.randn((K, N), device=DEVICE, dtype=torch.float16)
if TORCH_HAS_FP8 and fp8_inputs:
a = a.to(torch.float8_e5m2)
b = b.T
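Allocation becomes backend-agnostic via DEVICE, while feature gates such as `is_cuda()` stay explicit; both pieces of information come from the same active driver. A short sketch of that split (illustrative only):

```python
import triton

DEVICE = triton.runtime.driver.active.get_active_torch_device()
target = triton.runtime.driver.active.get_current_target()

# Tensors are placed with DEVICE; capability checks still name the backend.
if target.backend == "cuda":
    print("fp8 path enabled on", DEVICE)
else:
    print("skipping fp8 path on", target.backend)
```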
8 changes: 5 additions & 3 deletions python/tutorials/04-low-memory-dropout.py
@@ -38,6 +38,8 @@
import triton
import triton.language as tl

+ DEVICE = triton.runtime.driver.active.get_active_torch_device()


@triton.jit
def _dropout(
@@ -71,10 +73,10 @@ def dropout(x, x_keep, p):


# Input tensor
- x = torch.randn(size=(10, )).xpu()
+ x = torch.randn(size=(10, ), device=DEVICE)
# Dropout mask
p = 0.5
- x_keep = (torch.rand(size=(10, )) > p).to(torch.int32).xpu()
+ x_keep = (torch.rand(size=(10, ), device=DEVICE) > p).to(torch.int32)
#
output = dropout(x, x_keep=x_keep, p=p)
print(tabulate.tabulate([
@@ -138,7 +140,7 @@ def seeded_dropout(x, p, seed):
return output


- x = torch.randn(size=(10, )).xpu()
+ x = torch.randn(size=(10, ), device=DEVICE)
# Compare this to the baseline - dropout mask is never instantiated!
output = seeded_dropout(x, p=0.5, seed=123)
output2 = seeded_dropout(x, p=0.5, seed=123)
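The `.xpu()` calls this file drops implied a host allocation followed by a copy; passing `device=DEVICE` allocates on the accelerator directly. A small sketch of the two patterns side by side (both end with a device tensor; the second avoids the host staging step):

```python
import torch
import triton

DEVICE = triton.runtime.driver.active.get_active_torch_device()

x_copied = torch.randn(size=(10, )).to(DEVICE)      # host alloc, then copy
x_direct = torch.randn(size=(10, ), device=DEVICE)  # allocated on DEVICE
```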
6 changes: 4 additions & 2 deletions python/tutorials/05-layer-norm.py
@@ -42,6 +42,8 @@
except ModuleNotFoundError:
HAS_APEX = False

+ DEVICE = triton.runtime.driver.active.get_active_torch_device()


@triton.jit
def _layer_norm_fwd_fused(
@@ -290,7 +292,7 @@ def backward(ctx, dy):
layer_norm = LayerNorm.apply


- def test_layer_norm(M, N, dtype, eps=1e-5, device='xpu'):
+ def test_layer_norm(M, N, dtype, eps=1e-5, device=DEVICE):
# create data
x_shape = (M, N)
w_shape = (x_shape[-1], )
@@ -329,7 +331,7 @@ def test_layer_norm(M, N, dtype, eps=1e-5, device='xpu'):
plot_name='layer-norm-backward',
args={'M': 4096, 'dtype': torch.float16, 'mode': 'backward'},
))
- def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='xpu'):
+ def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device=DEVICE):
# create data
x_shape = (M, N)
w_shape = (x_shape[-1], )
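A small Python subtlety with the new signatures: `device=DEVICE` is evaluated once, when the function is defined, not per call. That is fine here because DEVICE is a module-level constant, but callers who want a different placement must pass `device=` explicitly; a tiny illustration with hypothetical names:

```python
import torch

DEFAULT = torch.device("cpu")

def make_buffer(n, device=DEFAULT):
    # The default was captured at definition time; rebinding DEFAULT later
    # does not affect it. Override per call with device=...
    return torch.zeros(n, device=device)

cpu_buf = make_buffer(8)                                   # captured default
xpu_buf = make_buffer(8, device=torch.device("xpu", 0))    # explicit override (needs an XPU build)
```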
12 changes: 7 additions & 5 deletions python/tutorials/06-fused-attention.py
@@ -19,6 +19,8 @@
import triton
import triton.language as tl

+ DEVICE = triton.runtime.driver.active.get_active_torch_device()


def is_hip():
return triton.runtime.driver.active.get_current_target().backend == "hip"
@@ -526,13 +528,13 @@ def backward(ctx, do):
@pytest.mark.parametrize("causal", [True])
def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16):
torch.manual_seed(20)
- q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="xpu").normal_(mean=0.0, std=0.5).requires_grad_())
- k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="xpu").normal_(mean=0.0, std=0.5).requires_grad_())
- v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="xpu").normal_(mean=0.0, std=0.5).requires_grad_())
+ q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
+ k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
+ v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
sm_scale = 0.5
dout = torch.randn_like(q)
# reference implementation
- M = torch.tril(torch.ones((N_CTX, N_CTX), device="xpu"))
+ M = torch.tril(torch.ones((N_CTX, N_CTX), device=DEVICE))
p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
if causal:
p[:, :, M == 0] = float("-inf")
@@ -600,7 +602,7 @@ def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16):


@triton.testing.perf_report(configs)
- def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, mode, provider, device="xpu"):
+ def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, mode, provider, device=DEVICE):
assert mode in ["fwd", "bwd"]
dtype = torch.float16
if "triton" in provider:
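For readers following the test, the reference path above amounts to masked scaled-dot-product attention computed on DEVICE; a compact restatement of that reference math (a sketch, not the fused Triton kernel):

```python
import torch

def reference_attention(q, k, v, sm_scale, causal=True):
    # q, k, v: (Z, H, N_CTX, HEAD_DIM) tensors on the same device.
    n_ctx = q.shape[-2]
    p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
    if causal:
        mask = torch.tril(torch.ones((n_ctx, n_ctx), device=q.device))
        p = p.masked_fill(mask == 0, float("-inf"))
    p = torch.softmax(p.float(), dim=-1).to(q.dtype)
    return torch.matmul(p, v)
```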
6 changes: 4 additions & 2 deletions python/tutorials/07-extern-functions.py
@@ -25,6 +25,8 @@

from pathlib import Path

+ DEVICE = triton.runtime.driver.active.get_active_torch_device()


@triton.jit
def asin_kernel(
@@ -49,8 +51,8 @@ def asin_kernel(

torch.manual_seed(0)
size = 98432
- x = torch.rand(size, device='xpu')
- output_triton = torch.zeros(size, device='xpu')
+ x = torch.rand(size, device=DEVICE)
+ output_triton = torch.zeros(size, device=DEVICE)
output_torch = torch.asin(x)
n_elements = output_torch.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
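The block above stops at the grid definition; launching an element-wise kernel like this follows the same shape as the vector-add tutorial. A hedged sketch of such a launch, assuming the kernel takes (x_ptr, y_ptr, n_elements, BLOCK_SIZE) and with an illustrative block size not taken from this file:

```python
# One program instance per BLOCK_SIZE-sized chunk of the input.
asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024)
```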
35 changes: 18 additions & 17 deletions python/tutorials/08-grouped-gemm.py
@@ -31,6 +31,8 @@
import triton
import triton.language as tl

+ DEVICE = triton.runtime.driver.active.get_active_torch_device()


def is_cuda():
return triton.runtime.driver.active.get_current_target().backend == "cuda"
@@ -145,7 +147,6 @@ def grouped_matmul_kernel(


def group_gemm_fn(group_A, group_B):
- device = torch.device('xpu')
assert len(group_A) == len(group_B)
group_size = len(group_A)

@@ -161,7 +162,7 @@ def group_gemm_fn(group_A, group_B):
assert A.shape[1] == B.shape[0]
M, K = A.shape
K, N = B.shape
- C = torch.empty((M, N), device=device, dtype=A.dtype)
+ C = torch.empty((M, N), device=DEVICE, dtype=A.dtype)
group_C.append(C)
A_addrs.append(A.data_ptr())
B_addrs.append(B.data_ptr())
@@ -170,11 +171,11 @@ def group_gemm_fn(group_A, group_B):
g_lds += [A.stride(0), B.stride(0), C.stride(0)]

# note these are device tensors
- d_a_ptrs = torch.tensor(A_addrs, device=device)
- d_b_ptrs = torch.tensor(B_addrs, device=device)
- d_c_ptrs = torch.tensor(C_addrs, device=device)
- d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=device)
- d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=device)
+ d_a_ptrs = torch.tensor(A_addrs, device=DEVICE)
+ d_b_ptrs = torch.tensor(B_addrs, device=DEVICE)
+ d_c_ptrs = torch.tensor(C_addrs, device=DEVICE)
+ d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE)
+ d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE)
# we use a fixed number of CTA, and it's auto-tunable
grid = lambda META: (META['NUM_SM'], )
grouped_matmul_kernel[grid](
@@ -201,8 +202,8 @@ def group_gemm_fn(group_A, group_B):
M = group_m[i]
N = group_n[i]
K = group_k[i]
A = torch.rand((M, K), device="xpu", dtype=torch.float16)
B = torch.rand((K, N), device="xpu", dtype=torch.float16)
A = torch.rand((M, K), device=DEVICE, dtype=torch.float16)
B = torch.rand((K, N), device=DEVICE, dtype=torch.float16)
group_A.append(A)
group_B.append(B)

@@ -264,9 +265,9 @@ def benchmark(N, provider):
g_lds = []
group_C = []
for i in range(group_size):
A = torch.rand((N, N), device="xpu", dtype=torch.float16)
B = torch.rand((N, N), device="xpu", dtype=torch.float16)
C = torch.empty((N, N), device="xpu", dtype=torch.float16)
A = torch.rand((N, N), device=DEVICE, dtype=torch.float16)
B = torch.rand((N, N), device=DEVICE, dtype=torch.float16)
C = torch.empty((N, N), device=DEVICE, dtype=torch.float16)
group_A.append(A)
group_B.append(B)
group_C.append(C)
@@ -276,11 +277,11 @@ def benchmark(N, provider):
g_sizes += [N, N, N]
g_lds += [N, N, N]

d_a_ptrs = torch.tensor(A_addrs, device="xpu")
d_b_ptrs = torch.tensor(B_addrs, device="xpu")
d_c_ptrs = torch.tensor(C_addrs, device="xpu")
d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device="xpu")
d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device="xpu")
d_a_ptrs = torch.tensor(A_addrs, device=DEVICE)
d_b_ptrs = torch.tensor(B_addrs, device=DEVICE)
d_c_ptrs = torch.tensor(C_addrs, device=DEVICE)
d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE)
d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE)

quantiles = [0.5, 0.2, 0.8]
if provider == ref_lib.lower():
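Reviewer note on the pattern above: the grouped GEMM passes each matrix as a raw device address collected into an integer tensor, and the kernel reinterprets those integers as typed pointers on its side. A minimal sketch of the host-side half of that pattern in isolation (assumes DEVICE as defined at the top of the file; shapes are illustrative):

```python
import torch
import triton

DEVICE = triton.runtime.driver.active.get_active_torch_device()

mats = [torch.rand((64, 64), device=DEVICE, dtype=torch.float16) for _ in range(4)]

# Each entry is the raw device address of one matrix; torch stores them as
# int64, and the kernel casts them back to pointers when it loads them.
ptrs = torch.tensor([m.data_ptr() for m in mats], device=DEVICE)
```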