Commit 689dcfe

[INTERPRETER] Fix argument passing for internal parameters in function declarations (#5169)
1 parent 220e51c

File tree

python/test/unit/language/test_core.py
python/triton/runtime/interpreter.py

2 files changed: +19 −18 lines

python/test/unit/language/test_core.py

Lines changed: 14 additions & 13 deletions
@@ -5581,7 +5581,7 @@ def matmul_kernel( #
         stride_cm, stride_cn, #
         BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, #
         low_precision_acc: tl.constexpr, #
-        num_pipeline_stages: tl.constexpr = 3 #
+        num_stages: tl.constexpr = 3 #
 ):
     pid = tl.program_id(axis=0)
     num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
@@ -5593,7 +5593,7 @@ def matmul_kernel( #
     a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
     b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
-    for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K), num_stages=num_pipeline_stages):
+    for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K), num_stages=num_stages):
         a = tl.load(a_ptrs)
         b = tl.load(b_ptrs)
         accumulator = tl.dot(a, b, acc=accumulator, max_num_imprecise_acc=low_precision_acc)
@@ -5632,7 +5632,7 @@ def test_dot_max_num_imprecise_acc(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, in_type_s
     max_num_impressive_acc = low_precision_acc if low_precision_acc <= BLOCK_K else None
     h = matmul_kernel[grid](a, b, C, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), C.stride(0),
                             C.stride(1), BLOCK_M, BLOCK_N, BLOCK_K, max_num_impressive_acc, num_warps=num_warps,
-                            num_pipeline_stages=num_stages)
+                            num_stages=num_stages)
     torch_a = torch.from_numpy(A).to(device=device)
     th_a = f8_to_f16(torch_a, in_type_str)
     torch_b = torch.from_numpy(B).to(device=device)
@@ -5824,7 +5824,7 @@ def test_tl_range(device):
     pgm = matmul_kernel[
         1,
     ](a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), BLOCK_M, BLOCK_N,
-      BLOCK_K, 0, num_pipeline_stages=5)
+      BLOCK_K, 0, num_stages=5)
     ref_out = torch.matmul(a, b).to(torch.float32)
     if is_interpreter():
         # GPU invokes tensor core for float16 matmul, which is not supported in interpreter.
@@ -5850,8 +5850,8 @@ def maxnreg_noinline2(X):
     tl.store(X, 0)


+@pytest.mark.interpreter
 def test_maxnreg(device):
-    assert not is_interpreter(), "this test won't work with the interpreter"
     if not is_cuda():
         pytest.skip('maxnreg only works on CUDA')

@@ -5865,14 +5865,15 @@ def kernel(X):
     X = torch.empty(1, dtype=torch.int32, device=device)
     k = kernel[(1, )](X, maxnreg=42)

-    # Ensure that .maxnreg is set on the kernel function (marked with .entry)
-    # and not on either of the noinline functions (marked with .func).
-    try:
-        assert re.search(r'\.visible \.entry [^{;]*\.maxnreg 42', k.asm["ptx"])
-        assert not re.search(r'\.visible \.func [^{;]*\.maxnreg', k.asm["ptx"])
-    except AssertionError:
-        print("Failing ptx:\n", k.asm["ptx"])
-        raise
+    if not is_interpreter():
+        # Ensure that .maxnreg is set on the kernel function (marked with .entry)
+        # and not on either of the noinline functions (marked with .func).
+        try:
+            assert re.search(r'\.visible \.entry [^{;]*\.maxnreg 42', k.asm["ptx"])
+            assert not re.search(r'\.visible \.func [^{;]*\.maxnreg', k.asm["ptx"])
+        except AssertionError:
+            print("Failing ptx:\n", k.asm["ptx"])
+            raise


 @pytest.mark.interpreter
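The test changes above all follow from the interpreter now forwarding only the keywords that appear in the kernel's own signature: the test kernel's pipelining constexpr can simply be named num_stages, and test_maxnreg can run under the interpreter because launch options it does not understand, such as maxnreg, are dropped instead of tripping an assertion. The snippet below is not part of the commit; it is a minimal sketch of the same pattern with illustrative names (copy_kernel, BLOCK), assuming Triton is installed and the interpreter is enabled with TRITON_INTERPRET=1 so plain CPU tensors can be used.

# Minimal sketch (not from the commit): a kernel-level constexpr named `num_stages`,
# passed by keyword at launch. Run with TRITON_INTERPRET=1 so the kernel executes on CPU.
import torch
import triton
import triton.language as tl


@triton.jit
def copy_kernel(src_ptr, dst_ptr, n_elements, BLOCK: tl.constexpr, num_stages: tl.constexpr = 3):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    # `num_stages` feeds tl.range's pipelining hint; the interpreter runs the loop serially.
    for _ in tl.range(0, 1, num_stages=num_stages):
        x = tl.load(src_ptr + offsets, mask=mask)
        tl.store(dst_ptr + offsets, x, mask=mask)


src = torch.arange(16, dtype=torch.float32)
dst = torch.empty_like(src)
# Before this commit the interpreter stripped `num_stages` from kwargs as a reserved
# keyword, so the kernel silently fell back to its default; now the keyword reaches the
# kernel because it appears in the function signature.
copy_kernel[(1, )](src, dst, src.numel(), BLOCK=16, num_stages=4)
assert torch.equal(src, dst)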

python/triton/runtime/interpreter.py

Lines changed: 5 additions & 5 deletions
@@ -1034,9 +1034,6 @@ def _implicit_cvt(arg):

 interpreter_builder = InterpreterBuilder()

-# These keywords are not supported by the interpreter
-RESERVED_KWS = ["num_warps", "num_stages", "num_ctas", "enable_fp_fusion", "grid", "maxnreg"]
-

 class GridExecutor:

@@ -1077,10 +1074,13 @@ def _restore_args_dev(self, args_dev, args_hst, kwargs, kwargs_hst):
             kwarg_dev.data.copy_(kwarg_hst.to(kwarg_dev.device).data)

     def __call__(self, *args_dev, **kwargs):
-        # removes reserved keywords from kwargs
-        kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS}
         if kwargs.pop("warmup", False):
             return
+        # Drop reserved launch keywords that are not declared in the kernel's signature.
+        # Triton doesn't support keyword-only, variable positional, or variable keyword
+        # arguments, so it is safe to inspect only positional-or-keyword parameters
+        # (i.e., argspec.args).
+        argspec = inspect.getfullargspec(self.fn)
+        kwargs = {k: v for k, v in kwargs.items() if k in argspec.args}
         # copy arguments to the host
         args_hst, kwargs_hst = self._init_args_hst(args_dev, kwargs)
         # remaps core language functions to interpreted ones
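The core of the fix is the keyword filter in GridExecutor.__call__: instead of subtracting a hard-coded RESERVED_KWS deny list, the interpreter keeps exactly the keywords that name parameters of the kernel being launched. The snippet below is a standalone sketch of that filtering idea in plain Python; filter_kwargs_for and fake_kernel are hypothetical names for illustration, not interpreter APIs.

# Standalone sketch (hypothetical helper, not interpreter code): keep only the keyword
# arguments that name real parameters of `fn`, dropping launch-only options.
import inspect


def filter_kwargs_for(fn, kwargs):
    # getfullargspec(fn).args lists positional-or-keyword parameters, which is the only
    # kind a Triton kernel may declare (no *args, **kwargs, or keyword-only parameters).
    argspec = inspect.getfullargspec(fn)
    return {k: v for k, v in kwargs.items() if k in argspec.args}


def fake_kernel(x_ptr, n_elements, BLOCK, num_stages=3):
    pass


launch_kwargs = {"BLOCK": 128, "num_stages": 4, "num_warps": 8, "maxnreg": 42}
print(filter_kwargs_for(fake_kernel, launch_kwargs))
# {'BLOCK': 128, 'num_stages': 4}: num_warps and maxnreg are GPU launch options the
# interpreter ignores; they are dropped because they are not in the signature.

Compared with the old deny list, this approach also future-proofs the interpreter: a launch option added to the compiler later is ignored automatically, and a kernel parameter that happens to share a name with a launch option (like num_stages above) is no longer swallowed.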
