Commit 5b10edf

add CPU only test + default split test

ggml-ci

1 parent bbd8b66

File tree

1 file changed: +17 -8 lines changed

tests/test-thread-safety.cpp

Lines changed: 17 additions & 8 deletions
@@ -32,7 +32,6 @@ int main(int argc, char ** argv) {
     // }
     //}, NULL);
 
-    auto mparams = common_model_params_to_llama(params);
     auto cparams = common_context_params_to_llama(params);
 
     int dev_count = ggml_backend_dev_count();
@@ -43,7 +42,7 @@ int main(int argc, char ** argv) {
             gpu_dev_count++;
         }
     }
-    const int num_models = gpu_dev_count + 1; // GPUs + 1 CPU model
+    const int num_models = gpu_dev_count + 1 + 1; // GPUs + 1 CPU model + 1 layer split
     //const int num_models = std::max(1, gpu_dev_count);
     const int num_contexts = std::max(1, params.n_parallel);
 
@@ -52,8 +51,17 @@ int main(int argc, char ** argv) {
     std::atomic<bool> failed = false;
 
     for (int m = 0; m < num_models; ++m) {
-        mparams.split_mode = LLAMA_SPLIT_MODE_NONE;
-        mparams.main_gpu = m < gpu_dev_count ? m : -1;
+        auto mparams = common_model_params_to_llama(params);
+
+        if (m < gpu_dev_count) {
+            mparams.split_mode = LLAMA_SPLIT_MODE_NONE;
+            mparams.main_gpu = m;
+        } else if (m == gpu_dev_count) {
+            mparams.split_mode = LLAMA_SPLIT_MODE_NONE;
+            mparams.main_gpu = -1; // CPU model
+        } else {
+            mparams.split_mode = LLAMA_SPLIT_MODE_LAYER;;
+        }
 
         llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
         if (model == NULL) {
@@ -111,20 +119,21 @@ int main(int argc, char ** argv) {
                         token = llama_vocab_bos(vocab);
                     }
 
+                    result += common_token_to_piece(ctx.get(), token);
+
                     if (llama_vocab_is_eog(vocab, token)) {
                         break;
                     }
-                    result += common_token_to_piece(ctx.get(), token);
 
                     batch = llama_batch_get_one(&token, 1);
                     if (llama_decode(ctx.get(), batch)) {
-                        LOG_ERR("failed to decode\n");
+                        LOG_ERR("Model %d/%d, Context %d/%d: failed to decode\n", m + 1, num_models, c + 1, num_contexts);
                         failed.store(true);
                         return;
                     }
                 }
 
-                LOG_INF("Model %d/%d, Context %d/%d: Result: '%s'\n", m + 1, num_models, c + 1, num_contexts, result.c_str());
+                LOG_INF("Model %d/%d, Context %d/%d: %s\n\n", m + 1, num_models, c + 1, num_contexts, result.c_str());
             });
         }
     }
@@ -138,6 +147,6 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    LOG_INF("All threads completed successfully.\n");
+    LOG_INF("All threads finished without errors.\n");
     return 0;
 }
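
For quick reading, the per-model setup introduced by this commit can be summarized as follows. This is a consolidated sketch assembled from the hunks above, not part of the commit itself; the context creation and generation threads that follow it in tests/test-thread-safety.cpp are unchanged.

// Consolidated sketch of the new model-setup loop (assembled from the diff above):
// one model pinned to each GPU, one CPU-only model, and one model using the
// default layer split across all available devices.
const int num_models = gpu_dev_count + 1 + 1; // GPUs + 1 CPU model + 1 layer split

for (int m = 0; m < num_models; ++m) {
    auto mparams = common_model_params_to_llama(params); // fresh params per model

    if (m < gpu_dev_count) {
        mparams.split_mode = LLAMA_SPLIT_MODE_NONE;  // whole model on GPU m
        mparams.main_gpu   = m;
    } else if (m == gpu_dev_count) {
        mparams.split_mode = LLAMA_SPLIT_MODE_NONE;
        mparams.main_gpu   = -1;                     // CPU-only model
    } else {
        mparams.split_mode = LLAMA_SPLIT_MODE_LAYER; // layer split (default mode)
    }

    llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
    // ... context creation and per-context generation threads as in the rest of the test
}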
