
Commit 244b146

Update

[ghstack-poisoned]

2 parents: 7e97fd0 + 11a5a02

File tree: 25 files changed, +985 −49 lines

.ci/scripts/unittest-linux.sh

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ if [[ "$BUILD_TOOL" == "cmake" ]]; then
 
   # We need the runner to test the built library.
   PYTHON_EXECUTABLE=python \
-  CMAKE_ARGS="-DEXECUTORCH_BUILD_EXECUTOR_RUNNER=ON -DEXECUTORCH_BUILD_TESTS=ON" \
+  CMAKE_ARGS="-DEXECUTORCH_BUILD_EXTENSION_EVALUE_UTIL=ON -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON -DEXECUTORCH_BUILD_EXECUTOR_RUNNER=ON -DEXECUTORCH_BUILD_TESTS=ON" \
   .ci/scripts/setup-linux.sh "$@"
 
   .ci/scripts/unittest-linux-cmake.sh

.ci/scripts/unittest-macos.sh

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ trap 'rm -rfv ${TMP_DIR}' EXIT
 # Setup MacOS dependencies as there is no Docker support on MacOS atm
 # We need the runner to test the built library.
 PYTHON_EXECUTABLE=python \
-CMAKE_ARGS="-DEXECUTORCH_BUILD_EXECUTOR_RUNNER=ON -DEXECUTORCH_BUILD_TESTS=ON" \
+CMAKE_ARGS="-DEXECUTORCH_BUILD_EXTENSION_EVALUE_UTIL=ON -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON -DEXECUTORCH_BUILD_EXECUTOR_RUNNER=ON -DEXECUTORCH_BUILD_TESTS=ON" \
 ${CONDA_RUN} --no-capture-output \
 .ci/scripts/setup-macos.sh "$@"
 

.github/workflows/trunk.yml

Lines changed: 1 addition & 1 deletion
@@ -387,7 +387,7 @@ jobs:
       eval "$(conda shell.bash hook)"
 
       # Install requirements
-      ${CONDA_RUN} python install_executorch.py
+      ${CONDA_RUN} EXECUTORCH_BUILD_TORCHAO=1 python install_executorch.py
       ${CONDA_RUN} sh examples/models/llama/install_requirements.sh
 
       # Run test

CMakeLists.txt

Lines changed: 20 additions & 3 deletions
@@ -548,6 +548,16 @@ if(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/data_loader)
 endif()
 
+if(EXECUTORCH_BUILD_EXTENSION_EVALUE_UTIL)
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/evalue_util)
+  install(
+    DIRECTORY extension/evalue_util/
+    DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/executorch/extension/evalue_util
+    FILES_MATCHING
+    PATTERN "*.h"
+  )
+endif()
+
 if(EXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/flat_tensor)
 endif()

@@ -576,6 +586,12 @@ endif()
 
 if(EXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/runner_util)
+  install(
+    DIRECTORY extension/runner_util/
+    DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/executorch/extension/runner_util
+    FILES_MATCHING
+    PATTERN "*.h"
+  )
 endif()
 
 if(EXECUTORCH_BUILD_EXTENSION_TENSOR)

@@ -651,8 +667,7 @@ if(EXECUTORCH_BUILD_PYBIND)
 
   # util lib
   add_library(
-    util ${CMAKE_CURRENT_SOURCE_DIR}/extension/evalue_util/print_evalue.cpp
-         ${CMAKE_CURRENT_SOURCE_DIR}/extension/aten_util/aten_bridge.cpp
+    util ${CMAKE_CURRENT_SOURCE_DIR}/extension/aten_util/aten_bridge.cpp
   )
   target_include_directories(
     util PUBLIC ${_common_include_directories} ${TORCH_INCLUDE_DIRS}

@@ -695,7 +710,9 @@ endif()
 
 if(EXECUTORCH_BUILD_EXECUTOR_RUNNER)
   # Baseline libraries that executor_runner will link against.
-  set(_executor_runner_libs executorch gflags)
+  set(_executor_runner_libs executorch extension_evalue_util
+      extension_runner_util gflags
+  )
 
   if(EXECUTORCH_BUILD_KERNELS_OPTIMIZED)
     list(APPEND _executor_runner_libs optimized_native_cpu_ops_lib)
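
Note: executor_runner now links extension_evalue_util and extension_runner_util in addition to executorch and gflags, and the headers of both extensions gain install() rules. This is also why the CI scripts above add -DEXECUTORCH_BUILD_EXTENSION_EVALUE_UTIL=ON and -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON to CMAKE_ARGS: those options must be enabled for the runner's new link dependencies to exist.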
Lines changed: 226 additions & 0 deletions
@@ -0,0 +1,226 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

#define PRECISION ${PRECISION}
#define T ${buffer_scalar_type(DTYPE)}
${define_required_extensions(DTYPE)}

layout(std430) buffer;

#include "indexing_utils.h"

// Flash Attention tensors: outputs O and running stats (l, m), inputs Q/K/V
${layout_declare_tensor(B, "rw", "t_O", DTYPE, "buffer")}
${layout_declare_tensor(B, "rw", "t_l", "float", "buffer")}
${layout_declare_tensor(B, "rw", "t_m", "float", "buffer")}
${layout_declare_tensor(B, "r", "t_Q", DTYPE, "buffer")}
${layout_declare_tensor(B, "r", "t_K", DTYPE, "buffer")}
${layout_declare_tensor(B, "r", "t_V", DTYPE, "buffer")}

${layout_declare_ubo(B, "ivec4", "Q_sizes")} // [B, H, N, D]
${layout_declare_ubo(B, "ivec4", "K_sizes")}
${layout_declare_ubo(B, "ivec4", "V_sizes")}
${layout_declare_ubo(B, "ivec4", "O_sizes")}

${layout_declare_ubo(B, "ivec3", "l_sizes")} // [B, H, N]
${layout_declare_ubo(B, "ivec3", "m_sizes")} // [B, H, N]

${layout_declare_ubo(B, "float", "scale")}
${layout_declare_ubo(B, "int", "block_size_r")} // Br (num rows in Q block)
${layout_declare_ubo(B, "int", "block_size_c")} // Bc (num cols in K/V block)
${layout_declare_ubo(B, "int", "input_pos")} // Starting position for causal masking
${layout_declare_ubo(B, "int", "num_heads")} // Number of query heads
${layout_declare_ubo(B, "int", "num_kv_heads")} // Number of key/value heads

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

// Maximum block sizes to prevent array overflow
#define MAX_BR 64
#define MAX_BC 128

void main() {
  // Each thread processes one row block
  const int thread_id = int(gl_GlobalInvocationID.x);

  // Tensor dimensions: Q_sizes = [D, H, N, B] from graph.sizes_ubo().
  // The UBO layout is different from the PyTorch tensor layout.
  const int head_dim = Q_sizes.x;   // D (head dim)
  const int num_heads = Q_sizes.y;  // H (num heads)
  const int seq_len = Q_sizes.z;    // N (sequence length)
  const int batch_size = Q_sizes.w; // B (batch)

  // Block sizes
  const int Br = block_size_r;
  const int Bc = block_size_c;

  const int Tr = (seq_len + Br - 1) / Br; // Number of row blocks
  const int total_row_blocks = batch_size * num_heads * Tr;

  if (thread_id >= total_row_blocks) {
    return;
  }

  // Decode thread_id to (batch, head, row_block)
  const int batch = thread_id / (num_heads * Tr);
  const int remaining = thread_id % (num_heads * Tr);
  const int head = remaining / Tr;
  const int row_block = remaining % Tr;

  // Calculate row range for this block
  const int row_start = row_block * Br;
  const int row_end = min(row_start + Br, seq_len);
  const int actual_Br = row_end - row_start;

  // Base indices for this batch
  const int q_base = batch * (seq_len * num_heads * head_dim);
  const int k_base = batch * (seq_len * num_heads * head_dim);
  const int v_base = batch * (seq_len * num_heads * head_dim);
  const int o_base = batch * (seq_len * num_heads * head_dim);
  const int lm_base = batch * (seq_len * num_heads);

  // STEP 2: Initialize O = 0, l = 0, m = -inf for this row block
  for (int r = 0; r < actual_Br; r++) {
    const int seq_pos = row_start + r;
    const int lm_idx = lm_base + head * seq_len + seq_pos;

    t_l[lm_idx] = 0.0;
    t_m[lm_idx] = -1.0 / 0.0; // -infinity

    for (int dim = 0; dim < head_dim; dim++) {
      const int o_idx = o_base + seq_pos * (num_heads * head_dim) + head * head_dim + dim;
      t_O[o_idx] = T(0.0);
    }
  }

  // STEP 5: Outer loop over column blocks (for the K, V tensors)
  const int Tc = (seq_len + Bc - 1) / Bc; // Number of column blocks
  for (int j = 0; j < Tc; j++) {
    const int col_start = j * Bc;
    const int col_end = min(col_start + Bc, seq_len);
    const int actual_Bc = col_end - col_start;

    // STEPs 6-8 are done implicitly below

    // Load current statistics for all rows in this block
    float m_i[MAX_BR];
    float l_i[MAX_BR];
    for (int r = 0; r < actual_Br; r++) {
      const int seq_pos = row_start + r;
      const int lm_idx = lm_base + head * seq_len + seq_pos;
      m_i[r] = t_m[lm_idx];
      l_i[r] = t_l[lm_idx];
    }

    // STEP 9: Compute Sij = Qi * Kj^T
    T S_block[MAX_BR][MAX_BC]; // Uses the MAX_BR and MAX_BC constants
    float m_tilde_ij[MAX_BR];  // Row maxes (float to match l/m)
    float l_tilde_ij[MAX_BR];  // Row sums (float to match l/m)

    // Initialize row statistics
    for (int r = 0; r < actual_Br; r++) {
      m_tilde_ij[r] = -1.0 / 0.0; // -infinity
      l_tilde_ij[r] = 0.0;
    }

    // Compute attention scores Sij = Qi @ Kj^T
    for (int r = 0; r < actual_Br; r++) {
      const int global_row = row_start + r;
      for (int c = 0; c < actual_Bc; c++) {
        const int global_col = col_start + c;

        // For multi-query attention: map query head to KV head
        const int kv_head = (head * num_kv_heads) / num_heads;

        // Dot product: Q[seq_pos, :] · K[col_pos, :]
        T score = T(0.0);
        for (int dim = 0; dim < head_dim; dim++) {
          const int q_idx = q_base + global_row * (num_heads * head_dim) + head * head_dim + dim;
          const int k_idx = k_base + global_col * (num_kv_heads * head_dim) + kv_head * head_dim + dim;
          score += t_Q[q_idx] * t_K[k_idx];
        }
        score *= scale;

        // Apply causal masking: mask if global_col > global_row + input_pos
        if (global_col > global_row + input_pos) {
          score = T(-1.0 / 0.0); // Set to negative infinity
        }

        S_block[r][c] = score;

        // Track row maximum (after masking)
        m_tilde_ij[r] = max(m_tilde_ij[r], float(score));
      }
    }

    // STEP 10: Compute P'ij = exp(Sij - m'ij) and l'ij = rowsum(P'ij)
    for (int r = 0; r < actual_Br; r++) {
      // Handle the case where all scores are -inf (fully masked row)
      if (isinf(m_tilde_ij[r]) && m_tilde_ij[r] < 0.0) {
        // All scores are -inf, so all probabilities are 0
        for (int c = 0; c < actual_Bc; c++) {
          S_block[r][c] = T(0.0);
        }
        l_tilde_ij[r] = 0.0;
      } else {
        // Normal case: compute softmax
        for (int c = 0; c < actual_Bc; c++) {
          S_block[r][c] = exp(S_block[r][c] - T(m_tilde_ij[r]));
          l_tilde_ij[r] += float(S_block[r][c]);
        }
      }
    }

    // STEP 11: Softmax update
    float m_new_i[MAX_BR];
    float l_new_i[MAX_BR];
    for (int r = 0; r < actual_Br; r++) {
      m_new_i[r] = max(m_i[r], m_tilde_ij[r]);

      l_new_i[r] = exp(m_i[r] - m_new_i[r]) * l_i[r] + exp(m_tilde_ij[r] - m_new_i[r]) * l_tilde_ij[r];
    }

    // STEP 12: Update Oi
    for (int r = 0; r < actual_Br; r++) {
      const int global_row = row_start + r;
      float alpha = exp(m_i[r] - m_new_i[r]);
      float beta = exp(m_tilde_ij[r] - m_new_i[r]);

      // For multi-query attention: map query head to KV head
      const int kv_head = (head * num_kv_heads) / num_heads;

      for (int dim = 0; dim < head_dim; dim++) {
        const int o_idx = o_base + global_row * (num_heads * head_dim) + head * head_dim + dim;

        // Compute P'ij @ Vj for this dimension
        T pv_sum = T(0.0);
        for (int c = 0; c < actual_Bc; c++) {
          const int global_col = col_start + c;
          const int v_idx = v_base + global_col * (num_kv_heads * head_dim) + kv_head * head_dim + dim;
          pv_sum += S_block[r][c] * t_V[v_idx];
        }

        // Check for division by zero before updating output
        if (l_new_i[r] <= 0.0) {
          t_O[o_idx] = T(0.0); // Set to zero to avoid NaN
        } else {
          // Oi = (alpha * l_i * Oi + beta * P'ij @ Vj) / l_new_i
          t_O[o_idx] = (T(alpha) * T(l_i[r]) * t_O[o_idx] + T(beta) * pv_sum) / T(l_new_i[r]);
        }
      }
    }

    // STEP 13: Update li, mi
    for (int r = 0; r < actual_Br; r++) {
      const int seq_pos = row_start + r;
      const int lm_idx = lm_base + head * seq_len + seq_pos;
      t_l[lm_idx] = l_new_i[r];
      t_m[lm_idx] = m_new_i[r];
    }
  }
}
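
The shader implements the blockwise FlashAttention recurrence with running per-row statistics l (softmax denominator) and m (row max): m_new = max(m, m~), l_new = exp(m - m_new) * l + exp(m~ - m_new) * l~, and O = (exp(m - m_new) * l * O + exp(m~ - m_new) * P~ @ V) / l_new. What follows is a minimal NumPy sketch of that recurrence, not part of the commit: the name flash_attention_ref is illustrative, it covers a single (batch, head) pair, assumes input_pos >= 0, and takes K/V already sliced to the KV head that kv_head = (head * num_kv_heads) / num_heads would select in the shader.

import numpy as np

def flash_attention_ref(Q, K, V, scale, Br=64, Bc=128, input_pos=0):
    """Q, K, V: [N, D] for one (batch, head) pair. Returns O: [N, D]."""
    N, D = Q.shape
    O = np.zeros((N, D))
    l = np.zeros(N)          # running softmax denominators (t_l)
    m = np.full(N, -np.inf)  # running row maxima (t_m)

    for row_start in range(0, N, Br):       # Tr row blocks
        rows = slice(row_start, min(row_start + Br, N))
        for col_start in range(0, N, Bc):   # Tc column blocks
            cols = slice(col_start, min(col_start + Bc, N))

            # STEP 9: scores S = Qi @ Kj^T * scale, with causal masking
            S = Q[rows] @ K[cols].T * scale
            r_idx = np.arange(rows.start, rows.stop)[:, None]
            c_idx = np.arange(cols.start, cols.stop)[None, :]
            S = np.where(c_idx > r_idx + input_pos, -np.inf, S)

            # STEP 10: block-local softmax statistics; m_safe mirrors the
            # shader's fully-masked-row guard (all-(-inf) rows yield P = 0)
            m_tilde = S.max(axis=1)
            m_safe = np.where(np.isfinite(m_tilde), m_tilde, 0.0)
            P = np.exp(S - m_safe[:, None])
            l_tilde = P.sum(axis=1)

            # STEP 11: merge block statistics with the running ones
            m_new = np.maximum(m[rows], m_tilde)
            alpha = np.exp(m[rows] - m_new)
            beta = np.exp(m_tilde - m_new)
            l_new = alpha * l[rows] + beta * l_tilde

            # STEP 12: rescale the running output and add this block's P @ V
            denom = np.where(l_new > 0, l_new, 1.0)  # division-by-zero guard
            O[rows] = (alpha[:, None] * l[rows][:, None] * O[rows]
                       + beta[:, None] * (P @ V[cols])) / denom[:, None]

            # STEP 13: store the updated statistics
            m[rows], l[rows] = m_new, l_new
    return O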
Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
flash_attention:
  parameter_names_with_default_values:
    DTYPE: float
    STORAGE: buffer
  generate_variant_forall:
    DTYPE:
      - VALUE: float
  shader_variants:
    - NAME: flash_attention_buffer
      STORAGE: buffer
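
As a quick sanity check that the blockwise recurrence matches an unblocked softmax-attention computation, the sketch above can be compared against a direct reference (again illustrative, not part of the commit):

N, D = 48, 16
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((N, D)) for _ in range(3))
scale = 1.0 / np.sqrt(D)

# Direct reference: full scores, causal mask (input_pos = 0), row softmax
S = Q @ K.T * scale
S[np.triu_indices(N, k=1)] = -np.inf
P = np.exp(S - S.max(axis=1, keepdims=True))
O_ref = (P / P.sum(axis=1, keepdims=True)) @ V

O_blk = flash_attention_ref(Q, K, V, scale, Br=16, Bc=16)
assert np.allclose(O_blk, O_ref, atol=1e-6)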
