
Commit 96998c2

Merge branch 'ggml-org:master' into mradermacher
2 parents 75b3322 + 086cf81 · commit 96998c2

4 files changed: 60 additions & 47 deletions

src/llama-batch.cpp

Lines changed: 2 additions & 0 deletions
@@ -157,6 +157,8 @@ bool llama_batch_allocr::init(
         n_outputs += batch.logits[i] != 0;
     }
 
+    has_cpl = false;
+
     // determine coupled sequences
     // these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
     for (int32_t i = 0; i < batch.n_tokens; ++i) {

src/llama-batch.h

Lines changed: 1 addition & 1 deletion
@@ -117,7 +117,7 @@ class llama_batch_allocr {
     using seq_cpl_t = std::vector<bool>;
 
     // helper flag to quickly determine if there are any coupled sequences in the batch
-    bool has_cpl;
+    bool has_cpl = false;
 
     std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
     std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
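
Note: both has_cpl edits are the same fix seen from two sides. llama_batch_allocr::init() is intended to be called repeatedly on the same allocator object, so the coupled-sequence flag has to be cleared at the start of each call; the default member initializer in the header additionally covers the very first use. A minimal sketch of the stale-flag pattern this guards against (simplified, hypothetical stand-in type, not the llama.cpp classes):

    // build: g++ -std=c++17 sketch.cpp
    #include <cassert>

    // hypothetical, simplified stand-in for a reusable batch allocator
    struct batch_allocr_sketch {
        bool has_cpl = false;            // default initializer covers the first batch

        void init(bool batch_has_coupled_seqs) {
            has_cpl = false;             // explicit reset covers reuse of the same object
            if (batch_has_coupled_seqs) {
                has_cpl = true;
            }
        }
    };

    int main() {
        batch_allocr_sketch balloc;
        balloc.init(/*batch_has_coupled_seqs=*/true);   // first batch couples sequences
        balloc.init(/*batch_has_coupled_seqs=*/false);  // second batch does not
        assert(!balloc.has_cpl);         // without the reset, the stale 'true' would leak here
        return 0;
    }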

src/llama-kv-cache-unified.cpp

Lines changed: 18 additions & 29 deletions
@@ -1287,6 +1287,8 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
     const int64_t n_tps = n_tokens/n_stream;
     const int64_t n_tps_pad = GGML_PAD(n_tps, GGML_KQ_MASK_PAD);
 
+    std::fill(data, data + ggml_nelements(dst), -INFINITY);
+
     // Use only the previous KV cells of the correct sequence for each token of the ubatch.
     // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
     // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
@@ -1310,44 +1312,31 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
 
                 const llama_pos p1 = ubatch->pos[i];
 
-                for (uint32_t j = 0; j < n_kv; ++j) {
-                    float f = 0.0f;
-
-                    bool masked = false;
+                const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii);
 
+                for (uint32_t j = 0; j < n_kv; ++j) {
                     if (cells.is_empty(j)) {
-                        masked = true;
-                    } else {
-                        const llama_pos p0 = cells.pos_get(j);
-
-                        // mask the token if not the same sequence
-                        masked = masked || (!cells.seq_has(j, seq_id));
+                        continue;
+                    }
 
-                        // mask future tokens
-                        masked = masked || (causal_attn && p0 > p1);
+                    // mask the token if not the same sequence
+                    if (!cells.seq_has(j, seq_id)) {
+                        continue;
+                    }
 
-                        // apply SWA if any
-                        masked = masked || (is_masked_swa(p0, p1));
+                    const llama_pos p0 = cells.pos_get(j);
 
-                        if (!masked && hparams.use_alibi) {
-                            f = -std::abs(p0 - p1);
-                        }
+                    // mask future tokens
+                    if (causal_attn && p0 > p1) {
+                        continue;
                     }
 
-                    if (masked) {
-                        f = -INFINITY;
+                    // apply SWA if any
+                    if (is_masked_swa(p0, p1)) {
+                        continue;
                     }
 
-                    data[h*n_stream*n_tps_pad*n_kv + s*n_tps_pad*n_kv + ii*n_kv + j] = f;
-                }
-
-                // mask padded tokens
-                if (data) {
-                    for (uint32_t ii = n_tps; ii < n_tps_pad; ++ii) {
-                        for (uint32_t j = 0; j < n_kv; ++j) {
-                            data[h*n_stream*n_tps_pad*n_kv + s*n_tps_pad*n_kv + ii*n_kv + j] = -INFINITY;
-                        }
-                    }
+                    data[idst + j] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
                 }
             }
         }
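
Note: the rewritten loop inverts the masking logic. Instead of tracking a per-cell masked flag and writing either -INFINITY or a 0/ALiBi value (plus a separate pass for the padded token rows), the whole mask buffer is pre-filled with -INFINITY and only the cells that pass every filter are overwritten via the precomputed row offset idst. A standalone sketch of that fill-then-overwrite pattern with toy sizes; the cell_pos array and positions are invented for illustration, and the per-sequence and SWA checks of the real code are omitted:

    // build: g++ -std=c++17 mask_sketch.cpp
    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
        // toy dimensions standing in for the real n_kv / n_stream / n_tps / n_tps_pad
        const uint32_t n_kv = 8, n_stream = 1, n_tps = 3, n_tps_pad = 4;
        const bool causal_attn = true, use_alibi = false;

        std::vector<float> data(n_stream*n_tps_pad*n_kv);

        // 1) everything starts masked: empty cells and padded rows need no extra pass
        std::fill(data.begin(), data.end(), -INFINITY);

        // toy cache: position stored in each KV cell, -1 marks an empty cell
        const std::vector<int> cell_pos = {0, 1, 2, -1, -1, -1, -1, -1};

        const uint32_t h = 0; // single iteration of the outer head-group loop
        for (uint32_t s = 0; s < n_stream; ++s) {
            for (uint32_t ii = 0; ii < n_tps; ++ii) {
                const int p1 = 3 + (int) ii; // toy position of this batch token
                const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii);

                for (uint32_t j = 0; j < n_kv; ++j) {
                    const int p0 = cell_pos[j];
                    if (p0 < 0)                 continue; // empty cell stays -INFINITY
                    if (causal_attn && p0 > p1) continue; // future token stays -INFINITY

                    // 2) only surviving cells are written: 0 or an ALiBi-style bias
                    data[idst + j] = use_alibi ? -std::abs((float)(p0 - p1)) : 0.0f;
                }
            }
        }

        for (uint32_t ii = 0; ii < n_tps_pad; ++ii) {   // padded row ii == 3 stays -INFINITY
            for (uint32_t j = 0; j < n_kv; ++j) {
                std::printf("%8.1f", data[ii*n_kv + j]);
            }
            std::printf("\n");
        }
        return 0;
    }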

src/llama-model.cpp

Lines changed: 39 additions & 17 deletions
@@ -16564,46 +16564,68 @@ struct llm_build_lfm2 : public llm_graph_context {
             ggml_tensor * cur,
             llm_graph_input_rs * inp_recr,
             int il) {
-        const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
+        const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
+        const uint32_t kv_head = mctx_cur->get_head();
+        const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+        const int64_t n_seqs = ubatch.n_seqs;
+        GGML_ASSERT(n_seqs != 0);
+        GGML_ASSERT(ubatch.equal_seqs);
+        GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
+
+        GGML_ASSERT(hparams.n_shortconv_l_cache > 1);
+        const uint32_t d_conv = hparams.n_shortconv_l_cache - 1;
+
+        // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
+        cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
 
         auto * bcx = build_lora_mm(model.layers[il].shortconv.in_proj, cur);
         cb(bcx, "model.layers.{}.conv.in_proj", il);
 
         constexpr auto n_chunks = 3;
         GGML_ASSERT(bcx->ne[0] % n_chunks == 0);
         auto const chunk_size = bcx->ne[0] / n_chunks;
-        auto * b = ggml_view_2d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->nb[1], 0 * chunk_size * ggml_element_size(bcx));
-        auto * c = ggml_view_2d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->nb[1], 1 * chunk_size * ggml_element_size(bcx));
-        auto * x = ggml_view_2d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->nb[1], 2 * chunk_size * ggml_element_size(bcx));
+        auto * b = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2], 0*chunk_size*ggml_element_size(bcx));
+        auto * c = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2], 1*chunk_size*ggml_element_size(bcx));
+        auto * x = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2], 2*chunk_size*ggml_element_size(bcx));
 
         auto * bx = ggml_transpose(ctx0, ggml_mul(ctx0, b, x));
 
-        // read conv state directly, with build_rs generation is slower
-        ggml_tensor * conv_state = mctx_cur->get_r_l(il);
-        const int64_t n_seqs = ubatch.n_seqs;
-        ggml_tensor * conv = build_rs(inp_recr, gf, conv_state, hparams.n_embd_r(), n_seqs);
-        conv = ggml_reshape_3d(ctx0, conv_state, hparams.n_shortconv_l_cache - 1, hparams.n_embd, n_seqs);
+        // read conv state
+        auto * conv_state = mctx_cur->get_r_l(il);
+        auto * conv_rs = build_rs(inp_recr, gf, conv_state, hparams.n_embd_r(), n_seqs);
+        auto * conv = ggml_reshape_3d(ctx0, conv_rs, d_conv, hparams.n_embd, n_seqs);
 
         bx = ggml_concat(ctx0, conv, bx, 0);
         GGML_ASSERT(bx->ne[0] > conv->ne[0]);
 
-        auto * new_conv = ggml_view_2d(ctx0, bx, conv->ne[0], bx->ne[1], bx->nb[1], (bx->ne[0] - conv->ne[0]) * ggml_element_size(bx));
+        // last d_conv columns is a new conv state
+        auto * new_conv = ggml_view_3d(ctx0, bx, conv->ne[0], bx->ne[1], bx->ne[2], bx->nb[1], bx->nb[2], (bx->ne[0] - conv->ne[0])*ggml_element_size(bx));
         GGML_ASSERT(ggml_are_same_shape(conv, new_conv));
 
-        // write conv state
-        ggml_build_forward_expand(gf, ggml_cpy(ctx0, new_conv, conv_state));
+        // write new conv conv state
+        ggml_build_forward_expand(
+                gf,
+                ggml_cpy(
+                    ctx0,
+                    new_conv,
+                    ggml_view_1d(
+                        ctx0,
+                        conv_state,
+                        ggml_nelements(new_conv),
+                        kv_head*d_conv*n_embd*ggml_element_size(new_conv)
+                    )
+                )
+        );
 
         auto * conv_kernel = model.layers[il].shortconv.conv;
-        GGML_ASSERT(hparams.n_shortconv_l_cache > 0);
-
-        // construct ssm_conv op
-        ggml_tensor * conv_out = ggml_ssm_conv(ctx0, bx, conv_kernel);
+        auto * conv_out = ggml_ssm_conv(ctx0, bx, conv_kernel);
         cb(conv_out, "model.layers.{}.conv.conv", il);
 
         auto * y = ggml_mul(ctx0, c, conv_out);
-
         y = build_lora_mm(model.layers[il].shortconv.out_proj, y);
         cb(y, "model.layers.{}.conv.out_proj", il);
+        // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
+        y = ggml_reshape_2d(ctx0, y, y->ne[0], n_seq_tokens * n_seqs);
 
         return y;
     }
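
Note: the llm_build_lfm2 changes are mostly multi-sequence bookkeeping: the token stream is reshaped to {n_embd, n_seq_tokens, n_seqs}, the b/c/x chunks become 3D views taken at element offsets of 0/1/2 × chunk_size into bcx, and the updated conv state is copied back into the recurrent cache at an element offset of kv_head*d_conv*n_embd. A small sketch of just that offset arithmetic; the sizes below are made up, only the formulas mirror the diff:

    // build: g++ -std=c++17 offsets_sketch.cpp
    #include <cstdint>
    #include <cstdio>

    int main() {
        // hypothetical model/batch dimensions
        const int64_t n_embd   = 16;                  // embedding size
        const int64_t d_conv   = 3;                   // n_shortconv_l_cache - 1
        const int64_t n_chunks = 3;                   // bcx packs b, c and x side by side
        const int64_t bcx_ne0  = n_chunks*n_embd;     // first dimension of bcx
        const int64_t chunk    = bcx_ne0/n_chunks;    // chunk_size in the diff
        const size_t  elt      = sizeof(float);       // stand-in for ggml_element_size()

        // byte offsets of the b, c, x views, as in ggml_view_3d(..., k*chunk_size*ggml_element_size(bcx))
        for (int k = 0; k < n_chunks; ++k) {
            std::printf("chunk %d starts at byte offset %zu\n", k, (size_t) (k*chunk)*elt);
        }

        // destination offset of the new conv state inside the recurrent cache tensor:
        // one {d_conv, n_embd} block per cache slot, starting at the slot given by kv_head
        const int64_t kv_head = 5;                    // hypothetical cache head returned by get_head()
        const size_t  dst_off = (size_t) (kv_head*d_conv*n_embd)*elt;
        std::printf("conv state for kv_head=%lld written at byte offset %zu\n",
                    (long long) kv_head, dst_off);
        return 0;
    }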
