Skip to content

Commit fad8e24

Browse files
committed
Merge branch 'main' into gregory/windows-support
2 parents cb3bbf9 + b8fc4b9 commit fad8e24

File tree

7 files changed

+410
-27
lines changed

7 files changed

+410
-27
lines changed

benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -78,6 +78,9 @@ def _attn_fwd(Q, K, V, sm_scale, M, Out, #
7878
start_m = tl.program_id(2)
7979
off_z = tl.program_id(0)
8080
off_h = tl.program_id(1)
81+
if N_CTX <= 512:
82+
start_m = tl.program_id(0)
83+
off_z = tl.program_id(2)
8184
qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
8285

8386
# block pointers
@@ -176,6 +179,9 @@ def forward(q, k, v, causal, sm_scale):
176179
num_warps = 8 if Lq == 64 else 16
177180
stage = 3 if causal else 1
178181
grid = lambda args: (q.shape[0], q.shape[1], triton.cdiv(q.shape[2], args['BLOCK_M']))
182+
n_ctx = q.shape[2]
183+
if n_ctx <= 512:
184+
grid = lambda args: (triton.cdiv(q.shape[2], args['BLOCK_M']), q.shape[1], q.shape[0])
179185
M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
180186

181187
if os.getenv('TRITON_INTEL_ADVANCED_PATH', '0') == '0':

python/test/unit/language/test_core.py

Lines changed: 26 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1490,18 +1490,30 @@ def kernel(X):
14901490
for shape in [(2, 2), (2, 8), (8, 2), (8, 8), (32, 32), (64, 64)]
14911491
for axis in [0, 1]
14921492
for num_ctas in num_ctas_list
1493-
for dtype_x_str in ['float32', 'uint64', 'int64', 'float64']])
1493+
for dtype_x_str in ['float16', 'float32', 'uint64', 'int64', 'float64']])
14941494
def test_tensor_atomic_rmw(shape, axis, num_ctas, dtype_x_str, device):
14951495
check_type_supported(dtype_x_str, device)
1496+
if is_interpreter() and dtype_x_str == 'float16':
1497+
pytest.skip('float16 atomic_add does not work in the interpreter mode')
14961498
shape0, shape1 = shape
14971499
# triton kernel
14981500

14991501
@triton.jit
1500-
def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
1502+
def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr, DTYPE: tl.constexpr):
15011503
off0 = tl.arange(0, SHAPE0)
15021504
off1 = tl.arange(0, SHAPE1)
15031505
x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :])
1506+
1507+
if DTYPE == tl.float16:
1508+
# sum can have bad numerics when accumulating in float16.
1509+
# if we're dealing with float16, do the sum in float32.
1510+
x = x.to(tl.float32)
1511+
15041512
z = tl.sum(x, axis=AXIS)
1513+
1514+
if DTYPE == tl.float16:
1515+
z = z.to(DTYPE)
1516+
15051517
if AXIS == 1:
15061518
old = tl.atomic_add(Z + off0, z)
15071519
tl.store(OLD + off0, old)
@@ -1515,13 +1527,23 @@ def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.const
15151527
z = numpy_random(z_shape, dtype_str=dtype_x_str, rs=rs)
15161528
old = np.zeros(z_shape, dtype=getattr(np, dtype_x_str))
15171529
# reference results
1518-
z_ref = z + np.sum(x, axis=axis, keepdims=False)
1530+
if x.dtype == np.float16:
1531+
# do the sum in float32 to reduce numerical variation
1532+
z_ref = z + np.sum(x.astype(np.float32), axis=axis, keepdims=False).astype(x.dtype)
1533+
else:
1534+
z_ref = z + np.sum(x, axis=axis, keepdims=False)
15191535
old_ref = np.copy(z)
15201536
# triton result
15211537
x_tri = to_triton(x, device=device)
15221538
z_tri = to_triton(z, device=device)
15231539
old_tri = to_triton(old, device=device)
1524-
kernel[(1, )](z_tri, x_tri, old_tri, axis, shape0, shape1, num_ctas=num_ctas)
1540+
1541+
def torch_to_triton_dtype(t):
1542+
if t == torch.float16:
1543+
return tl.float16
1544+
return None
1545+
1546+
kernel[(1, )](z_tri, x_tri, old_tri, axis, shape0, shape1, torch_to_triton_dtype(x_tri.dtype), num_ctas=num_ctas)
15251547
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
15261548
np.testing.assert_equal(old_ref, to_numpy(old_tri))
15271549

scripts/skiplist/a770/language.txt

Lines changed: 161 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,166 @@
11
# https://github.com/intel/intel-xpu-backend-for-triton/issues/1434
22
test/unit/language/test_core.py::test_precise_math[1-tl.math.sqrt_rn(x)-tl.math.sqrt(x.to(tl.float64)).to(tl.float32)]
3+
test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float16-float16]
4+
test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float16-float32]
5+
test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float32-float32]
6+
test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-int8-int8]
7+
test/unit/language/test_core.py::test_dot3d[1-1-64-64-64-32-32-float16-float16]
8+
test/unit/language/test_core.py::test_dot3d[1-1-64-64-64-32-32-float16-float32]
9+
test/unit/language/test_core.py::test_dot3d[1-1-64-64-64-32-32-float32-float32]
10+
test/unit/language/test_core.py::test_dot3d[1-1-64-64-64-32-32-int8-int8]
11+
test/unit/language/test_core.py::test_dot3d[1-16-32-32-32-32-32-float16-float16]
12+
test/unit/language/test_core.py::test_dot3d[1-16-32-32-32-32-32-float16-float32]
13+
test/unit/language/test_core.py::test_dot3d[1-16-32-32-32-32-32-float32-float32]
14+
test/unit/language/test_core.py::test_dot3d[1-16-32-32-32-32-32-int8-int8]
15+
test/unit/language/test_core.py::test_dot3d[1-16-64-64-64-32-32-float16-float16]
16+
test/unit/language/test_core.py::test_dot3d[1-16-64-64-64-32-32-float16-float32]
17+
test/unit/language/test_core.py::test_dot3d[1-16-64-64-64-32-32-float32-float32]
18+
test/unit/language/test_core.py::test_dot3d[1-16-64-64-64-32-32-int8-int8]
19+
test/unit/language/test_core.py::test_dot3d[1-2-32-32-32-32-32-float16-float16]
20+
test/unit/language/test_core.py::test_dot3d[1-2-32-32-32-32-32-float16-float32]
21+
test/unit/language/test_core.py::test_dot3d[1-2-32-32-32-32-32-float32-float32]
22+
test/unit/language/test_core.py::test_dot3d[1-2-32-32-32-32-32-int8-int8]
23+
test/unit/language/test_core.py::test_dot3d[1-2-64-64-64-32-32-float16-float16]
24+
test/unit/language/test_core.py::test_dot3d[1-2-64-64-64-32-32-float16-float32]
25+
test/unit/language/test_core.py::test_dot3d[1-2-64-64-64-32-32-float32-float32]
26+
test/unit/language/test_core.py::test_dot3d[1-2-64-64-64-32-32-int8-int8]
27+
test/unit/language/test_core.py::test_dot3d[1-4-32-32-32-32-32-float16-float16]
28+
test/unit/language/test_core.py::test_dot3d[1-4-32-32-32-32-32-float16-float32]
29+
test/unit/language/test_core.py::test_dot3d[1-4-32-32-32-32-32-float32-float32]
30+
test/unit/language/test_core.py::test_dot3d[1-4-32-32-32-32-32-int8-int8]
31+
test/unit/language/test_core.py::test_dot3d[1-4-64-64-64-32-32-float16-float16]
32+
test/unit/language/test_core.py::test_dot3d[1-4-64-64-64-32-32-float16-float32]
33+
test/unit/language/test_core.py::test_dot3d[1-4-64-64-64-32-32-float32-float32]
34+
test/unit/language/test_core.py::test_dot3d[1-4-64-64-64-32-32-int8-int8]
35+
test/unit/language/test_core.py::test_dot3d[1-8-32-32-32-32-32-float16-float16]
36+
test/unit/language/test_core.py::test_dot3d[1-8-32-32-32-32-32-float16-float32]
37+
test/unit/language/test_core.py::test_dot3d[1-8-32-32-32-32-32-float32-float32]
38+
test/unit/language/test_core.py::test_dot3d[1-8-32-32-32-32-32-int8-int8]
39+
test/unit/language/test_core.py::test_dot3d[1-8-64-64-64-32-32-float16-float16]
40+
test/unit/language/test_core.py::test_dot3d[1-8-64-64-64-32-32-float16-float32]
41+
test/unit/language/test_core.py::test_dot3d[1-8-64-64-64-32-32-float32-float32]
42+
test/unit/language/test_core.py::test_dot3d[1-8-64-64-64-32-32-int8-int8]
43+
test/unit/language/test_core.py::test_dot3d[2-1-32-32-32-32-32-float16-float16]
44+
test/unit/language/test_core.py::test_dot3d[2-1-32-32-32-32-32-float16-float32]
45+
test/unit/language/test_core.py::test_dot3d[2-1-32-32-32-32-32-float32-float32]
46+
test/unit/language/test_core.py::test_dot3d[2-1-32-32-32-32-32-int8-int8]
47+
test/unit/language/test_core.py::test_dot3d[2-1-64-64-64-32-32-float16-float16]
48+
test/unit/language/test_core.py::test_dot3d[2-1-64-64-64-32-32-float16-float32]
49+
test/unit/language/test_core.py::test_dot3d[2-1-64-64-64-32-32-float32-float32]
50+
test/unit/language/test_core.py::test_dot3d[2-1-64-64-64-32-32-int8-int8]
51+
test/unit/language/test_core.py::test_dot3d[2-16-32-32-32-32-32-float16-float16]
52+
test/unit/language/test_core.py::test_dot3d[2-16-32-32-32-32-32-float16-float32]
53+
test/unit/language/test_core.py::test_dot3d[2-16-32-32-32-32-32-float32-float32]
54+
test/unit/language/test_core.py::test_dot3d[2-16-32-32-32-32-32-int8-int8]
55+
test/unit/language/test_core.py::test_dot3d[2-16-64-64-64-32-32-float16-float16]
56+
test/unit/language/test_core.py::test_dot3d[2-16-64-64-64-32-32-float16-float32]
57+
test/unit/language/test_core.py::test_dot3d[2-16-64-64-64-32-32-float32-float32]
58+
test/unit/language/test_core.py::test_dot3d[2-16-64-64-64-32-32-int8-int8]
59+
test/unit/language/test_core.py::test_dot3d[2-2-32-32-32-32-32-float16-float16]
60+
test/unit/language/test_core.py::test_dot3d[2-2-32-32-32-32-32-float16-float32]
61+
test/unit/language/test_core.py::test_dot3d[2-2-32-32-32-32-32-float32-float32]
62+
test/unit/language/test_core.py::test_dot3d[2-2-32-32-32-32-32-int8-int8]
63+
test/unit/language/test_core.py::test_dot3d[2-2-64-64-64-32-32-float16-float16]
64+
test/unit/language/test_core.py::test_dot3d[2-2-64-64-64-32-32-float16-float32]
65+
test/unit/language/test_core.py::test_dot3d[2-2-64-64-64-32-32-float32-float32]
66+
test/unit/language/test_core.py::test_dot3d[2-2-64-64-64-32-32-int8-int8]
67+
test/unit/language/test_core.py::test_dot3d[2-4-32-32-32-32-32-float16-float16]
68+
test/unit/language/test_core.py::test_dot3d[2-4-32-32-32-32-32-float16-float32]
69+
test/unit/language/test_core.py::test_dot3d[2-4-32-32-32-32-32-float32-float32]
70+
test/unit/language/test_core.py::test_dot3d[2-4-32-32-32-32-32-int8-int8]
71+
test/unit/language/test_core.py::test_dot3d[2-4-64-64-64-32-32-float16-float16]
72+
test/unit/language/test_core.py::test_dot3d[2-4-64-64-64-32-32-float16-float32]
73+
test/unit/language/test_core.py::test_dot3d[2-4-64-64-64-32-32-float32-float32]
74+
test/unit/language/test_core.py::test_dot3d[2-4-64-64-64-32-32-int8-int8]
75+
test/unit/language/test_core.py::test_dot3d[2-8-32-32-32-32-32-float16-float16]
76+
test/unit/language/test_core.py::test_dot3d[2-8-32-32-32-32-32-float16-float32]
77+
test/unit/language/test_core.py::test_dot3d[2-8-32-32-32-32-32-float32-float32]
78+
test/unit/language/test_core.py::test_dot3d[2-8-32-32-32-32-32-int8-int8]
79+
test/unit/language/test_core.py::test_dot3d[2-8-64-64-64-32-32-float16-float16]
80+
test/unit/language/test_core.py::test_dot3d[2-8-64-64-64-32-32-float16-float32]
81+
test/unit/language/test_core.py::test_dot3d[2-8-64-64-64-32-32-float32-float32]
82+
test/unit/language/test_core.py::test_dot3d[2-8-64-64-64-32-32-int8-int8]
83+
test/unit/language/test_core.py::test_dot3d[4-1-32-32-32-32-32-float16-float16]
84+
test/unit/language/test_core.py::test_dot3d[4-1-32-32-32-32-32-float16-float32]
85+
test/unit/language/test_core.py::test_dot3d[4-1-32-32-32-32-32-float32-float32]
86+
test/unit/language/test_core.py::test_dot3d[4-1-32-32-32-32-32-int8-int8]
87+
test/unit/language/test_core.py::test_dot3d[4-1-64-64-64-32-32-float16-float16]
88+
test/unit/language/test_core.py::test_dot3d[4-1-64-64-64-32-32-float16-float32]
89+
test/unit/language/test_core.py::test_dot3d[4-1-64-64-64-32-32-float32-float32]
90+
test/unit/language/test_core.py::test_dot3d[4-1-64-64-64-32-32-int8-int8]
91+
test/unit/language/test_core.py::test_dot3d[4-16-32-32-32-32-32-float16-float16]
92+
test/unit/language/test_core.py::test_dot3d[4-16-32-32-32-32-32-float16-float32]
93+
test/unit/language/test_core.py::test_dot3d[4-16-32-32-32-32-32-float32-float32]
94+
test/unit/language/test_core.py::test_dot3d[4-16-32-32-32-32-32-int8-int8]
95+
test/unit/language/test_core.py::test_dot3d[4-16-64-64-64-32-32-float16-float16]
96+
test/unit/language/test_core.py::test_dot3d[4-16-64-64-64-32-32-float16-float32]
97+
test/unit/language/test_core.py::test_dot3d[4-16-64-64-64-32-32-float32-float32]
98+
test/unit/language/test_core.py::test_dot3d[4-16-64-64-64-32-32-int8-int8]
99+
test/unit/language/test_core.py::test_dot3d[4-2-32-32-32-32-32-float16-float16]
100+
test/unit/language/test_core.py::test_dot3d[4-2-32-32-32-32-32-float16-float32]
101+
test/unit/language/test_core.py::test_dot3d[4-2-32-32-32-32-32-float32-float32]
102+
test/unit/language/test_core.py::test_dot3d[4-2-32-32-32-32-32-int8-int8]
103+
test/unit/language/test_core.py::test_dot3d[4-2-64-64-64-32-32-float16-float16]
104+
test/unit/language/test_core.py::test_dot3d[4-2-64-64-64-32-32-float16-float32]
105+
test/unit/language/test_core.py::test_dot3d[4-2-64-64-64-32-32-float32-float32]
106+
test/unit/language/test_core.py::test_dot3d[4-2-64-64-64-32-32-int8-int8]
107+
test/unit/language/test_core.py::test_dot3d[4-4-128-128-64-64-64-float16-float16]
108+
test/unit/language/test_core.py::test_dot3d[4-4-32-32-32-32-32-float16-float16]
109+
test/unit/language/test_core.py::test_dot3d[4-4-32-32-32-32-32-float16-float32]
110+
test/unit/language/test_core.py::test_dot3d[4-4-32-32-32-32-32-float32-float32]
111+
test/unit/language/test_core.py::test_dot3d[4-4-32-32-32-32-32-int8-int8]
112+
test/unit/language/test_core.py::test_dot3d[4-4-64-64-64-32-32-float16-float16]
113+
test/unit/language/test_core.py::test_dot3d[4-4-64-64-64-32-32-float16-float32]
114+
test/unit/language/test_core.py::test_dot3d[4-4-64-64-64-32-32-float32-float32]
115+
test/unit/language/test_core.py::test_dot3d[4-4-64-64-64-32-32-int8-int8]
116+
test/unit/language/test_core.py::test_dot3d[4-8-32-32-32-32-32-float16-float16]
117+
test/unit/language/test_core.py::test_dot3d[4-8-32-32-32-32-32-float16-float32]
118+
test/unit/language/test_core.py::test_dot3d[4-8-32-32-32-32-32-float32-float32]
119+
test/unit/language/test_core.py::test_dot3d[4-8-32-32-32-32-32-int8-int8]
120+
test/unit/language/test_core.py::test_dot3d[4-8-64-64-64-32-32-float16-float16]
121+
test/unit/language/test_core.py::test_dot3d[4-8-64-64-64-32-32-float16-float32]
122+
test/unit/language/test_core.py::test_dot3d[4-8-64-64-64-32-32-float32-float32]
123+
test/unit/language/test_core.py::test_dot3d[4-8-64-64-64-32-32-int8-int8]
124+
test/unit/language/test_core.py::test_dot3d[8-1-32-32-32-32-32-float16-float16]
125+
test/unit/language/test_core.py::test_dot3d[8-1-32-32-32-32-32-float16-float32]
126+
test/unit/language/test_core.py::test_dot3d[8-1-32-32-32-32-32-float32-float32]
127+
test/unit/language/test_core.py::test_dot3d[8-1-32-32-32-32-32-int8-int8]
128+
test/unit/language/test_core.py::test_dot3d[8-1-64-64-64-32-32-float16-float16]
129+
test/unit/language/test_core.py::test_dot3d[8-1-64-64-64-32-32-float16-float32]
130+
test/unit/language/test_core.py::test_dot3d[8-1-64-64-64-32-32-float32-float32]
131+
test/unit/language/test_core.py::test_dot3d[8-1-64-64-64-32-32-int8-int8]
132+
test/unit/language/test_core.py::test_dot3d[8-16-32-32-32-32-32-float16-float16]
133+
test/unit/language/test_core.py::test_dot3d[8-16-32-32-32-32-32-float16-float32]
134+
test/unit/language/test_core.py::test_dot3d[8-16-32-32-32-32-32-float32-float32]
135+
test/unit/language/test_core.py::test_dot3d[8-16-32-32-32-32-32-int8-int8]
136+
test/unit/language/test_core.py::test_dot3d[8-16-64-64-64-32-32-float16-float16]
137+
test/unit/language/test_core.py::test_dot3d[8-16-64-64-64-32-32-float16-float32]
138+
test/unit/language/test_core.py::test_dot3d[8-16-64-64-64-32-32-float32-float32]
139+
test/unit/language/test_core.py::test_dot3d[8-16-64-64-64-32-32-int8-int8]
140+
test/unit/language/test_core.py::test_dot3d[8-2-32-32-32-32-32-float16-float16]
141+
test/unit/language/test_core.py::test_dot3d[8-2-32-32-32-32-32-float16-float32]
142+
test/unit/language/test_core.py::test_dot3d[8-2-32-32-32-32-32-float32-float32]
143+
test/unit/language/test_core.py::test_dot3d[8-2-32-32-32-32-32-int8-int8]
144+
test/unit/language/test_core.py::test_dot3d[8-2-64-64-64-32-32-float16-float16]
145+
test/unit/language/test_core.py::test_dot3d[8-2-64-64-64-32-32-float16-float32]
146+
test/unit/language/test_core.py::test_dot3d[8-2-64-64-64-32-32-float32-float32]
147+
test/unit/language/test_core.py::test_dot3d[8-2-64-64-64-32-32-int8-int8]
148+
test/unit/language/test_core.py::test_dot3d[8-4-32-32-32-32-32-float16-float16]
149+
test/unit/language/test_core.py::test_dot3d[8-4-32-32-32-32-32-float16-float32]
150+
test/unit/language/test_core.py::test_dot3d[8-4-32-32-32-32-32-float32-float32]
151+
test/unit/language/test_core.py::test_dot3d[8-4-32-32-32-32-32-int8-int8]
152+
test/unit/language/test_core.py::test_dot3d[8-4-64-64-64-32-32-float16-float16]
153+
test/unit/language/test_core.py::test_dot3d[8-4-64-64-64-32-32-float16-float32]
154+
test/unit/language/test_core.py::test_dot3d[8-4-64-64-64-32-32-float32-float32]
155+
test/unit/language/test_core.py::test_dot3d[8-4-64-64-64-32-32-int8-int8]
156+
test/unit/language/test_core.py::test_dot3d[8-8-32-32-32-32-32-float16-float16]
157+
test/unit/language/test_core.py::test_dot3d[8-8-32-32-32-32-32-float16-float32]
158+
test/unit/language/test_core.py::test_dot3d[8-8-32-32-32-32-32-float32-float32]
159+
test/unit/language/test_core.py::test_dot3d[8-8-32-32-32-32-32-int8-int8]
160+
test/unit/language/test_core.py::test_dot3d[8-8-64-64-64-32-32-float16-float16]
161+
test/unit/language/test_core.py::test_dot3d[8-8-64-64-64-32-32-float16-float32]
162+
test/unit/language/test_core.py::test_dot3d[8-8-64-64-64-32-32-float32-float32]
163+
test/unit/language/test_core.py::test_dot3d[8-8-64-64-64-32-32-int8-int8]
3164
# https://github.com/intel/intel-xpu-backend-for-triton/issues/983
4165
test/unit/language/test_core.py::test_noinline[shared]
5166
test/unit/language/test_core.py::test_dot[1-128-128-64-2-True-True-none-tf32-int8-int8-1_0]

0 commit comments

Comments
 (0)