Skip to content

Commit 7657835

Browse files
committed
tests : fix overflow and memory leaks in test-model-random
* tests : fix integer types in test-model-random
1 parent 9cd402c commit 7657835

File tree

2 files changed: +17 −14 lines

src/llama-model.cpp (3 additions, 4 deletions)

@@ -8847,9 +8847,7 @@ struct llm_build_starcoder2 : public llm_graph_context {
 };

 struct llm_build_mamba : public llm_graph_context {
-    const llama_model & model;
-
-    llm_build_mamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params), model(model) {
+    llm_build_mamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
         ggml_tensor * cur;
         ggml_tensor * inpL;
@@ -8865,7 +8863,7 @@ struct llm_build_mamba : public llm_graph_context {
                 LLM_NORM_RMS, il);
             cb(cur, "attn_norm", il);

-            cur = build_mamba_layer(gf, cur, state_copy, ubatch, il);
+            cur = build_mamba_layer(gf, cur, state_copy, model, ubatch, il);

             if (il == n_layer - 1) {
                 // skip computing output for unused tokens
@@ -8906,6 +8904,7 @@ struct llm_build_mamba : public llm_graph_context {
             ggml_cgraph * gf,
             ggml_tensor * cur,
             ggml_tensor * state_copy,
+            const llama_model & model,
             const llama_ubatch & ubatch,
             int il) const {
         const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);

tests/test-model-random.cpp (14 additions, 10 deletions)

@@ -227,7 +227,7 @@ struct gguf_value {
             for (size_t i = 0; i < arr_size; ++i) {
                 memcpy(data.data() + type_size * i, &(*value.array)[i].value, type_size);
             }
-            gguf_set_arr_data(ctx, k, arr_type, data.data(), data.size());
+            gguf_set_arr_data(ctx, k, arr_type, data.data(), data.size() / type_size);
         }
         // TODO: handle nested arrays
     }
@@ -317,7 +317,12 @@ struct model_variant {
             gguf_add_tensor(ctx_gguf, tensor);
         }

-        return gguf_write_to_file(ctx_gguf, fname, false);
+        bool status = gguf_write_to_file(ctx_gguf, fname, false);
+
+        ggml_free(ctx);
+        gguf_free(ctx_gguf);
+
+        return status;
     }

     static void insert_from_arch(std::vector<model_variant> & variants, llm_arch arch) {
@@ -762,9 +767,8 @@ int main(int argc, char ** argv) {
     std::mt19937 rng(42);

     // TODO: multiple sequences per token
-    const int64_t n_batch = 2048;
-    const int64_t n_seq_len = 1024;
-    std::uniform_int_distribution<int64_t> rand_seq_init_len(n_seq_len / 4, 3 * n_seq_len / 4);
+    const int32_t n_batch = 2048;
+    const int32_t n_seq_len = 1024;

     llama_batch batch = llama_batch_init(n_batch, 0, 1);
     // TODO: batch with embeddings
@@ -794,10 +798,10 @@ int main(int argc, char ** argv) {
     // const auto n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model));
     // const auto n_embd = llama_model_n_embd(model);

-    for (int64_t n_seq_max : { 1, 2, 13 } ) {
+    for (int32_t n_seq_max : { 1, 2, 13 } ) {

         // TODO(later): context shift testing
-        for (int64_t n_ctx : { n_seq_len * n_seq_max }) {
+        for (int32_t n_ctx : { n_seq_len * n_seq_max }) {

             std::vector<reference_logits> ref_outputs;
@@ -824,7 +828,7 @@ int main(int argc, char ** argv) {

             for (bool shuffle : { false, true }) {

-                for (int64_t n_ubatch : { 1, 2, 512 } ) {
+                for (int32_t n_ubatch : { 1, 2, 512 } ) {

                     std::vector<bool> valid(n_seq_max, true);
@@ -852,7 +856,7 @@ int main(int argc, char ** argv) {
                     if (batch.n_tokens < n_batch) {
                         const int64_t seq_len =
                             std::min(n_batch - batch.n_tokens,
-                                     (int64_t) ref_outputs[seq_id].prompt_len - seq_id_n_past[seq_id]);
+                                     ref_outputs[seq_id].prompt_len - seq_id_n_past[seq_id]);

                         ref_outputs[seq_id].add_to_batch(batch, seq_id_n_past[seq_id], seq_len, seq_id);
                         seq_ids_in_batch.insert(seq_id);
@@ -891,7 +895,7 @@ int main(int argc, char ** argv) {
                     }

                     fprintf(stdout,
-                            "Comparing output for '%s', with shuffle=%i, n_seq_max=%li, n_ctx=%li, n_ubatch=%li: ",
+                            "Comparing output for '%s', with shuffle=%i, n_seq_max=%i, n_ctx=%i, n_ubatch=%i: ",
                             variant.name.c_str(), shuffle, n_seq_max, n_ctx, n_ubatch);
                     if (std::all_of(valid.begin(), valid.end(), [](bool v) { return v; })) {
                         fprintf(stdout, "\033[1;32mOK\033[0m\n");

Comments (0)