From e5ef519d46a7d6bba32e553ded6e6c52d9365c6b Mon Sep 17 00:00:00 2001
From: Chen Lai
Date: Wed, 2 Oct 2024 12:30:40 -0700
Subject: [PATCH 1/2] prefill model (#5807)

Summary:
python -m executorch.examples.models.llama2.export_llama --disable_dynamic_shape --qnn --pt2e_quantize qnn_16a4w

Segfault error stacktrace
```
[INFO] [Qnn ExecuTorch]: Initialize Qnn backend parameters for Qnn executorch backend type 2
[INFO] [Qnn ExecuTorch]: Caching: Caching is in SAVE MODE.
[WARNING] [Qnn ExecuTorch]: Qnn API version 2.19.0 is used. The version is tested against 2.18.0.
[INFO] [Qnn ExecuTorch]: Running level=3 optimization.
AddressSanitizer:DEADLYSIGNAL
=================================================================
==1523599==ERROR: AddressSanitizer: SEGV on unknown address 0x000000000020 (pc 0x7f1585ee38e2 bp 0x7f16d5ab8800 sp 0x7ffed19ab8b0 T0)
==1523599==The signal is caused by a READ memory access.
==1523599==Hint: address points to the zero page.
SCARINESS: 10 (null-deref)
#0 0x7f1585ee38e2 (/home/chenlai/fbsource/third-party/qualcomm/qnn/qnn-2.26/lib/x86_64-linux-clang/libQnnHtp.so+0x2ce38e2) (BuildId: bc3ab8ddc89a0e65)
#1 0x7f1585dd8926 (/home/chenlai/fbsource/third-party/qualcomm/qnn/qnn-2.26/lib/x86_64-linux-clang/libQnnHtp.so+0x2bd8926) (BuildId: bc3ab8ddc89a0e65)
#2 0x7f15844d1161 (/home/chenlai/fbsource/third-party/qualcomm/qnn/qnn-2.26/lib/x86_64-linux-clang/libQnnHtp.so+0x12d1161) (BuildId: bc3ab8ddc89a0e65)
#3 0x7f15844dcac6 (/home/chenlai/fbsource/third-party/qualcomm/qnn/qnn-2.26/lib/x86_64-linux-clang/libQnnHtp.so+0x12dcac6) (BuildId: bc3ab8ddc89a0e65)
#4 0x7f15844d245b (/home/chenlai/fbsource/third-party/qualcomm/qnn/qnn-2.26/lib/x86_64-linux-clang/libQnnHtp.so+0x12d245b) (BuildId: bc3ab8ddc89a0e65)
#5 0x7f15b9bc7b21 in auto torch::executor::qnn::QnnInterface::qnn_backend_validate_op_config(void*, Qnn_OpConfig_t) const fbcode/executorch/backends/qualcomm/runtime/backends/QnnFunctionInterface.h:39
#6 0x7f15b9bc7682 in torch::executor::qnn::QnnBackend::BackendValidateOpConfig(Qnn_OpConfig_t const&) fbcode/executorch/backends/qualcomm/runtime/backends/QnnBackendCommon.h:41
#7 0x7f15b9bc7115 in torch::executor::qnn::QnnManager::IsNodeSupportedByBackend(std::vector, std::allocator>>&) fbcode/executorch/backends/qualcomm/runtime/QnnManager.cpp:450
#8 0x7f15b9dd44ee in torch::executor::qnn::PyQnnManager::IsNodeSupportedByBackend(std::vector, std::allocator>>&) fbcode/executorch/backends/qualcomm/aot/python/PyQnnManagerAdaptor.h:57
#9 0x7f15b9e5b986 in pybind11::cpp_function::cpp_function, std::allocator>>&, pybind11::name, pybind11::is_method, pybind11::sibling>(bool (torch::executor::qnn::PyQnnManager::*)(std::vector, std::allocator>>&), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&)::'lambda'(torch::executor::qnn::PyQnnManager*, std::vector, std::allocator>>&)::operator()(torch::executor::qnn::PyQnnManager*, std::vector, std::allocator>>&) const fbsource/pybind11/pybind11.h:84
#10 0x7f15b9e5b8b5 in bool pybind11::detail::argument_loader, std::allocator>>&>::call_impl, std::allocator>>&, pybind11::name, pybind11::is_method, pybind11::sibling>(bool (torch::executor::qnn::PyQnnManager::*)(std::vector, std::allocator>>&), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&)::'lambda'(torch::executor::qnn::PyQnnManager*, std::vector, std::allocator>>&)&, 0ul, 1ul, pybind11::detail::void_type>(torch::executor::qnn::PyQnnManager&&, std::integer_sequence, pybind11::detail::void_type&&) && fbsource/pybind11/cast.h:2042
#11 0x7f15b9e53831
in std::enable_if::value, bool>::type pybind11::detail::argument_loader, std::allocator>>&>::call, std::allocator>>&, pybind11::name, pybind11::is_method, pybind11::sibling>(bool (torch::executor::qnn::PyQnnManager::*)(std::vector, std::allocator>>&), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&)::'lambda'(torch::executor::qnn::PyQnnManager*, std::vector, std::allocator>>&)&>(pybind11::cpp_function::cpp_function, std::allocator>>&, pybind11::name, pybind11::is_method, pybind11::sibling>(bool (torch::executor::qnn::PyQnnManager::*)(std::vector, std::allocator>>&), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&)::'lambda'(torch::executor::qnn::PyQnnManager*, std::vector, std::allocator>>&)&) && fbsource/pybind11/cast.h:2014 #12 0x7f15b9e53454 in void pybind11::cpp_function::initialize, std::allocator>>&, pybind11::name, pybind11::is_method, pybind11::sibling>(bool (torch::executor::qnn::PyQnnManager::*)(std::vector, std::allocator>>&), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&)::'lambda'(torch::executor::qnn::PyQnnManager*, std::vector, std::allocator>>&), bool, torch::executor::qnn::PyQnnManager*, std::vector, std::allocator>>&, pybind11::name, pybind11::is_method, pybind11::sibling>(bool&&, torch::executor::qnn::PyQnnManager (*)(std::vector, std::allocator>>&), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&)::'lambda'(pybind11::detail::function_call&)::operator()(pybind11::detail::function_call&) const fbsource/pybind11/pybind11.h:193 #13 0x7f15b9e530d3 in void pybind11::cpp_function::initialize, std::allocator>>&, pybind11::name, pybind11::is_method, pybind11::sibling>(bool (torch::executor::qnn::PyQnnManager::*)(std::vector, std::allocator>>&), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&)::'lambda'(torch::executor::qnn::PyQnnManager*, std::vector, std::allocator>>&), bool, torch::executor::qnn::PyQnnManager*, std::vector, std::allocator>>&, pybind11::name, pybind11::is_method, pybind11::sibling>(bool&&, torch::executor::qnn::PyQnnManager (*)(std::vector, std::allocator>>&), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&)::'lambda'(pybind11::detail::function_call&)::__invoke(pybind11::detail::function_call&) fbsource/pybind11/pybind11.h:170 #14 0x7f15b9d8f707 in pybind11::cpp_function::dispatcher(_object*, _object*, _object*) fbsource/pybind11/pybind11.h:767 #15 0x327141 in cfunction_call(_object*, _object*, _object*) (.__uniq.281047882695835599676768160755749362799) (/usr/local/fbcode/platform010/bin/python3.10+0x327141) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #16 0x349630 in _PyObject_MakeTpCall (/usr/local/fbcode/platform010/bin/python3.10+0x349630) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #17 0x5897d4 in method_vectorcall(_object*, _object* const*, unsigned long, _object*) (.__uniq.243338978568352371442406765225626566013.llvm.6236606370933165261) (/usr/local/fbcode/platform010/bin/python3.10+0x5897d4) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #18 0x3928df in call_function(_ts*, PyTraceInfo*, _object***, long, _object*) (.__uniq.79849310599369217189729546442812793949) (/usr/local/fbcode/platform010/bin/python3.10+0x3928df) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #19 0x331421 in _PyEval_EvalFrameDefault (/usr/local/fbcode/platform010/bin/python3.10+0x331421) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #20 0x327547 in _PyFunction_Vectorcall 
(/usr/local/fbcode/platform010/bin/python3.10+0x327547) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #21 0x3928df in call_function(_ts*, PyTraceInfo*, _object***, long, _object*) (.__uniq.79849310599369217189729546442812793949) (/usr/local/fbcode/platform010/bin/python3.10+0x3928df) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #22 0x3313f2 in _PyEval_EvalFrameDefault (/usr/local/fbcode/platform010/bin/python3.10+0x3313f2) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #23 0x327547 in _PyFunction_Vectorcall (/usr/local/fbcode/platform010/bin/python3.10+0x327547) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #24 0x3928df in call_function(_ts*, PyTraceInfo*, _object***, long, _object*) (.__uniq.79849310599369217189729546442812793949) (/usr/local/fbcode/platform010/bin/python3.10+0x3928df) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #25 0x3313f2 in _PyEval_EvalFrameDefault (/usr/local/fbcode/platform010/bin/python3.10+0x3313f2) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #26 0x327547 in _PyFunction_Vectorcall (/usr/local/fbcode/platform010/bin/python3.10+0x327547) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #27 0x3928df in call_function(_ts*, PyTraceInfo*, _object***, long, _object*) (.__uniq.79849310599369217189729546442812793949) (/usr/local/fbcode/platform010/bin/python3.10+0x3928df) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #28 0x3313f2 in _PyEval_EvalFrameDefault (/usr/local/fbcode/platform010/bin/python3.10+0x3313f2) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #29 0x327547 in _PyFunction_Vectorcall (/usr/local/fbcode/platform010/bin/python3.10+0x327547) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #30 0x3928df in call_function(_ts*, PyTraceInfo*, _object***, long, _object*) (.__uniq.79849310599369217189729546442812793949) (/usr/local/fbcode/platform010/bin/python3.10+0x3928df) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #31 0x331577 in _PyEval_EvalFrameDefault (/usr/local/fbcode/platform010/bin/python3.10+0x331577) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #32 0x327547 in _PyFunction_Vectorcall (/usr/local/fbcode/platform010/bin/python3.10+0x327547) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #33 0x3928df in call_function(_ts*, PyTraceInfo*, _object***, long, _object*) (.__uniq.79849310599369217189729546442812793949) (/usr/local/fbcode/platform010/bin/python3.10+0x3928df) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #34 0x3313f2 in _PyEval_EvalFrameDefault (/usr/local/fbcode/platform010/bin/python3.10+0x3313f2) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #35 0x327547 in _PyFunction_Vectorcall (/usr/local/fbcode/platform010/bin/python3.10+0x327547) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #36 0x3928df in call_function(_ts*, PyTraceInfo*, _object***, long, _object*) (.__uniq.79849310599369217189729546442812793949) (/usr/local/fbcode/platform010/bin/python3.10+0x3928df) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #37 0x3313f2 in _PyEval_EvalFrameDefault (/usr/local/fbcode/platform010/bin/python3.10+0x3313f2) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #38 0x39b8ca in _PyEval_Vector (/usr/local/fbcode/platform010/bin/python3.10+0x39b8ca) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #39 0x39ad7d in _PyObject_FastCallDictTstate (/usr/local/fbcode/platform010/bin/python3.10+0x39ad7d) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #40 0x3c8b72 in slot_tp_call(_object*, _object*, _object*) (.__uniq.235726554139783955843240177532338160225) 
(/usr/local/fbcode/platform010/bin/python3.10+0x3c8b72) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #41 0x392ca8 in call_function(_ts*, PyTraceInfo*, _object***, long, _object*) (.__uniq.79849310599369217189729546442812793949) (/usr/local/fbcode/platform010/bin/python3.10+0x392ca8) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #42 0x3314ca in _PyEval_EvalFrameDefault (/usr/local/fbcode/platform010/bin/python3.10+0x3314ca) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #43 0x39b8ca in _PyEval_Vector (/usr/local/fbcode/platform010/bin/python3.10+0x39b8ca) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #44 0x331b18 in _PyEval_EvalFrameDefault (/usr/local/fbcode/platform010/bin/python3.10+0x331b18) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #45 0x327547 in _PyFunction_Vectorcall (/usr/local/fbcode/platform010/bin/python3.10+0x327547) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #46 0x3928df in call_function(_ts*, PyTraceInfo*, _object***, long, _object*) (.__uniq.79849310599369217189729546442812793949) (/usr/local/fbcode/platform010/bin/python3.10+0x3928df) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #47 0x3314ca in _PyEval_EvalFrameDefault (/usr/local/fbcode/platform010/bin/python3.10+0x3314ca) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #48 0x327547 in _PyFunction_Vectorcall (/usr/local/fbcode/platform010/bin/python3.10+0x327547) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #49 0x3928df in call_function(_ts*, PyTraceInfo*, _object***, long, _object*) (.__uniq.79849310599369217189729546442812793949) (/usr/local/fbcode/platform010/bin/python3.10+0x3928df) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #50 0x3313f2 in _PyEval_EvalFrameDefault (/usr/local/fbcode/platform010/bin/python3.10+0x3313f2) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #51 0x327547 in _PyFunction_Vectorcall (/usr/local/fbcode/platform010/bin/python3.10+0x327547) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #52 0x3928df in call_function(_ts*, PyTraceInfo*, _object***, long, _object*) (.__uniq.79849310599369217189729546442812793949) (/usr/local/fbcode/platform010/bin/python3.10+0x3928df) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #53 0x3313f2 in _PyEval_EvalFrameDefault (/usr/local/fbcode/platform010/bin/python3.10+0x3313f2) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #54 0x327547 in _PyFunction_Vectorcall (/usr/local/fbcode/platform010/bin/python3.10+0x327547) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #55 0x3928df in call_function(_ts*, PyTraceInfo*, _object***, long, _object*) (.__uniq.79849310599369217189729546442812793949) (/usr/local/fbcode/platform010/bin/python3.10+0x3928df) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #56 0x3314ca in _PyEval_EvalFrameDefault (/usr/local/fbcode/platform010/bin/python3.10+0x3314ca) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #57 0x327547 in _PyFunction_Vectorcall (/usr/local/fbcode/platform010/bin/python3.10+0x327547) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #58 0x3928df in call_function(_ts*, PyTraceInfo*, _object***, long, _object*) (.__uniq.79849310599369217189729546442812793949) (/usr/local/fbcode/platform010/bin/python3.10+0x3928df) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #59 0x3314ca in _PyEval_EvalFrameDefault (/usr/local/fbcode/platform010/bin/python3.10+0x3314ca) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #60 0x327547 in _PyFunction_Vectorcall (/usr/local/fbcode/platform010/bin/python3.10+0x327547) (BuildId: 
a620038add613fd8585eb50983ca8e455d54738e) #61 0x3928df in call_function(_ts*, PyTraceInfo*, _object***, long, _object*) (.__uniq.79849310599369217189729546442812793949) (/usr/local/fbcode/platform010/bin/python3.10+0x3928df) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #62 0x3314ca in _PyEval_EvalFrameDefault (/usr/local/fbcode/platform010/bin/python3.10+0x3314ca) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #63 0x327547 in _PyFunction_Vectorcall (/usr/local/fbcode/platform010/bin/python3.10+0x327547) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #64 0x3928df in call_function(_ts*, PyTraceInfo*, _object***, long, _object*) (.__uniq.79849310599369217189729546442812793949) (/usr/local/fbcode/platform010/bin/python3.10+0x3928df) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #65 0x3314ca in _PyEval_EvalFrameDefault (/usr/local/fbcode/platform010/bin/python3.10+0x3314ca) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #66 0x327547 in _PyFunction_Vectorcall (/usr/local/fbcode/platform010/bin/python3.10+0x327547) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #67 0x3928df in call_function(_ts*, PyTraceInfo*, _object***, long, _object*) (.__uniq.79849310599369217189729546442812793949) (/usr/local/fbcode/platform010/bin/python3.10+0x3928df) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #68 0x3314ca in _PyEval_EvalFrameDefault (/usr/local/fbcode/platform010/bin/python3.10+0x3314ca) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #69 0x327547 in _PyFunction_Vectorcall (/usr/local/fbcode/platform010/bin/python3.10+0x327547) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #70 0x3928df in call_function(_ts*, PyTraceInfo*, _object***, long, _object*) (.__uniq.79849310599369217189729546442812793949) (/usr/local/fbcode/platform010/bin/python3.10+0x3928df) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #71 0x3314ca in _PyEval_EvalFrameDefault (/usr/local/fbcode/platform010/bin/python3.10+0x3314ca) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #72 0x39b8ca in _PyEval_Vector (/usr/local/fbcode/platform010/bin/python3.10+0x39b8ca) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #73 0x431565 in PyEval_EvalCode (/usr/local/fbcode/platform010/bin/python3.10+0x431565) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #74 0x431447 in run_mod(_mod*, _object*, _object*, _object*, PyCompilerFlags*, _arena*) (.__uniq.251861886623903963524397139660542440724.llvm.17622910512627074885) (/usr/local/fbcode/platform010/bin/python3.10+0x431447) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #75 0x4e3054 in pyrun_file(_IO_FILE*, _object*, int, _object*, _object*, int, PyCompilerFlags*) (.__uniq.251861886623903963524397139660542440724) (/usr/local/fbcode/platform010/bin/python3.10+0x4e3054) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #76 0x4e2b54 in _PyRun_SimpleFileObject (/usr/local/fbcode/platform010/bin/python3.10+0x4e2b54) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #77 0x4e28f1 in _PyRun_AnyFileObject (/usr/local/fbcode/platform010/bin/python3.10+0x4e28f1) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #78 0x4d4a54 in Py_RunMain (/usr/local/fbcode/platform010/bin/python3.10+0x4d4a54) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #79 0x4d286b in pymain_main(_PyArgv*) (.__uniq.297908980262787110426434251325078884054) (/usr/local/fbcode/platform010/bin/python3.10+0x4d286b) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #80 0x4d2759 in Py_BytesMain (/usr/local/fbcode/platform010/bin/python3.10+0x4d2759) (BuildId: 
a620038add613fd8585eb50983ca8e455d54738e) #81 0x7f19e282c656 in __libc_start_call_main (/usr/local/fbcode/platform010/lib/libc.so.6+0x2c656) (BuildId: 93cdceeb8322234c38e1f2c93ad0ff10c7632fa6) #82 0x7f19e282c717 in __libc_start_main@GLIBC_2.2.5 (/usr/local/fbcode/platform010/lib/libc.so.6+0x2c717) (BuildId: 93cdceeb8322234c38e1f2c93ad0ff10c7632fa6) #83 0x553d90 in _start (/usr/local/fbcode/platform010/bin/python3.10+0x553d90) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) AddressSanitizer can not provide additional info. AddressSanitizer: SEGV (/home/chenlai/fbsource/third-party/qualcomm/qnn/qnn-2.26/lib/x86_64-linux-clang/libQnnHtp.so+0x2ce38e2) (BuildId: bc3ab8ddc89a0e65) ==1523599==ABORTING ``` Differential Revision: D63736779 --- examples/models/llama2/export_llama_lib.py | 55 +- examples/models/llama2/llama_transformer.py | 551 ++++-------------- examples/models/llama2/model.py | 69 ++- .../models/llama2/params/demo_config.json | 2 +- examples/models/llama2/runner/targets.bzl | 1 + examples/models/model_factory.py | 8 +- extension/llm/export/builder.py | 3 + extension/llm/export/partitioner_lib.py | 6 +- 8 files changed, 218 insertions(+), 477 deletions(-) diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py index a39bb048200..e2e5a178edf 100644 --- a/examples/models/llama2/export_llama_lib.py +++ b/examples/models/llama2/export_llama_lib.py @@ -53,21 +53,23 @@ get_quant_embedding_transform, get_quant_weight_transform, ) -from .source_transformation.quantized_kv_cache import ( - replace_kv_cache_with_quantized_kv_cache, -) + +# from .source_transformation.quantized_kv_cache import ( +# replace_kv_cache_with_quantized_kv_cache, +# ) from .source_transformation.rms_norm import replace_rms_norm_with_native_rms_norm from .source_transformation.rope import materialze_broadcast_of_rope_freq_cis -from .source_transformation.sdpa import ( - replace_causal_mask, - replace_kv_cache_with_coreml_kv_cache, - replace_kv_cache_with_simple_kv_cache, - replace_sdpa_with_coreml_sdpa, - replace_sdpa_with_custom_op, - replace_sdpa_with_flex_sdpa, - replace_sdpa_with_simple_sdpa, -) + +# from .source_transformation.sdpa import ( +# replace_causal_mask, +# replace_kv_cache_with_coreml_kv_cache, +# replace_kv_cache_with_simple_kv_cache, +# replace_sdpa_with_coreml_sdpa, +# replace_sdpa_with_custom_op, +# replace_sdpa_with_flex_sdpa, +# replace_sdpa_with_simple_sdpa, +# ) IS_FBCODE = True # os.environ.get("FBCODE_PLATFORM", False) FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" @@ -893,23 +895,20 @@ def _get_source_transforms( # noqa assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True" transforms.append(replace_kv_cache_with_quantized_kv_cache) + if args.qnn: + # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils` + from executorch.backends.qualcomm.utils.utils import convert_linear_to_conv2d + + # transforms.append(replace_kv_cache_with_simple_kv_cache) + # transforms.append(replace_sdpa_with_flex_sdpa) + # transforms.append(replace_causal_mask) + transforms.append(replace_rms_norm_with_native_rms_norm) + if args.optimized_rotation_path: + transforms.append(fuse_layer_norms) + transforms.append(get_model_with_r1_r2(args.optimized_rotation_path)) + transforms.append(convert_linear_to_conv2d) if args.use_kv_cache: - if args.qnn: - # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import 
`executorch.backends.qualcomm.utils.utils` - from executorch.backends.qualcomm.utils.utils import ( - convert_linear_to_conv2d, - ) - - transforms.append(replace_kv_cache_with_simple_kv_cache) - transforms.append(replace_sdpa_with_flex_sdpa) - transforms.append(replace_causal_mask) - transforms.append(replace_rms_norm_with_native_rms_norm) - if args.optimized_rotation_path: - transforms.append(fuse_layer_norms) - transforms.append(get_model_with_r1_r2(args.optimized_rotation_path)) - transforms.append(convert_linear_to_conv2d) - - elif args.mps: + if args.mps: # Currently mps doesn't support sdpa op, use the simpler decomposition # to get free perf gain. transforms.append(replace_sdpa_with_simple_sdpa) diff --git a/examples/models/llama2/llama_transformer.py b/examples/models/llama2/llama_transformer.py index 8e17013ae3d..2a229a87609 100644 --- a/examples/models/llama2/llama_transformer.py +++ b/examples/models/llama2/llama_transformer.py @@ -1,27 +1,14 @@ -# @lint-ignore-every LICENSELINT -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# Llama 2 is licensed under the LLAMA 2 Community License, -# Copyright (c) Meta Platforms, Inc. All Rights Reserved. - -# Please refer to README.md in the same folder for more information. - +import logging +import math from dataclasses import dataclass -from functools import partial -from typing import Dict, Optional, Tuple +from typing import Optional, Tuple import torch -import torch.nn.functional as F -from executorch.examples.models.llama2.rope import ( - apply_rotary_emb, - hf_apply_rotary_emb, - hf_precompute_freqs_cis, - precompute_freqs_cis, -) +from torch.nn import functional as F + -from torch import nn +logger: logging.Logger = logging.getLogger() class RMSNorm(torch.nn.Module): @@ -39,9 +26,8 @@ def __init__(self, dim: int, eps: float = 1e-6): """ super().__init__() - self.dim = dim self.eps = eps - self.weight = nn.Parameter(torch.ones(dim)) + self.weight = torch.nn.Parameter(torch.ones(dim)) def _norm(self, x): """ @@ -54,7 +40,7 @@ def _norm(self, x): torch.Tensor: The normalized tensor. """ - return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps) + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) def forward(self, x): """ @@ -71,12 +57,6 @@ def forward(self, x): return output * self.weight -def find_multiple(n: int, k: int) -> int: - if n % k == 0: - return n - return n + k - (n % k) - - @dataclass class ModelArgs: dim: int = 4096 @@ -84,182 +64,56 @@ class ModelArgs: n_heads: int = 32 n_kv_heads: Optional[int] = None vocab_size: int = -1 # defined later by tokenizer - hidden_dim: Optional[int] = None + invocation_vocab_size: int = -1 # defined later by tokenizer multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 ffn_dim_multiplier: Optional[float] = None norm_eps: float = 1e-5 max_batch_size: int = 32 max_seq_len: int = 2048 - moe: bool = False # True to enable the MoE (Mixture of Experts) - num_experts: int = 8 # Number of experts - num_activated_experts: int = 2 # Number of experts to activate - use_kv_cache: bool = False # Use key/value cache - use_sdpa_with_kv_cache_op: bool = ( - False # Use custom sdpa op that updates kv cache in-place - ) - # Generate logits for all inputs. When it's True, it would take big memory usage - # at runtime. Enable it only necessary (e.g., use perplexity tools that requires - # logits for all input tokens.) 
- generate_full_logits: bool = False - enable_dynamic_shape: bool = False # export model with dynamic shape support - # A dictionary mapping from pruned token-id to original token-id - output_prune_map: Optional[Dict[int, int]] = None - use_hf_rope: bool = False # Use HuggingFace's RoPE implementation - rope_theta: Optional[float] = ( - None # The official name to override self.rope_freq_base. - ) - rope_freq_base: float = 10000.0 # The base frequency for RoPE. Keep it for BC. - use_scaled_rope: bool = False # Use scaled RoPE, introduced in llama3.1. - # Additional Model Metadata needed at runtime - bos_idx: int = 1 - eos_idx: int = 3 - bos_count: int = -1 # i.e., a single EOS is used as BOS - eos_count: int = 2 - - def __post_init__(self): - if self.n_kv_heads is None: - self.n_kv_heads = self.n_heads - - # rope_theta overrides rope_freq_base since it's the official name. - if self.rope_theta is not None: - self.rope_freq_base = self.rope_theta - - if self.use_sdpa_with_kv_cache_op: - assert self.use_kv_cache, "use_sdpa_with_kv_cache_op requires use_kv_cache" - - if self.hidden_dim is None: - # If hidden_dim is not explicitly set in the ModelArgs, - # then calculate implicitly based on dim and also multiple of `args.multiple_of` - multiple_of = self.multiple_of - hidden_dim = 4 * self.dim - hidden_dim = int(2 * hidden_dim / 3) - if self.ffn_dim_multiplier is not None: - hidden_dim = int(self.ffn_dim_multiplier * hidden_dim) - self.hidden_dim = find_multiple(hidden_dim, multiple_of) + hidden_dim: Optional[int] = None -class KVCache(nn.Module): - def __init__( - self, - max_batch_size: int, - max_seq_length: int, - n_heads: int, - head_dim: int, - transpose_cache: bool, - enable_dynamic_shape: bool, - dtype=torch.float32, - ): - super().__init__() - self.max_seq_length = max_seq_length - self.is_tranposed = transpose_cache - if transpose_cache: - cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) - else: - cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim) - - self.max_batch_size = max_batch_size - self.n_heads = n_heads - self.head_dim = head_dim - self.transpose_cache = transpose_cache - self.enable_dynamic_shape = enable_dynamic_shape - self.register_buffer( - "k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu") - ) - self.register_buffer( - "v_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu") - ) +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) # pyre-ignore + freqs = torch.outer(t, freqs).float() # pyre-ignore + freqs_cos = torch.cos(freqs) + freqs_sin = torch.sin(freqs) + return freqs_cos, freqs_sin - def update( - self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - # input_pos: [S], k_val: [B, H, S, D] or [B, S, H, D] depending on transpose_cache - if self.enable_dynamic_shape: - start_pos = input_pos[0].item() - torch._check_is_size(start_pos) - torch._check(start_pos < self.max_seq_length) - dim_to_slice = 2 if self.transpose_cache else 1 - seq_length = k_val.size(dim_to_slice) - # Replace the entry in the cache for this token - # The following lines are equivalent to: - # cache_k[:bsz, start_pos : start_pos + seqlen] = xk - # cache_v[:bsz, start_pos : start_pos + seqlen] = xv - # when dim_to_slice is 1 - # We use .narrow() here to make the compiler happy - # pyre-ignore: Incompatible parameter type [6] - narrowed_k = 
self.k_cache.narrow(dim_to_slice, start_pos, seq_length) - # pyre-ignore: Incompatible parameter type [6] - narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length) - - narrowed_k.copy_(k_val) - narrowed_v.copy_(v_val) - return self.k_cache, self.v_cache - else: - k_out = self.k_cache - v_out = self.v_cache - if self.transpose_cache: - k_out[:, :, input_pos] = k_val - v_out[:, :, input_pos] = v_val - else: - k_out[:, input_pos] = k_val - v_out[:, input_pos] = v_val - return k_out, v_out +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(shape) -class SDPA(nn.Module): - def __init__( - self, - kv_cache: KVCache, - dim: int, - head_dim: int, - n_rep: int, - max_seq_len: int, - enable_dynamic_shape: bool, - ): - super().__init__() - self.kv_cache = kv_cache - self.dim = dim - self.head_dim = head_dim - self.n_rep = n_rep - self.max_seq_len = max_seq_len - self.enable_dynamic_shape = enable_dynamic_shape +def apply_rotary_emb( + xq: torch.Tensor, xk: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: - def forward( - self, - input_pos: torch.Tensor, - q: torch.Tensor, # Already have rotary embeddings. (bs, seqlen, n_local_heads, head_dim) - k: torch.Tensor, # Already have rotary embeddings. (bs, seqlen, n_local_kv_heads, head_dim) - v: torch.Tensor, # (bs, seqlen, n_local_kv_heads, head_dim) - bsz, - seqlen, - mask: torch.Tensor, - ) -> torch.Tensor: - q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - - k, v = self.kv_cache.update(input_pos, k, v) - if self.enable_dynamic_shape: - start_pos = input_pos[-1].item() - torch._check_is_size(start_pos) - torch._check(start_pos < self.max_seq_len) - seq_length = q.size(2) - # pyre-ignore: Incompatible parameter type [6] - attn_mask = mask.narrow(0, start_pos, seq_length) - else: - attn_mask = mask[None, None, input_pos] + xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1) + xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1) + + freqs_cos = reshape_for_broadcast(freqs_cos, xq_r) + freqs_sin = reshape_for_broadcast(freqs_sin, xq_r) - k = k.repeat_interleave(self.n_rep, dim=1) - v = v.repeat_interleave(self.n_rep, dim=1) - y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0) + xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin + xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos + xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin + xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos - return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) + xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3) + xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) -class Attention(nn.Module): - def __init__(self, args: ModelArgs, layer_id: int): + +class Attention(torch.nn.Module): + def __init__(self, args: ModelArgs): super().__init__() - self.use_kv_cache = args.use_kv_cache self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads assert args.n_heads % self.n_kv_heads == 0 model_parallel_size = 1 @@ -267,295 +121,146 @@ def __init__(self, args: ModelArgs, layer_id: int): self.n_local_kv_heads = self.n_kv_heads // model_parallel_size self.n_rep = self.n_local_heads // self.n_local_kv_heads 
self.head_dim = args.dim // args.n_heads - self.max_batch_size = args.max_batch_size - self.max_seq_len = args.max_seq_len - self.dim = args.dim - # args.dim = 4096, args.n_heads = 32, self.head_dim = 4096 / 32 = 125 - self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False) - self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False) - self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False) - self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False) - - self.layer_id = layer_id - - causal_mask = torch.tril( - torch.ones( - self.max_seq_len, - self.max_seq_len, - dtype=torch.bool, - device="cpu", - ) - ) - self.register_buffer("mask", causal_mask, persistent=False) - - if self.use_kv_cache: - self.kv_cache = KVCache( - args.max_batch_size, - args.max_seq_len, - self.n_kv_heads, - self.head_dim, - not args.use_sdpa_with_kv_cache_op, # if we are using the custom op dont transpose the cache. Expect untransposed q k v - args.enable_dynamic_shape, - ) - self.SDPA = SDPA( - kv_cache=self.kv_cache, - dim=self.dim, - head_dim=self.head_dim, - n_rep=self.n_rep, - max_seq_len=self.max_seq_len, - enable_dynamic_shape=args.enable_dynamic_shape, - ) - if args.use_hf_rope: - self.apply_rotary_emb = hf_apply_rotary_emb - else: - self.apply_rotary_emb = apply_rotary_emb + self.wq = torch.nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False) + self.wk = torch.nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wv = torch.nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wo = torch.nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False) + # set large value of -inf (or -32768 with int16) when we want to + # ignore correspnding values in the mask + mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-32768")) + mask = torch.triu(mask, diagonal=1) + self.register_buffer("mask", mask) def forward( self, x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor, - input_pos: Optional[torch.Tensor] = None, ): bsz, seqlen, _ = x.shape # QKV - q, k, v = self.wq(x), self.wk(x), self.wv(x) - # We need view_copy elimination - q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim) - k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) - v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) + xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) # RoPE relative positional embeddings - q, k = self.apply_rotary_emb(q, k, freqs_cos, freqs_sin) - - if self.use_kv_cache: - assert input_pos is not None - output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask) - return self.wo(output) - - q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - k = k.transpose(1, 2) - v = v.transpose(1, 2) + xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin) # grouped multiquery attention: expand out keys and values - k = k.repeat_interleave(self.n_rep, dim=1) - v = v.repeat_interleave(self.n_rep, dim=1) - + xk = [ + torch.cat([xk[:, :, i : i + 1, :]] * self.n_rep, dim=2) + for i in range(xk.size(2)) + ] + xk = torch.cat(xk, dim=2) + + xv = [ + torch.cat([xv[:, :, i : i + 1, :]] * self.n_rep, dim=2) + for i in range(xv.size(2)) + ] + xv = torch.cat(xv, dim=2) + + # make heads into a batch dimension + xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + xk = 
xk.transpose(1, 2) + xv = xv.transpose(1, 2) + + scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim) assert hasattr(self, "mask") - - mask = self.mask[:seqlen, :seqlen] - - output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + scores = ( + scores + self.mask[:, :, :seqlen, :seqlen] + ) # (bs, n_local_heads, seqlen, cache_len + seqlen) + scores = F.softmax(scores.float(), dim=-1).type_as(xq) + output = torch.matmul(scores, xv) # (bs, n_local_heads, seqlen, head_dim) output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) output = self.wo(output) - return output -class FeedForward(nn.Module): - def __init__(self, args: ModelArgs): +class FeedForward(torch.nn.Module): + def __init__(self, dim: int, hidden_dim: int, multiple_of: int): super().__init__() - assert args.hidden_dim is not None - hidden_dim: int = args.hidden_dim - self.w1 = nn.Linear(args.dim, hidden_dim, bias=False) - self.w2 = nn.Linear(hidden_dim, args.dim, bias=False) - self.w3 = nn.Linear(args.dim, hidden_dim, bias=False) + self.w1 = torch.nn.Linear(dim, hidden_dim, bias=False) + self.w2 = torch.nn.Linear(hidden_dim, dim, bias=False) + self.w3 = torch.nn.Linear(dim, hidden_dim, bias=False) def forward(self, x): - return self.w2(F.silu(self.w1(x)) * self.w3(x)) + x = F.silu(self.w1(x)) * self.w3(x) + x = self.w2(x) + return x -class ConditionalFeedForward(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.dim = args.dim - hidden_dim = args.hidden_dim - if hidden_dim is None: - # If hidden_dim is not explicitly set in the ModelArgs, - # then calculate implicitly based on dim and also multiple of `args.multiple_of` - multiple_of = args.multiple_of - hidden_dim = 4 * self.dim - hidden_dim = int(2 * hidden_dim / 3) - hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) - - self.w1 = nn.Parameter(torch.randn(args.num_experts, hidden_dim, self.dim)) - self.w2 = nn.Parameter(torch.randn(args.num_experts, hidden_dim, self.dim)) - self.w3 = nn.Parameter(torch.randn(args.num_experts, hidden_dim, self.dim)) - self.num_experts = args.num_experts - - def forward(self, x: torch.Tensor, expert_indices: torch.Tensor) -> torch.Tensor: - w1_weights = self.w1[expert_indices].transpose(-1, -2) # [T, A, D, D] - w3_weights = self.w3[expert_indices].transpose(-1, -2) # [T, A, D, D] - w2_weights = self.w2[expert_indices] # [T, A, D, D] - x1 = F.silu(torch.einsum("ti,taio -> tao", x, w1_weights)) - x3 = torch.einsum("ti, taio -> tao", x, w3_weights) - expert_outs = torch.einsum("tao, taoi -> tai", (x1 * x3), w2_weights) - return expert_outs - - -class MOEFeedForward(nn.Module): - def __init__(self, config) -> None: - super().__init__() - self.gate = nn.Linear(config.dim, config.num_experts, bias=False) - self.cond_ffn = ConditionalFeedForward(config) - self.dim = config.dim - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = x.view(-1, self.dim) - # T = num_tokens, E = num_experts, D = hidden dim, A = activated experts - # x: [T, D] - scores = self.gate(x) # [T, E] - expert_weights, expert_indices = torch.topk(scores, 2, dim=-1) # [T, A], [T, A] - expert_weights = expert_weights.softmax(dim=-1) # [T, A] - expert_outs = self.cond_ffn(x, expert_indices) - return torch.einsum("tai,ta -> ti", expert_outs, expert_weights) - - -class TransformerBlock(nn.Module): +class TransformerBlock(torch.nn.Module): def __init__(self, layer_id: int, args: ModelArgs): super().__init__() - self.use_kv_cache = args.use_kv_cache self.n_heads = args.n_heads 
self.dim = args.dim self.head_dim = args.dim // args.n_heads - self.attention = Attention(args, layer_id) - if args.moe: - self.block_sparse_moe = MOEFeedForward(args) + self.attention = Attention(args) + if args.hidden_dim is None: + hidden_dim = 4 * args.dim + hidden_dim = int(2 * hidden_dim / 3) + hidden_dim = args.multiple_of * ( + (hidden_dim + args.multiple_of - 1) // args.multiple_of + ) else: - self.feed_forward = FeedForward(args) - self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) - self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) - - def forward(self, x, freqs_cos, freqs_sin, input_pos=None): # x: 1xN - h = self.attention.forward( - self.attention_norm(x), freqs_cos, freqs_sin, input_pos + hidden_dim = args.hidden_dim + self.feed_forward = FeedForward( + dim=args.dim, + hidden_dim=hidden_dim, + multiple_of=args.multiple_of, ) + self.layer_id = layer_id + self.attention_norm = torch.nn.RMSNorm(args.dim, eps=args.norm_eps) + self.ffn_norm = torch.nn.RMSNorm(args.dim, eps=args.norm_eps) - h = x + h - if hasattr(self, "block_sparse_moe"): - out = h + self.block_sparse_moe(self.ffn_norm(h)) - else: - out = h + self.feed_forward(self.ffn_norm(h)) + def forward(self, x, freqs_cos, freqs_sin): + h = x + self.attention.forward(self.attention_norm(x), freqs_cos, freqs_sin) + out = h + self.feed_forward.forward(self.ffn_norm(h)) return out -class Transformer(nn.Module): +class LastTimeStepPool(torch.nn.Module): + def forward(self, logits: torch.Tensor, seq_lens: torch.Tensor) -> torch.Tensor: + bsz, _, dim = logits.shape + idx = seq_lens.unsqueeze(1).expand(bsz, dim).unsqueeze(1) + return logits.gather(1, idx - 1).squeeze(1) + + +class Transformer(torch.nn.Module): def __init__(self, params: ModelArgs): super().__init__() self.params = params self.vocab_size = params.vocab_size self.n_layers = params.n_layers - self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim) + self.tok_embeddings = torch.nn.Embedding(params.vocab_size, params.dim) self.layers = torch.nn.ModuleList() for layer_id in range(params.n_layers): self.layers.append(TransformerBlock(layer_id, params)) - self.norm = RMSNorm(params.dim, eps=params.norm_eps) - self.output = nn.Linear(params.dim, params.vocab_size, bias=False) - self.use_kv_cache = params.use_kv_cache - self.generate_full_logits = params.generate_full_logits - self.max_seq_len = params.max_seq_len - self.output_prune_map = params.output_prune_map - if params.use_hf_rope: - self.precompute_freqs_cis = hf_precompute_freqs_cis - else: - self.precompute_freqs_cis = partial( - precompute_freqs_cis, use_scaled=params.use_scaled_rope - ) - freqs_cos, freqs_sin = self.precompute_freqs_cis( - params.dim // params.n_heads, - ( - params.max_seq_len # Normal llama2. - if params.ffn_dim_multiplier is None - else params.max_seq_len * 2 # Sharded checkpoint. 
- ), - params.rope_freq_base, + self.norm = torch.nn.RMSNorm(params.dim, eps=params.norm_eps) + self.out = torch.nn.Linear(params.dim, params.vocab_size, bias=False) + + freqs_cos, freqs_sin = precompute_freqs_cis( + self.params.dim // self.params.n_heads, self.params.max_seq_len ) self.register_buffer("freqs_cos", freqs_cos, persistent=False) self.register_buffer("freqs_sin", freqs_sin, persistent=False) - def forward( - self, - tokens: Optional[torch.LongTensor] = None, # tokens - input_pos: Optional[ - torch.LongTensor - ] = None, # Scalar tensor indicating size of window of the caches - h: Optional[torch.FloatTensor] = None, # embeddings - ) -> torch.Tensor: - if (tokens is None) ^ (h is not None): - raise ValueError( - "You cannot specify both tokens and h at the same time, and must specify either one" - ) - if tokens is not None and h is None: - h = self.tok_embeddings(tokens) - seqlen = h.shape[1] - - if self.use_kv_cache: - assert ( - input_pos is not None - ), "input_pos must be provided when use_kv_cache is True" - - if self.params.enable_dynamic_shape: - # when KV cache is used, seqlen is most likely 1. We want to slice from the start_pos. - input_pos_item = input_pos[-1].item() - torch._check_is_size(input_pos_item) - torch._check(input_pos_item < self.params.max_seq_len) - # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor - freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seqlen) - # pyre-ignore: Incompatible parameter type [6] - freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seqlen) - else: - # When not using dynamic shape, use of the .item results in - # symints, due to querying the data from tensor. - # this path avoids that for mps backend, although probably mps backend - # can support dynamic shape? - freqs_cos = self.freqs_cos[input_pos] - freqs_sin = self.freqs_sin[input_pos] - - else: - assert input_pos is None, "input_pos is unused when use_kv_cache is False" - freqs_cos = self.freqs_cos[:seqlen] - freqs_sin = self.freqs_sin[:seqlen] + def forward(self, tokens: torch.Tensor) -> torch.Tensor: + _bsz, seqlen = tokens.shape + h = self.tok_embeddings(tokens) + freqs_cos = self.freqs_cos[:seqlen] + freqs_sin = self.freqs_sin[:seqlen] for layer in self.layers: - h = layer( - h, - freqs_cos, - freqs_sin, - input_pos, - ) - - if not self.generate_full_logits: - # Only the last logit is used for the new generated token - h = h[:, -1, :] + h = layer(h, freqs_cos, freqs_sin) h = self.norm(h) - logits = self.output(h) - - if self.output_prune_map is not None: - # expand to original size so that downstream applications can use the logits as-is. 
- if self.generate_full_logits: - # (1, seq_len, pruned_size) -> (1, seq_len, original_size) - expanded_logits = torch.full( - [logits.shape[0], logits.shape[1], self.vocab_size], - float("-inf"), - device=logits.device, - dtype=logits.dtype, - ) - expanded_logits[:, :, list(self.output_prune_map.values())] = logits - else: - # (1, pruned_size) -> (1, original_size) - expanded_logits = torch.full( - [logits.shape[0], self.vocab_size], - float("-inf"), - device=logits.device, - dtype=logits.dtype, - ) - expanded_logits[:, list(self.output_prune_map.values())] = logits - logits = expanded_logits - - return logits + invocation_logits = self.out(h) + + return invocation_logits diff --git a/examples/models/llama2/model.py b/examples/models/llama2/model.py index a4081d1bd57..5842ef544f0 100644 --- a/examples/models/llama2/model.py +++ b/examples/models/llama2/model.py @@ -150,16 +150,31 @@ def __init__(self, **kwargs): output_prune_map = {int(k): v for (k, v) in output_prune_map.items()} max_seq_len = self.max_seq_len max_batch_size = 1 + print("params: ", params) + params.pop("rope_theta", None) model_args: ModelArgs = ModelArgs( max_seq_len=max_seq_len, max_batch_size=max_batch_size, - use_kv_cache=self.use_kv_cache, - use_sdpa_with_kv_cache_op=self.use_sdpa_with_kv_cache_op, - generate_full_logits=self.generate_full_logits, - output_prune_map=output_prune_map, - enable_dynamic_shape=self.enable_dynamic_shape, + # input_vocab_size=params["input_vocab_size"], + # use_kv_cache=self.use_kv_cache, + # use_sdpa_with_kv_cache_op=self.use_sdpa_with_kv_cache_op, + # generate_full_logits=self.generate_full_logits, + # output_prune_map=output_prune_map, + # enable_dynamic_shape=self.enable_dynamic_shape, **params, ) + # model_args: ModelArgs = ( + # ModelArgs( + # dim=512, + # hidden_dim=1536, + # n_heads=8, + # n_kv_heads=2, + # n_layers=19, + # vocab_size=128256, + # invocation_vocab_size=8, + # use_layer_norm_op=True, + # ), + # ) if kwargs.get("fairseq2", False): print("Using fairseq2 checkpoint") checkpoint = convert_to_llama_checkpoint(checkpoint=checkpoint) @@ -170,10 +185,24 @@ def __init__(self, **kwargs): print(f"{key} : {weights.numel()} : {weights.size()}") print("============= /weights ================") - # Within the device="meta" context, tensors that are created do not carry data. - # They possess all other metadata a tensor carries such as size, stride, requires_grad. - with torch.device("meta"): - self.model_ = Transformer(model_args) + # Within the device="meta" context, tensors that are created do not carry data. + # They possess all other metadata a tensor carries such as size, stride, requires_grad. + # with torch.device("meta"): + # self.model_ = Transformer(model_args) + # self.model_ = Transformer( + # ModelArgs( + # dim=512, + # hidden_dim=1536, + # n_heads=8, + # n_kv_heads=2, + # n_layers=19, + # vocab_size=128256, + # invocation_vocab_size=8, + # use_layer_norm_op=True, + # ), + # ) + self.model_ = Transformer(model_args) + print("model: ", self.model_) if "int8" in str(checkpoint_path): print("Using int8 weight-only quantization!") @@ -263,11 +292,11 @@ def __init__(self, **kwargs): # assign=True: load params/buffers by assignment instead of performing an in-place copy. # Because we are using device="meta", tensors do not have memory associated with them # and an in-place copy is a no-op. Use assign=True in load_state_dict for this scenario. 
- missing, unexpected = self.model_.load_state_dict( - checkpoint, - strict=False, - assign=True, - ) # self.model_ = Transformer(gptconf) + # missing, unexpected = self.model_.load_state_dict( + # checkpoint, + # strict=False, + # assign=True, + # ) # self.model_ = Transformer(gptconf) if kwargs.get("verbose", False): print("============= missing keys ================") print(missing) @@ -296,11 +325,13 @@ def get_example_inputs(self): if self.use_kv_cache: return self.get_example_inputs_kvcache_sdpa() else: - return ( - torch.tensor( - [[1, 2, 3]], dtype=torch.long - ), # tokens, with kv cache our input token length is always just 1 token. - ) + # return ( + # torch.tensor( + # [[1, 2, 3]], dtype=torch.long + # ), # tokens, with kv cache our input token length is always just 1 token. + # ) + b = torch.ones(1, 64, dtype=torch.long) + return (b,) # assumption is the custom op doesnt support dynamic shape right now. It might but its untested so lets first get static shape working def get_example_inputs_kvcache_sdpa(self): diff --git a/examples/models/llama2/params/demo_config.json b/examples/models/llama2/params/demo_config.json index 13287f117e9..754d09b5ca2 100644 --- a/examples/models/llama2/params/demo_config.json +++ b/examples/models/llama2/params/demo_config.json @@ -1 +1 @@ -{"dim": 64, "multiple_of": 4, "n_heads": 8, "n_layers": 5, "norm_eps": 1e-05, "vocab_size": 512} \ No newline at end of file +{"dim": 64, "multiple_of": 4, "n_heads": 8, "n_layers": 1, "norm_eps": 1e-05, "vocab_size": 512} diff --git a/examples/models/llama2/runner/targets.bzl b/examples/models/llama2/runner/targets.bzl index 96d47ffce21..eb5dfe87299 100644 --- a/examples/models/llama2/runner/targets.bzl +++ b/examples/models/llama2/runner/targets.bzl @@ -29,6 +29,7 @@ def define_common_targets(): ], # qnn_executorch_backend can be added below //executorch/backends/qualcomm:qnn_executorch_backend exported_deps = [ + "//executorch/backends/qualcomm:qnn_executorch_backend", "//executorch/backends/xnnpack:xnnpack_backend", "//executorch/extension/llm/runner:stats", "//executorch/extension/llm/runner:text_decoder_runner" + aten_suffix, diff --git a/examples/models/model_factory.py b/examples/models/model_factory.py index fb317e3bca3..8913bd50484 100644 --- a/examples/models/model_factory.py +++ b/examples/models/model_factory.py @@ -35,9 +35,11 @@ def create_model( ValueError: If the provided model class is not found in the module. """ package_prefix = "executorch." if not os.getcwd().endswith("executorch") else "" - module = importlib.import_module( - f"{package_prefix}examples.models.{module_name}" - ) + print(f"package_prefix: {package_prefix}") + # module = importlib.import_module( + # f"{package_prefix}examples.models.{module_name}" + # ) + module = importlib.import_module(f"executorch.examples.models.{module_name}") if hasattr(module, model_class_name): model_class = getattr(module, model_class_name) diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py index ae0ca6df757..16f77668839 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -9,6 +9,7 @@ # ExecuTorch. 
 import logging
+import os
 from enum import Enum
 from typing import Any, Callable, List, Optional
 
@@ -34,6 +35,7 @@ from torch.ao.quantization.quantizer import Quantizer
 from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
 from torch.nn.attention import SDPBackend
+from tqdm import tqdm
 
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
 logging.basicConfig(level=logging.INFO, format=FORMAT)
@@ -150,6 +152,7 @@ def source_transform(
         return self
 
     def _get_dynamic_shape(self) -> Any:
+        return None
         if self.dynamic_shapes:
             return self.dynamic_shapes
 
diff --git a/extension/llm/export/partitioner_lib.py b/extension/llm/export/partitioner_lib.py
index 37b215a51ff..8f827443f7a 100644
--- a/extension/llm/export/partitioner_lib.py
+++ b/extension/llm/export/partitioner_lib.py
@@ -154,9 +154,9 @@ def get_qnn_partitioner(
     num_sharding: int = 0,
     soc_model: str = "SM8650",  # default to SM8650
 ):
-    assert (
-        use_kv_cache is True
-    ), "Qualcomm backend currently only supports static shape and use_kv_cache=True is the only way to support it at the moment"
+    # assert (
+    #     use_kv_cache is True
+    # ), "Qualcomm backend currently only supports static shape and use_kv_cache=True is the only way to support it at the moment"
     try:
         # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.partition.qnn_partitioner`
         from executorch.backends.qualcomm.partition.qnn_partitioner import (

From b951b2e4709b099e1000a6a75ca8406f0e0e3b30 Mon Sep 17 00:00:00 2001
From: Sheng Feng Wu
Date: Wed, 2 Oct 2024 23:30:14 -0700
Subject: [PATCH 2/2] - Quantize and delegate embedding op
 - Quantize matmul with 16a8w

---
 .../qualcomm/quantizer/custom_annotation.py | 26 +++++++++++++++++++
 examples/models/llama2/export_llama_lib.py  |  2 +-
 extension/llm/export/partitioner_lib.py     |  2 +-
 extension/llm/export/quantizer_lib.py       | 15 ++++++-----
 4 files changed, 36 insertions(+), 9 deletions(-)

diff --git a/backends/qualcomm/quantizer/custom_annotation.py b/backends/qualcomm/quantizer/custom_annotation.py
index 9cde50b9c70..881d24bbb5e 100644
--- a/backends/qualcomm/quantizer/custom_annotation.py
+++ b/backends/qualcomm/quantizer/custom_annotation.py
@@ -118,3 +118,29 @@ def annotate_matmul_input1(node: Node, quantization_config: QuantizationConfig):
         if "SDPA" in full_qualified_name:
             annotate_matmul(node, quantization_config_16a8w)
             annotate_matmul_input1(node.args[1], quantization_config_8a8w)
+
+
+def custom_annotate_matmul_16a8w(gm: torch.fx.GraphModule):
+    """
+    Annotate matmul op with 16a8w quantization config
+    """
+
+    def annotate_matmul(node: Node, quantization_config: QuantizationConfig):
+        input_qspec_map = {}
+        input_act = node.args[0]
+        input_spec = quantization_config.input_activation
+        input_qspec_map[input_act] = input_spec
+        input_act1 = node.args[1]
+        input_spec1 = quantization_config.weight
+        input_qspec_map[input_act1] = input_spec1
+        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            output_qspec=quantization_config.output_activation,
+            _annotated=True,
+        )
+
+    # Annotate 16a8w for matmul op to get better performance
+    quantization_config_16a8w = get_16a8w_qnn_ptq_config()
+    for node in gm.graph.nodes:
+        if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
+            annotate_matmul(node, quantization_config_16a8w)
diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py
index e2e5a178edf..144e4f69f68 100644
---
a/examples/models/llama2/export_llama_lib.py +++ b/examples/models/llama2/export_llama_lib.py @@ -526,7 +526,7 @@ def get_quantizer_and_quant_params(args): if args.qnn and args.pt2e_quantize: assert len(quantizers) == 0, "Should not enable both xnnpack and qnn" qnn_quantizer, quant_dtype = get_qnn_quantizer( - args.pt2e_quantize, args.quantization_mode + args.pt2e_quantize, args.use_kv_cache, args.quantization_mode ) quantizers.append(qnn_quantizer) if args.coreml and args.pt2e_quantize: diff --git a/extension/llm/export/partitioner_lib.py b/extension/llm/export/partitioner_lib.py index 8f827443f7a..58f4bda1b3c 100644 --- a/extension/llm/export/partitioner_lib.py +++ b/extension/llm/export/partitioner_lib.py @@ -179,7 +179,7 @@ def get_qnn_partitioner( ) use_fp16 = True - skip_node_op_set = {"llama.fallback.default", "aten.embedding.default"} + skip_node_op_set = {"llama.fallback.default"} if pt2e_quantize is not None: use_fp16 = False diff --git a/extension/llm/export/quantizer_lib.py b/extension/llm/export/quantizer_lib.py index 45d9932724e..2cd5c9dfc60 100644 --- a/extension/llm/export/quantizer_lib.py +++ b/extension/llm/export/quantizer_lib.py @@ -143,11 +143,13 @@ def check_embedding_byte_registered(): def get_qnn_quantizer( pt2e_quantize: str, + use_kv_cache: bool, quantization_mode: Optional[str] = None, ): try: from executorch.backends.qualcomm.quantizer.custom_annotation import ( # pyre-fixme[21] custom_annotate_llama_matmul_16a8w, + custom_annotate_matmul_16a8w, ) # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.quantizer.quantizer` @@ -198,8 +200,12 @@ def get_qnn_quantizer( get_16a4w_qnn_ptq_config(act_observer=MinMaxObserver) ) qnn_quantizer.set_per_channel_weight_dtype(weight_dtype_for_16bit_act="int4") - # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`. - custom_annotations = (custom_annotate_llama_matmul_16a8w,) + if use_kv_cache: + # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`. + custom_annotations = (custom_annotate_llama_matmul_16a8w,) + else: + # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`. + custom_annotations = (custom_annotate_matmul_16a8w,) else: raise AssertionError( f"No support for quant type {quant_config}. Support 8a8w, 16a16w and 16a4w." @@ -209,11 +215,6 @@ def get_qnn_quantizer( quantization_mode is None ), "Currently qnn backend only supports QnnQuantizer via pt2e flow" qnn_quantizer.add_custom_quant_annotations(custom_annotations) - qnn_quantizer.add_discard_ops( - [ - torch.ops.aten.embedding.default, - ] - ) return qnn_quantizer, quant_dtype
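
Reviewer note (not part of the patches): the rewritten llama_transformer.py in patch 1 drops the KV cache and switches to the real-valued RoPE formulation shown above. The standalone sketch below copies `precompute_freqs_cis`, `reshape_for_broadcast`, and `apply_rotary_emb` from that diff so they can be sanity-checked outside the export flow; the tensor sizes and the shape/norm assertions at the bottom are illustrative additions, not from the patch.

```python
import torch


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    # Rotation frequencies for head_dim // 2 (real, imaginary) pairs, one row per position.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    return torch.cos(freqs), torch.sin(freqs)


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    # (seqlen, head_dim // 2) -> (1, seqlen, 1, head_dim // 2) so it broadcasts over batch/heads.
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == x.ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(shape)


def apply_rotary_emb(xq, xk, freqs_cos, freqs_sin):
    # Treat the last dim as interleaved (real, imaginary) pairs and rotate each pair.
    xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
    xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)
    freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
    freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)
    xq_out = torch.stack(
        [xq_r * freqs_cos - xq_i * freqs_sin, xq_r * freqs_sin + xq_i * freqs_cos],
        dim=-1,
    ).flatten(3)
    xk_out = torch.stack(
        [xk_r * freqs_cos - xk_i * freqs_sin, xk_r * freqs_sin + xk_i * freqs_cos],
        dim=-1,
    ).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


if __name__ == "__main__":
    # Shapes mirror the prefill path: (bsz, seqlen, n_heads, head_dim); sizes are arbitrary.
    bsz, seqlen, n_heads, head_dim = 1, 64, 8, 64
    xq = torch.randn(bsz, seqlen, n_heads, head_dim)
    xk = torch.randn(bsz, seqlen, n_heads, head_dim)
    freqs_cos, freqs_sin = precompute_freqs_cis(head_dim, seqlen)
    xq_rot, xk_rot = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
    assert xq_rot.shape == xq.shape and xk_rot.shape == xk.shape
    # Rotation is norm-preserving per position, so per-token norms should be unchanged.
    assert torch.allclose(xq_rot.norm(dim=-1), xq.norm(dim=-1), atol=1e-4)
```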