Skip to content

Commit 425f6ee

Browse files
agron911 and meta-codesync[bot]
authored and committed
[Cherry-pick] [Frontend][Test] Make gluon/test_frontend.py runnable locally (#8006) (#624)
Summary: Cherry-picked from upstream OAI repository. Original Commit: d441626 Original Author: Jeff Niu Original Date: 2025-08-29 05:59:56 -0700 Original commit message: ``` [Frontend][Test] Make `gluon/test_frontend.py` runnable locally (#8006) ``` This PR was automatically cherry-picked from the upstream triton-lang/triton repository. Pull Request resolved: #624 Reviewed By: dshi7 Differential Revision: D86112807 Pulled By: agron911 fbshipit-source-id: 1ce60cdc592ec7d57bd660ef36ed50ce15fa8635
1 parent d14b0a0 commit 425f6ee

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

python/test/gluon/test_frontend.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import expecttest
2-
import torch
32
import pytest
43
import re
54

@@ -128,7 +127,7 @@ def kernel(src_layout: ttgl.constexpr, dst_layout: ttgl.constexpr):
128127
with pytest.raises(CompilationError) as e:
129128
src_layout: ttgl.constexpr = ttgl.AutoLayout()
130129
dst_layout: ttgl.constexpr = ttgl.BlockedLayout([2], [32], [4], [0])
131-
kernel.warmup(src_layout, dst_layout, grid=(1, ))
130+
run_parser(kernel, *make_args(src_layout, dst_layout), target=target)
132131

133132
assert "layout conversion from AutoLayout()" in str(e.value.__cause__)
134133
assert "to BlockedLayout(size_per_thread=[2]" in str(e.value.__cause__)
@@ -699,7 +698,7 @@ def async_tma_kernel(input_desc, XBLOCK: ttgl.constexpr):
699698

700699
@pytest.mark.parametrize("target", [HOPPER_TARGET, BLACKWELL_TARGET])
701700
def test_async_tma(target):
702-
input = torch.randn((1024, 1024), device="cuda", dtype=torch.float16)
701+
input = MockTensor(ttgl.float16, (1024, 1024))
703702
XBLOCK = 128
704703
shared_layout = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2)
705704
input_desc = TensorDescriptor.from_tensor(input, [XBLOCK, XBLOCK], shared_layout)
@@ -758,7 +757,7 @@ def async_tma_blackwell_kernel(input_desc, XBLOCK: ttgl.constexpr):
758757

759758

760759
def test_async_tma_blackwell():
761-
input = torch.randn((1024, 1024), device="cuda", dtype=torch.float16)
760+
input = MockTensor(ttgl.float16, (1024, 1024))
762761
XBLOCK = 128
763762
shared_layout = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2)
764763
input_desc = TensorDescriptor.from_tensor(input, [1, XBLOCK], shared_layout)

python/triton/runtime/jit.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -968,8 +968,17 @@ def wrap_dtype(arg):
968968
return MockTensor(arg)
969969
return arg
970970

971-
def __init__(self, dtype):
971+
def __init__(self, dtype, shape=None):
972+
if shape is None:
973+
shape = [1]
972974
self.dtype = dtype
975+
self.shape = shape
976+
977+
def stride(self):
978+
strides = [1]
979+
for size in self.shape[1:]:
980+
strides.append(strides[-1] * size)
981+
return tuple(reversed(strides))
973982

974983
@staticmethod
975984
def data_ptr():

0 commit comments

Comments (0)