Skip to content

Commit 72e8a2a

Browse files
authored
Merge branch 'main' into fix-doc-build-version-theme
2 parents c034b2a + 149e23d commit 72e8a2a

16 files changed

+331
-179
lines changed

backends/arm/tosa/backend.py

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -133,7 +133,14 @@ def _preprocess( # noqa: C901
133133
if not artifact_path:
134134
artifact_path = ""
135135

136-
tosa_graph = ts.TosaSerializer(artifact_path)
136+
version = tosa_spec.version
137+
tosa_graph = ts.TosaSerializer(
138+
artifact_path,
139+
targetMajor=version.major,
140+
targetMinor=version.minor,
141+
targetPatch=version.micro,
142+
targetDraft=False,
143+
)
137144

138145
if not (
139146
tosa_spec.version.major == ts.TOSA_VERSION_MAJOR

backends/vulkan/op_registry.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -687,7 +687,7 @@ def register_dequantize_for_conv2d_op():
687687
@update_features("llama::sdpa_with_kv_cache")
688688
def register_sdpa_with_kv_cache_op():
689689
return OpFeatures(
690-
inputs_storage=utils.WIDTH_PACKED_TEXTURE,
690+
inputs_storage=utils.CONTIGUOUS_ANY,
691691
supports_resize=True,
692692
supports_prepacking=True,
693693
)

backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.glsl

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -76,6 +76,7 @@ void main() {
7676
const int Q_H = q_projected_sizes.y;
7777
// sequence length
7878
const int S = q_projected_sizes.z;
79+
const int S_aligned = align_up_4(S);
7980
// manually determine size of the context_len dim of the attention weight.
8081
// The "actual" tensor sizes may have been aligned to a multiple of 4 to allow
8182
// memory loads to be aligned to texel boundaries.
@@ -96,7 +97,7 @@ void main() {
9697
// number of threads in the work group.
9798
for (int c4 = worker_id; c4 < C4_limit; c4 += NUM_WORKERS_PER_WG) {
9899
VEC4_T in_texel = load_attn_weights_c4(
99-
c4, s, q_h, context_texel_len, S, Q_H);
100+
c4, s, q_h, context_texel_len, S_aligned, Q_H);
100101

101102
for (int comp = 0; comp < 4; comp++) {
102103
local_exp_sum += exp(in_texel[comp]);
@@ -108,7 +109,7 @@ void main() {
108109
for (int c4 = C4_limit; c4 < context_texel_len; ++c4) {
109110
const int c_base = mul_4(c4);
110111
VEC4_T in_texel = load_attn_weights_c4(
111-
c4, s, q_h, context_texel_len, S, Q_H);
112+
c4, s, q_h, context_texel_len, S_aligned, Q_H);
112113

113114
[[unroll]] for (int comp = 0; comp < 4; comp++) {
114115
if (c_base + comp < context_len) {
@@ -138,19 +139,19 @@ void main() {
138139
// Now go back through each element in the row and normalize
139140
for (int c4 = worker_id; c4 < C4_limit; c4 += NUM_WORKERS_PER_WG) {
140141
VEC4_T in_texel = load_attn_weights_c4(
141-
c4, s, q_h, context_texel_len, S, Q_H);
142+
c4, s, q_h, context_texel_len, S_aligned, Q_H);
142143

143144
VEC4_T out_texel = exp(in_texel) / local_exp_sum;
144145
store_attn_weights_softmax_c4(
145-
out_texel, c4, s, q_h, context_texel_len, S, Q_H);
146+
out_texel, c4, s, q_h, context_texel_len, S_aligned, Q_H);
146147
}
147148
// First thread in the work group responsible for handling last texel if it
148149
// contains any padded elements
149150
if (worker_id == 0) {
150151
for (int c4 = C4_limit; c4 < context_texel_len; ++c4) {
151152
const int c_base = mul_4(c4);
152153
VEC4_T in_texel = load_attn_weights_c4(
153-
c4, s, q_h, context_texel_len, S, Q_H);
154+
c4, s, q_h, context_texel_len, S_aligned, Q_H);
154155

155156
// Ensure that padding elements are set to 0.
156157
VEC4_T out_texel = VEC4_T(0);
@@ -160,7 +161,7 @@ void main() {
160161
}
161162
}
162163
store_attn_weights_softmax_c4(
163-
out_texel, c4, s, q_h, context_texel_len, S, Q_H);
164+
out_texel, c4, s, q_h, context_texel_len, S_aligned, Q_H);
164165
}
165166
}
166167
}

backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl

Lines changed: 23 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -81,6 +81,7 @@ void main() {
8181
const int Q_H = q_projected_sizes.y;
8282
// sequence length
8383
const int S = q_projected_sizes.z;
84+
const int S_aligned = align_up_4(S);
8485

8586
// number of K/V heads
8687
const int KV_H = k_cache_sizes.y;
@@ -118,55 +119,27 @@ void main() {
118119
}
119120
// Otherwise, need to actually compute output tile
120121
else {
121-
const bool dont_check_bounds = (S - s) >= TILE_M &&
122-
(context_len - c) >= TILE_N;
123-
124-
if (dont_check_bounds) {
125-
for (int d4 = worker_id; d4 < D4; d4 += NUM_WORKERS_PER_OUT) {
126-
load_q_projected_tile_no_checks(
127-
q_tile,
128-
d4,
129-
s,
130-
q_h,
131-
D4,
132-
Q_H,
133-
S);
134-
135-
load_k_cache_tile_no_checks(
136-
w_tile,
137-
d4,
138-
c,
139-
kv_h,
140-
D4,
141-
context_len,
142-
C,
143-
KV_H);
144-
145-
fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile);
146-
}
147-
} else {
148-
for (int d4 = worker_id; d4 < D4; d4 += NUM_WORKERS_PER_OUT) {
149-
load_q_projected_tile_with_checks(
150-
q_tile,
151-
d4,
152-
s,
153-
q_h,
154-
D4,
155-
Q_H,
156-
S);
157-
158-
load_k_cache_tile_with_checks(
159-
w_tile,
160-
d4,
161-
c,
162-
kv_h,
163-
D4,
164-
context_len,
165-
C,
166-
KV_H);
167-
168-
fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile);
169-
}
122+
for (int d4 = worker_id; d4 < D4; d4 += NUM_WORKERS_PER_OUT) {
123+
load_q_projected_tile_with_checks(
124+
q_tile,
125+
d4,
126+
s,
127+
q_h,
128+
D4,
129+
Q_H,
130+
S);
131+
132+
load_k_cache_tile_with_checks(
133+
w_tile,
134+
d4,
135+
c,
136+
kv_h,
137+
D4,
138+
context_len,
139+
C,
140+
KV_H);
141+
142+
fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile);
170143
}
171144
}
172145

@@ -205,7 +178,7 @@ void main() {
205178
s,
206179
q_h,
207180
context_texel_len,
208-
S,
181+
S_aligned,
209182
Q_H);
210183
}
211184
}

backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.yaml

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -12,10 +12,14 @@ sdpa_compute_attn_weights_coop:
1212
TILE_K4: 1
1313
TILE_N4: 1
1414
generate_variant_forall:
15+
combination:
16+
parameter_names: [IO_STORAGE, K_CACHE_STORAGE]
17+
combos:
18+
- parameter_values: [texture3d, texture3d]
19+
- parameter_values: [buffer, texture3d]
20+
- parameter_values: [buffer, buffer]
1521
DTYPE:
1622
- VALUE: float
1723
- VALUE: half
1824
shader_variants:
19-
- NAME: sdpa_compute_attn_weights_coop_texture3d_texture3d
20-
- NAME: sdpa_compute_attn_weights_coop_buffer_texture3d
21-
IO_STORAGE: buffer
25+
- NAME: sdpa_compute_attn_weights_coop

backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl

Lines changed: 24 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -93,6 +93,7 @@ void main() {
9393
const int Q_H = q_projected_sizes.y;
9494
// sequence length
9595
const int S = q_projected_sizes.z;
96+
const int S_aligned = align_up_4(S);
9697

9798
// number of K/V heads
9899
const int KV_H = k_cache_sizes.y;
@@ -129,55 +130,28 @@ void main() {
129130
}
130131
// Otherwise, need to actually compute output tile
131132
else {
132-
const bool dont_check_bounds = (S - s) >= TILE_M &&
133-
(context_len - c) >= TILE_N;
134-
135-
if (dont_check_bounds) {
136-
for (int d4 = 0; d4 < D4; d4++) {
137-
load_q_projected_tile_no_checks(
138-
q_tile,
139-
d4,
140-
s,
141-
q_h,
142-
D4,
143-
Q_H,
144-
S);
145-
146-
load_k_cache_tile_no_checks(
147-
w_tile,
148-
d4,
149-
c,
150-
kv_h,
151-
D4,
152-
context_len,
153-
C,
154-
KV_H);
155-
156-
fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile);
157-
}
158-
} else {
159-
for (int d4 = 0; d4 < D4; d4++) {
160-
load_q_projected_tile_with_checks(
161-
q_tile,
162-
d4,
163-
s,
164-
q_h,
165-
D4,
166-
Q_H,
167-
S);
168-
169-
load_k_cache_tile_with_checks(
170-
w_tile,
171-
d4,
172-
c,
173-
kv_h,
174-
D4,
175-
context_len,
176-
C,
177-
KV_H);
178-
179-
fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile);
180-
}
133+
for (int d4 = 0; d4 < D4; d4++) {
134+
load_q_projected_tile_with_checks(
135+
q_tile,
136+
d4,
137+
s,
138+
q_h,
139+
D4,
140+
Q_H,
141+
S);
142+
143+
load_k_cache_tile_with_checks(
144+
w_tile,
145+
d4,
146+
c,
147+
kv_h,
148+
D4,
149+
context_len,
150+
C,
151+
KV_H);
152+
153+
154+
fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile);
181155
}
182156

183157
// Apply scale and mask
@@ -196,6 +170,6 @@ void main() {
196170
s,
197171
q_h,
198172
context_texel_len,
199-
S,
173+
S_aligned,
200174
Q_H);
201175
}

backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.yaml

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -13,10 +13,14 @@ sdpa_compute_attn_weights_tiled:
1313
TILE_K4: 1
1414
TILE_N4: 1
1515
generate_variant_forall:
16+
combination:
17+
parameter_names: [IO_STORAGE, K_CACHE_STORAGE]
18+
combos:
19+
- parameter_values: [texture3d, texture3d]
20+
- parameter_values: [buffer, texture3d]
21+
- parameter_values: [buffer, buffer]
1622
DTYPE:
1723
- VALUE: float
1824
- VALUE: half
1925
shader_variants:
20-
- NAME: sdpa_compute_attn_weights_tiled_texture3d_texture3d
21-
- NAME: sdpa_compute_attn_weights_tiled_buffer_texture3d
22-
IO_STORAGE: buffer
26+
- NAME: sdpa_compute_attn_weights_tiled

backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.glsl

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -81,6 +81,7 @@ void main() {
8181
const int Q_H = q_projected_sizes.y;
8282
// sequence length
8383
const int S = q_projected_sizes.z;
84+
const int S_aligned = align_up_4(S);
8485

8586
// number of K/V heads
8687
const int KV_H = v_cache_sizes.y;
@@ -120,7 +121,7 @@ void main() {
120121
s,
121122
q_h,
122123
context_texel_len,
123-
S,
124+
S_aligned,
124125
Q_H);
125126

126127
load_v_cache_tile_no_checks(
@@ -146,7 +147,7 @@ void main() {
146147
s,
147148
q_h,
148149
context_texel_len,
149-
S,
150+
S_aligned,
150151
Q_H);
151152

152153
load_v_cache_tile_with_checks(

backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.yaml

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -12,10 +12,14 @@ sdpa_compute_out_coop:
1212
TILE_K4: 1
1313
TILE_N4: 1
1414
generate_variant_forall:
15+
combination:
16+
parameter_names: [IO_STORAGE, V_CACHE_STORAGE]
17+
combos:
18+
- parameter_values: [texture3d, texture3d]
19+
- parameter_values: [buffer, texture3d]
20+
- parameter_values: [buffer, buffer]
1521
DTYPE:
1622
- VALUE: float
1723
- VALUE: half
1824
shader_variants:
19-
- NAME: sdpa_compute_out_coop_texture3d_texture3d
20-
- NAME: sdpa_compute_out_coop_buffer_texture3d
21-
IO_STORAGE: buffer
25+
- NAME: sdpa_compute_out_coop

backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.glsl

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -75,6 +75,7 @@ void main() {
7575
const int Q_H = q_projected_sizes.y;
7676
// sequence length
7777
const int S = q_projected_sizes.z;
78+
const int S_aligned = align_up_4(S);
7879

7980
// number of K/V heads
8081
const int KV_H = v_cache_sizes.y;
@@ -113,7 +114,7 @@ void main() {
113114
s,
114115
q_h,
115116
context_texel_len,
116-
S,
117+
S_aligned,
117118
Q_H);
118119

119120
load_v_cache_tile_no_checks(
@@ -136,7 +137,7 @@ void main() {
136137
s,
137138
q_h,
138139
context_texel_len,
139-
S,
140+
S_aligned,
140141
Q_H);
141142

142143
load_v_cache_tile_with_checks(

0 commit comments

Comments
 (0)