Skip to content

Commit bd0714b

Browse files
committed
reuse LLM_ARCH and LLM_TENSOR
1 parent 431bb08 commit bd0714b

File tree

5 files changed

+165
-224
lines changed

5 files changed

+165
-224
lines changed

src/llama-arch.cpp

Lines changed: 60 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
6363
{ LLM_ARCH_GRANITE_MOE, "granitemoe" },
6464
{ LLM_ARCH_CHAMELEON, "chameleon" },
6565
{ LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
66+
{ LLM_ARCH_VISION_LLAVA, "llava" },
67+
{ LLM_ARCH_VISION_MOBILEVLM, "mobilevlm" },
68+
{ LLM_ARCH_VISION_MINICPMV, "minicpmv" },
6669
{ LLM_ARCH_UNKNOWN, "(unknown)" },
6770
};
6871

@@ -1314,77 +1317,75 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
13141317
{ LLM_TENSOR_POS_NET_ATTN_OUT, "posnet.%d.attn_output" },
13151318
},
13161319
},
1320+
// vision
13171321
{
1318-
LLM_ARCH_UNKNOWN,
1322+
LLM_ARCH_VISION_LLAVA,
13191323
{
1320-
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
1321-
},
1324+
{ LLM_TENSOR_V_MMPROJ, "v.mmproj_%d" },
1325+
{ LLM_TENSOR_V_ENC_EMBD_CLS, "v.enc.embd.cls" },
1326+
{ LLM_TENSOR_V_ENC_EMBD_PATCH, "v.enc.embd.patch" },
1327+
{ LLM_TENSOR_V_ENC_EMBD_POS, "v.enc.embd.pos" },
1328+
{ LLM_TENSOR_V_ENC_ATTN_Q, "v.enc.blk.%d.attn_q" },
1329+
{ LLM_TENSOR_V_ENC_ATTN_K, "v.enc.blk.%d.attn_k" },
1330+
{ LLM_TENSOR_V_ENC_ATTN_V, "v.enc.blk.%d.attn_v" },
1331+
{ LLM_TENSOR_V_ENC_INPUT_NORM, "v.enc.blk.%d.input_norm" },
1332+
{ LLM_TENSOR_V_ENC_OUTPUT, "v.enc.blk.%d.output" },
1333+
{ LLM_TENSOR_V_ENC_OUTPUT_NORM, "v.enc.blk.%d.output_norm" },
1334+
{ LLM_TENSOR_V_ENC_FFN_UP, "v.enc.blk.%d.ffn_up" },
1335+
{ LLM_TENSOR_V_ENC_FFN_DOWN, "v.enc.blk.%d.ffn_down" },
1336+
{ LLM_TENSOR_V_PRE_NORM, "v.pre_norm" },
1337+
{ LLM_TENSOR_V_POST_NORM, "v.post_norm" },
1338+
}
13221339
},
1323-
};
1324-
1325-
static const std::map<vision_arch, std::map<vision_tensor, const char *>> VISION_TENSOR_NAMES = {
13261340
{
1327-
VISION_ARCH_LLAVA,
1341+
LLM_ARCH_VISION_MOBILEVLM,
13281342
{
1329-
{ VISION_TENSOR_MMPROJ, "v.mmproj_%d" },
1330-
{ VISION_TENSOR_ENC_EMBD_CLS, "v.enc.embd.cls" },
1331-
{ VISION_TENSOR_ENC_EMBD_PATCH, "v.enc.embd.patch" },
1332-
{ VISION_TENSOR_ENC_EMBD_POS, "v.enc.embd.pos" },
1333-
{ VISION_TENSOR_ENC_ATTN_Q, "v.enc.blk.%d.attn_q" },
1334-
{ VISION_TENSOR_ENC_ATTN_K, "v.enc.blk.%d.attn_k" },
1335-
{ VISION_TENSOR_ENC_ATTN_V, "v.enc.blk.%d.attn_v" },
1336-
{ VISION_TENSOR_ENC_INPUT_NORM, "v.enc.blk.%d.input_norm" },
1337-
{ VISION_TENSOR_ENC_OUTPUT, "v.enc.blk.%d.output" },
1338-
{ VISION_TENSOR_ENC_OUTPUT_NORM, "v.enc.blk.%d.output_norm" },
1339-
{ VISION_TENSOR_ENC_FFN_UP, "v.enc.blk.%d.ffn_up" },
1340-
{ VISION_TENSOR_ENC_FFN_DOWN, "v.enc.blk.%d.ffn_down" },
1341-
{ VISION_TENSOR_PRE_NORM, "v.pre_norm" },
1342-
{ VISION_TENSOR_POST_NORM, "v.post_norm" },
1343+
{ LLM_TENSOR_V_MMPROJ_MLP, "v.mmproj.mlp.%d" },
1344+
{ LLM_TENSOR_V_MMPROJ_PEG, "v.mmproj.peg.%d" },
1345+
{ LLM_TENSOR_V_ENC_EMBD_CLS, "v.enc.embd.cls" },
1346+
{ LLM_TENSOR_V_ENC_EMBD_PATCH, "v.enc.embd.patch" },
1347+
{ LLM_TENSOR_V_ENC_EMBD_POS, "v.enc.embd.pos" },
1348+
{ LLM_TENSOR_V_ENC_ATTN_Q, "v.enc.blk.%d.attn_q" },
1349+
{ LLM_TENSOR_V_ENC_ATTN_K, "v.enc.blk.%d.attn_k" },
1350+
{ LLM_TENSOR_V_ENC_ATTN_V, "v.enc.blk.%d.attn_v" },
1351+
{ LLM_TENSOR_V_ENC_INPUT_NORM, "v.enc.blk.%d.input_norm" },
1352+
{ LLM_TENSOR_V_ENC_OUTPUT, "v.enc.blk.%d.output" },
1353+
{ LLM_TENSOR_V_ENC_OUTPUT_NORM, "v.enc.blk.%d.output_norm" },
1354+
{ LLM_TENSOR_V_ENC_FFN_UP, "v.enc.blk.%d.ffn_up" },
1355+
{ LLM_TENSOR_V_ENC_FFN_DOWN, "v.enc.blk.%d.ffn_down" },
1356+
{ LLM_TENSOR_V_PRE_NORM, "v.pre_norm" },
1357+
{ LLM_TENSOR_V_POST_NORM, "v.post_norm" },
13431358
}
13441359
},
13451360
{
1346-
VISION_ARCH_MOBILEVLM,
1361+
LLM_ARCH_VISION_MINICPMV,
13471362
{
1348-
{ VISION_TENSOR_MMPROJ_MLP, "v.mmproj.mlp.%d" },
1349-
{ VISION_TENSOR_MMPROJ_PEG, "v.mmproj.peg.%d" },
1350-
{ VISION_TENSOR_ENC_EMBD_CLS, "v.enc.embd.cls" },
1351-
{ VISION_TENSOR_ENC_EMBD_PATCH, "v.enc.embd.patch" },
1352-
{ VISION_TENSOR_ENC_EMBD_POS, "v.enc.embd.pos" },
1353-
{ VISION_TENSOR_ENC_ATTN_Q, "v.enc.blk.%d.attn_q" },
1354-
{ VISION_TENSOR_ENC_ATTN_K, "v.enc.blk.%d.attn_k" },
1355-
{ VISION_TENSOR_ENC_ATTN_V, "v.enc.blk.%d.attn_v" },
1356-
{ VISION_TENSOR_ENC_INPUT_NORM, "v.enc.blk.%d.input_norm" },
1357-
{ VISION_TENSOR_ENC_OUTPUT, "v.enc.blk.%d.output" },
1358-
{ VISION_TENSOR_ENC_OUTPUT_NORM, "v.enc.blk.%d.output_norm" },
1359-
{ VISION_TENSOR_ENC_FFN_UP, "v.enc.blk.%d.ffn_up" },
1360-
{ VISION_TENSOR_ENC_FFN_DOWN, "v.enc.blk.%d.ffn_down" },
1361-
{ VISION_TENSOR_PRE_NORM, "v.pre_norm" },
1362-
{ VISION_TENSOR_POST_NORM, "v.post_norm" },
1363+
{ LLM_TENSOR_V_ENC_EMBD_PATCH, "v.enc.embd.patch" },
1364+
{ LLM_TENSOR_V_ENC_EMBD_POS, "v.enc.embd.pos" },
1365+
{ LLM_TENSOR_V_ENC_ATTN_Q, "v.enc.blk.%d.attn_q" },
1366+
{ LLM_TENSOR_V_ENC_ATTN_K, "v.enc.blk.%d.attn_k" },
1367+
{ LLM_TENSOR_V_ENC_ATTN_V, "v.enc.blk.%d.attn_v" },
1368+
{ LLM_TENSOR_V_ENC_INPUT_NORM, "v.enc.blk.%d.input_norm" },
1369+
{ LLM_TENSOR_V_ENC_OUTPUT, "v.enc.blk.%d.output" },
1370+
{ LLM_TENSOR_V_ENC_OUTPUT_NORM, "v.enc.blk.%d.output_norm" },
1371+
{ LLM_TENSOR_V_ENC_FFN_UP, "v.enc.blk.%d.ffn_up" },
1372+
{ LLM_TENSOR_V_ENC_FFN_DOWN, "v.enc.blk.%d.ffn_down" },
1373+
{ LLM_TENSOR_V_RESMPL_POS_EMBD_K, "v.resmpl.pos_embd_k" },
1374+
{ LLM_TENSOR_V_RESMPL_ATTN_IN, "v.resmpl.attn_in" },
1375+
{ LLM_TENSOR_V_RESMPL_ATTN_OUT, "v.resmpl.attn_out" },
1376+
{ LLM_TENSOR_V_RESMPL_KV_PROJ, "v.resmpl.kv_proj" },
1377+
{ LLM_TENSOR_V_RESMPL_NORM_POST, "v.resmpl.norm_post" },
1378+
{ LLM_TENSOR_V_RESMPL_NORM_KV, "v.resmpl.norm_kv" },
1379+
{ LLM_TENSOR_V_RESMPL_NORM_Q, "v.resmpl.norm_q" },
1380+
{ LLM_TENSOR_V_RESMPL_PROJ, "v.resmpl.proj" },
1381+
{ LLM_TENSOR_V_RESMPL_QUERY, "v.resmpl.query" },
13631382
}
13641383
},
13651384
{
1366-
VISION_ARCH_MINICPMV,
1385+
LLM_ARCH_UNKNOWN,
13671386
{
1368-
{ VISION_TENSOR_ENC_EMBD_PATCH, "v.enc.embd.patch" },
1369-
{ VISION_TENSOR_ENC_EMBD_POS, "v.enc.embd.pos" },
1370-
{ VISION_TENSOR_ENC_ATTN_Q, "v.enc.blk.%d.attn_q" },
1371-
{ VISION_TENSOR_ENC_ATTN_K, "v.enc.blk.%d.attn_k" },
1372-
{ VISION_TENSOR_ENC_ATTN_V, "v.enc.blk.%d.attn_v" },
1373-
{ VISION_TENSOR_ENC_INPUT_NORM, "v.enc.blk.%d.input_norm" },
1374-
{ VISION_TENSOR_ENC_OUTPUT, "v.enc.blk.%d.output" },
1375-
{ VISION_TENSOR_ENC_OUTPUT_NORM, "v.enc.blk.%d.output_norm" },
1376-
{ VISION_TENSOR_ENC_FFN_UP, "v.enc.blk.%d.ffn_up" },
1377-
{ VISION_TENSOR_ENC_FFN_DOWN, "v.enc.blk.%d.ffn_down" },
1378-
{ VISION_TENSOR_RESMPL_POS_EMBD_K, "v.resmpl.pos_embd_k" },
1379-
{ VISION_TENSOR_RESMPL_ATTN_IN, "v.resmpl.attn_in" },
1380-
{ VISION_TENSOR_RESMPL_ATTN_OUT, "v.resmpl.attn_out" },
1381-
{ VISION_TENSOR_RESMPL_KV_PROJ, "v.resmpl.kv_proj" },
1382-
{ VISION_TENSOR_RESMPL_NORM_POST, "v.resmpl.norm_post" },
1383-
{ VISION_TENSOR_RESMPL_NORM_KV, "v.resmpl.norm_kv" },
1384-
{ VISION_TENSOR_RESMPL_NORM_Q, "v.resmpl.norm_q" },
1385-
{ VISION_TENSOR_RESMPL_PROJ, "v.resmpl.proj" },
1386-
{ VISION_TENSOR_RESMPL_QUERY, "v.resmpl.query" },
1387-
}
1387+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
1388+
},
13881389
},
13891390
};
13901391

@@ -1537,12 +1538,7 @@ std::string LLM_KV::operator()(llm_kv kv) const {
15371538
return ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch));
15381539
}
15391540

1540-
template<>
1541-
std::string BASE_TN_IMPL<llm_arch, llm_tensor>::str() const {
1542-
if (LLM_TENSOR_NAMES.find(arch) == LLM_TENSOR_NAMES.end()) {
1543-
throw std::runtime_error(format("Cannot find tensor name mapping for arch %d", arch));
1544-
}
1545-
1541+
std::string LLM_TN_IMPL::str() const {
15461542
if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) {
15471543
return "__missing__";
15481544
}
@@ -1557,26 +1553,6 @@ std::string BASE_TN_IMPL<llm_arch, llm_tensor>::str() const {
15571553
return name;
15581554
}
15591555

1560-
template<>
1561-
std::string BASE_TN_IMPL<vision_arch, vision_tensor>::str() const {
1562-
if (VISION_TENSOR_NAMES.find(arch) == VISION_TENSOR_NAMES.end()) {
1563-
throw std::runtime_error(format("Cannot find tensor name mapping for arch %d", arch));
1564-
}
1565-
1566-
if (VISION_TENSOR_NAMES.at(arch).find(tensor) == VISION_TENSOR_NAMES.at(arch).end()) {
1567-
return "__missing__";
1568-
}
1569-
1570-
std::string name = ::format(VISION_TENSOR_NAMES.at(arch).at(tensor), bid, xid);
1571-
1572-
if (suffix != nullptr) {
1573-
name += ".";
1574-
name += suffix;
1575-
}
1576-
1577-
return name;
1578-
}
1579-
15801556
const char * llm_arch_name(llm_arch arch) {
15811557
auto it = LLM_ARCH_NAMES.find(arch);
15821558
if (it == LLM_ARCH_NAMES.end()) {

src/llama-arch.h

Lines changed: 36 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -66,16 +66,13 @@ enum llm_arch {
6666
LLM_ARCH_GRANITE_MOE,
6767
LLM_ARCH_CHAMELEON,
6868
LLM_ARCH_WAVTOKENIZER_DEC,
69+
// vision
70+
LLM_ARCH_VISION_LLAVA,
71+
LLM_ARCH_VISION_MOBILEVLM,
72+
LLM_ARCH_VISION_MINICPMV,
6973
LLM_ARCH_UNKNOWN,
7074
};
7175

72-
enum vision_arch {
73-
VISION_ARCH_UNKNOWN,
74-
VISION_ARCH_LLAVA,
75-
VISION_ARCH_MOBILEVLM,
76-
VISION_ARCH_MINICPMV,
77-
};
78-
7976
enum llm_kv {
8077
LLM_KV_GENERAL_TYPE,
8178
LLM_KV_GENERAL_ARCHITECTURE,
@@ -354,35 +351,33 @@ enum llm_tensor {
354351
LLM_TENSOR_POS_NET_ATTN_K,
355352
LLM_TENSOR_POS_NET_ATTN_V,
356353
LLM_TENSOR_POS_NET_ATTN_OUT,
357-
};
358-
359-
enum vision_tensor {
360-
VISION_TENSOR_MMPROJ,
361-
VISION_TENSOR_MMPROJ_MLP,
362-
VISION_TENSOR_MMPROJ_PEG,
363-
VISION_TENSOR_ENC_EMBD_CLS,
364-
VISION_TENSOR_ENC_EMBD_PATCH,
365-
VISION_TENSOR_ENC_EMBD_POS,
366-
VISION_TENSOR_ENC_ATTN_Q,
367-
VISION_TENSOR_ENC_ATTN_K,
368-
VISION_TENSOR_ENC_ATTN_V,
369-
VISION_TENSOR_ENC_INPUT_NORM,
370-
VISION_TENSOR_ENC_OUTPUT,
371-
VISION_TENSOR_ENC_OUTPUT_NORM,
372-
VISION_TENSOR_ENC_FFN_UP,
373-
VISION_TENSOR_ENC_FFN_DOWN,
374-
VISION_TENSOR_PRE_NORM,
375-
VISION_TENSOR_POST_NORM,
376-
// minicpmv
377-
VISION_TENSOR_RESMPL_POS_EMBD_K,
378-
VISION_TENSOR_RESMPL_ATTN_IN,
379-
VISION_TENSOR_RESMPL_ATTN_OUT,
380-
VISION_TENSOR_RESMPL_KV_PROJ,
381-
VISION_TENSOR_RESMPL_NORM_POST,
382-
VISION_TENSOR_RESMPL_NORM_KV,
383-
VISION_TENSOR_RESMPL_NORM_Q,
384-
VISION_TENSOR_RESMPL_PROJ,
385-
VISION_TENSOR_RESMPL_QUERY,
354+
// vision
355+
LLM_TENSOR_V_MMPROJ,
356+
LLM_TENSOR_V_MMPROJ_MLP,
357+
LLM_TENSOR_V_MMPROJ_PEG,
358+
LLM_TENSOR_V_ENC_EMBD_CLS,
359+
LLM_TENSOR_V_ENC_EMBD_PATCH,
360+
LLM_TENSOR_V_ENC_EMBD_POS,
361+
LLM_TENSOR_V_ENC_ATTN_Q,
362+
LLM_TENSOR_V_ENC_ATTN_K,
363+
LLM_TENSOR_V_ENC_ATTN_V,
364+
LLM_TENSOR_V_ENC_INPUT_NORM,
365+
LLM_TENSOR_V_ENC_OUTPUT,
366+
LLM_TENSOR_V_ENC_OUTPUT_NORM,
367+
LLM_TENSOR_V_ENC_FFN_UP,
368+
LLM_TENSOR_V_ENC_FFN_DOWN,
369+
LLM_TENSOR_V_PRE_NORM,
370+
LLM_TENSOR_V_POST_NORM,
371+
// vision - minicpmv
372+
LLM_TENSOR_V_RESMPL_POS_EMBD_K,
373+
LLM_TENSOR_V_RESMPL_ATTN_IN,
374+
LLM_TENSOR_V_RESMPL_ATTN_OUT,
375+
LLM_TENSOR_V_RESMPL_KV_PROJ,
376+
LLM_TENSOR_V_RESMPL_NORM_POST,
377+
LLM_TENSOR_V_RESMPL_NORM_KV,
378+
LLM_TENSOR_V_RESMPL_NORM_Q,
379+
LLM_TENSOR_V_RESMPL_PROJ,
380+
LLM_TENSOR_V_RESMPL_QUERY,
386381
};
387382

388383
enum llm_tensor_layer {
@@ -408,10 +403,9 @@ struct LLM_KV {
408403
// std::string name = tn(LLM_TENSOR_TOKEN_EMBD, "bias"); -> "token_embd.bias"
409404
// std::string name = tn(LLM_TENSOR_ATTN_NORM, "weight", 3); -> "blk.3.attn_norm.weight"
410405
//
411-
template<typename Tname, typename Ttensor>
412-
struct BASE_TN_IMPL {
413-
const Tname arch;
414-
const Ttensor tensor;
406+
struct LLM_TN_IMPL {
407+
const llm_arch arch;
408+
const llm_tensor tensor;
415409
const char * const suffix;
416410
const int bid;
417411
const int xid;
@@ -422,16 +416,15 @@ struct BASE_TN_IMPL {
422416
return str();
423417
}
424418

425-
friend bool operator==(const std::string & str, const BASE_TN_IMPL & tn) {
419+
friend bool operator==(const std::string & str, const LLM_TN_IMPL & tn) {
426420
return str == tn.str();
427421
}
428422

429-
friend bool operator!=(const std::string & str, const BASE_TN_IMPL & tn) {
423+
friend bool operator!=(const std::string & str, const LLM_TN_IMPL & tn) {
430424
return str != tn.str();
431425
}
432426
};
433427

434-
using LLM_TN_IMPL = BASE_TN_IMPL<llm_arch, llm_tensor>;
435428
struct LLM_TN {
436429
LLM_TN(llm_arch arch) : arch(arch) {}
437430

@@ -446,20 +439,6 @@ struct LLM_TN {
446439
}
447440
};
448441

449-
struct VISION_TN {
450-
VISION_TN(vision_arch arch) : arch(arch) {}
451-
452-
vision_arch arch;
453-
454-
BASE_TN_IMPL<vision_arch, vision_tensor> operator()(vision_tensor tensor, const char * suffix, int bid = -1, int xid = -1) const {
455-
return { arch, tensor, suffix, bid, xid };
456-
}
457-
458-
BASE_TN_IMPL<vision_arch, vision_tensor> operator()(vision_tensor tensor, int bid = -1, int xid = -1) const {
459-
return { arch, tensor, nullptr, bid, xid };
460-
}
461-
};
462-
463442

464443
struct llm_tensor_info {
465444
llm_tensor_layer layer;
@@ -470,6 +449,4 @@ const char * llm_arch_name(llm_arch arch);
470449

471450
llm_arch llm_arch_from_string(const std::string & name);
472451

473-
vision_arch vision_arch_from_string(const std::string & name);
474-
475452
const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor);

0 commit comments

Comments (0)