Commit 3af2709

Fixing mamba part of plamo2

1 parent 6e84697 · commit 3af2709

3 files changed: +26 -120 lines

examples/eval-callback/eval-callback.cpp (1 addition, 1 deletion)

```diff
@@ -122,7 +122,7 @@ static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) {
 
     if (!ggml_is_quantized(t->type)) {
         uint8_t * data = is_host ? (uint8_t *) t->data : cb_data->data.data();
-        ggml_print_tensor(data, t->type, t->ne, t->nb, 3);
+        ggml_print_tensor(data, t->type, t->ne, t->nb, 256);
    }
 
    return true;
```
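Note on the change above: `ggml_print_tensor` is a static helper in this example, and its final argument appears to cap how many elements are printed per dimension before the output is elided. Raising it from 3 to 256 makes the callback dumps detailed enough to compare against a reference run. A minimal standalone sketch of the same idea, with no ggml dependency (the `print_rows_truncated` helper and the sizes are illustrative, not llama.cpp API):

```cpp
// Print at most the first `limit` values of each row of a row-major float
// matrix, eliding the rest -- the effect the raised limit has on the dumps.
#include <cstdio>
#include <vector>

static void print_rows_truncated(const std::vector<float> & data, int64_t rows, int64_t cols, int64_t limit) {
    for (int64_t r = 0; r < rows; ++r) {
        printf("[");
        for (int64_t c = 0; c < cols && c < limit; ++c) {
            printf("%s%8.4f", c ? ", " : "", data[r * cols + c]);
        }
        if (cols > limit) {
            printf(", ...");
        }
        printf("]\n");
    }
}

int main() {
    std::vector<float> t(4 * 512, 0.5f);
    print_rows_truncated(t, 4, 512, 256); // with limit = 3 the dump is too coarse to compare runs
    return 0;
}
```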

src/llama-context.cpp (0 additions, 103 deletions)

```diff
@@ -1131,109 +1131,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
             }
         }
 
-        // Debug: Dump tensor values after computation (for PLaMo-2 only)
-#define PLAMO2_DEBUG
-#ifdef PLAMO2_DEBUG
-        if (model.arch == LLM_ARCH_PLAMO2) { // Only for small inputs
-            // Create debug directory if it doesn't exist
-#ifdef _WIN32
-            _mkdir("debug_tensors");
-#else
-            mkdir("debug_tensors", 0755);
-#endif
-            // Find debug tensors by searching through the graph (gf is now accessible via res->get_graph())
-            ggml_cgraph* current_gf = res->get_graph();
-            for (int i = 0; i < ggml_graph_n_nodes(current_gf); ++i) {
-                ggml_tensor* node = ggml_graph_node(current_gf, i);
-                printf("Processing node: %s\n", node->name ? node->name : "unknown");
-                if (node && node->name) {
-                    bool should_dump = (strcmp(node->name, "embedding_output") == 0) ||
-                                       (strstr(node->name, "mamba_") == node->name) ||
-                                       (strstr(node->name, "attn_norm") == node->name) ||
-                                       (strstr(node->name, "norm") == node->name) ||
-                                       (strcmp(node->name, "tokens") == 0) ||
-                                       (strstr(node->name, "attn_pre_norm") == node->name) ||
-                                       (strcmp(node->name, "inp_embd") == 0) ||
-                                       (strcmp(node->name, "inp_tokens") == 0);
-
-                    if (strcmp(node->name, "tokens") == 0) {
-                        llama_token* token_data = (llama_token*)node->data;
-                        printf("Input Tokens: ");
-                        for (int j = 0; j < node->ne[0]; ++j) {
-                            printf("%d ", token_data[j]);
-                        }
-                        printf("\n");
-                        continue; // Skip dumping tensor values for "tokens"
-                    }
-
-                    if (should_dump && node->data) {
-                        printf("=== Post-Compute Tensor Values ===\n");
-                        printf("Tensor: %s\n", node->name);
-                        printf("Shape: [%ld, %ld", node->ne[0], node->ne[1]);
-                        if (node->ne[2] > 1) printf(", %ld", node->ne[2]);
-                        if (node->ne[3] > 1) printf(", %ld", node->ne[3]);
-                        printf("]\n");
-
-                        int64_t total_elements = ggml_nelements(node);
-                        float* data = new float[total_elements];
-                        if (node->type == GGML_TYPE_F32) {
-                            data = (float*)node->data;
-                        } else if (node->type == GGML_TYPE_BF16) {
-                            ggml_bf16_t * bf16_data = (ggml_bf16_t*)node->data;
-                            for (int64_t j = 0; j < total_elements; j++) {
-                                printf("%.6f -> %.6f \n", bf16_data[j], ggml_bf16_to_fp32(bf16_data[j]));
-                            }
-                            ggml_bf16_to_fp32_row((ggml_bf16_t*)node->data, data, total_elements);
-                        }
-
-                        if (total_elements > 0) {
-                            // Calculate statistics
-                            float sum = 0.0f, sum_sq = 0.0f, min_val = data[0], max_val = data[0];
-                            for (int64_t j = 0; j < total_elements; j++) {
-                                sum += data[j];
-                                sum_sq += data[j] * data[j];
-                                min_val = fminf(min_val, data[j]);
-                                max_val = fmaxf(max_val, data[j]);
-                            }
-
-                            float mean = sum / total_elements;
-                            float variance = (sum_sq / total_elements) - (mean * mean);
-                            float std_dev = sqrtf(variance);
-
-                            printf("Stats - Mean: %.6f, Std: %.6f, Min: %.6f, Max: %.6f\n",
-                                   mean, std_dev, min_val, max_val);
-
-                            // Print first 8 values
-                            printf("First 8 values: ");
-                            for (int j = 0; j < 8 && j < total_elements; j++) {
-                                printf("%.6f ", data[j]);
-                            }
-                            printf("\n");
-
-                            // Save to file for detailed comparison
-                            char filename[256];
-                            snprintf(filename, sizeof(filename), "debug_tensors/%s.csv", node->name);
-                            FILE* f = fopen(filename, "w");
-                            if (f) {
-                                for (int64_t j = 0; j < total_elements; ++j) {
-                                    fprintf(f, "%f", data[j]);
-                                    if ((j + 1) % node->ne[0] == 0) {
-                                        fprintf(f, "\n");
-                                    } else {
-                                        fprintf(f, ",");
-                                    }
-                                }
-                                fclose(f);
-                                printf("Saved to: %s\n", filename);
-                            }
-                        }
-                        printf("==================================\n");
-                    }
-                }
-            }
-        }
-#endif // PLAMO2_DEBUG
-
         n_outputs_prev += n_outputs;
     } while (mctx->next());
 
```
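The deleted block dumped per-tensor statistics and CSV files straight from `llama_context::decode()`. The eval-callback example touched above already provides this kind of per-tensor inspection through the scheduler eval callback, which is presumably why the in-tree debug code can go. If the ad-hoc summary/CSV output is still wanted outside the hot path, here is a standalone sketch of that part of the removed code, working on a plain float buffer (the `dump_tensor_stats` name and CSV layout follow the deleted block; nothing here is llama.cpp API):

```cpp
// Standalone sketch of the mean/std/min/max summary and CSV dump the
// removed PLAMO2_DEBUG block performed, decoupled from ggml tensors.
#include <cmath>
#include <cstdio>
#include <string>
#include <vector>

static void dump_tensor_stats(const std::string & name, const std::vector<float> & data, size_t row_len) {
    if (data.empty() || row_len == 0) {
        return;
    }
    double sum = 0.0, sum_sq = 0.0;
    float min_val = data[0], max_val = data[0];
    for (float v : data) {
        sum    += v;
        sum_sq += (double) v * v;
        min_val = std::fmin(min_val, v);
        max_val = std::fmax(max_val, v);
    }
    const double mean    = sum / data.size();
    const double std_dev = std::sqrt(sum_sq / data.size() - mean * mean);
    printf("%s: mean %.6f, std %.6f, min %.6f, max %.6f\n", name.c_str(), mean, std_dev, min_val, max_val);

    // CSV: one tensor row per line, matching the deleted debug_tensors/<name>.csv format.
    const std::string path = name + ".csv";
    if (FILE * f = fopen(path.c_str(), "w")) {
        for (size_t j = 0; j < data.size(); ++j) {
            fprintf(f, "%f%c", data[j], ((j + 1) % row_len == 0) ? '\n' : ',');
        }
        fclose(f);
    }
}

int main() {
    std::vector<float> x = {0.1f, -0.2f, 0.3f, 0.4f, -0.5f, 0.6f};
    dump_tensor_stats("mamba_x_split", x, 3);
    return 0;
}
```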

src/llama-model.cpp (25 additions, 16 deletions)
```diff
@@ -8234,13 +8234,8 @@ struct llm_build_plamo2 : public llm_graph_context {
             // ggml_graph_add_node(gf, model.layers[il].attn_norm);
             // cb(model.layers[il].attn_norm, "attn_norm", il);
 
-            ggml_graph_add_node(gf, model.layers[il].attn_norm);
-            cb(model.layers[il].attn_norm, "attn_norm_weight", il);
-
             // pre_mixer_norm
-            cb(inpL, "attn_pre_norm_input", il);
             cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
-            cb(cur, "attn_pre_norm", il);
 
             // check if this layer is Mamba or Attention
             bool is_mamba_layer = hparams.is_recurrent(il);
```
```diff
@@ -8280,6 +8275,10 @@
             cur = ggml_add(ctx0, cur, residual);
 
             inpL = cur;
+
+            if (il >= 2) {
+                break;
+            }
         }
 
         cur = inpL;
```
```diff
@@ -8445,17 +8444,28 @@ struct llm_build_plamo2 : public llm_graph_context {
        ggml_tensor * zx = build_lora_mm(model.layers[il].ssm_in, cur);
        cb(zx, "mamba_in_proj", il);
 
+       zx = ggml_permute(ctx0, zx, 0, 2, 1, 3);
+       zx = ggml_reshape_4d(ctx0, zx, 2 * hparams.ssm_head_dim, hparams.ssm_num_heads, n_seq_tokens, n_seqs);
+       cb(zx, "mamba_in_proj_out", il);
+
        // split into z and x
        // => {d_inner, n_seq_tokens, n_seqs}
-       ggml_tensor * x = ggml_view_3d(ctx0, zx, d_inner, zx->ne[1], zx->ne[2], zx->nb[1], zx->nb[2], 0);
-       ggml_tensor * z = ggml_view_3d(ctx0, zx, d_inner, zx->ne[1], zx->ne[2], zx->nb[1], zx->nb[2], d_inner*ggml_element_size(zx));
+       ggml_tensor * x = ggml_view_4d(ctx0, zx, hparams.ssm_head_dim, zx->ne[1], zx->ne[2], zx->ne[3], zx->nb[1], zx->nb[2], zx->nb[3], hparams.ssm_head_dim*ggml_element_size(zx));
+       x = ggml_cont(ctx0, x);
+       x = ggml_reshape_4d(ctx0, x, hparams.ssm_head_dim * hparams.ssm_num_heads, 1, n_seq_tokens, n_seqs);
+       x = ggml_permute(ctx0, x, 0, 2, 1, 3);
        cb(x, "mamba_x_split", il);
+       ggml_tensor * z = ggml_view_4d(ctx0, zx, hparams.ssm_head_dim, zx->ne[1], zx->ne[2], zx->ne[3], zx->nb[1], zx->nb[2], zx->nb[3], 0);
+       z = ggml_cont(ctx0, z);
+       z = ggml_reshape_4d(ctx0, z, hparams.ssm_head_dim * hparams.ssm_num_heads, 1, n_seq_tokens, n_seqs);
+       z = ggml_permute(ctx0, z, 0, 2, 1, 3);
        cb(z, "mamba_z_split", il);
 
        // conv1d
        {
            // => {d_conv - 1 + n_seq_tokens, d_inner, n_seqs}
            ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, x), 0);
+           cb(conv_x, "mamba_conv1d_input", il);
 
            // copy last (d_conv - 1) columns back into the state cache
            ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner, n_seqs,
```
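The reworked split treats the `ssm_in` projection output as `ssm_num_heads` blocks of `2 * ssm_head_dim` values, taking `z` from the first half of each block (offset 0) and `x` from the second half (offset `ssm_head_dim`), rather than the previous two contiguous `d_inner`-sized halves. A self-contained sketch of that per-head layout on a plain array (the layout is what the new views imply; `head_dim`/`num_heads` sizes are illustrative):

```cpp
// Per-head [z | x] layout assumed by the new split: each head owns a
// contiguous block of 2*head_dim values, z in the first half, x in the second.
#include <cstdio>
#include <vector>

int main() {
    const int head_dim  = 4;
    const int num_heads = 2;

    // zx laid out head-major: [z_h0 | x_h0 | z_h1 | x_h1 | ...]
    std::vector<float> zx(2 * head_dim * num_heads);
    for (size_t i = 0; i < zx.size(); ++i) {
        zx[i] = (float) i;
    }

    std::vector<float> z(head_dim * num_heads);
    std::vector<float> x(head_dim * num_heads);
    for (int h = 0; h < num_heads; ++h) {
        for (int d = 0; d < head_dim; ++d) {
            z[h * head_dim + d] = zx[h * 2 * head_dim + d];            // offset 0, like the z view
            x[h * head_dim + d] = zx[h * 2 * head_dim + head_dim + d]; // offset head_dim, like the x view
        }
    }

    printf("z: "); for (float v : z) printf("%g ", v); printf("\n");
    printf("x: "); for (float v : x) printf("%g ", v); printf("\n");
    return 0;
}
```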
```diff
@@ -8471,9 +8481,6 @@ struct llm_build_plamo2 : public llm_graph_context {
            x = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d);
            cb(x, "mamba_conv1d", il);
 
-           // bias
-           // x = ggml_add(ctx0, x, model.layers[il].ssm_conv1d_b); // PLaMo-2 does not use bias here
-
            x = ggml_silu(ctx0, x);
            cb(x, "mamba_conv1d_silu", il);
        }
```
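For reference, this step is a depthwise causal 1-D convolution over the time axis followed by SiLU; the commented-out bias add is dropped because, per the deleted comment, PLaMo-2 uses no conv bias. A minimal numeric sketch of the operation on plain arrays (one channel, illustrative values; this is the math the step is expected to compute, not ggml's implementation):

```cpp
// Depthwise causal conv over time (kernel size d_conv, state holds the last
// d_conv-1 inputs, no bias) followed by SiLU, for a single channel.
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const int d_conv   = 4;                         // kernel width
    const int n_tokens = 3;
    std::vector<float> x      = {1.0f, 2.0f, 3.0f}; // current tokens, one channel
    std::vector<float> kernel = {0.1f, 0.2f, 0.3f, 0.4f};
    std::vector<float> state(d_conv - 1, 0.0f);     // carried over from the previous step

    // conv_x = [state | x], analogous to the ggml_concat of conv state and x in the graph
    std::vector<float> conv_x = state;
    conv_x.insert(conv_x.end(), x.begin(), x.end());

    for (int t = 0; t < n_tokens; ++t) {
        float acc = 0.0f;
        for (int k = 0; k < d_conv; ++k) {
            acc += kernel[k] * conv_x[t + k];       // causal: only current and past inputs
        }
        const float y = acc / (1.0f + std::exp(-acc)); // SiLU(acc) = acc * sigmoid(acc)
        printf("t=%d conv=%.4f silu=%.4f\n", t, acc, y);
    }
    return 0;
}
```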
```diff
@@ -8486,9 +8493,9 @@ struct llm_build_plamo2 : public llm_graph_context {
 
        // split into dt, B, C
        const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16));
-       ggml_tensor * dt = ggml_view_3d(ctx0, x_bcdt, dt_dim, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], 0);
-       ggml_tensor * C = ggml_view_3d(ctx0, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], ggml_element_size(x_bcdt)*dt_dim);
-       ggml_tensor * B = ggml_view_3d(ctx0, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], ggml_element_size(x_bcdt)*(dt_dim + d_state));
+       ggml_tensor * B = ggml_view_3d(ctx0, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], 0);
+       ggml_tensor * C = ggml_view_3d(ctx0, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], ggml_element_size(x_bcdt)*d_state);
+       ggml_tensor * dt = ggml_view_3d(ctx0, x_bcdt, dt_dim, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], ggml_element_size(x_bcdt)*(2*d_state));
        cb(B, "mamba_B_raw", il);
        cb(C, "mamba_C_raw", il);
        cb(dt, "mamba_dt_raw", il);
```
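This hunk changes the assumed per-row layout of the `x_bcdt` projection output from `[dt | C | B]` to `[B | C | dt]`: B now starts at offset 0, C at `d_state`, and dt at `2*d_state` elements into each row. A tiny sketch of the new offset arithmetic on a flat row (`d_state`/`dt_dim` values are illustrative, not PLaMo-2's real hyperparameters):

```cpp
// New per-row layout of the bcdt projection: [B | C | dt].
#include <cstdio>
#include <vector>

int main() {
    const int d_state = 4;
    const int dt_dim  = 2;
    std::vector<float> row(2 * d_state + dt_dim);
    for (size_t i = 0; i < row.size(); ++i) {
        row[i] = (float) i;
    }

    const float * B  = row.data();               // offset 0
    const float * C  = row.data() + d_state;     // offset d_state
    const float * dt = row.data() + 2 * d_state; // offset 2*d_state

    printf("B : %g..%g\n", B[0],  B[d_state - 1]);
    printf("C : %g..%g\n", C[0],  C[d_state - 1]);
    printf("dt: %g..%g\n", dt[0], dt[dt_dim - 1]);
    return 0;
}
```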
```diff
@@ -8503,15 +8510,17 @@ struct llm_build_plamo2 : public llm_graph_context {
 
        // dt_proj: {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs}
        dt = build_lora_mm(model.layers[il].ssm_dt, dt);
-       dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
        cb(dt, "mamba_dt_proj", il);
 
        // This is corresponding to the broadcast_to operation in ssd_update_state() of the originall code
-       ggml_tensor * dt_expanded = ggml_new_tensor_2d(ctx0, dt->type, d_inner, n_seq_tokens);
+       ggml_tensor * dt_expanded = ggml_new_tensor_2d(ctx0, dt->type, dt_dim * hparams.ssm_num_heads, dt->ne[1]);
        dt = ggml_repeat(ctx0, dt, dt_expanded);
+       cb(dt, "mamba_dt_expanded", il);
+
        ggml_tensor * A_expanded = ggml_new_tensor_2d(ctx0, model.layers[il].ssm_a->type, d_state, d_inner);
        A_expanded = ggml_repeat(ctx0, model.layers[il].ssm_a, A_expanded);
-       cb(dt, "mamba_dt_expanded", il);
+       A_expanded = ggml_exp(ctx0, A_expanded);
+       A_expanded = ggml_scale(ctx0, A_expanded, -1.0f);
        cb(A_expanded, "mamba_A_expanded", il);
 
        // SSM scan operation
```
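The two added ops compute `A = -exp(ssm_a)`, i.e. they treat the stored `ssm_a` tensor as `A_log`, which matches the usual Mamba parameterization: A stays strictly negative, so the discretized factor `exp(dt * A)` lies in (0, 1) and the SSM state decays rather than blows up. The hunk also drops the `ssm_dt_b` bias add and expands `dt` to the head layout (`dt_dim * ssm_num_heads`). A small numeric sketch of the parameterization (illustrative values only):

```cpp
// A = -exp(A_log) (ggml_exp followed by ggml_scale(..., -1.0f)) keeps the
// discretized state-update factor exp(dt * A) inside (0, 1).
#include <cmath>
#include <cstdio>

int main() {
    const float A_log = 0.5f;             // illustrative stored parameter (ssm_a)
    const float A     = -std::exp(A_log); // always < 0

    for (float dt : {0.01f, 0.1f, 1.0f}) {
        const float decay = std::exp(dt * A); // per-step state decay factor
        printf("dt=%.2f  A=%.4f  exp(dt*A)=%.4f\n", dt, A, decay);
    }
    return 0;
}
```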
