-
Notifications
You must be signed in to change notification settings - Fork 0
Fix flash attention state accumulation #8
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: tzj/qlutattn
Are you sure you want to change the base?
Fix flash attention state accumulation #8
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
Extends the flash-attention state tensor to include accumulated VKQ values, updates the relevant operators and API validation for the new layout, and enhances the test to work with the expanded state tensor.
- Extend state tensor shape from
[2, n_heads * seq_len]to[2 + head_dim, n_heads * seq_len] - Update
ggml_flash_attn_ext_with_stateand CPU op implementations to read/write VKQ slots - Update API docs and bolster tests for the new layout
Reviewed Changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| tests/test-flash-attn-state.cpp | Adjusted reset/print routines for expanded state; updated creation of state tensor |
| ggml/src/ggml.c | Updated state-dimension assertions in ggml_flash_attn_ext_with_state |
| ggml/src/ggml-cpu/ops.cpp | Expanded validation and read/write paths to handle VKQ values |
| ggml/include/ggml.h | Updated function comment to document new [M, S, VKQ...] layout |
Comments suppressed due to low confidence (2)
tests/test-flash-attn-state.cpp:376
- [nitpick] After printing the final state, the test does not verify the VKQ portion of the state tensor. Add assertions to confirm that the VKQ entries match expected values across segments.
int row = state->ne[0];
tests/test-flash-attn-state.cpp:154
- The variable
head_dimis not defined inmain()before it’s used here. Declare and initializehead_dim(e.g. fromv->ne[0]orstate->ne[0] - 2) prior to creating the state tensor.
ggml_tensor * state = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2 + head_dim, n_heads * seq_len);
…gml-org#16038) Initalizing RESERVED_NAME in is_reserved_name() is not thread safe and leads to corrupted memory when used from multiple threads as can be seen in the asan trace below. This fixes the initialization to make it thread-safe. #0 0x000100abd018 in std::__1::pair<std::__1::__hash_iterator<std::__1::__hash_node<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>>, void*>*>, bool> std::__1::__hash_table<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>>, std::__1::hash<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>>>, std::__1::equal_to<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>>>, std::__1::allocator<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>>>>::__emplace_unique_key_args<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>>, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>> const&>(std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>> const&, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>> const&) __hash_table:1565 #1 0x000100ab0320 in SchemaConverter::visit(nlohmann::json_abi_v3_12_0::basic_json<nlohmann::json_abi_v3_12_0::ordered_map, std::__1::vector, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>>, bool, long long, unsigned long long, double, std::__1::allocator, nlohmann::json_abi_v3_12_0::adl_serializer, std::__1::vector<unsigned char, std::__1::allocator<unsigned char>>, void> const&, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>> const&) json-schema-to-grammar.cpp:802 #2 0x000100aafc48 in std::__1::__function::__func<build_grammar(std::__1::function<void (common_grammar_builder const&)> const&, common_grammar_options const&)::$_2, std::__1::allocator<build_grammar(std::__1::function<void (common_grammar_builder const&)> const&, common_grammar_options const&)::$_2>, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>> (std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>> const&, nlohmann::json_abi_v3_12_0::basic_json<nlohmann::json_abi_v3_12_0::ordered_map, std::__1::vector, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>>, bool, long long, unsigned long long, double, std::__1::allocator, nlohmann::json_abi_v3_12_0::adl_serializer, std::__1::vector<unsigned char, std::__1::allocator<unsigned char>>, void> const&)>::operator()(std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>> const&, nlohmann::json_abi_v3_12_0::basic_json<nlohmann::json_abi_v3_12_0::ordered_map, std::__1::vector, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>>, bool, long long, unsigned long long, double, std::__1::allocator, nlohmann::json_abi_v3_12_0::adl_serializer, std::__1::vector<unsigned char, std::__1::allocator<unsigned char>>, void> const&) function.h:319 #3 0x000100a2c938 in std::__1::__function::__func<common_chat_params_init_llama_3_x(minja::chat_template const&, templates_params const&, bool)::$_0::operator()(common_grammar_builder const&) const::'lambda'(nlohmann::json_abi_v3_12_0::basic_json<nlohmann::json_abi_v3_12_0::ordered_map, std::__1::vector, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>>, bool, long long, unsigned long long, double, std::__1::allocator, nlohmann::json_abi_v3_12_0::adl_serializer, std::__1::vector<unsigned char, std::__1::allocator<unsigned char>>, void> const&), std::__1::allocator<common_chat_params_init_llama_3_x(minja::chat_template const&, templates_params const&, bool)::$_0::operator()(common_grammar_builder const&) const::'lambda'(nlohmann::json_abi_v3_12_0::basic_json<nlohmann::json_abi_v3_12_0::ordered_map, std::__1::vector, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>>, bool, long long, unsigned long long, double, std::__1::allocator, nlohmann::json_abi_v3_12_0::adl_serializer, std::__1::vector<unsigned char, std::__1::allocator<unsigned char>>, void> const&)>, void (nlohmann::json_abi_v3_12_0::basic_json<nlohmann::json_abi_v3_12_0::ordered_map, std::__1::vector, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>>, bool, long long, unsigned long long, double, std::__1::allocator, nlohmann::json_abi_v3_12_0::adl_serializer, std::__1::vector<unsigned char, std::__1::allocator<unsigned char>>, void> const&)>::operator()(nlohmann::json_abi_v3_12_0::basic_json<nlohmann::json_abi_v3_12_0::ordered_map, std::__1::vector, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>>, bool, long long, unsigned long long, double, std::__1::allocator, nlohmann::json_abi_v3_12_0::adl_serializer, std::__1::vector<unsigned char, std::__1::allocator<unsigned char>>, void> const&) function.h:319 #4 0x000100a139f8 in foreach_function(nlohmann::json_abi_v3_12_0::basic_json<nlohmann::json_abi_v3_12_0::ordered_map, std::__1::vector, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>>, bool, long long, unsigned long long, double, std::__1::allocator, nlohmann::json_abi_v3_12_0::adl_serializer, std::__1::vector<unsigned char, std::__1::allocator<unsigned char>>, void> const&, std::__1::function<void (nlohmann::json_abi_v3_12_0::basic_json<nlohmann::json_abi_v3_12_0::ordered_map, std::__1::vector, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>>, bool, long long, unsigned long long, double, std::__1::allocator, nlohmann::json_abi_v3_12_0::adl_serializer, std::__1::vector<unsigned char, std::__1::allocator<unsigned char>>, void> const&)> const&) chat.cpp:762 #5 0x000100a2a7f4 in std::__1::__function::__func<common_chat_params_init_llama_3_x(minja::chat_template const&, templates_params const&, bool)::$_0, std::__1::allocator<common_chat_params_init_llama_3_x(minja::chat_template const&, templates_params const&, bool)::$_0>, void (common_grammar_builder const&)>::operator()(common_grammar_builder const&) function.h:319 #6 0x000100aa98f4 in build_grammar(std::__1::function<void (common_grammar_builder const&)> const&, common_grammar_options const&) json-schema-to-grammar.cpp:982 #7 0x0001009c9314 in common_chat_params_init_llama_3_x(minja::chat_template const&, templates_params const&, bool) chat.cpp:1110 #8 0x0001009b8afc in common_chat_templates_apply_jinja(common_chat_templates const*, common_chat_templates_inputs const&) chat.cpp:1992 #9 0x0001009b533c in common_chat_templates_apply(common_chat_templates const*, common_chat_templates_inputs const&) chat.cpp:2074 #10 0x000100810120 in llamacpp_apply_chat_template+0x724 (predict_oai-98384e17fb94e863:arm64+0x100090120) ... ==45482==Register values: x[0] = 0x00006020004147f8 x[1] = 0x00006080000013c8 x[2] = 0x0000000000000000 x[3] = 0x0000604006289738 x[4] = 0x0000000000000002 x[5] = 0x0000000000000001 x[6] = 0x04034000004b4000 x[7] = 0x0000000000000001 x[8] = 0xbebebebebebebebe x[9] = 0x17d7d7d7d7d7d7d7 x[10] = 0x00000c04000828ff x[11] = 0x0000000000000001 x[12] = 0x000000002018d383 x[13] = 0x0000000000000000 x[14] = 0xfa0000000000fafa x[15] = 0x000010700001ffff x[16] = 0x000000019dc012c0 x[17] = 0x00000001021284f8 x[18] = 0x0000000000000000 x[19] = 0x00000001700acdc0 x[20] = 0x0000000000000002 x[21] = 0x000000002018d384 x[22] = 0x16dd16fd2e731151 x[23] = 0x0000007000020000 x[24] = 0x0000000100c69c08 x[25] = 0x0000000100c69c20 x[26] = 0x00006080000013c7 x[27] = 0x0000000100c69c00 x[28] = 0x00000001700acd60 fp = 0x00000001700aceb0 lr = 0x0000000100abce30 sp = 0x00000001700acd60 AddressSanitizer can not provide additional info. SUMMARY: AddressSanitizer: SEGV __hash_table:1565 in std::__1::pair<std::__1::__hash_iterator<std::__1::__hash_node<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>>, void*>*>, bool> std::__1::__hash_table<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>>, std::__1::hash<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>>>, std::__1::equal_to<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>>>, std::__1::allocator<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>>>>::__emplace_unique_key_args<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>>, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>> const&>(std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>> const&, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>> const&) Thread T5 created by T0 here: #0 0x0001020b99d4 in pthread_create+0x5c (libclang_rt.asan_osx_dynamic.dylib:arm64e+0x359d4) #1 0x000100873910 in std::sys::pal::unix::thread::Thread::new::h77254fdd87a28e05+0x118 (predict_oai-98384e17fb94e863:arm64+0x1000f3910) #2 0x0001007c7a1c in test::run_test::haeb3c2bcd5ed6cf6+0x76c (predict_oai-98384e17fb94e863:arm64+0x100047a1c) #3 0x0001007aedb0 in test::console::run_tests_console::he9d142d704f3a986+0x149c (predict_oai-98384e17fb94e863:arm64+0x10002edb0) #4 0x0001007c5758 in test::test_main::hf86a5e20735245b9+0x118 (predict_oai-98384e17fb94e863:arm64+0x100045758) #5 0x0001007c5da0 in test::test_main_static::h61ee9c8fd30abca0+0x54 (predict_oai-98384e17fb94e863:arm64+0x100045da0) ... ==45482==ABORTING
ggml-org#17031) * Faster tensors (#8) Add fast matrix and matrix/vector multiplication. * Use map for shader replacements instead of pair of strings
Summary
Testing
cmake --build build-x86_64 --config Release -j12./build-x86_64/bin/test-flash-attn-statehttps://chatgpt.com/codex/tasks/task_e_6854917486ac83329eb2492474af1ecc