Commit 380e6a7

Update on "[ET-VK] Shortening code for slice op when packed dim is not the same as slice dim."

This diff shortens the Slice op's code for the case where the packed dimension is not the same as the slice dimension.

Differential Revision: [D70737264](https://our.internmc.facebook.com/intern/diff/D70737264/)

[ghstack-poisoned]
2 parents 8401f7a + b3181a4 commit 380e6a7
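
The Vulkan slice-op sources themselves are not among the diffs shown below (the visible files come from the rebase onto main), but as a rough, hypothetical sketch of the idea in the title, assuming the usual 4-elements-per-texel texture packing: when the slice dimension differs from the packed dimension, every output texel maps to exactly one input texel, so the shader only has to copy texels at shifted coordinates rather than unpack and repack elements.

import torch

# Rough illustration only -- not the actual ET-VK shader, and the 4-wide texel
# packing below is an assumption about the texture layout. Slicing along a dim
# other than the packed dim selects whole texels, so the op reduces to copying
# texels at shifted coordinates instead of unpacking and repacking elements.
H, W = 6, 8                              # W is the packed dim (multiple of 4)
t = torch.arange(H * W).reshape(H, W)
texels = t.reshape(H, W // 4, 4)         # each group of 4 is one "texel"
sliced = texels[2:5]                     # slice along dim 0, not the packed dim
assert torch.equal(sliced.reshape(3, W), t[2:5])   # texels survive intact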

45 files changed: +510 −520 lines

.ci/scripts/unittest-buck2.sh

Lines changed: 2 additions & 2 deletions

@@ -20,5 +20,5 @@ buck2 query "//backends/apple/... + //backends/example/... + \
 # TODO: expand the covered scope of Buck targets.
 # //runtime/kernel/... is failing because //third-party:torchgen_files's shell script can't find python on PATH.
 # //runtime/test/... requires Python torch, which we don't have in our OSS buck setup.
-buck2 build //runtime/backend/... //runtime/core/... //runtime/executor: //runtime/kernel/... //runtime/platform/...
-buck2 test //runtime/backend/... //runtime/core/... //runtime/executor: //runtime/kernel/... //runtime/platform/...
+buck2 test //kernels/portable/... //runtime/backend/... //runtime/core/... \
+  //runtime/executor: //runtime/kernel/... //runtime/platform/...

.lintrunner.toml

Lines changed: 2 additions & 0 deletions

@@ -218,6 +218,8 @@ exclude_patterns = [
     'examples/**',
     'extension/**',
     'kernels/optimized/**',
+    # Justified <functional> include.
+    'runtime/kernel/thread_parallel_interface.h',
     'scripts/**',
     'third-party/**',
     'util/**',

CMakeLists.txt

Lines changed: 0 additions & 1 deletion

@@ -751,7 +751,6 @@ if(EXECUTORCH_BUILD_PTHREADPOOL
    AND EXECUTORCH_BUILD_CPUINFO
 )
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/threadpool)
-  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/parallel)
 endif()

 if(EXECUTORCH_BUILD_PYBIND)

CODEOWNERS

Lines changed: 12 additions & 12 deletions

@@ -52,31 +52,31 @@
 /extension/export_util @kimishpatel
 /extension/flat_tensor @lucylq
 /extension/gguf_util @larryliu0820
-/extension/kernel_util @kimishpatel @manuelcandales
-/extension/llm @jackzhxng @iseeyuan @larryliu0820
-/extension/memory_allocator @JacobSzwejbka
+/extension/kernel_util @kimishpatel @manuelcandales @swolchok
+/extension/llm @jackzhxng @iseeyuan @larryliu0820 @swolchok
+/extension/memory_allocator @JacobSzwejbka @swolchok
 /extension/module @shoumikhin
-/extension/parallel @kimishpatel
+/extension/parallel @kimishpatel @swolchok
 /extension/pybindings @JacobSzwejbka @larryliu0820
-/extension/pytree @JacobSzwejbka
-# /extension/runner_util @dbort
+/extension/pytree @JacobSzwejbka @swolchok
+/extension/runner_util @swolchok
 /extension/tensor @shoumikhin
-# /extension/testing_util @dbort
-/extension/threadpool @kimishpatel
+/extension/testing_util @swolchok
+/extension/threadpool @kimishpatel @swolchok
 /extension/training @JacobSzwejbka

-/kernels @manuelcandales
+/kernels @manuelcandales @swolchok

 /profiler @tarun292 @Gasoonjia

-/runtime @JacobSzwejbka @lucylq
+/runtime @JacobSzwejbka @lucylq @swolchok
 /runtime/backend @cccclai

 /schema @JacobSzwejbka @lucylq

-/scripts @GregoryComer
+/scripts @GregoryComer @swolchok

-/shim @larryliu0820 @GregoryComer
+/shim @larryliu0820 @GregoryComer @swolchok

 /third-party @GregoryComer

Test.cmake

Lines changed: 0 additions & 1 deletion

@@ -13,7 +13,6 @@ if(BUILD_TESTING)
   add_subdirectory(extension/evalue_util/test)
   add_subdirectory(extension/kernel_util/test)
   add_subdirectory(extension/memory_allocator/test)
-  add_subdirectory(extension/parallel/test)
   add_subdirectory(extension/pytree/test)
   add_subdirectory(kernels/portable/cpu/util/test)
   add_subdirectory(kernels/prim_ops/test)

build/cmake_deps.toml

Lines changed: 4 additions & 18 deletions

@@ -88,7 +88,6 @@ excludes = [
 deps = [
   "executorch",
   "executorch_core",
-  "extension_parallel",
   "extension_threadpool",
   "portable_kernels",
 ]
@@ -131,7 +130,7 @@ excludes = [
 deps = [
   "executorch_core",
   "executorch",
-  "extension_parallel",
+  "extension_threadpool",
 ]

 [targets.optimized_native_cpu_ops]
@@ -146,7 +145,6 @@ excludes = [
 deps = [
   "executorch_core",
   "executorch",
-  "extension_parallel",
   "extension_threadpool",
   "portable_kernels",
 ]
@@ -227,19 +225,6 @@ deps = [
   "extension_runner_util",
 ]

-[targets.extension_parallel]
-buck_targets = [
-  "//extension/parallel:thread_parallel",
-]
-filters = [
-  ".cpp$",
-]
-deps = [
-  "executorch",
-  "executorch_core",
-  "extension_threadpool",
-]
-
 [targets.extension_tensor]
 buck_targets = [
   "//extension/tensor:tensor",
@@ -379,6 +364,7 @@ excludes = [
 deps = [
   "executorch",
   "executorch_core",
+  "extension_threadpool",
   "xnnpack_backend",
   "portable_kernels",
 ]
@@ -393,6 +379,7 @@ filters = [
 deps = [
   "executorch",
   "executorch_core",
+  "extension_threadpool",
 ]

 [targets.xnnpack_schema]
@@ -427,7 +414,6 @@ deps = [
   "executorch",
   "executorch_core",
   "optimized_kernels",
-  "extension_parallel",
   "extension_threadpool",
   "reduce_util",
   "xnnpack_backend",
@@ -465,7 +451,7 @@ deps = [
   "executorch_core",
   "extension_data_loader",
   "extension_module",
-  "extension_parallel",
+  "extension_threadpool",
   "portable_kernels",
   "quantized_kernels",
   "xnnpack_backend",

build/executorch-config.cmake

Lines changed: 1 addition & 7 deletions

@@ -75,7 +75,6 @@ set(lib_list
   custom_ops
   extension_module
   extension_module_static
-  extension_parallel
   extension_runner_util
   extension_tensor
   extension_threadpool
@@ -131,14 +130,9 @@ endforeach()

 # TODO: investigate use of install(EXPORT) to cleanly handle
 # target_compile_options/target_compile_definitions for everything.
-if(TARGET extension_parallel)
-  set_target_properties(
-    extension_parallel PROPERTIES INTERFACE_LINK_LIBRARIES extension_threadpool
-  )
-endif()
 if(TARGET cpublas)
   set_target_properties(
-    cpublas PROPERTIES INTERFACE_LINK_LIBRARIES extension_parallel
+    cpublas PROPERTIES INTERFACE_LINK_LIBRARIES extension_threadpool
   )
 endif()
 if(TARGET extension_threadpool)

examples/models/llama/runner/static_attention_io_manager.h

Lines changed: 16 additions & 2 deletions

@@ -38,6 +38,11 @@ class StaticKVCache {
     reset();
   }

+  StaticKVCache(const StaticKVCache& other) = delete;
+  StaticKVCache& operator=(const StaticKVCache& other) = delete;
+  StaticKVCache(StaticKVCache&& other) = delete;
+  StaticKVCache& operator=(StaticKVCache&& other) = delete;
+
   ~StaticKVCache() {
     allocator_.deallocate(data_, data_size_);
   }
@@ -200,6 +205,15 @@ class StaticAttentionMask {
     reset();
   }

+  StaticAttentionMask(const StaticAttentionMask& other) = delete;
+  StaticAttentionMask& operator=(const StaticAttentionMask& other) = delete;
+  StaticAttentionMask(StaticAttentionMask&& other) = delete;
+  StaticAttentionMask& operator=(StaticAttentionMask&& other) = delete;
+
+  ~StaticAttentionMask() {
+    allocator_.deallocate(data_, data_size_);
+  }
+
   /**
    * Reset the mask to the state where the cache contains no valid data.
    */
@@ -315,7 +329,7 @@ class StaticAttentionIOManager {
     input_pos_ += update_len;
     kCaches_.update(method, k_cache_output_indices, update_len);
     vCaches_.update(method, v_cache_output_indices, update_len);
-    for (auto it : attentionMasks_) {
+    for (auto& it : attentionMasks_) {
       it.second.updateCacheMask(update_len);
     }
   }
@@ -324,7 +338,7 @@ class StaticAttentionIOManager {
     input_pos_ = 0;
     kCaches_.reset();
     vCaches_.reset();
-    for (auto it : attentionMasks_) {
+    for (auto& it : attentionMasks_) {
       it.second.reset();
     }
   }

examples/models/llama/static_attention.py

Lines changed: 66 additions & 1 deletion

@@ -210,6 +210,7 @@ def __init__(self, config: ModelArgs, layer_id: int, rope: Rope):
         self.inv_scale = 1.0 / (float(self.head_dim) ** 0.5)
         self.attention_qkv_bias = config.attention_qkv_bias
         self.use_qk_norm = config.use_qk_norm
+        self.use_conv2d = False

         assert not self.use_qk_norm, "QK norm not supported in static attention yet"
         self.wqs = nn.ModuleList(
@@ -255,9 +256,25 @@ def forward(
         in_cache_state = kwargs.get("in_cache_state")
         out_cache_state = kwargs.get("out_cache_state")

+        bsz, seq_len, dim = x.shape
+        if self.use_conv2d:
+            x = x.reshape(bsz, seq_len, 1, dim).transpose(1, 3)
+
         new_qs = [self.wqs[i](x) for i in range(self.n_heads)]
         new_ks = [self.wks[i](x) for i in range(self.n_kv_heads)]
         new_vs = [self.wvs[i](x) for i in range(self.n_kv_heads)]
+
+        if self.use_conv2d:
+
+            def from_conv2ds(ts):
+                return [
+                    t.reshape(bsz, self.head_dim, seq_len).transpose(1, 2) for t in ts
+                ]
+
+            new_qs = from_conv2ds(new_qs)
+            new_ks = from_conv2ds(new_ks)
+            new_vs = from_conv2ds(new_vs)
+
         new_qs = [self.rope(q, freqs_cos, freqs_sin) for q in new_qs]
         new_ks = [self.rope(k, freqs_cos, freqs_sin) for k in new_ks]
         all_ks = []
@@ -282,7 +299,14 @@ def forward(
             heads.append(attn @ all_vs[kv_idx])

         y = torch.cat(heads, dim=-1)
-        y = self.wo(y)
+        if self.use_conv2d:
+            y = (
+                self.wo(y.reshape(bsz, seq_len, 1, -1).transpose(1, 3))
+                .transpose(1, 3)
+                .reshape(bsz, seq_len, -1)
+            )
+        else:
+            y = self.wo(y)
         return y, {"out_cache_state": out_cache_state}

     def load_weights_from_attention_mha(self, other: AttentionMHA):
@@ -300,3 +324,44 @@ def load_weights_from_attention_mha(self, other: AttentionMHA):
             )

         self.wo.weight.data.copy_(other.wo.weight)
+
+    def linear_to_conv2d(self):
+        def transfer_weight(linear, conv2d):
+            conv2d.weight.data.copy_(linear.weight[:, :, None, None])
+            return conv2d
+
+        self.wqs = nn.ModuleList(
+            [
+                transfer_weight(
+                    linear,
+                    nn.Conv2d(self.dim, self.head_dim, 1, bias=self.attention_qkv_bias),
+                )
+                for linear in self.wqs
+            ]
+        )
+        self.wks = nn.ModuleList(
+            [
+                transfer_weight(
+                    linear,
+                    nn.Conv2d(self.dim, self.head_dim, 1, bias=self.attention_qkv_bias),
+                )
+                for linear in self.wks
+            ]
+        )
+        self.wvs = nn.ModuleList(
+            [
+                transfer_weight(
+                    linear,
+                    nn.Conv2d(self.dim, self.head_dim, 1, bias=self.attention_qkv_bias),
+                )
+                for linear in self.wvs
+            ]
+        )
+        self.wo = transfer_weight(
+            self.wo,
+            nn.Conv2d(
+                self.n_heads * self.head_dim, self.dim, 1, bias=self.attention_qkv_bias
+            ),
+        )
+
+        self.use_conv2d = True
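
The linear_to_conv2d conversion above relies on a 1x1 nn.Conv2d being pointwise-equivalent to nn.Linear once the activations are laid out channels-first as (batch, channels, 1, seq_len). A minimal standalone check of that equivalence (the variable names here are illustrative, not part of the diff):

import torch
from torch import nn

# Sanity check: a 1x1 Conv2d with weights copied from a Linear layer produces
# the same result when the (batch, seq, dim) input is moved to channels-first
# (batch, dim, 1, seq) -- the same transposition used in forward() above.
dim, head_dim, bsz, seq_len = 64, 16, 2, 8
linear = nn.Linear(dim, head_dim, bias=False)
conv = nn.Conv2d(dim, head_dim, 1, bias=False)
conv.weight.data.copy_(linear.weight[:, :, None, None])

x = torch.rand(bsz, seq_len, dim)
y_linear = linear(x)
y_conv = (
    conv(x.reshape(bsz, seq_len, 1, dim).transpose(1, 3))
    .transpose(1, 3)
    .reshape(bsz, seq_len, -1)
)
assert torch.allclose(y_linear, y_conv, atol=1e-5)

test_without_cache in the next file exercises the same equivalence end to end via static_attn.linear_to_conv2d().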

examples/models/llama/tests/test_static_attention.py

Lines changed: 31 additions & 25 deletions

@@ -17,32 +17,38 @@ def setUp(self):
         torch.manual_seed(42)

     def test_without_cache(self):
-        config = ModelArgs(
-            dim=64,
-            n_heads=4,
-            n_kv_heads=2,
-            max_seq_len=8,
-        )
-        layer_id = 0
-        rope = Rope(config)
-        attn_mha = AttentionMHA(config, layer_id, rope).eval()
-        static_attn = StaticAttention(config, layer_id, rope).eval()
-        static_attn.load_weights_from_attention_mha(attn_mha)
+        def test(use_conv2d):
+            config = ModelArgs(
+                dim=64,
+                n_heads=4,
+                n_kv_heads=2,
+                max_seq_len=8,
+            )
+            layer_id = 0
+            rope = Rope(config)
+            attn_mha = AttentionMHA(config, layer_id, rope).eval()
+            static_attn = StaticAttention(config, layer_id, rope).eval()
+            static_attn.load_weights_from_attention_mha(attn_mha)
+            if use_conv2d:
+                static_attn.linear_to_conv2d()
+
+            x = torch.rand(1, config.max_seq_len, config.dim)
+            freqs_cos, freqs_sin = rope.get_freqs(None, config.max_seq_len)
+            expected, _ = attn_mha(x, freqs_cos, freqs_sin)
+            mask = torch.triu(
+                torch.full((1, config.max_seq_len, config.max_seq_len), float("-inf")),
+                diagonal=1,
+            )
+            y, _ = static_attn(
+                x,
+                freqs_cos,
+                freqs_sin,
+                mask=mask,
+            )
+            self.assertTrue(torch.isclose(y, expected, rtol=1e-3).all())

-        x = torch.rand(1, config.max_seq_len, config.dim)
-        freqs_cos, freqs_sin = rope.get_freqs(None, config.max_seq_len)
-        expected, _ = attn_mha(x, freqs_cos, freqs_sin)
-        mask = torch.triu(
-            torch.full((1, config.max_seq_len, config.max_seq_len), float("-inf")),
-            diagonal=1,
-        )
-        y, _ = static_attn(
-            x,
-            freqs_cos,
-            freqs_sin,
-            mask=mask,
-        )
-        self.assertTrue(torch.isclose(y, expected, rtol=1e-3).all())
+        test(True)
+        test(False)

     def test_hf_rope_without_cache(self):
         config = ModelArgs(
