
Commit 70a4ddf
Merge commit '6af74b2f4535682abfc0b08958bc2c6831036d29'
2 parents: 9bda03d + 6af74b2

File tree: 7 files changed, +21 −8 lines


python/test/unit/hopper/test_experimental_tma.py
Lines changed: 5 additions & 5 deletions

```diff
@@ -57,7 +57,7 @@ def kernel(Z, desc, SIZE: tl.constexpr, BYVAL_TMA: tl.constexpr):
 @triton.jit
 def matmul_kernel_tma(a_desc_ptr, b_desc_ptr, c_desc_ptr,  #
                       M, N, K, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
-                      BYVAL_TMA: tl.constexpr):
+                      BYVAL_TMA: tl.constexpr, dtype: tl.constexpr):
     if not BYVAL_TMA:
         tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)
         tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
@@ -72,11 +72,11 @@ def matmul_kernel_tma(a_desc_ptr, b_desc_ptr, c_desc_ptr, #
     offs_k = 0
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
-        a = tl._experimental_descriptor_load(a_desc_ptr, [offs_am, offs_k], [BLOCK_SIZE_M, BLOCK_SIZE_K], tl.float16)
-        b = tl._experimental_descriptor_load(b_desc_ptr, [offs_k, offs_bn], [BLOCK_SIZE_K, BLOCK_SIZE_N], tl.float16)
+        a = tl._experimental_descriptor_load(a_desc_ptr, [offs_am, offs_k], [BLOCK_SIZE_M, BLOCK_SIZE_K], dtype)
+        b = tl._experimental_descriptor_load(b_desc_ptr, [offs_k, offs_bn], [BLOCK_SIZE_K, BLOCK_SIZE_N], dtype)
         accumulator = tl.dot(a, b, acc=accumulator)
         offs_k += BLOCK_SIZE_K
-    accumulator = accumulator.to(tl.float16)
+    accumulator = accumulator.to(dtype)
     tl._experimental_descriptor_store(c_desc_ptr, accumulator, [offs_am, offs_bn])
 
 
@@ -101,7 +101,7 @@ def test_experimental_tma_matmul(num_stages, BLOCK_M, BLOCK_N, BLOCK_K, byval_tm
     desc_c = create_tma_desc_gmem_ptr(C.data_ptr(), [M, N], [BLOCK_M, BLOCK_N], C.element_size())
     kernel = matmul_kernel_tma[(triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1,
                                 1)](desc_a, desc_b, desc_c, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, BYVAL_TMA=byval_tma,
-                                    num_warps=8, num_stages=num_stages)
+                                    num_warps=8, num_stages=num_stages, dtype=tl.float16)
     ref_out = torch.matmul(A.to(torch.float32), B.to(torch.float32)).to(torch.float16)
     torch.testing.assert_close(ref_out, C, rtol=1e-3, atol=1e-3)
     if BLOCK_M >= 64 and BLOCK_N >= 64:
```
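With `dtype` now a `tl.constexpr` parameter, the same kernel can in principle be launched for element types other than float16. A minimal sketch of such a launch, assuming bfloat16 tensors and reusing the test's `create_tma_desc_gmem_ptr` helper; whether a given dtype is supported by the experimental TMA path depends on hardware and backend, so treat this as illustrative only:

```python
import torch
import triton
import triton.language as tl

# Hypothetical bfloat16 launch of the now dtype-parametrized kernel.
M, N, K = 512, 512, 256
BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 32
A = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
B = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
C = torch.empty(M, N, device="cuda", dtype=torch.bfloat16)
desc_a = create_tma_desc_gmem_ptr(A.data_ptr(), [M, K], [BLOCK_M, BLOCK_K], A.element_size())
desc_b = create_tma_desc_gmem_ptr(B.data_ptr(), [K, N], [BLOCK_K, BLOCK_N], B.element_size())
desc_c = create_tma_desc_gmem_ptr(C.data_ptr(), [M, N], [BLOCK_M, BLOCK_N], C.element_size())
grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1, 1)
matmul_kernel_tma[grid](desc_a, desc_b, desc_c, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K,
                        BYVAL_TMA=False,
                        dtype=tl.bfloat16,  # selects the load/store element type
                        num_warps=8, num_stages=3)
```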

python/triton/language/core.py
Lines changed: 1 addition & 1 deletion

```diff
@@ -1613,7 +1613,7 @@ def _experimental_descriptor_load(desc_pointer, offsets, shape, dtype, _builder=
 
     This loads a tensor of data based on the descriptor and offsets.
     """
-    type = block_type(dtype, shape)
+    type = block_type(_constexpr_to_value(dtype), shape)
     return semantic.descriptor_load(desc_pointer, offsets, "", "", type, _builder)
 
 
```
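The `_constexpr_to_value` call is what makes the kernel change above work: when `dtype` is passed as a `tl.constexpr` kernel argument it arrives wrapped, and `block_type` needs the bare `tl.dtype`. Roughly, the unwrapping pattern looks like this (a sketch, not Triton's exact code):

```python
# Sketch of the constexpr-unwrapping pattern used here.
class constexpr:
    """Marks a value as a compile-time constant."""

    def __init__(self, value):
        self.value = value

def _constexpr_to_value(v):
    # Unwrap a constexpr to its underlying value; pass anything else through.
    return v.value if isinstance(v, constexpr) else v
```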

python/triton/testing.py
Lines changed: 2 additions & 1 deletion

```diff
@@ -5,6 +5,7 @@
 from contextlib import contextmanager
 from typing import Any, Dict, List
 from . import language as tl
+from . import runtime
 import time
 import logging
 
@@ -161,7 +162,7 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flu
     assert return_mode in ["min", "max", "mean", "median", "all"]
     import torch
 
-    di = torch._dynamo.device_interface.get_interface_for_device(device_type)
+    di = runtime.driver.active.get_device_interface()
 
     fn()
     di.synchronize()
```
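With this change, `do_bench` asks the active Triton driver for its device interface instead of reaching into the private `torch._dynamo` API, so it works on any backend that implements `get_device_interface` (see the driver changes below). A typical call, assuming an NVIDIA GPU:

```python
import torch
import triton

a = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
b = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)

# Internally, do_bench now resolves torch.cuda via
# runtime.driver.active.get_device_interface().
ms = triton.testing.do_bench(lambda: torch.matmul(a, b))
print(f"matmul: {ms:.3f} ms")
```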

python/tutorials/09-persistent-matmul.py
Lines changed: 1 addition & 1 deletion

```diff
@@ -554,7 +554,7 @@ def bench(K, dtype, tiles_per_update, reps=10):
     if cublas is not None:
         for _ in range(reps):
             cublas_matmul(a, b)
-        time.sleep(0.01)
+            time.sleep(0.01)
     if dtype == torch.float16:
         for _ in range(reps):
             torch_matmul(a, b)
```

third_party/amd/backend/driver.py
Lines changed: 4 additions & 0 deletions

```diff
@@ -484,6 +484,10 @@ def __init__(self):
         self.utils = HIPUtils()
         self.launcher_cls = HIPLauncher
 
+    def get_device_interface(self):
+        import torch
+        return torch.cuda
+
     @staticmethod
     def is_active():
         import torch
```

third_party/intel/backend/driver.py
Lines changed: 4 additions & 0 deletions

```diff
@@ -479,6 +479,10 @@ def get_current_target(self):
         warp_size = 32
         return GPUTarget("xpu", dev_property, warp_size)
 
+    def get_device_interface(self):
+        import torch
+        return torch.xpu
+
     @staticmethod
     def is_active():
         import torch
```

third_party/nvidia/backend/driver.py
Lines changed: 4 additions & 0 deletions

```diff
@@ -440,6 +440,10 @@ def get_current_target(self):
         warp_size = 32
         return GPUTarget("cuda", capability, warp_size)
 
+    def get_device_interface(self):
+        import torch
+        return torch.cuda
+
     @staticmethod
     def is_active():
         import torch
```
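All three backends now expose the torch device module they map to (`torch.cuda` for NVIDIA and AMD, `torch.xpu` for Intel), giving callers a device-agnostic handle for synchronization and timing. A minimal sketch of consuming it, assuming a build that includes this commit:

```python
import triton

# Ask the active backend for its torch device module
# (torch.cuda on NVIDIA/AMD, torch.xpu on Intel).
di = triton.runtime.driver.active.get_device_interface()

start = di.Event(enable_timing=True)
end = di.Event(enable_timing=True)
start.record()
# ... launch some GPU work here ...
end.record()
di.synchronize()
print(f"elapsed: {start.elapsed_time(end):.3f} ms")
```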
