Skip to content

Commit 6ee86e5

Browse files
committed
graph : restore ubatch in build_cb
ggml-ci
1 parent f63aeec commit 6ee86e5

File tree

4 files changed

+6
-5
lines changed

4 files changed

+6
-5
lines changed

src/llama-context.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,7 @@ bool llama_context::apply_adapter_cvec(
196196
void llama_context::build_cb(
197197
ggml_tensor * cur,
198198
const char * name,
199+
const llama_ubatch & ubatch,
199200
int il) {
200201
if (il >= 0) {
201202
ggml_format_name(cur, "%s-%d", name, il);
@@ -213,10 +214,7 @@ void llama_context::build_cb(
213214
// norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends
214215
// FIXME: fix in ggml_backend_sched
215216
const bool full_offload = model.params.n_gpu_layers > (int) model.hparams.n_layer;
216-
// TODO: during #11213, the requirement for ubatch.n_tokens < 32 was removed to simplify
217-
// not sure if this is still needed, but it can be brought back if needed
218-
//if (ubatch.n_tokens < 32 || full_offload) {
219-
if (full_offload) {
217+
if (ubatch.n_tokens < 32 || full_offload) {
220218
if (il != -1 && strcmp(name, "norm") == 0) {
221219
const auto & dev_layer = model.dev_layer(il);
222220
for (auto & backend : backends) {

src/llama-context.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ struct llama_context : public llama_graph_i {
8585
virtual void build_cb(
8686
ggml_tensor * cur,
8787
const char * name,
88+
const llama_ubatch & ubatch,
8889
int il);
8990

9091
// TODO: add encode/decode graphs

src/llama-graph.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ class llama_graph_i {
1414
virtual void build_cb(
1515
ggml_tensor * cur,
1616
const char * name,
17+
const llama_ubatch & ubatch,
1718
int il) = 0;
1819

1920
// apply control vector for layer il

src/llama-model.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,7 @@ static ggml_backend_buffer_type_t select_weight_buft(const llama_hparams & hpara
248248
return cur_buft;
249249
}
250250
}
251+
251252
return nullptr;
252253
}
253254

@@ -3888,7 +3889,7 @@ struct llm_build_context {
38883889

38893890
// TODO: tmp
38903891
void cb(struct ggml_tensor * cur, const char * name, int il) {
3891-
lgf.build_cb(cur, name, il);
3892+
lgf.build_cb(cur, name, ubatch, il);
38923893
}
38933894

38943895
// TODO: tmp

0 commit comments

Comments (0)