
Commit 8280645

context : move common inputs to base class

ggml-ci

1 parent: d5e8e1a

2 files changed: +111 −111 lines changed

src/llama-context.cpp

Lines changed: 89 additions & 89 deletions
@@ -987,6 +987,95 @@ ggml_tensor * llama_context::build_rope_factors(int il) {
     return model.layers[il].rope_short;
 }
 
+ggml_tensor * llama_context::build_inp_embd(
+        ggml_context * ctx0,
+        ggml_tensor * tok_embd,
+        const llama_ubatch & ubatch) {
+    const auto & hparams = model.hparams;
+
+    const int64_t n_embd = hparams.n_embd;
+
+    struct ggml_tensor * inpL;
+
+    if (ubatch.token) {
+        inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
+        //cb(inp_tokens, "inp_tokens", -1);
+        ggml_set_input(inp_tokens);
+
+        inpL = ggml_get_rows(ctx0, tok_embd, inp_tokens);
+
+        // apply lora for embedding tokens if needed
+        for (const auto & lora : loras) {
+            struct llama_adapter_lora_weight * lw = lora.first->get_weight(tok_embd);
+            if (lw == nullptr) {
+                continue;
+            }
+
+            const float adapter_scale = lora.second;
+            const float scale = lw->get_scale(lora.first->alpha, adapter_scale);
+
+            struct ggml_tensor * inpL_delta = ggml_scale(ctx0, ggml_mul_mat(
+                        ctx0, lw->b, // non-transposed lora_b
+                        ggml_get_rows(ctx0, lw->a, inp_tokens)
+                        ), scale);
+
+            inpL = ggml_add(ctx0, inpL, inpL_delta);
+        }
+    } else {
+        inp_embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
+        inpL = inp_embd;
+        ggml_set_input(inp_embd);
+    }
+
+    // For Granite architecture
+    if (hparams.f_embedding_scale != 0.0f) {
+        inpL = ggml_scale(ctx0, inpL, hparams.f_embedding_scale);
+    }
+
+    //cb(inpL, "inp_embd", -1);
+
+    return inpL;
+}
+
+ggml_tensor * llama_context::build_inp_pos(
+        ggml_context * ctx0,
+        int32_t n_tokens) {
+    inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_token());
+    ggml_set_input(inp_pos);
+
+    return inp_pos;
+}
+
+ggml_tensor * llama_context::build_inp_out_ids(
+        ggml_context * ctx0,
+        int32_t n_tokens,
+        bool worst_case) {
+    const int32_t n_out_ids = worst_case ? n_tokens : n_outputs;
+
+    inp_out_ids = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_out_ids);
+    ggml_set_input(inp_out_ids);
+
+    return inp_out_ids;
+}
+
+ggml_tensor * llama_context::build_inp_mean(
+        ggml_context * ctx0,
+        int32_t n_tokens) {
+    inp_mean = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
+    ggml_set_input(inp_mean);
+
+    return inp_mean;
+}
+
+ggml_tensor * llama_context::build_inp_cls(
+        ggml_context * ctx0,
+        int32_t n_tokens) {
+    inp_cls = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+    ggml_set_input(inp_cls);
+
+    return inp_cls;
+}
+
 //
 // state
 //
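The ggml calls above only build graph nodes, so the arithmetic they describe is easy to miss. As a standalone illustration (not part of this commit), the sketch below spells out what the LoRA branch of build_inp_embd computes: gather the rows of lora_a for the tokens in the batch, multiply by lora_b, scale, and add the delta onto the base embedding. The matrix layouts and the scale formula (adapter_scale * alpha / rank) are assumptions based on common LoRA conventions; lw->get_scale may compute its scale differently.

#include <cstdio>
#include <vector>

int main() {
    const int n_vocab = 4, n_embd = 3, n_rank = 2;

    // base embedding table, row-major [n_vocab x n_embd]
    std::vector<float> tok_embd(n_vocab * n_embd, 0.10f);
    // LoRA factors: a is [n_vocab x n_rank], b is [n_rank x n_embd]
    std::vector<float> lora_a(n_vocab * n_rank, 0.01f);
    std::vector<float> lora_b(n_rank * n_embd, 0.02f);

    const float alpha = 16.0f, adapter_scale = 1.0f;
    const float scale = adapter_scale * alpha / n_rank; // assumed scale convention

    const std::vector<int> tokens = {2, 0, 3}; // stand-in for ubatch.token

    for (int t : tokens) {
        for (int e = 0; e < n_embd; ++e) {
            float v = tok_embd[t * n_embd + e];    // ggml_get_rows(tok_embd, inp_tokens)
            float delta = 0.0f;
            for (int r = 0; r < n_rank; ++r) {
                // ggml_get_rows(lw->a, inp_tokens) followed by ggml_mul_mat with lw->b
                delta += lora_a[t * n_rank + r] * lora_b[r * n_embd + e];
            }
            v += scale * delta;                    // ggml_scale + ggml_add
            std::printf("%.4f ", v);
        }
        std::printf("\n");
    }
    return 0;
}

The point of the get_rows trick is that only the lora_a rows for tokens actually present in the batch are touched, so the delta costs O(n_tokens * rank * n_embd) rather than a full vocabulary-sized product.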
@@ -2682,95 +2771,6 @@ ggml_tensor * llama_context_kv_self::build_soft_max_ext(
     return ggml_soft_max_ext(ctx0, kq, inp_KQ_mask_cnv, kq_scale, hparams.f_max_alibi_bias);
 }
 
-ggml_tensor * llama_context_kv_self::build_inp_embd(
-        ggml_context * ctx0,
-        ggml_tensor * tok_embd,
-        const llama_ubatch & ubatch) {
-    const auto & hparams = model.hparams;
-
-    const int64_t n_embd = hparams.n_embd;
-
-    struct ggml_tensor * inpL;
-
-    if (ubatch.token) {
-        inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
-        //cb(inp_tokens, "inp_tokens", -1);
-        ggml_set_input(inp_tokens);
-
-        inpL = ggml_get_rows(ctx0, tok_embd, inp_tokens);
-
-        // apply lora for embedding tokens if needed
-        for (const auto & lora : loras) {
-            struct llama_adapter_lora_weight * lw = lora.first->get_weight(tok_embd);
-            if (lw == nullptr) {
-                continue;
-            }
-
-            const float adapter_scale = lora.second;
-            const float scale = lw->get_scale(lora.first->alpha, adapter_scale);
-
-            struct ggml_tensor * inpL_delta = ggml_scale(ctx0, ggml_mul_mat(
-                        ctx0, lw->b, // non-transposed lora_b
-                        ggml_get_rows(ctx0, lw->a, inp_tokens)
-                        ), scale);
-
-            inpL = ggml_add(ctx0, inpL, inpL_delta);
-        }
-    } else {
-        inp_embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
-        inpL = inp_embd;
-        ggml_set_input(inp_embd);
-    }
-
-    // For Granite architecture
-    if (hparams.f_embedding_scale != 0.0f) {
-        inpL = ggml_scale(ctx0, inpL, hparams.f_embedding_scale);
-    }
-
-    //cb(inpL, "inp_embd", -1);
-
-    return inpL;
-}
-
-ggml_tensor * llama_context_kv_self::build_inp_pos(
-        ggml_context * ctx0,
-        int32_t n_tokens) {
-    inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_token());
-    ggml_set_input(inp_pos);
-
-    return inp_pos;
-}
-
-ggml_tensor * llama_context_kv_self::build_inp_out_ids(
-        ggml_context * ctx0,
-        int32_t n_tokens,
-        bool worst_case) {
-    const int32_t n_out_ids = worst_case ? n_tokens : n_outputs;
-
-    inp_out_ids = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_out_ids);
-    ggml_set_input(inp_out_ids);
-
-    return inp_out_ids;
-}
-
-ggml_tensor * llama_context_kv_self::build_inp_mean(
-        ggml_context * ctx0,
-        int32_t n_tokens) {
-    inp_mean = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
-    ggml_set_input(inp_mean);
-
-    return inp_mean;
-}
-
-ggml_tensor * llama_context_kv_self::build_inp_cls(
-        ggml_context * ctx0,
-        int32_t n_tokens) {
-    inp_cls = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
-    ggml_set_input(inp_cls);
-
-    return inp_cls;
-}
-
 void llama_context_kv_self::build_k_shift(
         ggml_context * ctx0,
         ggml_cgraph * graph) {
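One detail worth noting from the helpers that moved: inp_mean is a 2-D [n_tokens, n_tokens] F32 tensor, unlike the 1-D I32 inputs around it. That shape lets mean pooling be expressed as a single matrix multiply in the graph. The standalone sketch below illustrates the idea, with row s of the pooling matrix holding weight 1/len for each token of sequence s; the exact row/column orientation used by llama.cpp is an assumption here.

#include <cstdio>
#include <vector>

int main() {
    const int n_tokens = 4, n_embd = 2;

    // token embeddings, row-major [n_tokens x n_embd]
    std::vector<float> emb = {1, 2,  3, 4,  5, 6,  7, 8};

    // pooling matrix [n_tokens x n_tokens]: tokens 0-2 form sequence 0,
    // token 3 forms sequence 1; unused rows stay zero
    std::vector<float> mean(n_tokens * n_tokens, 0.0f);
    for (int j = 0; j < 3; ++j) mean[0 * n_tokens + j] = 1.0f / 3.0f;
    mean[1 * n_tokens + 3] = 1.0f;

    // pooled = mean * emb, which the graph computes with one ggml_mul_mat
    for (int s = 0; s < 2; ++s) {
        for (int e = 0; e < n_embd; ++e) {
            float acc = 0.0f;
            for (int t = 0; t < n_tokens; ++t) {
                acc += mean[s * n_tokens + t] * emb[t * n_embd + e];
            }
            std::printf("%.4f ", acc); // prints 3 4 for seq 0, then 7 8
        }
        std::printf("\n");
    }
    return 0;
}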

src/llama-context.h

Lines changed: 22 additions & 22 deletions
@@ -169,6 +169,28 @@ struct llama_context : public llama_graph_i {
 
     virtual ggml_tensor * build_rope_factors(int il);
 
+    virtual ggml_tensor * build_inp_embd(
+            ggml_context * ctx0,
+            ggml_tensor * tok_embd,
+            const llama_ubatch & ubatch);
+
+    virtual ggml_tensor * build_inp_pos(
+            ggml_context * ctx0,
+            int32_t n_tokens);
+
+    virtual ggml_tensor * build_inp_out_ids(
+            ggml_context * ctx0,
+            int32_t n_tokens,
+            bool worst_case);
+
+    virtual ggml_tensor * build_inp_mean(
+            ggml_context * ctx0,
+            int32_t n_tokens);
+
+    virtual ggml_tensor * build_inp_cls(
+            ggml_context * ctx0,
+            int32_t n_tokens);
+
     // state save/load
 
     virtual size_t state_get_size();

@@ -330,28 +352,6 @@ class llama_context_kv_self : public llama_context {
     struct ggml_tensor * inp_KQ_mask_swa_cnv; // [kv_size, n_batch]
     struct ggml_tensor * inp_K_shift; // I32 [kv_size]
 
-    virtual ggml_tensor * build_inp_embd(
-            ggml_context * ctx0,
-            ggml_tensor * tok_embd,
-            const llama_ubatch & ubatch) override;
-
-    virtual ggml_tensor * build_inp_pos(
-            ggml_context * ctx0,
-            int32_t n_tokens) override;
-
-    virtual ggml_tensor * build_inp_out_ids(
-            ggml_context * ctx0,
-            int32_t n_tokens,
-            bool worst_case) override;
-
-    virtual ggml_tensor * build_inp_mean(
-            ggml_context * ctx0,
-            int32_t n_tokens) override;
-
-    virtual ggml_tensor * build_inp_cls(
-            ggml_context * ctx0,
-            int32_t n_tokens) override;
-
     virtual void build_attn_inp(
             ggml_context * ctx0,
             int32_t n_tokens,
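Taken together, the header change is a textbook pull-up refactor: the five build_inp_* virtuals gain a shared implementation in the base llama_context, and the duplicate overrides in llama_context_kv_self disappear. A minimal sketch of the resulting dispatch behaviour, using hypothetical stand-in names rather than the real llama.cpp types:

#include <cstdio>

struct context_base {                   // stands in for llama_context
    virtual ~context_base() = default;
    virtual int build_inp_pos(int n_tokens) { // shared implementation lives here now
        return n_tokens;                      // e.g. one position id per token
    }
};

struct context_kv_self : context_base { // stands in for llama_context_kv_self
    // no build_inp_pos override needed: calls resolve to the base version,
    // while other virtuals (build_attn_inp, ...) can still be specialized
};

int main() {
    context_kv_self ctx;
    context_base * base = &ctx;
    std::printf("%d\n", base->build_inp_pos(8)); // prints 8 via the inherited method
    return 0;
}

Because the methods stay virtual, a derived context that does need a different input layout can still override any of them individually.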
