|
@@ -26,6 +26,7 @@
     convert_linear_to_conv2d,
     generate_htp_compiler_spec,
     generate_qnn_executorch_compiler_spec,
+    get_soc_to_chipset_map,
 )
 from executorch.examples.qualcomm.oss_scripts.llama2.model.static_llama import (
     LlamaModel,
@@ -47,15 +48,6 @@
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 
 
-soc_to_chipset_map = {
-    "SSG2115P": QcomChipset.SSG2115P,
-    "SM8650": QcomChipset.SM8650,
-    "SM8550": QcomChipset.SM8550,
-    "SM8475": QcomChipset.SM8475,
-    "SM8450": QcomChipset.SM8450,
-}
-
-
 pte_filename = "llama2_qnn"
 
 
@@ -402,7 +394,7 @@ def compile(args): |
     end_quantize_ts = time.time()
     print("single_llama.quantize(quant_dtype)", end_quantize_ts - start_quantize_ts)
     single_llama.lowering_modules(
-        args.artifact, kv_type=kv_type, soc_model=soc_to_chipset_map[args.model]
+        args.artifact, kv_type=kv_type, soc_model=get_soc_to_chipset_map()[args.model]
     )
     end_lowering_ts = time.time()
     print("Complete Compile", end_lowering_ts - end_quantize_ts)
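
For readers without the executorch tree at hand, the change swaps the script-local `soc_to_chipset_map` dict for a shared helper imported alongside the other Qualcomm utilities. The sketch below only illustrates that pattern, assuming the helper is a small function returning the same SoC-name-to-chipset dict; `QcomChipsetStub` is a hypothetical stand-in for the real `QcomChipset` enum that the script already imports.

```python
# Illustrative sketch of the pattern adopted by this diff, not the executorch source.
# QcomChipsetStub is a stand-in for the real QcomChipset enum; the real helper lives
# in the shared Qualcomm utils module the import block above pulls from (assumption).
from enum import Enum, auto


class QcomChipsetStub(Enum):
    SSG2115P = auto()
    SM8650 = auto()
    SM8550 = auto()
    SM8475 = auto()
    SM8450 = auto()


def get_soc_to_chipset_map():
    """Return the SoC-name -> chipset mapping formerly kept as a script-local dict."""
    return {
        "SSG2115P": QcomChipsetStub.SSG2115P,
        "SM8650": QcomChipsetStub.SM8650,
        "SM8550": QcomChipsetStub.SM8550,
        "SM8475": QcomChipsetStub.SM8475,
        "SM8450": QcomChipsetStub.SM8450,
    }


# Assuming the helper is a function (as the get_ prefix suggests), the call site
# indexes its return value, which is why the lowering call reads
# soc_model=get_soc_to_chipset_map()[args.model].
soc_model = get_soc_to_chipset_map()["SM8650"]
print(soc_model)  # QcomChipsetStub.SM8650
```

Keeping the mapping behind one shared accessor means a new SoC model only has to be registered once instead of in every example script.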
|