@@ -316,7 +316,7 @@ bool gguf_read_emplace_helper(const struct gguf_reader & gr, std::vector<struct
316316 return true ;
317317}
318318
319- struct gguf_context * gguf_init_from_file_impl (FILE * file, struct gguf_init_params params) {
319+ static struct gguf_context * gguf_init_from_file_impl (FILE * file, struct gguf_init_params params, tensor_shape_read_cb_t on_tensor_shape_read ) {
320320 const struct gguf_reader gr (file);
321321 struct gguf_context * ctx = new gguf_context;
322322
@@ -525,25 +525,52 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
525525 {
526526 uint32_t n_dims = -1 ;
527527 ok = ok && gr.read (n_dims);
528- if (n_dims > GGML_MAX_DIMS) {
529- GGML_LOG_ERROR (" %s: tensor '%s' has invalid number of dimensions: %" PRIu32 " > %" PRIu32 " \n " ,
530- __func__, info.t .name , n_dims, GGML_MAX_DIMS);
531- ok = false ;
532- break ;
533- }
534- for (uint32_t j = 0 ; ok && j < GGML_MAX_DIMS; ++j) {
535- info.t .ne [j] = 1 ;
528+
529+ std::vector<int64_t > ne (n_dims);
530+ for (uint32_t j = 0 ; ok && j < n_dims; ++j) {
531+ ne[j] = 1 ;
536532 if (j < n_dims) {
537- ok = ok && gr.read (info. t . ne [j]);
533+ ok = ok && gr.read (ne[j]);
538534 }
539535
540536 // check that all ne are non-negative
541- if (info. t . ne [j] < 0 ) {
537+ if (ne[j] < 0 ) {
542538 GGML_LOG_ERROR (" %s: tensor '%s' dimension %" PRIu32 " has invalid number of elements: %" PRIi64 " < 0\n " ,
543- __func__, info.t .name , j, info.t .ne [j]);
539+ __func__, info.t .name , j, ne[j]);
540+ ok = false ;
541+ break ;
542+ }
543+ }
544+
545+ if (!ok) {
546+ break ;
547+ }
548+
549+ if (on_tensor_shape_read) {
550+ gguf_tensor_shape shape;
551+ ok = on_tensor_shape_read (ne.data (), n_dims, &shape);
552+ if (!ok) {
553+ GGML_LOG_ERROR (" %s: tensor '%s' on_tensor_shape_read return false \n " ,
554+ __func__, info.t .name );
555+ break ;
556+ }
557+ for (uint32_t j = 0 ; j < GGML_MAX_DIMS; ++j) {
558+ info.t .ne [j] = shape.ne [j];
559+ }
560+ } else {
561+ if (n_dims > GGML_MAX_DIMS) {
562+ GGML_LOG_ERROR (" %s: tensor '%s' has invalid number of dimensions: %" PRIu32 " > %" PRIu32 " \n " ,
563+ __func__, info.t .name , n_dims, GGML_MAX_DIMS);
544564 ok = false ;
545565 break ;
546566 }
567+ for (uint32_t j = 0 ; j < GGML_MAX_DIMS; ++j) {
568+ if (j < n_dims) {
569+ info.t .ne [j] = ne[j];
570+ } else {
571+ info.t .ne [j] = 1 ;
572+ }
573+ }
547574 }
548575
549576 // check that the total number of elements is representable
@@ -730,19 +757,27 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
730757 return ctx;
731758}
732759
733- struct gguf_context * gguf_init_from_file (const char * fname, struct gguf_init_params params) {
760+ struct gguf_context * gguf_init_from_file_impl (FILE * file, struct gguf_init_params params) {
761+ return gguf_init_from_file_impl (file, params, nullptr );
762+ }
763+
764+ struct gguf_context * gguf_init_from_file_ext (const char * fname, struct gguf_init_params params, tensor_shape_read_cb_t on_tensor_shape_read) {
734765 FILE * file = ggml_fopen (fname, " rb" );
735766
736767 if (!file) {
737768 GGML_LOG_ERROR (" %s: failed to open GGUF file '%s'\n " , __func__, fname);
738769 return nullptr ;
739770 }
740771
741- struct gguf_context * result = gguf_init_from_file_impl (file, params);
772+ struct gguf_context * result = gguf_init_from_file_impl (file, params, on_tensor_shape_read );
742773 fclose (file);
743774 return result;
744775}
745776
777+ struct gguf_context * gguf_init_from_file (const char * fname, struct gguf_init_params params) {
778+ return gguf_init_from_file_ext (fname, params, NULL );;
779+ }
780+
746781void gguf_free (struct gguf_context * ctx) {
747782 if (ctx == nullptr ) {
748783 return ;
0 commit comments