@@ -27,7 +27,8 @@
 import triton.language as tl
 from triton.runtime import driver
 
-DEVICE = triton.runtime.driver.active.get_current_target().backend
+DEVICE = torch.device(triton.runtime.driver.active.get_current_target().backend,
+                      triton.runtime.driver.active.get_current_device())
 
 
 def is_hip():
@@ -112,8 +113,7 @@ def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n |
 # %%
 # We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
 
-device = getattr(torch, DEVICE).current_device()
-properties = driver.active.utils.get_device_properties(device)
+properties = driver.active.utils.get_device_properties(DEVICE.index)
 NUM_SM = properties["multiprocessor_count"]
 SIZE_SMEM = properties["max_shared_mem"]
 WARPS_PER_EU = 8  # TODO: Get from properties
@@ -229,8 +229,8 @@ def allocated_slm_size(size_smem): |
     ))
 def benchmark(M, N, provider):
     x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
-    stream = getattr(torch, DEVICE).Stream()
-    getattr(torch, DEVICE).set_stream(stream)
+    stream = getattr(torch, DEVICE.type).Stream()
+    getattr(torch, DEVICE.type).set_stream(stream)
     if provider == 'torch':
         ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
     if provider == 'triton':
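
For reference, a minimal sketch of why the change works: with this patch DEVICE is a torch.device rather than a plain backend string, so DEVICE.type and DEVICE.index replace the old string-based lookups. The "cuda" backend string and index 0 below are assumptions for illustration only; in the tutorial both come from the active Triton driver.

import torch

# Illustrative stand-in for the patched definition; the real tutorial derives
# the backend name and device ordinal from triton.runtime.driver.active.
DEVICE = torch.device("cuda", 0)

print(DEVICE.type)   # "cuda" -> names the torch submodule to use, e.g. torch.cuda
print(DEVICE.index)  # 0      -> device ordinal passed to get_device_properties

if torch.cuda.is_available():
    # The benchmark now resolves the backend module from the device type;
    # getattr(torch, DEVICE) only worked while DEVICE was a bare string.
    backend_mod = getattr(torch, DEVICE.type)
    stream = backend_mod.Stream()
    backend_mod.set_stream(stream)

The same pattern applies to other backends (e.g. torch.xpu) because every torch.device carries a type string naming the corresponding torch submodule and an integer index for per-device queries.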