diff --git a/benchmarks/triton_kernels_benchmark/benchmark_driver.py b/benchmarks/triton_kernels_benchmark/benchmark_driver.py
index d9b55cc988..f22b970c42 100644
--- a/benchmarks/triton_kernels_benchmark/benchmark_driver.py
+++ b/benchmarks/triton_kernels_benchmark/benchmark_driver.py
@@ -499,6 +499,9 @@ def get_current_target(self):
         warp_size = 32
         return GPUTarget("xpu", dev_property, warp_size)
 
+    def get_active_torch_device(self):
+        return torch.device("xpu", self.get_current_device())
+
     @staticmethod
     def is_active():
         return torch.xpu.is_available()
diff --git a/python/triton/backends/driver.py b/python/triton/backends/driver.py
index 202ae15686..6606b21ca8 100644
--- a/python/triton/backends/driver.py
+++ b/python/triton/backends/driver.py
@@ -1,4 +1,4 @@
-from abc import ABCMeta, abstractmethod, abstractclassmethod
+from abc import ABCMeta, abstractmethod
 from typing import Callable, List, Protocol, Sequence
 
 
@@ -10,7 +10,8 @@ def __call__(self, kernel_call: Callable, *, quantiles: List[float], **kwargs) -
 
 class DriverBase(metaclass=ABCMeta):
 
-    @abstractclassmethod
+    @classmethod
+    @abstractmethod
     def is_active(self):
         pass
 
@@ -18,6 +19,10 @@ def is_active(self):
     def get_current_target(self):
         pass
 
+    @abstractmethod
+    def get_active_torch_device(self):
+        pass
+
     @abstractmethod
     def get_benchmarker(self) -> Benchmarker:
         """
diff --git a/python/tutorials/01-vector-add.py b/python/tutorials/01-vector-add.py
index 0449b802e5..1e77ca7c1d 100644
--- a/python/tutorials/01-vector-add.py
+++ b/python/tutorials/01-vector-add.py
@@ -23,6 +23,8 @@
 import triton
 import triton.language as tl
 
+DEVICE = triton.runtime.driver.active.get_active_torch_device()
+
 
 @triton.jit
 def add_kernel(x_ptr,  # *Pointer* to first input vector.
@@ -60,7 +62,7 @@ def add_kernel(x_ptr,  # *Pointer* to first input vector.
 def add(x: torch.Tensor, y: torch.Tensor):
     # We need to preallocate the output.
     output = torch.empty_like(x)
-    assert x.is_xpu and y.is_xpu and output.is_xpu
+    assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE
     n_elements = output.numel()
     # The SPMD launch grid denotes the number of kernel instances that run in parallel.
     # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int].
@@ -81,8 +83,8 @@ def add(x: torch.Tensor, y: torch.Tensor):
 
 torch.manual_seed(0)
 size = 98432
-x = torch.rand(size, device='xpu')
-y = torch.rand(size, device='xpu')
+x = torch.rand(size, device=DEVICE)
+y = torch.rand(size, device=DEVICE)
 output_torch = x + y
 output_triton = add(x, y)
 print(output_torch.cpu())
@@ -116,8 +118,8 @@ def add(x: torch.Tensor, y: torch.Tensor):
         args={},  # Values for function arguments not in `x_names` and `y_name`.
     ))
 def benchmark(size, provider):
-    x = torch.rand(size, device='xpu', dtype=torch.float32)
-    y = torch.rand(size, device='xpu', dtype=torch.float32)
+    x = torch.rand(size, device=DEVICE, dtype=torch.float32)
+    y = torch.rand(size, device=DEVICE, dtype=torch.float32)
     quantiles = [0.5, 0.2, 0.8]
     if provider == 'torch':
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles)
diff --git a/python/tutorials/02-fused-softmax.py b/python/tutorials/02-fused-softmax.py
index f8e72e05a2..5785c7edde 100644
--- a/python/tutorials/02-fused-softmax.py
+++ b/python/tutorials/02-fused-softmax.py
@@ -27,6 +27,8 @@
 import triton.language as tl
 from triton.runtime import driver
 
+DEVICE = triton.runtime.driver.active.get_active_torch_device()
+
 
 def is_hip():
     return triton.runtime.driver.active.get_current_target().backend == "hip"
@@ -110,8 +112,7 @@ def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n
 # %%
 # We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
 
-device = torch.xpu.current_device()
-properties = driver.active.utils.get_device_properties(device)
+properties = driver.active.utils.get_device_properties(DEVICE.index)
 NUM_SM = properties["multiprocessor_count"]
 SIZE_SMEM = properties["max_shared_mem"]
 WARPS_PER_EU = 8  # TODO: Get from properties
@@ -120,7 +121,6 @@ def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n
 WARP_SIZE = properties["sub_group_sizes"][-1]
 WG_SIZE = properties["max_work_group_size"]
 max_num_warps = WG_SIZE // WARP_SIZE
-target = triton.runtime.driver.active.get_current_target()
 warps_per_sm = WARPS_PER_EU * EU_PER_SM
 max_num_resident_warps = NUM_SM * warps_per_sm
 kernels = {}
@@ -194,7 +194,7 @@ def allocated_slm_size(size_smem):
 # This will allow us to verify that our padding mechanism works.
 
 torch.manual_seed(0)
-x = torch.randn(1823, 781, device='xpu')
+x = torch.randn(1823, 781, device=DEVICE)
 y_triton = softmax(x)
 y_torch = torch.softmax(x, axis=1)
 assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
@@ -226,9 +226,9 @@ def allocated_slm_size(size_smem):
         args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
     ))
 def benchmark(M, N, provider):
-    x = torch.randn(M, N, device='xpu', dtype=torch.float32)
-    stream = torch.xpu.Stream()
-    torch.xpu.set_stream(stream)
+    x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
+    stream = getattr(torch, DEVICE.type).Stream()
+    getattr(torch, DEVICE.type).set_stream(stream)
     if provider == 'torch':
         ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
     if provider == 'triton':
diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py
index d7f4fda652..07121e6b9c 100644
--- a/python/tutorials/03-matrix-multiplication.py
+++ b/python/tutorials/03-matrix-multiplication.py
@@ -154,6 +154,8 @@
 import triton
 import triton.language as tl
 
+DEVICE = triton.runtime.driver.active.get_active_torch_device()
+
 
 def is_cuda():
     return triton.runtime.driver.active.get_current_target().backend == "cuda"
@@ -390,8 +392,8 @@ def matmul(a, b, activation=""):
 # We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS).
 
 torch.manual_seed(0)
-a = torch.randn((512, 512), device='xpu', dtype=torch.float16)
-b = torch.randn((512, 512), device='xpu', dtype=torch.float16)
+a = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
+b = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
 triton_output = matmul(a, b)
 torch_output = torch.matmul(a, b)
 print(f"triton_output_with_fp16_inputs={triton_output}")
@@ -408,8 +410,8 @@ def matmul(a, b, activation=""):
 TORCH_HAS_FP8 = hasattr(torch, "float8_e5m2")
 if TORCH_HAS_FP8 and is_cuda():
     torch.manual_seed(0)
-    a = torch.randn((512, 512), device="xpu", dtype=torch.float16)
-    b = torch.randn((512, 512), device="xpu", dtype=torch.float16)
+    a = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
+    b = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
     a = a.to(torch.float8_e5m2)
     # pre-transpose b for efficiency.
     b = b.T
@@ -458,8 +460,8 @@ def matmul(a, b, activation=""):
 
 @triton.testing.perf_report(configs)
 def benchmark(M, N, K, provider, fp8_inputs):
-    a = torch.randn((M, K), device='xpu', dtype=torch.float16)
-    b = torch.randn((K, N), device='xpu', dtype=torch.float16)
+    a = torch.randn((M, K), device=DEVICE, dtype=torch.float16)
+    b = torch.randn((K, N), device=DEVICE, dtype=torch.float16)
     if TORCH_HAS_FP8 and fp8_inputs:
         a = a.to(torch.float8_e5m2)
         b = b.T
diff --git a/python/tutorials/04-low-memory-dropout.py b/python/tutorials/04-low-memory-dropout.py
index fc1fceb5a7..3dd84da47e 100644
--- a/python/tutorials/04-low-memory-dropout.py
+++ b/python/tutorials/04-low-memory-dropout.py
@@ -38,6 +38,8 @@
 import triton
 import triton.language as tl
 
+DEVICE = triton.runtime.driver.active.get_active_torch_device()
+
 
 @triton.jit
 def _dropout(
@@ -71,10 +73,10 @@ def dropout(x, x_keep, p):
 
 
 # Input tensor
-x = torch.randn(size=(10, )).xpu()
+x = torch.randn(size=(10, ), device=DEVICE)
 # Dropout mask
 p = 0.5
-x_keep = (torch.rand(size=(10, )) > p).to(torch.int32).xpu()
+x_keep = (torch.rand(size=(10, ), device=DEVICE) > p).to(torch.int32)
 #
 output = dropout(x, x_keep=x_keep, p=p)
 print(tabulate.tabulate([
@@ -138,7 +140,7 @@ def seeded_dropout(x, p, seed):
     return output
 
 
-x = torch.randn(size=(10, )).xpu()
+x = torch.randn(size=(10, ), device=DEVICE)
 # Compare this to the baseline - dropout mask is never instantiated!
 output = seeded_dropout(x, p=0.5, seed=123)
 output2 = seeded_dropout(x, p=0.5, seed=123)
diff --git a/python/tutorials/05-layer-norm.py b/python/tutorials/05-layer-norm.py
index c9d00d4593..85a8500308 100644
--- a/python/tutorials/05-layer-norm.py
+++ b/python/tutorials/05-layer-norm.py
@@ -42,6 +42,8 @@
 except ModuleNotFoundError:
     HAS_APEX = False
 
+DEVICE = triton.runtime.driver.active.get_active_torch_device()
+
 
 @triton.jit
 def _layer_norm_fwd_fused(
@@ -290,7 +292,7 @@ def backward(ctx, dy):
 layer_norm = LayerNorm.apply
 
 
-def test_layer_norm(M, N, dtype, eps=1e-5, device='xpu'):
+def test_layer_norm(M, N, dtype, eps=1e-5, device=DEVICE):
     # create data
     x_shape = (M, N)
     w_shape = (x_shape[-1], )
@@ -329,7 +331,7 @@ def test_layer_norm(M, N, dtype, eps=1e-5, device='xpu'):
         plot_name='layer-norm-backward',
         args={'M': 4096, 'dtype': torch.float16, 'mode': 'backward'},
     ))
-def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='xpu'):
+def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device=DEVICE):
     # create data
     x_shape = (M, N)
     w_shape = (x_shape[-1], )
diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py
index b646683763..b753d331f6 100644
--- a/python/tutorials/06-fused-attention.py
+++ b/python/tutorials/06-fused-attention.py
@@ -19,6 +19,8 @@
 import triton
 import triton.language as tl
 
+DEVICE = triton.runtime.driver.active.get_active_torch_device()
+
 
 def is_hip():
     return triton.runtime.driver.active.get_current_target().backend == "hip"
@@ -526,13 +528,13 @@ def backward(ctx, do):
 @pytest.mark.parametrize("causal", [True])
 def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16):
     torch.manual_seed(20)
-    q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="xpu").normal_(mean=0.0, std=0.5).requires_grad_())
-    k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="xpu").normal_(mean=0.0, std=0.5).requires_grad_())
-    v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="xpu").normal_(mean=0.0, std=0.5).requires_grad_())
+    q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
+    k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
+    v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
     sm_scale = 0.5
     dout = torch.randn_like(q)
     # reference implementation
-    M = torch.tril(torch.ones((N_CTX, N_CTX), device="xpu"))
+    M = torch.tril(torch.ones((N_CTX, N_CTX), device=DEVICE))
     p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
     if causal:
         p[:, :, M == 0] = float("-inf")
@@ -600,7 +602,7 @@ def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16):
 
 
 @triton.testing.perf_report(configs)
-def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, mode, provider, device="xpu"):
+def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, mode, provider, device=DEVICE):
     assert mode in ["fwd", "bwd"]
     dtype = torch.float16
     if "triton" in provider:
diff --git a/python/tutorials/07-extern-functions.py b/python/tutorials/07-extern-functions.py
index 45e4c697c4..6c1007befe 100644
--- a/python/tutorials/07-extern-functions.py
+++ b/python/tutorials/07-extern-functions.py
@@ -25,6 +25,8 @@
 
 from pathlib import Path
 
+DEVICE = triton.runtime.driver.active.get_active_torch_device()
+
 
 @triton.jit
 def asin_kernel(
@@ -49,8 +51,8 @@ def asin_kernel(
 
 torch.manual_seed(0)
 size = 98432
-x = torch.rand(size, device='xpu')
-output_triton = torch.zeros(size, device='xpu')
+x = torch.rand(size, device=DEVICE)
+output_triton = torch.zeros(size, device=DEVICE)
 output_torch = torch.asin(x)
 n_elements = output_torch.numel()
 grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
diff --git a/python/tutorials/08-grouped-gemm.py b/python/tutorials/08-grouped-gemm.py
index 8814187230..6f55fd9dcc 100644
--- a/python/tutorials/08-grouped-gemm.py
+++ b/python/tutorials/08-grouped-gemm.py
@@ -31,6 +31,8 @@
 import triton
 import triton.language as tl
 
+DEVICE = triton.runtime.driver.active.get_active_torch_device()
+
 
 def is_cuda():
     return triton.runtime.driver.active.get_current_target().backend == "cuda"
@@ -145,7 +147,6 @@ def grouped_matmul_kernel(
 
 
 def group_gemm_fn(group_A, group_B):
-    device = torch.device('xpu')
     assert len(group_A) == len(group_B)
     group_size = len(group_A)
 
@@ -161,7 +162,7 @@ def group_gemm_fn(group_A, group_B):
         assert A.shape[1] == B.shape[0]
         M, K = A.shape
         K, N = B.shape
-        C = torch.empty((M, N), device=device, dtype=A.dtype)
+        C = torch.empty((M, N), device=DEVICE, dtype=A.dtype)
         group_C.append(C)
         A_addrs.append(A.data_ptr())
         B_addrs.append(B.data_ptr())
@@ -170,11 +171,11 @@ def group_gemm_fn(group_A, group_B):
         g_lds += [A.stride(0), B.stride(0), C.stride(0)]
 
     # note these are device tensors
-    d_a_ptrs = torch.tensor(A_addrs, device=device)
-    d_b_ptrs = torch.tensor(B_addrs, device=device)
-    d_c_ptrs = torch.tensor(C_addrs, device=device)
-    d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=device)
-    d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=device)
+    d_a_ptrs = torch.tensor(A_addrs, device=DEVICE)
+    d_b_ptrs = torch.tensor(B_addrs, device=DEVICE)
+    d_c_ptrs = torch.tensor(C_addrs, device=DEVICE)
+    d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE)
+    d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE)
     # we use a fixed number of CTA, and it's auto-tunable
     grid = lambda META: (META['NUM_SM'], )
     grouped_matmul_kernel[grid](
@@ -201,8 +202,8 @@ def group_gemm_fn(group_A, group_B):
     M = group_m[i]
     N = group_n[i]
     K = group_k[i]
-    A = torch.rand((M, K), device="xpu", dtype=torch.float16)
-    B = torch.rand((K, N), device="xpu", dtype=torch.float16)
+    A = torch.rand((M, K), device=DEVICE, dtype=torch.float16)
+    B = torch.rand((K, N), device=DEVICE, dtype=torch.float16)
     group_A.append(A)
     group_B.append(B)
 
@@ -264,9 +265,9 @@ def benchmark(N, provider):
     g_lds = []
     group_C = []
     for i in range(group_size):
-        A = torch.rand((N, N), device="xpu", dtype=torch.float16)
-        B = torch.rand((N, N), device="xpu", dtype=torch.float16)
-        C = torch.empty((N, N), device="xpu", dtype=torch.float16)
+        A = torch.rand((N, N), device=DEVICE, dtype=torch.float16)
+        B = torch.rand((N, N), device=DEVICE, dtype=torch.float16)
+        C = torch.empty((N, N), device=DEVICE, dtype=torch.float16)
         group_A.append(A)
         group_B.append(B)
         group_C.append(C)
@@ -276,11 +277,11 @@ def benchmark(N, provider):
         g_sizes += [N, N, N]
         g_lds += [N, N, N]
 
-    d_a_ptrs = torch.tensor(A_addrs, device="xpu")
-    d_b_ptrs = torch.tensor(B_addrs, device="xpu")
-    d_c_ptrs = torch.tensor(C_addrs, device="xpu")
-    d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device="xpu")
-    d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device="xpu")
+    d_a_ptrs = torch.tensor(A_addrs, device=DEVICE)
+    d_b_ptrs = torch.tensor(B_addrs, device=DEVICE)
+    d_c_ptrs = torch.tensor(C_addrs, device=DEVICE)
+    d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE)
+    d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE)
 
     quantiles = [0.5, 0.2, 0.8]
     if provider == ref_lib.lower():
diff --git a/python/tutorials/10-experimental-block-pointer.py b/python/tutorials/10-experimental-block-pointer.py
index 78896af957..9a3fec38b7 100644
--- a/python/tutorials/10-experimental-block-pointer.py
+++ b/python/tutorials/10-experimental-block-pointer.py
@@ -95,6 +95,8 @@
 import triton
 import triton.language as tl
 
+DEVICE = triton.runtime.driver.active.get_active_torch_device()
+
 
 @triton.autotune(
     configs=[
@@ -345,23 +347,23 @@ def matmul(a, b, accum_dtype, res_dtype):
         #    [ 1 1 1 ... ],
         #    [ 0 1 1 ... ], ... ]
         # in order only add 3 values per result matrix element.
-        a = torch.randn(shape, device='xpu', dtype=dtype)
-        b = torch.eye(shape[-2], device='xpu', dtype=dtype) + torch.diag(
-            torch.ones(shape[-2] - 1, device='xpu', dtype=dtype), diagonal=1) + torch.diag(
-                torch.ones(shape[-2] - 1, device='xpu', dtype=dtype), diagonal=-1)
+        a = torch.randn(shape, device=DEVICE, dtype=dtype)
+        b = torch.eye(shape[-2], device=DEVICE, dtype=dtype) + torch.diag(
+            torch.ones(shape[-2] - 1, device=DEVICE, dtype=dtype), diagonal=1) + torch.diag(
+                torch.ones(shape[-2] - 1, device=DEVICE, dtype=dtype), diagonal=-1)
         # duplicate b on batch dimension.
         if len(shape) == 3:
            b = b.unsqueeze(0).repeat(shape[0], 1, 1)
     else:
-        a = torch.randn(shape, device='xpu', dtype=dtype)
-        b = torch.randn(shape, device='xpu', dtype=dtype)
+        a = torch.randn(shape, device=DEVICE, dtype=dtype)
+        b = torch.randn(shape, device=DEVICE, dtype=dtype)
     torch_output = torch.matmul(a, b).to(dtype=res_dtype)
 else:
-    a = torch.randint(low=-127, high=128, size=shape, device='xpu', dtype=dtype)
-    b = torch.randint(low=-127, high=128, size=shape, device='xpu', dtype=dtype)
+    a = torch.randint(low=-127, high=128, size=shape, device=DEVICE, dtype=dtype)
+    b = torch.randint(low=-127, high=128, size=shape, device=DEVICE, dtype=dtype)
     # torch.matmul clamps values to input dtype; IPEX doesn't support int32 matmul
     torch_output = torch.matmul(a.to(device='cpu', dtype=accum_dtype),
-                                b.to(device='cpu', dtype=accum_dtype)).to(device='xpu', dtype=res_dtype)
+                                b.to(device='cpu', dtype=accum_dtype)).to(device=DEVICE, dtype=res_dtype)
 
 triton_output = matmul(a, b, accum_dtype, res_dtype)
 
@@ -408,8 +410,8 @@ def matmul(a, b, accum_dtype, res_dtype):
 
 
 @triton.testing.perf_report(configs)
 def benchmark(M, N, K, provider):
-    a = torch.randn((M, K), device='xpu', dtype=torch.float16)
-    b = torch.randn((K, N), device='xpu', dtype=torch.float16)
+    a = torch.randn((M, K), device=DEVICE, dtype=torch.float16)
+    b = torch.randn((K, N), device=DEVICE, dtype=torch.float16)
     quantiles = [0.5, 0.2, 0.8]
     if provider == ref_lib.lower():
diff --git a/third_party/amd/backend/driver.py b/third_party/amd/backend/driver.py
index 99e5509eca..537604d8d4 100644
--- a/third_party/amd/backend/driver.py
+++ b/third_party/amd/backend/driver.py
@@ -502,6 +502,11 @@ def get_current_target(self):
         warp_size = device_properties['warpSize']
         return GPUTarget("hip", arch.split(':')[0], warp_size)
 
+    def get_active_torch_device(self):
+        import torch
+        # when using hip devices, the device string in pytorch is "cuda"
+        return torch.device("cuda", self.get_current_device())
+
     def get_benchmarker(self):
         from triton.testing import do_bench
         return do_bench
diff --git a/third_party/intel/backend/driver.py b/third_party/intel/backend/driver.py
index 9510e73b9c..f2a237c35b 100644
--- a/third_party/intel/backend/driver.py
+++ b/third_party/intel/backend/driver.py
@@ -558,6 +558,10 @@ def get_current_target(self):
         warp_size = 32
         return GPUTarget("xpu", dev_property, warp_size)
 
+    def get_active_torch_device(self):
+        import torch
+        return torch.device("xpu", self.get_current_device())
+
     def get_device_interface(self):
         import torch
         return torch.xpu
diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py
index edeab969ab..e41b4a1386 100644
--- a/third_party/nvidia/backend/driver.py
+++ b/third_party/nvidia/backend/driver.py
@@ -504,6 +504,10 @@ def get_current_target(self):
         warp_size = 32
         return GPUTarget("cuda", capability, warp_size)
 
+    def get_active_torch_device(self):
+        import torch
+        return torch.device("cuda", self.get_current_device())
+
     def get_device_interface(self):
         import torch
         return torch.cuda