Skip to content

Commit 5b6bf2f

Browse files
authored
move string array functionality into model-loader
1 parent 41049e6 commit 5b6bf2f

File tree

2 files changed

+47
-30
lines changed

2 files changed

+47
-30
lines changed

src/llama-model-loader.cpp

Lines changed: 42 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -288,61 +288,84 @@ namespace GGUFMeta {
288288

289289
template<typename T>
290290
bool llama_model_loader::get_arr(const std::string & key, std::vector<T> & result, bool required) {
291-
const int kid = gguf_find_key(meta.get(), key.c_str());
291+
const gguf_context * ctx = meta.get();
292+
const int kid = gguf_find_key(ctx, key.c_str());
292293

293-
if (kid < 0 || gguf_get_kv_type(meta.get(), kid) != GGUF_TYPE_ARRAY) {
294+
if (kid < 0 || gguf_get_kv_type(ctx, kid) != GGUF_TYPE_ARRAY) {
294295
if (required) {
295296
throw std::runtime_error(format("array key not found in model: %s", key.c_str()));
296297
}
297298
return false;
298299
}
299300

300301
struct GGUFMeta::ArrayInfo arr_info =
301-
GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta.get(), kid);
302+
GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(ctx, kid);
302303

303304
switch (arr_info.gt) {
304305
case GGUF_TYPE_UINT32:
305-
case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same<T, int32_t>::value) ||
306-
(std::is_same<T, uint32_t>::value)); break;
307-
case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
306+
case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same<T, int32_t>::value) ||
307+
(std::is_same<T, uint32_t>::value)); break;
308+
case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
309+
case GGUF_TYPE_STRING: GGML_ASSERT((std::is_same<T, std::string>::value)); break;
308310
default:
309-
throw std::runtime_error(format("%s is not a float32/uint32/int32 array", key.c_str()));
311+
throw std::runtime_error(format("%s is not a string/float32/uint32/int32 array", key.c_str()));
310312
}
311313

312-
result.resize(arr_info.length);
313-
result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length);
314+
if constexpr (std::is_same<T, std::string>::value) {
315+
const size_t n_items = gguf_get_arr_n(ctx, kid);
316+
result.clear();
317+
318+
for (size_t i = 0; i < n_items; i++) {
319+
const T value = gguf_get_arr_str(ctx, kid, i);
320+
result.emplace_back(value);
321+
}
322+
} else {
323+
result.resize(arr_info.length);
324+
result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length);
325+
}
314326

315327
return true;
316328
}
317329

318330
template<typename T, size_t N_MAX>
319331
bool llama_model_loader::get_arr(const std::string & key, std::array<T, N_MAX> & result, bool required) {
320-
const int kid = gguf_find_key(meta.get(), key.c_str());
332+
const gguf_context * ctx = meta.get();
333+
const int kid = gguf_find_key(ctx, key.c_str());
321334

322-
if (kid < 0 || gguf_get_kv_type(meta.get(), kid) != GGUF_TYPE_ARRAY) {
335+
if (kid < 0 || gguf_get_kv_type(ctx, kid) != GGUF_TYPE_ARRAY) {
323336
if (required) {
324337
throw std::runtime_error(format("array key not found in model: %s", key.c_str()));
325338
}
326339
return false;
327340
}
328341

329342
struct GGUFMeta::ArrayInfo arr_info =
330-
GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta.get(), kid);
343+
GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(ctx, kid);
331344

332345
switch (arr_info.gt) {
333346
case GGUF_TYPE_UINT32:
334-
case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same<T, int32_t>::value) ||
335-
(std::is_same<T, uint32_t>::value)); break;
336-
case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
347+
case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same<T, int32_t>::value) ||
348+
(std::is_same<T, uint32_t>::value)); break;
349+
case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
350+
case GGUF_TYPE_STRING: GGML_ASSERT((std::is_same<T, std::string>::value)); break;
337351
default:
338-
throw std::runtime_error(format("%s is not a float32/uint32/int32 array", key.c_str()));
352+
throw std::runtime_error(format("%s is not a string/float32/uint32/int32 array", key.c_str()));
339353
}
340354

341355
if (arr_info.length > N_MAX) {
342356
throw std::runtime_error(format("array length %u for key %s exceeds max %u", (uint32_t) arr_info.length, key.c_str(), (uint32_t) N_MAX));
343357
}
344358

345-
std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin());
359+
if constexpr (std::is_same<T, std::string>::value) {
360+
const size_t n_items = gguf_get_arr_n(meta.get(), kid);
361+
362+
for (size_t i = 0; i < n_items; i++) {
363+
const T value = gguf_get_arr_str(meta.get(), kid, i);
364+
result[i] = value;
365+
}
366+
} else {
367+
std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin());
368+
}
346369

347370
return true;
348371
}
@@ -352,6 +375,8 @@ namespace GGUFMeta {
352375
return get_arr(llm_kv(kid), result, required);
353376
}
354377

378+
template bool llama_model_loader::get_arr<std::vector<std::string>>(enum llm_kv kid, std::vector<std::string> & result, bool required);
379+
355380
template<typename T>
356381
bool llama_model_loader::get_key(const std::string & key, T & result, bool required) {
357382
auto it = kv_overrides.find(key);

src/llama-model.cpp

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -425,22 +425,13 @@ void llama_model::load_hparams(llama_model_loader & ml) {
425425

426426
// get metadata as string
427427
for (int i = 0; i < gguf_get_n_kv(ctx); i++) {
428-
const char * name = gguf_get_key(ctx, i);
429428
gguf_type type = gguf_get_kv_type(ctx, i);
430-
431429
if (type == GGUF_TYPE_ARRAY) {
432-
if (LLM_KV(arch)(LLM_KV_CLASSIFIER_OUTPUT_LABELS) == name) {
433-
const size_t n_items = gguf_get_arr_n(ctx, i);
434-
435-
for (size_t j = 0; j < n_items; j++) {
436-
const std::string value = gguf_get_arr_str(ctx, i, j);
437-
classifier_labels.emplace_back(value);
438-
}
439-
}
440-
} else {
441-
const std::string value = gguf_kv_to_str(ctx, i);
442-
gguf_kv.emplace(name, value);
430+
continue;
443431
}
432+
const char * name = gguf_get_key(ctx, i);
433+
const std::string value = gguf_kv_to_str(ctx, i);
434+
gguf_kv.emplace(name, value);
444435
}
445436

446437
// get general kv
@@ -553,6 +544,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
553544
ml.get_key(LLM_KV_VOCAB_SIZE, n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, n_vocab, false);
554545

555546
// for classifier models
547+
ml.get_arr(LLM_KV_CLASSIFIER_OUTPUT_LABELS, classifier_labels, false);
556548
if (!classifier_labels.empty()) {
557549
hparams.n_cls_out = classifier_labels.size();
558550
}

0 commit comments

Comments
 (0)