Skip to content

Commit d9f1d13

Browse files
committed
add gguf_init_from_file_ext impl
1 parent 36f2215 commit d9f1d13

File tree

1 file changed

+49
-14
lines changed

1 file changed

+49
-14
lines changed

ggml/src/gguf.cpp

Lines changed: 49 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ bool gguf_read_emplace_helper(const struct gguf_reader & gr, std::vector<struct
316316
return true;
317317
}
318318

319-
struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params) {
319+
static struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params, tensor_shape_read_cb_t on_tensor_shape_read) {
320320
const struct gguf_reader gr(file);
321321
struct gguf_context * ctx = new gguf_context;
322322

@@ -525,25 +525,52 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
525525
{
526526
uint32_t n_dims = -1;
527527
ok = ok && gr.read(n_dims);
528-
if (n_dims > GGML_MAX_DIMS) {
529-
GGML_LOG_ERROR("%s: tensor '%s' has invalid number of dimensions: %" PRIu32 " > %" PRIu32 "\n",
530-
__func__, info.t.name, n_dims, GGML_MAX_DIMS);
531-
ok = false;
532-
break;
533-
}
534-
for (uint32_t j = 0; ok && j < GGML_MAX_DIMS; ++j) {
535-
info.t.ne[j] = 1;
528+
529+
std::vector<int64_t> ne(n_dims);
530+
for (uint32_t j = 0; ok && j < n_dims; ++j) {
531+
ne[j] = 1;
536532
if (j < n_dims) {
537-
ok = ok && gr.read(info.t.ne[j]);
533+
ok = ok && gr.read(ne[j]);
538534
}
539535

540536
// check that all ne are non-negative
541-
if (info.t.ne[j] < 0) {
537+
if (ne[j] < 0) {
542538
GGML_LOG_ERROR("%s: tensor '%s' dimension %" PRIu32 " has invalid number of elements: %" PRIi64 " < 0\n",
543-
__func__, info.t.name, j, info.t.ne[j]);
539+
__func__, info.t.name, j, ne[j]);
540+
ok = false;
541+
break;
542+
}
543+
}
544+
545+
if (!ok) {
546+
break;
547+
}
548+
549+
if (on_tensor_shape_read) {
550+
gguf_tensor_shape shape;
551+
ok = on_tensor_shape_read(ne.data(), n_dims, &shape);
552+
if (!ok) {
553+
GGML_LOG_ERROR("%s: tensor '%s' on_tensor_shape_read return false \n",
554+
__func__, info.t.name);
555+
break;
556+
}
557+
for (uint32_t j = 0; j < GGML_MAX_DIMS; ++j) {
558+
info.t.ne[j] = shape.ne[j];
559+
}
560+
} else {
561+
if (n_dims > GGML_MAX_DIMS) {
562+
GGML_LOG_ERROR("%s: tensor '%s' has invalid number of dimensions: %" PRIu32 " > %" PRIu32 "\n",
563+
__func__, info.t.name, n_dims, GGML_MAX_DIMS);
544564
ok = false;
545565
break;
546566
}
567+
for (uint32_t j = 0; j < GGML_MAX_DIMS; ++j) {
568+
if (j < n_dims) {
569+
info.t.ne[j] = ne[j];
570+
} else {
571+
info.t.ne[j] = 1;
572+
}
573+
}
547574
}
548575

549576
// check that the total number of elements is representable
@@ -730,19 +757,27 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
730757
return ctx;
731758
}
732759

733-
struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params) {
760+
struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params) {
761+
return gguf_init_from_file_impl(file, params, nullptr);
762+
}
763+
764+
struct gguf_context * gguf_init_from_file_ext(const char * fname, struct gguf_init_params params, tensor_shape_read_cb_t on_tensor_shape_read) {
734765
FILE * file = ggml_fopen(fname, "rb");
735766

736767
if (!file) {
737768
GGML_LOG_ERROR("%s: failed to open GGUF file '%s'\n", __func__, fname);
738769
return nullptr;
739770
}
740771

741-
struct gguf_context * result = gguf_init_from_file_impl(file, params);
772+
struct gguf_context * result = gguf_init_from_file_impl(file, params, on_tensor_shape_read);
742773
fclose(file);
743774
return result;
744775
}
745776

777+
struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params) {
778+
return gguf_init_from_file_ext(fname, params, NULL);;
779+
}
780+
746781
void gguf_free(struct gguf_context * ctx) {
747782
if (ctx == nullptr) {
748783
return;

0 commit comments

Comments
 (0)