|
95 | 95 | import triton |
96 | 96 | import triton.language as tl |
97 | 97 |
|
| 98 | +DEVICE = triton.runtime.driver.active.get_active_torch_device() |
| 99 | + |
98 | 100 |
|
99 | 101 | @triton.autotune( |
100 | 102 | configs=[ |
@@ -345,23 +347,23 @@ def matmul(a, b, accum_dtype, res_dtype): |
345 | 347 | # [ 1 1 1 ... ], |
346 | 348 | # [ 0 1 1 ... ], ... ] |
347 | 349 | # in order only add 3 values per result matrix element. |
348 | | - a = torch.randn(shape, device='xpu', dtype=dtype) |
349 | | - b = torch.eye(shape[-2], device='xpu', dtype=dtype) + torch.diag( |
350 | | - torch.ones(shape[-2] - 1, device='xpu', dtype=dtype), diagonal=1) + torch.diag( |
351 | | - torch.ones(shape[-2] - 1, device='xpu', dtype=dtype), diagonal=-1) |
| 350 | + a = torch.randn(shape, device=DEVICE, dtype=dtype) |
| 351 | + b = torch.eye(shape[-2], device=DEVICE, dtype=dtype) + torch.diag( |
| 352 | + torch.ones(shape[-2] - 1, device=DEVICE, dtype=dtype), diagonal=1) + torch.diag( |
| 353 | + torch.ones(shape[-2] - 1, device=DEVICE, dtype=dtype), diagonal=-1) |
352 | 354 | # duplicate b on batch dimension. |
353 | 355 | if len(shape) == 3: |
354 | 356 | b = b.unsqueeze(0).repeat(shape[0], 1, 1) |
355 | 357 | else: |
356 | | - a = torch.randn(shape, device='xpu', dtype=dtype) |
357 | | - b = torch.randn(shape, device='xpu', dtype=dtype) |
| 358 | + a = torch.randn(shape, device=DEVICE, dtype=dtype) |
| 359 | + b = torch.randn(shape, device=DEVICE, dtype=dtype) |
358 | 360 | torch_output = torch.matmul(a, b).to(dtype=res_dtype) |
359 | 361 | else: |
360 | | - a = torch.randint(low=-127, high=128, size=shape, device='xpu', dtype=dtype) |
361 | | - b = torch.randint(low=-127, high=128, size=shape, device='xpu', dtype=dtype) |
| 362 | + a = torch.randint(low=-127, high=128, size=shape, device=DEVICE, dtype=dtype) |
| 363 | + b = torch.randint(low=-127, high=128, size=shape, device=DEVICE, dtype=dtype) |
362 | 364 | # torch.matmul clamps values to input dtype; IPEX doesn't support int32 matmul |
363 | 365 | torch_output = torch.matmul(a.to(device='cpu', dtype=accum_dtype), |
364 | | - b.to(device='cpu', dtype=accum_dtype)).to(device='xpu', dtype=res_dtype) |
| 366 | + b.to(device='cpu', dtype=accum_dtype)).to(device=DEVICE, dtype=res_dtype) |
365 | 367 |
|
366 | 368 | triton_output = matmul(a, b, accum_dtype, res_dtype) |
367 | 369 |
|
@@ -408,8 +410,8 @@ def matmul(a, b, accum_dtype, res_dtype): |
408 | 410 |
|
409 | 411 | @triton.testing.perf_report(configs) |
410 | 412 | def benchmark(M, N, K, provider): |
411 | | - a = torch.randn((M, K), device='xpu', dtype=torch.float16) |
412 | | - b = torch.randn((K, N), device='xpu', dtype=torch.float16) |
| 413 | + a = torch.randn((M, K), device=DEVICE, dtype=torch.float16) |
| 414 | + b = torch.randn((K, N), device=DEVICE, dtype=torch.float16) |
413 | 415 |
|
414 | 416 | quantiles = [0.5, 0.2, 0.8] |
415 | 417 | if provider == ref_lib.lower(): |
|
0 commit comments