Commit dec6ce2

llama : fix buffer checks for mamba and rwk
1 parent 0a683e8 commit dec6ce2

2 files changed: 29 additions & 8 deletions

ggml/src/ggml.c

Lines changed: 1 addition & 0 deletions
@@ -7272,6 +7272,7 @@ struct ggml_tensor * ggml_ssm_conv(
     const int64_t n_s = sx->ne[2];
 
     // TODO: maybe support other strides than 1?
+    // FIXME: this is always true?
     GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t);
     GGML_ASSERT(sx->ne[1] == d_inner);
     GGML_ASSERT(n_t >= 0);
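
For context, the asserts in this hunk encode the shape contract of ggml_ssm_conv: the rolling state sx must carry d_conv - 1 extra columns in front of the n_t tokens, and its second dimension must match the d_inner channels of the conv kernel. The added FIXME presumably refers to n_t being derived from sx->ne[0] and d_conv just above, which would make the first assert hold by construction. A minimal sketch of shape-consistent operands, assuming an existing ggml_context * ctx; the sizes below are illustrative, not taken from the commit:

    // hypothetical sizes, for illustration only
    const int64_t d_conv  = 4;    // conv kernel width   (c->ne[0])
    const int64_t d_inner = 1536; // conv channels       (c->ne[1])
    const int64_t n_t     = 512;  // tokens per sequence
    const int64_t n_s     = 1;    // number of sequences

    ggml_tensor * c  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_conv, d_inner);
    // sx holds d_conv - 1 columns of carried-over state ahead of the n_t tokens,
    // which is what GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t) checks
    ggml_tensor * sx = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_conv - 1 + n_t, d_inner, n_s);
    ggml_tensor * y  = ggml_ssm_conv(ctx, sx, c);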

src/llama.cpp

Lines changed: 28 additions & 8 deletions
@@ -7122,7 +7122,7 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w
             } break;
         case GGML_OP_MUL_MAT:
             {
-                ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, w->ne[0], 512);
+                ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], 512, w->ne[2], w->ne[3]);
                 op_tensor = ggml_mul_mat(ctx, w, b);
             } break;
         case GGML_OP_MUL_MAT_ID:
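
The dummy activation tensor b is presumably widened to 4-D here so that weights with more than two dimensions pass ggml_mul_mat's shape check, which requires the operands to agree on ne[0] and b's ne[2]/ne[3] to be broadcast-compatible with the weight's. A minimal sketch, assuming an existing ggml_context * ctx and a hypothetical 3-D weight with illustrative sizes:

    // hypothetical 3-D weight: 4096 x 4096 with 8 slices along dim 2
    ggml_tensor * w = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 4096, 4096, 8);

    // b shares ne[0] with w and copies w's ne[2]/ne[3]; a plain 2-D b
    // (ne[2] == 1) would fail the broadcast check against w->ne[2] == 8
    ggml_tensor * b   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], 512, w->ne[2], w->ne[3]);
    ggml_tensor * out = ggml_mul_mat(ctx, w, b); // out is 4096 x 512 x 8
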
@@ -7162,18 +7162,38 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w
             } break;
         case GGML_OP_SSM_CONV:
             {
-                // TODO: ggml_ssm_conv(ctx, conv_x, model.layers[il].ssm_conv1d);
-                op_tensor = ggml_ssm_conv(ctx, nullptr, w);
+                // FIXME
+                ggml_tensor * conv_x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 12345, w->ne[1], 6789);
+                op_tensor = ggml_ssm_conv(ctx, conv_x, w);
             } break;
         case GGML_OP_SSM_SCAN:
             {
-                // TODO: ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C);
-                op_tensor = ggml_ssm_scan(ctx, nullptr, nullptr, nullptr, w, nullptr, nullptr);
+                // FIXME
+                const int64_t d_state = w->ne[0];
+                const int64_t d_inner = w->ne[1];
+                const int64_t n_seq_tokens = 512;
+                const int64_t n_seqs = 1;
+                ggml_tensor * s = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, d_inner, n_seqs);
+                ggml_tensor * x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs);
+                ggml_tensor * dt = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs);
+                ggml_tensor * B = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs);
+                ggml_tensor * C = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs);
+                op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C);
             } break;
         case GGML_OP_RWKV_WKV:
             {
-                // TODO: ggml_rwkv_wkv(ctx, k, v, r, layer->time_mix_first, w, *wkv_state);
-                op_tensor = ggml_rwkv_wkv(ctx, nullptr, nullptr, nullptr, w, nullptr, nullptr);
+                // FIXME
+                const int64_t S = 123;
+                const int64_t H = 123;
+                const int64_t n_tokens = 123;
+                const int64_t n_seqs = 123;
+                ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, 1, H, n_tokens);
+                ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens);
+                ggml_tensor * r = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens);
+                ggml_tensor * tf = w;
+                ggml_tensor * td = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens);
+                ggml_tensor * state = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, n_seqs, S, H);
+                op_tensor = ggml_rwkv_wkv(ctx, k, v, r, tf, td, state);
             } break;
         default:
             GGML_ABORT("%s: missing test for op %s for tensor %s", __func__, ggml_op_name(op), w->name);
@@ -7448,7 +7468,7 @@ static bool llm_load_tensors(
 
         // tensors with "bias" suffix are always used with GGML_OP_ADD
         ggml_op op;
-        bool bias = strcmp(tn.suffix, "bias") == 0;
+        bool bias = tn.suffix != nullptr && strcmp(tn.suffix, "bias") == 0;
         if (bias) {
             op = GGML_OP_ADD;
         } else {
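
The last hunk guards the suffix comparison because passing a null pointer to strcmp is undefined behavior; the null test has to short-circuit before strcmp runs. A standalone sketch of the same pattern (is_bias_suffix is a hypothetical helper, not part of the commit):

    #include <cstring>

    // mirrors the guarded check added above: null test first, then strcmp
    static bool is_bias_suffix(const char * suffix) {
        return suffix != nullptr && std::strcmp(suffix, "bias") == 0;
    }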
