@@ -227,7 +227,7 @@ struct gguf_value {
227227 for (size_t i = 0 ; i < arr_size; ++i) {
228228 memcpy (data.data () + type_size * i, &(*value.array )[i].value , type_size);
229229 }
230- gguf_set_arr_data (ctx, k, arr_type, data.data (), data.size ());
230+ gguf_set_arr_data (ctx, k, arr_type, data.data (), data.size () / type_size );
231231 }
232232 // TODO: handle nested arrays
233233 }
@@ -317,7 +317,12 @@ struct model_variant {
317317 gguf_add_tensor (ctx_gguf, tensor);
318318 }
319319
320- return gguf_write_to_file (ctx_gguf, fname, false );
320+ bool status = gguf_write_to_file (ctx_gguf, fname, false );
321+
322+ ggml_free (ctx);
323+ gguf_free (ctx_gguf);
324+
325+ return status;
321326 }
322327
323328 static void insert_from_arch (std::vector<model_variant> & variants, llm_arch arch) {
@@ -762,9 +767,8 @@ int main(int argc, char ** argv) {
762767 std::mt19937 rng (42 );
763768
764769 // TODO: multiple sequences per token
765- const int64_t n_batch = 2048 ;
766- const int64_t n_seq_len = 1024 ;
767- std::uniform_int_distribution<int64_t > rand_seq_init_len (n_seq_len / 4 , 3 * n_seq_len / 4 );
770+ const int32_t n_batch = 2048 ;
771+ const int32_t n_seq_len = 1024 ;
768772
769773 llama_batch batch = llama_batch_init (n_batch, 0 , 1 );
770774 // TODO: batch with embeddings
@@ -794,10 +798,10 @@ int main(int argc, char ** argv) {
794798 // const auto n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model));
795799 // const auto n_embd = llama_model_n_embd(model);
796800
797- for (int64_t n_seq_max : { 1 , 2 , 13 } ) {
801+ for (int32_t n_seq_max : { 1 , 2 , 13 } ) {
798802
799803 // TODO(later): context shift testing
800- for (int64_t n_ctx : { n_seq_len * n_seq_max }) {
804+ for (int32_t n_ctx : { n_seq_len * n_seq_max }) {
801805
802806 std::vector<reference_logits> ref_outputs;
803807
@@ -824,7 +828,7 @@ int main(int argc, char ** argv) {
824828
825829 for (bool shuffle : { false , true }) {
826830
827- for (int64_t n_ubatch : { 1 , 2 , 512 } ) {
831+ for (int32_t n_ubatch : { 1 , 2 , 512 } ) {
828832
829833 std::vector<bool > valid (n_seq_max, true );
830834
@@ -852,7 +856,7 @@ int main(int argc, char ** argv) {
852856 if (batch.n_tokens < n_batch) {
853857 const int64_t seq_len =
854858 std::min (n_batch - batch.n_tokens ,
855- ( int64_t ) ref_outputs[seq_id].prompt_len - seq_id_n_past[seq_id]);
859+ ref_outputs[seq_id].prompt_len - seq_id_n_past[seq_id]);
856860
857861 ref_outputs[seq_id].add_to_batch (batch, seq_id_n_past[seq_id], seq_len, seq_id);
858862 seq_ids_in_batch.insert (seq_id);
@@ -891,7 +895,7 @@ int main(int argc, char ** argv) {
891895 }
892896
893897 fprintf (stdout,
894- " Comparing output for '%s', with shuffle=%i, n_seq_max=%li , n_ctx=%li , n_ubatch=%li : " ,
898+ " Comparing output for '%s', with shuffle=%i, n_seq_max=%i , n_ctx=%i , n_ubatch=%i : " ,
895899 variant.name .c_str (), shuffle, n_seq_max, n_ctx, n_ubatch);
896900 if (std::all_of (valid.begin (), valid.end (), [](bool v) { return v; })) {
897901 fprintf (stdout, " \033 [1;32mOK\033 [0m\n " );
0 commit comments