Skip to content

Commit a21d58d

Browse files
committed
Merge branch 'main' into etiotto.remove_rewrite_tensor_ptr
2 parents 00f8432 + dd36f6d commit a21d58d

File tree

7 files changed

+361
-31
lines changed

7 files changed

+361
-31
lines changed

benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -78,6 +78,9 @@ def _attn_fwd(Q, K, V, sm_scale, M, Out, #
7878
start_m = tl.program_id(2)
7979
off_z = tl.program_id(0)
8080
off_h = tl.program_id(1)
81+
if N_CTX <= 512:
82+
start_m = tl.program_id(0)
83+
off_z = tl.program_id(2)
8184
qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
8285

8386
# block pointers
@@ -176,6 +179,9 @@ def forward(q, k, v, causal, sm_scale):
176179
num_warps = 8 if Lq == 64 else 16
177180
stage = 3 if causal else 1
178181
grid = lambda args: (q.shape[0], q.shape[1], triton.cdiv(q.shape[2], args['BLOCK_M']))
182+
n_ctx = q.shape[2]
183+
if n_ctx <= 512:
184+
grid = lambda args: (triton.cdiv(q.shape[2], args['BLOCK_M']), q.shape[1], q.shape[0])
179185
M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
180186

181187
if os.getenv('TRITON_INTEL_ADVANCED_PATH', '0') == '0':

python/test/unit/language/test_block_pointer.py

Lines changed: 21 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -7,51 +7,53 @@
77

88

99
@triton.jit
10-
def block_copy_kernel(a_ptr, b_ptr, N, BLOCK_SIZE: tl.constexpr, padding_option: tl.constexpr):
10+
def block_copy_kernel(a_ptr, b_ptr, N, BLOCK_SIZE: tl.constexpr):
1111
pid = tl.program_id(0)
1212
# We only copy half of the data to see if the padding works
1313
a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(N // 2, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ),
1414
block_shape=(BLOCK_SIZE, ), order=(0, ))
1515
b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(N, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ),
1616
block_shape=(BLOCK_SIZE, ), order=(0, ))
17-
if padding_option is None:
18-
a = tl.load(a_block_ptr, boundary_check=(0, ))
19-
else:
20-
a = tl.load(a_block_ptr, boundary_check=(0, ), padding_option=padding_option)
17+
# if padding_option is None:
18+
a = tl.load(a_block_ptr, boundary_check=(0, ))
19+
# else:
20+
# a = tl.load(a_block_ptr, boundary_check=(0, ), padding_option=padding_option)
2121
tl.store(b_block_ptr, a, boundary_check=(0, ))
2222

2323

2424
@pytest.mark.interpreter
25-
@pytest.mark.parametrize("dtypes_str, n, padding_option", [ #
26-
(dtypes_str, n, padding)
27-
for dtypes_str in (("bool", "bool"), ("int16", "int16"), ("int32", "int32"), ("float16", "float16"),
28-
("float32", "float32"), ("bfloat16", "bfloat16"))
29-
for n in (64, 128, 256, 512, 1024)
30-
for padding in (None, "zero", "nan") #
25+
@pytest.mark.parametrize("dtypes_str, n", [ #
26+
(dtypes_str, n)
27+
# for dtypes_str in (("bool", "bool"), ("int16", "int16"), ("int32", "int32"), ("float16", "float16"),
28+
# ("float32", "float32"), ("bfloat16", "bfloat16"))
29+
for dtypes_str in [("float16", "float16")]
30+
for n in [64]
3131
])
32-
def test_block_copy(dtypes_str, n, padding_option, device):
32+
def test_block_copy(dtypes_str, n, device):
3333
src_dtype_str = dtypes_str[0]
3434
dst_dtype_str = dtypes_str[1]
3535
src_dtype = getattr(torch, src_dtype_str)
3636
dst_dtype = getattr(torch, dst_dtype_str)
3737
check_type_supported(src_dtype, device)
3838
check_type_supported(dst_dtype, device)
3939
if src_dtype_str in ("bool", "int16", "int32"):
40-
if padding_option == "nan":
41-
pytest.xfail("Padding with NaN is not supported for integer types")
40+
# if padding_option == "nan":
41+
# pytest.xfail("Padding with NaN is not supported for integer types")
4242
a = torch.randint(0, 2, (n, ), device=device, dtype=src_dtype)
4343
else:
4444
a = torch.randn((n, ), device=device, dtype=src_dtype)
4545
b = torch.zeros((n, ), device=device, dtype=dst_dtype)
4646

4747
grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]), )
48-
block_copy_kernel[grid](a_ptr=a, b_ptr=b, N=n, BLOCK_SIZE=64, padding_option=padding_option)
48+
block_copy_kernel[grid](a_ptr=a, b_ptr=b, N=n, BLOCK_SIZE=64)
4949
a.to(dst_dtype)
5050
assert torch.all(a[0:n // 2] == b[0:n // 2])
51-
if padding_option == "zero":
52-
assert torch.all(b[n // 2:n] == 0)
53-
elif padding_option == "nan":
54-
assert torch.all(torch.isnan(b[n // 2:n]))
51+
52+
53+
# if padding_option == "zero":
54+
# assert torch.all(b[n // 2:n] == 0)
55+
# elif padding_option == "nan":
56+
# assert torch.all(torch.isnan(b[n // 2:n]))
5557

5658

5759
@triton.jit

scripts/skiplist/a770/language.txt

Lines changed: 161 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,166 @@
11
# https://github.com/intel/intel-xpu-backend-for-triton/issues/1434
22
test/unit/language/test_core.py::test_precise_math[1-tl.math.sqrt_rn(x)-tl.math.sqrt(x.to(tl.float64)).to(tl.float32)]
3+
test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float16-float16]
4+
test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float16-float32]
5+
test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float32-float32]
6+
test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-int8-int8]
7+
test/unit/language/test_core.py::test_dot3d[1-1-64-64-64-32-32-float16-float16]
8+
test/unit/language/test_core.py::test_dot3d[1-1-64-64-64-32-32-float16-float32]
9+
test/unit/language/test_core.py::test_dot3d[1-1-64-64-64-32-32-float32-float32]
10+
test/unit/language/test_core.py::test_dot3d[1-1-64-64-64-32-32-int8-int8]
11+
test/unit/language/test_core.py::test_dot3d[1-16-32-32-32-32-32-float16-float16]
12+
test/unit/language/test_core.py::test_dot3d[1-16-32-32-32-32-32-float16-float32]
13+
test/unit/language/test_core.py::test_dot3d[1-16-32-32-32-32-32-float32-float32]
14+
test/unit/language/test_core.py::test_dot3d[1-16-32-32-32-32-32-int8-int8]
15+
test/unit/language/test_core.py::test_dot3d[1-16-64-64-64-32-32-float16-float16]
16+
test/unit/language/test_core.py::test_dot3d[1-16-64-64-64-32-32-float16-float32]
17+
test/unit/language/test_core.py::test_dot3d[1-16-64-64-64-32-32-float32-float32]
18+
test/unit/language/test_core.py::test_dot3d[1-16-64-64-64-32-32-int8-int8]
19+
test/unit/language/test_core.py::test_dot3d[1-2-32-32-32-32-32-float16-float16]
20+
test/unit/language/test_core.py::test_dot3d[1-2-32-32-32-32-32-float16-float32]
21+
test/unit/language/test_core.py::test_dot3d[1-2-32-32-32-32-32-float32-float32]
22+
test/unit/language/test_core.py::test_dot3d[1-2-32-32-32-32-32-int8-int8]
23+
test/unit/language/test_core.py::test_dot3d[1-2-64-64-64-32-32-float16-float16]
24+
test/unit/language/test_core.py::test_dot3d[1-2-64-64-64-32-32-float16-float32]
25+
test/unit/language/test_core.py::test_dot3d[1-2-64-64-64-32-32-float32-float32]
26+
test/unit/language/test_core.py::test_dot3d[1-2-64-64-64-32-32-int8-int8]
27+
test/unit/language/test_core.py::test_dot3d[1-4-32-32-32-32-32-float16-float16]
28+
test/unit/language/test_core.py::test_dot3d[1-4-32-32-32-32-32-float16-float32]
29+
test/unit/language/test_core.py::test_dot3d[1-4-32-32-32-32-32-float32-float32]
30+
test/unit/language/test_core.py::test_dot3d[1-4-32-32-32-32-32-int8-int8]
31+
test/unit/language/test_core.py::test_dot3d[1-4-64-64-64-32-32-float16-float16]
32+
test/unit/language/test_core.py::test_dot3d[1-4-64-64-64-32-32-float16-float32]
33+
test/unit/language/test_core.py::test_dot3d[1-4-64-64-64-32-32-float32-float32]
34+
test/unit/language/test_core.py::test_dot3d[1-4-64-64-64-32-32-int8-int8]
35+
test/unit/language/test_core.py::test_dot3d[1-8-32-32-32-32-32-float16-float16]
36+
test/unit/language/test_core.py::test_dot3d[1-8-32-32-32-32-32-float16-float32]
37+
test/unit/language/test_core.py::test_dot3d[1-8-32-32-32-32-32-float32-float32]
38+
test/unit/language/test_core.py::test_dot3d[1-8-32-32-32-32-32-int8-int8]
39+
test/unit/language/test_core.py::test_dot3d[1-8-64-64-64-32-32-float16-float16]
40+
test/unit/language/test_core.py::test_dot3d[1-8-64-64-64-32-32-float16-float32]
41+
test/unit/language/test_core.py::test_dot3d[1-8-64-64-64-32-32-float32-float32]
42+
test/unit/language/test_core.py::test_dot3d[1-8-64-64-64-32-32-int8-int8]
43+
test/unit/language/test_core.py::test_dot3d[2-1-32-32-32-32-32-float16-float16]
44+
test/unit/language/test_core.py::test_dot3d[2-1-32-32-32-32-32-float16-float32]
45+
test/unit/language/test_core.py::test_dot3d[2-1-32-32-32-32-32-float32-float32]
46+
test/unit/language/test_core.py::test_dot3d[2-1-32-32-32-32-32-int8-int8]
47+
test/unit/language/test_core.py::test_dot3d[2-1-64-64-64-32-32-float16-float16]
48+
test/unit/language/test_core.py::test_dot3d[2-1-64-64-64-32-32-float16-float32]
49+
test/unit/language/test_core.py::test_dot3d[2-1-64-64-64-32-32-float32-float32]
50+
test/unit/language/test_core.py::test_dot3d[2-1-64-64-64-32-32-int8-int8]
51+
test/unit/language/test_core.py::test_dot3d[2-16-32-32-32-32-32-float16-float16]
52+
test/unit/language/test_core.py::test_dot3d[2-16-32-32-32-32-32-float16-float32]
53+
test/unit/language/test_core.py::test_dot3d[2-16-32-32-32-32-32-float32-float32]
54+
test/unit/language/test_core.py::test_dot3d[2-16-32-32-32-32-32-int8-int8]
55+
test/unit/language/test_core.py::test_dot3d[2-16-64-64-64-32-32-float16-float16]
56+
test/unit/language/test_core.py::test_dot3d[2-16-64-64-64-32-32-float16-float32]
57+
test/unit/language/test_core.py::test_dot3d[2-16-64-64-64-32-32-float32-float32]
58+
test/unit/language/test_core.py::test_dot3d[2-16-64-64-64-32-32-int8-int8]
59+
test/unit/language/test_core.py::test_dot3d[2-2-32-32-32-32-32-float16-float16]
60+
test/unit/language/test_core.py::test_dot3d[2-2-32-32-32-32-32-float16-float32]
61+
test/unit/language/test_core.py::test_dot3d[2-2-32-32-32-32-32-float32-float32]
62+
test/unit/language/test_core.py::test_dot3d[2-2-32-32-32-32-32-int8-int8]
63+
test/unit/language/test_core.py::test_dot3d[2-2-64-64-64-32-32-float16-float16]
64+
test/unit/language/test_core.py::test_dot3d[2-2-64-64-64-32-32-float16-float32]
65+
test/unit/language/test_core.py::test_dot3d[2-2-64-64-64-32-32-float32-float32]
66+
test/unit/language/test_core.py::test_dot3d[2-2-64-64-64-32-32-int8-int8]
67+
test/unit/language/test_core.py::test_dot3d[2-4-32-32-32-32-32-float16-float16]
68+
test/unit/language/test_core.py::test_dot3d[2-4-32-32-32-32-32-float16-float32]
69+
test/unit/language/test_core.py::test_dot3d[2-4-32-32-32-32-32-float32-float32]
70+
test/unit/language/test_core.py::test_dot3d[2-4-32-32-32-32-32-int8-int8]
71+
test/unit/language/test_core.py::test_dot3d[2-4-64-64-64-32-32-float16-float16]
72+
test/unit/language/test_core.py::test_dot3d[2-4-64-64-64-32-32-float16-float32]
73+
test/unit/language/test_core.py::test_dot3d[2-4-64-64-64-32-32-float32-float32]
74+
test/unit/language/test_core.py::test_dot3d[2-4-64-64-64-32-32-int8-int8]
75+
test/unit/language/test_core.py::test_dot3d[2-8-32-32-32-32-32-float16-float16]
76+
test/unit/language/test_core.py::test_dot3d[2-8-32-32-32-32-32-float16-float32]
77+
test/unit/language/test_core.py::test_dot3d[2-8-32-32-32-32-32-float32-float32]
78+
test/unit/language/test_core.py::test_dot3d[2-8-32-32-32-32-32-int8-int8]
79+
test/unit/language/test_core.py::test_dot3d[2-8-64-64-64-32-32-float16-float16]
80+
test/unit/language/test_core.py::test_dot3d[2-8-64-64-64-32-32-float16-float32]
81+
test/unit/language/test_core.py::test_dot3d[2-8-64-64-64-32-32-float32-float32]
82+
test/unit/language/test_core.py::test_dot3d[2-8-64-64-64-32-32-int8-int8]
83+
test/unit/language/test_core.py::test_dot3d[4-1-32-32-32-32-32-float16-float16]
84+
test/unit/language/test_core.py::test_dot3d[4-1-32-32-32-32-32-float16-float32]
85+
test/unit/language/test_core.py::test_dot3d[4-1-32-32-32-32-32-float32-float32]
86+
test/unit/language/test_core.py::test_dot3d[4-1-32-32-32-32-32-int8-int8]
87+
test/unit/language/test_core.py::test_dot3d[4-1-64-64-64-32-32-float16-float16]
88+
test/unit/language/test_core.py::test_dot3d[4-1-64-64-64-32-32-float16-float32]
89+
test/unit/language/test_core.py::test_dot3d[4-1-64-64-64-32-32-float32-float32]
90+
test/unit/language/test_core.py::test_dot3d[4-1-64-64-64-32-32-int8-int8]
91+
test/unit/language/test_core.py::test_dot3d[4-16-32-32-32-32-32-float16-float16]
92+
test/unit/language/test_core.py::test_dot3d[4-16-32-32-32-32-32-float16-float32]
93+
test/unit/language/test_core.py::test_dot3d[4-16-32-32-32-32-32-float32-float32]
94+
test/unit/language/test_core.py::test_dot3d[4-16-32-32-32-32-32-int8-int8]
95+
test/unit/language/test_core.py::test_dot3d[4-16-64-64-64-32-32-float16-float16]
96+
test/unit/language/test_core.py::test_dot3d[4-16-64-64-64-32-32-float16-float32]
97+
test/unit/language/test_core.py::test_dot3d[4-16-64-64-64-32-32-float32-float32]
98+
test/unit/language/test_core.py::test_dot3d[4-16-64-64-64-32-32-int8-int8]
99+
test/unit/language/test_core.py::test_dot3d[4-2-32-32-32-32-32-float16-float16]
100+
test/unit/language/test_core.py::test_dot3d[4-2-32-32-32-32-32-float16-float32]
101+
test/unit/language/test_core.py::test_dot3d[4-2-32-32-32-32-32-float32-float32]
102+
test/unit/language/test_core.py::test_dot3d[4-2-32-32-32-32-32-int8-int8]
103+
test/unit/language/test_core.py::test_dot3d[4-2-64-64-64-32-32-float16-float16]
104+
test/unit/language/test_core.py::test_dot3d[4-2-64-64-64-32-32-float16-float32]
105+
test/unit/language/test_core.py::test_dot3d[4-2-64-64-64-32-32-float32-float32]
106+
test/unit/language/test_core.py::test_dot3d[4-2-64-64-64-32-32-int8-int8]
107+
test/unit/language/test_core.py::test_dot3d[4-4-128-128-64-64-64-float16-float16]
108+
test/unit/language/test_core.py::test_dot3d[4-4-32-32-32-32-32-float16-float16]
109+
test/unit/language/test_core.py::test_dot3d[4-4-32-32-32-32-32-float16-float32]
110+
test/unit/language/test_core.py::test_dot3d[4-4-32-32-32-32-32-float32-float32]
111+
test/unit/language/test_core.py::test_dot3d[4-4-32-32-32-32-32-int8-int8]
112+
test/unit/language/test_core.py::test_dot3d[4-4-64-64-64-32-32-float16-float16]
113+
test/unit/language/test_core.py::test_dot3d[4-4-64-64-64-32-32-float16-float32]
114+
test/unit/language/test_core.py::test_dot3d[4-4-64-64-64-32-32-float32-float32]
115+
test/unit/language/test_core.py::test_dot3d[4-4-64-64-64-32-32-int8-int8]
116+
test/unit/language/test_core.py::test_dot3d[4-8-32-32-32-32-32-float16-float16]
117+
test/unit/language/test_core.py::test_dot3d[4-8-32-32-32-32-32-float16-float32]
118+
test/unit/language/test_core.py::test_dot3d[4-8-32-32-32-32-32-float32-float32]
119+
test/unit/language/test_core.py::test_dot3d[4-8-32-32-32-32-32-int8-int8]
120+
test/unit/language/test_core.py::test_dot3d[4-8-64-64-64-32-32-float16-float16]
121+
test/unit/language/test_core.py::test_dot3d[4-8-64-64-64-32-32-float16-float32]
122+
test/unit/language/test_core.py::test_dot3d[4-8-64-64-64-32-32-float32-float32]
123+
test/unit/language/test_core.py::test_dot3d[4-8-64-64-64-32-32-int8-int8]
124+
test/unit/language/test_core.py::test_dot3d[8-1-32-32-32-32-32-float16-float16]
125+
test/unit/language/test_core.py::test_dot3d[8-1-32-32-32-32-32-float16-float32]
126+
test/unit/language/test_core.py::test_dot3d[8-1-32-32-32-32-32-float32-float32]
127+
test/unit/language/test_core.py::test_dot3d[8-1-32-32-32-32-32-int8-int8]
128+
test/unit/language/test_core.py::test_dot3d[8-1-64-64-64-32-32-float16-float16]
129+
test/unit/language/test_core.py::test_dot3d[8-1-64-64-64-32-32-float16-float32]
130+
test/unit/language/test_core.py::test_dot3d[8-1-64-64-64-32-32-float32-float32]
131+
test/unit/language/test_core.py::test_dot3d[8-1-64-64-64-32-32-int8-int8]
132+
test/unit/language/test_core.py::test_dot3d[8-16-32-32-32-32-32-float16-float16]
133+
test/unit/language/test_core.py::test_dot3d[8-16-32-32-32-32-32-float16-float32]
134+
test/unit/language/test_core.py::test_dot3d[8-16-32-32-32-32-32-float32-float32]
135+
test/unit/language/test_core.py::test_dot3d[8-16-32-32-32-32-32-int8-int8]
136+
test/unit/language/test_core.py::test_dot3d[8-16-64-64-64-32-32-float16-float16]
137+
test/unit/language/test_core.py::test_dot3d[8-16-64-64-64-32-32-float16-float32]
138+
test/unit/language/test_core.py::test_dot3d[8-16-64-64-64-32-32-float32-float32]
139+
test/unit/language/test_core.py::test_dot3d[8-16-64-64-64-32-32-int8-int8]
140+
test/unit/language/test_core.py::test_dot3d[8-2-32-32-32-32-32-float16-float16]
141+
test/unit/language/test_core.py::test_dot3d[8-2-32-32-32-32-32-float16-float32]
142+
test/unit/language/test_core.py::test_dot3d[8-2-32-32-32-32-32-float32-float32]
143+
test/unit/language/test_core.py::test_dot3d[8-2-32-32-32-32-32-int8-int8]
144+
test/unit/language/test_core.py::test_dot3d[8-2-64-64-64-32-32-float16-float16]
145+
test/unit/language/test_core.py::test_dot3d[8-2-64-64-64-32-32-float16-float32]
146+
test/unit/language/test_core.py::test_dot3d[8-2-64-64-64-32-32-float32-float32]
147+
test/unit/language/test_core.py::test_dot3d[8-2-64-64-64-32-32-int8-int8]
148+
test/unit/language/test_core.py::test_dot3d[8-4-32-32-32-32-32-float16-float16]
149+
test/unit/language/test_core.py::test_dot3d[8-4-32-32-32-32-32-float16-float32]
150+
test/unit/language/test_core.py::test_dot3d[8-4-32-32-32-32-32-float32-float32]
151+
test/unit/language/test_core.py::test_dot3d[8-4-32-32-32-32-32-int8-int8]
152+
test/unit/language/test_core.py::test_dot3d[8-4-64-64-64-32-32-float16-float16]
153+
test/unit/language/test_core.py::test_dot3d[8-4-64-64-64-32-32-float16-float32]
154+
test/unit/language/test_core.py::test_dot3d[8-4-64-64-64-32-32-float32-float32]
155+
test/unit/language/test_core.py::test_dot3d[8-4-64-64-64-32-32-int8-int8]
156+
test/unit/language/test_core.py::test_dot3d[8-8-32-32-32-32-32-float16-float16]
157+
test/unit/language/test_core.py::test_dot3d[8-8-32-32-32-32-32-float16-float32]
158+
test/unit/language/test_core.py::test_dot3d[8-8-32-32-32-32-32-float32-float32]
159+
test/unit/language/test_core.py::test_dot3d[8-8-32-32-32-32-32-int8-int8]
160+
test/unit/language/test_core.py::test_dot3d[8-8-64-64-64-32-32-float16-float16]
161+
test/unit/language/test_core.py::test_dot3d[8-8-64-64-64-32-32-float16-float32]
162+
test/unit/language/test_core.py::test_dot3d[8-8-64-64-64-32-32-float32-float32]
163+
test/unit/language/test_core.py::test_dot3d[8-8-64-64-64-32-32-int8-int8]
3164
# https://github.com/intel/intel-xpu-backend-for-triton/issues/983
4165
test/unit/language/test_core.py::test_noinline[shared]
5166
test/unit/language/test_core.py::test_dot[1-128-128-64-2-True-True-none-tf32-int8-int8-1_0]

0 commit comments

Comments
 (0)