Commit c8e5ceb

Enable FLLM on static llama
1 parent c2adfa9 commit c8e5ceb

8 files changed: +117 additions, -42 deletions

backends/qualcomm/runtime/QnnExecuTorchBackend.cpp

Lines changed: 8 additions & 1 deletion
@@ -11,7 +11,8 @@
 #include <executorch/backends/qualcomm/runtime/QnnExecuTorchBackend.h>
 #include <executorch/backends/qualcomm/runtime/QnnManager.h>
 #include <executorch/backends/qualcomm/runtime/backends/QnnCustomProtocol.h>
-
+#include <chrono>
+#include <iostream>
 namespace executorch {
 namespace backends {
 namespace qnn {
@@ -33,6 +34,7 @@ Result<DelegateHandle*> QnnExecuTorchBackend::init(
     BackendInitContext& context,
     FreeableBuffer* processed,
     ArrayRef<CompileSpec> compile_specs) const {
+  auto start = std::chrono::high_resolution_clock::now();
   // covert SizedBuffer to qnn ExecuTorch option
   QnnExecuTorchContextBinary qnn_context_blob;
   const qnn_delegate::QnnExecuTorchOptions* qnn_executorch_options = nullptr;
@@ -108,6 +110,11 @@ Result<DelegateHandle*> QnnExecuTorchBackend::init(
   add_cached_delegate(signature, qnn_manager);
   // This backend does not need its processed data after Init.
   processed->Free();
+  auto end = std::chrono::high_resolution_clock::now();
+  auto int_s = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
+
+  std::cout << "[Time consuming during init in QnnBackend] Init Time: " << int_s.count() << " milliseconds"
+            << std::endl;
   return qnn_manager;
 }

backends/qualcomm/runtime/backends/QnnBackendCache.cpp

Lines changed: 9 additions & 0 deletions
@@ -51,6 +51,11 @@ Error QnnBackendCache::GetQnnGraphInfoFromBinary(
   } else if (binaryinfo->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_2) {
     num_graphs = binaryinfo->contextBinaryInfoV2.numGraphs;
     graphs = binaryinfo->contextBinaryInfoV2.graphs;
+#if (QNN_API_VERSION_MAJOR >= 2 && QNN_API_VERSION_MINOR >= 21)
+  } else if (binaryinfo->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_3) {
+    num_graphs = binaryinfo->contextBinaryInfoV3.numGraphs;
+    graphs = binaryinfo->contextBinaryInfoV3.graphs;
+#endif
   } else {
     QNN_EXECUTORCH_LOG_WARN(
         "Unknown QNN BinaryInfo version %d.", binaryinfo->version);
@@ -62,6 +67,10 @@ Error QnnBackendCache::GetQnnGraphInfoFromBinary(
       RetrieveGraphInfo<QnnSystemContext_GraphInfoV1_t>(graphs[i].graphInfoV1);
     } else if (graphs->version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_2) {
       RetrieveGraphInfo<QnnSystemContext_GraphInfoV2_t>(graphs[i].graphInfoV2);
+#if (QNN_API_VERSION_MAJOR >= 2 && QNN_API_VERSION_MINOR >= 21)
+    } else if (graphs->version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_3) {
+      RetrieveGraphInfo<QnnSystemContext_GraphInfoV3_t>(graphs[i].graphInfoV3);
+#endif
     } else {
       QNN_EXECUTORCH_LOG_WARN(
           "Unknown QNN GraphInfo version %d.", binaryinfo->version);

backends/qualcomm/runtime/backends/QnnContextCommon.h

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@
 #include <executorch/backends/qualcomm/runtime/backends/QnnBackendCommon.h>
 #include <executorch/backends/qualcomm/runtime/backends/QnnCustomProtocol.h>
 #include <executorch/backends/qualcomm/runtime/backends/QnnDeviceCommon.h>
-
+#include <executorch/backends/qualcomm/runtime/backends/QnnProfiler.h>
 #include <memory>
 namespace executorch {
 namespace backends {

backends/qualcomm/runtime/backends/htpbackend/HtpBackendCache.cpp

Lines changed: 31 additions & 14 deletions
@@ -17,35 +17,52 @@ using executorch::runtime::Error;
 Error HtpBackendCache::RetrieveBackendBinaryInfo(
     const QnnSystemContext_BinaryInfo_t* binaryinfo) {
   QnnHtpSystemContext_HwBlobInfo_t* htp_hwblobinfo = nullptr;
-
+#if (QNN_API_VERSION_MAJOR >= 2 && QNN_API_VERSION_MINOR >= 21)
+  QnnHtpSystemContext_GraphBlobInfo_t* htp_graphblobinfo = nullptr;
+#endif
   if (binaryinfo->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_1) {
     htp_hwblobinfo = static_cast<QnnHtpSystemContext_HwBlobInfo_t*>(
         binaryinfo->contextBinaryInfoV1.hwInfoBlob);
   } else if (binaryinfo->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_2) {
     htp_hwblobinfo = static_cast<QnnHtpSystemContext_HwBlobInfo_t*>(
         binaryinfo->contextBinaryInfoV2.hwInfoBlob);
+#if (QNN_API_VERSION_MAJOR >= 2 && QNN_API_VERSION_MINOR >= 21)
+  } else if (binaryinfo->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_3) {
+    htp_graphblobinfo = static_cast<QnnHtpSystemContext_GraphBlobInfo_t*>(
+        binaryinfo->contextBinaryInfoV3.graphs->graphInfoV3.graphBlobInfo);
+#endif
   } else {
     QNN_EXECUTORCH_LOG_WARN(
         "Unknown QNN BinaryInfo version %d.", binaryinfo->version);
     return Error::Internal;
   }
 
-  if (htp_hwblobinfo == nullptr) {
-    QNN_EXECUTORCH_LOG_WARN(
-        "Htp hardware blob information is not found in binary information.");
-    return Error::Ok;
+  if (htp_hwblobinfo) {
+    if (htp_hwblobinfo->version ==
+        QNN_SYSTEM_CONTEXT_HTP_HW_INFO_BLOB_VERSION_V1) {
+      spill_fill_buf_ =
+          (*htp_hwblobinfo).contextBinaryHwInfoBlobV1_t.spillFillBufferSize;
+    } else {
+      QNN_EXECUTORCH_LOG_WARN(
+          "Unknown QNN Htp hw blob info version %d.", htp_hwblobinfo->version);
+      return Error::Internal;
+    }
   }
 
-  if (htp_hwblobinfo->version ==
-      QNN_SYSTEM_CONTEXT_HTP_HW_INFO_BLOB_VERSION_V1) {
-    spill_fill_buf_ =
-        (*htp_hwblobinfo).contextBinaryHwInfoBlobV1_t.spillFillBufferSize;
-  } else {
-    QNN_EXECUTORCH_LOG_WARN(
-        "Unknown QNN Htp hw blob info version %d.", htp_hwblobinfo->version);
-    return Error::Internal;
+#if (QNN_API_VERSION_MAJOR >= 2 && QNN_API_VERSION_MINOR >= 21)
+  if (htp_graphblobinfo) {
+    if (htp_graphblobinfo->version ==
+        QNN_SYSTEM_CONTEXT_HTP_GRAPH_INFO_BLOB_VERSION_V1) {
+      spill_fill_buf_ =
+          (*htp_graphblobinfo).contextBinaryGraphBlobInfoV1.spillFillBufferSize;
+    } else {
+      QNN_EXECUTORCH_LOG_WARN(
+          "Unknown QNN Htp graph blob info version %d.",
+          htp_graphblobinfo->version);
+      return Error::Internal;
+    }
   }
-
+#endif
   return Error::Ok;
 }

examples/models/llama/llama_transformer.py

Lines changed: 3 additions & 0 deletions
@@ -123,6 +123,9 @@ class ModelArgs:
     quantization_args: Optional[dict] = None
     lora_args: Optional[dict] = None
 
+    use_layer_norm_op: bool = False
+    use_rms_norm_op: bool = False
+
     def __post_init__(self):
         if self.n_kv_heads is None:
             self.n_kv_heads = self.n_heads
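
Note: both new ModelArgs fields default to False, so configurations that never set them keep the existing RMSNorm path. A minimal sketch of opting in (the trimmed dataclass and the fllm_args name are illustrative, not part of the commit):

from dataclasses import dataclass

@dataclass
class ModelArgs:
    # trimmed to the fields relevant to this change
    dim: int = 2048
    norm_eps: float = 1e-5
    use_layer_norm_op: bool = False  # request torch.nn.LayerNorm in the static llama blocks
    use_rms_norm_op: bool = False    # request torch.nn.RMSNorm explicitly

phone_args = ModelArgs()                       # defaults: unchanged RMSNorm behavior
fllm_args = ModelArgs(use_layer_norm_op=True)  # opt into LayerNorm-based blocks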

examples/qualcomm/oss_scripts/llama2/model/static_llama.py

Lines changed: 16 additions & 3 deletions
@@ -211,8 +211,15 @@ def __init__(self, config: ModelArgs, output_new_cache_only=False):
             config=config, output_new_cache_only=output_new_cache_only
         )
         self.feed_forward = FeedForward(config)
-        self.attention_norm = torch.nn.RMSNorm(config.dim, eps=config.norm_eps)
-        self.ffn_norm = torch.nn.RMSNorm(config.dim, eps=config.norm_eps)
+        if config.use_layer_norm_op:
+            self.attention_norm = torch.nn.LayerNorm(self.dim, eps=config.norm_eps)
+            self.ffn_norm = torch.nn.LayerNorm(self.dim, eps=config.norm_eps)
+        elif config.use_rms_norm_op:
+            self.attention_norm = torch.nn.RMSNorm(self.dim, eps=config.norm_eps)
+            self.ffn_norm = torch.nn.RMSNorm(self.dim, eps=config.norm_eps)
+        else:
+            self.attention_norm = torch.nn.RMSNorm(config.dim, eps=config.norm_eps)
+            self.ffn_norm = torch.nn.RMSNorm(config.dim, eps=config.norm_eps)
 
     def forward(
         self,
@@ -257,7 +264,13 @@ def __init__(self, config: ModelArgs, output_new_cache_only=True):
                 for _ in range(config.n_layers)
             ]
         )
-        self.norm = torch.nn.RMSNorm(config.dim, eps=config.norm_eps)
+        if config.use_layer_norm_op:
+            self.norm = torch.nn.LayerNorm(config.dim, eps=config.norm_eps)
+        elif config.use_rms_norm_op:
+            self.norm = torch.nn.RMSNorm(config.dim, eps=config.norm_eps)
+        else:
+            self.norm = torch.nn.RMSNorm(config.dim, eps=config.norm_eps)
+
        self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
        self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
        freqs_cos, freqs_sin = precompute_freqs_cis(
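
For reference, the selection logic that both __init__ changes above implement, written as a standalone helper (build_norm is an illustrative name, not a function in this commit): use_layer_norm_op takes precedence and switches the block to torch.nn.LayerNorm, while use_rms_norm_op, or neither flag, keeps torch.nn.RMSNorm.

import torch

def build_norm(config) -> torch.nn.Module:
    # Mirrors the branch order above: LayerNorm only when explicitly requested,
    # RMSNorm otherwise (the default path is unchanged).
    if config.use_layer_norm_op:
        return torch.nn.LayerNorm(config.dim, eps=config.norm_eps)
    return torch.nn.RMSNorm(config.dim, eps=config.norm_eps)

One detail worth noting from the diff: the two new branches read self.dim while the fallback branch keeps config.dim, so they assume the block stores the model dimension on self.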

examples/qualcomm/oss_scripts/llama3_2/llama.py

Lines changed: 34 additions & 16 deletions
@@ -120,6 +120,7 @@ def _prefill_calibrate(
     # TODO: change criteria & support batch inputs if necessary
     token_list = sp_model.encode(user_prompts, bos=True, eos=False)
     token_list = torch.tensor(token_list)[:max_cache_len].reshape(1, -1)
+    token_list = torch.where(token_list > 30000, torch.tensor(30000), token_list)
     last_prompt_pos = token_list.numel()
     if last_prompt_pos < max_cache_len:
         token_list = torch.cat(
@@ -168,6 +169,10 @@ def calibrate(
     else:
         raise RuntimeError("Get wrong inputs")
 
+def get_first_node(node):
+    if isinstance(node, tuple):
+        return get_first_node(node[0])
+    return node
 
 class SingleLlama:
     def __init__(self, llama_model, pte_filename) -> None:
@@ -199,9 +204,10 @@ def _tag_kv_ios(self, gm: torch.fx.GraphModule, kv_type, sharding_type):
             if (
                 n.op == "placeholder"
                 and len(users := list(n.users)) == 1
-                and users[0].meta["val"].size()[-2:] in input_cache_shape
-            ):
-                n.meta[QCOM_QUANTIZED_IO] = kv_type
+                # and users[0].meta["val"].size()[-2:] in input_cache_shape
+            ):
+                if get_first_node(users[0].meta["val"]).size()[-2:] in input_cache_shape:
+                    n.meta[QCOM_QUANTIZED_IO] = kv_type
             elif n.op == "output":
                 for a in n.args[0]:
                     # single head, kv mode
@@ -330,13 +336,15 @@ def compile(args, pte_filename):
     prefill_config = copy.copy(kv_config)
     prefill_config.max_seq_len = args.prefill_seq_len
     prefill_config.use_kv_cache = False
-
-    state_dict = torch.load(
-        args.checkpoint, weights_only=True, map_location="cpu", mmap=True
-    )
+
+    # TODO: Currently, we do not load the checkpoint for FLLM
+    if args.model_arch_device == "meta":
+        state_dict = torch.load(
+            args.checkpoint, weights_only=True, map_location="cpu", mmap=True
+        )
 
     llama_instance_list = []
-    with torch.device("meta"):
+    with torch.device(args.model_arch_device):
         if args.model_mode == "kv":
             llama_instance_list.append(
                 LlamaModel(kv_config, output_new_cache_only=True)
@@ -355,15 +363,17 @@
         else:
             raise RuntimeError(f"No such model_mode {args.model_mode}.")
 
-    if "model" in state_dict:
-        state_dict = state_dict["model"]
+    # TODO: Currently, we do not load the checkpoint for FLLM
+    if args.model_arch_device == "meta":
+        if "model" in state_dict:
+            state_dict = state_dict["model"]
 
-    for llama_instance in llama_instance_list:
-        llama_instance.load_state_dict(
-            state_dict,
-            strict=False,
-            assign=True,
-        )
+        for llama_instance in llama_instance_list:
+            llama_instance.load_state_dict(
+                state_dict,
+                strict=False,
+                assign=True,
+            )
     end_load_ts = time.time()
     logging.info(f"Time for loading checkpoint: {end_load_ts - start_ts}")
@@ -689,6 +699,14 @@ def main():
         type=int,
     )
 
+    parser.add_argument(
+        "--model_arch_device",
+        help="Specify the device for the model architecture. Use 'meta' for phone LLM (default) and 'cpu' for frane LLM.",
+        default="meta",
+        choices=["meta", "cpu"],
+        type=str,
+    )
+
     args = parser.parse_args()
     if args.compile_only and args.pre_gen_pte:
         exit("Cannot set both compile_only and pre_gen_pte as true")

examples/qualcomm/oss_scripts/llama3_2/runner/runner.cpp

Lines changed: 15 additions & 7 deletions
@@ -263,14 +263,15 @@ Error Runner::generate(
 
   ET_CHECK_MSG(!prompt.empty(), "prompt cannot be null");
 
-  if (!system_prompt.empty()) {
-    prompt_.append("<|start_header_id|>system<|end_header_id|>\n\n");
-    prompt_.append(system_prompt);
-    prompt_.append("<|eot_id|>\n");
-  }
-  prompt_.append("<|start_header_id|>user<|end_header_id|>\n\n");
+  // Only use prompt provided by user
+  // if (!system_prompt.empty()) {
+  //   prompt_.append("<|start_header_id|>system<|end_header_id|>\n\n");
+  //   prompt_.append(system_prompt);
+  //   prompt_.append("<|eot_id|>\n");
+  // }
+  // prompt_.append("<|start_header_id|>user<|end_header_id|>\n\n");
   prompt_.append(prompt);
-  prompt_.append("<|eot_id|><|start_header_id|>assistant<|end_header_id|>");
+  // prompt_.append("<|eot_id|><|start_header_id|>assistant<|end_header_id|>");
 
   if (token_callback) {
     token_callback("<|begin_of_text|>");
@@ -280,6 +281,13 @@ Error Runner::generate(
   seq_len = (seq_len > 0 && seq_len <= max_seq_len) ? seq_len : max_seq_len;
   Result<std::vector<uint64_t>> encode_res =
       tokenizer_->encode(prompt_, n_bos_, 0);
+  if (encode_res.ok()) {
+    for (auto& id : encode_res.get()) {
+      if (id > 30000) {
+        id = static_cast<uint64_t>(30000);
+      }
+    }
+  }
   ET_CHECK_OK_OR_RETURN_ERROR(
       encode_res.error(), "failed to encode prompt %s", prompt_.c_str());
