@@ -7792,9 +7792,75 @@ struct llm_build_plamo2 : public llm_graph_context {
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
                 }
             } else if (layer_type == "mamba") {
-                // Mamba layer processing - simplified implementation for now
-                // TODO: Implement full mamba layer logic
-                GGML_ASSERT(false && "Mamba layers not yet fully implemented for PLaMo2");
+                // Mamba layer processing
+                const int64_t d_conv  = hparams.ssm_d_conv;
+                const int64_t d_inner = hparams.ssm_d_inner;
+                const int64_t d_state = hparams.ssm_d_state;
+                const int64_t dt_rank = hparams.ssm_dt_rank;
+
+                // Apply linear transformation: n_embd -> 2*d_inner
+                ggml_tensor * xz = build_lora_mm(model.layers[il].ssm_in, mixer_norm);
+                cb(xz, "ssm_in", il);
+
+                // Split into x and z
+                ggml_tensor * x = ggml_view_2d(ctx0, xz, d_inner, n_tokens, xz->nb[1], 0);
+                ggml_tensor * z = ggml_view_2d(ctx0, xz, d_inner, n_tokens, xz->nb[1], d_inner*ggml_element_size(xz));
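+                // Layout of xz along ne0: the first d_inner values per token are x, the next d_inner are z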
+
+                // The simplified PLaMo2 implementation does not cache conv/SSM state,
+                // so a basic convolution over the current tokens is used
+                // Reshape x for convolution: {d_inner, n_tokens} -> {n_tokens, d_inner}
+                x = ggml_cont(ctx0, ggml_transpose(ctx0, x));
+
+                // Apply 1D convolution with padding of d_conv - 1
+                // Note: the PLaMo2 conv1d weight is stored as {d_inner, d_conv}; reshape it to {d_conv, d_inner} for ggml_conv_1d
+                ggml_tensor * conv_w = ggml_reshape_2d(ctx0, model.layers[il].ssm_conv1d, d_conv, d_inner);
+                x = ggml_conv_1d(ctx0, conv_w, x, 1, d_conv - 1, 1);
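+                // Note: the reference Mamba block uses a depthwise causal conv1d (one filter per channel);
+                // this single ggml_conv_1d call is part of the simplification and may not reproduce it exactly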
+
+                // Transpose back and apply SiLU
+                x = ggml_cont(ctx0, ggml_transpose(ctx0, x));
+                x = ggml_silu(ctx0, x);
+                cb(x, "ssm_conv", il);
+
+                // SSM sequence transformation
+                {
+                    // Project x to dt, B, C
+                    ggml_tensor * x_db = build_lora_mm(model.layers[il].ssm_bcdt, x);
+                    cb(x_db, "ssm_bcdt", il);
+
+                    // Split into dt, B, C
+                    ggml_tensor * dt = ggml_view_2d(ctx0, x_db, dt_rank, n_tokens, x_db->nb[1], 0);
+                    ggml_tensor * B  = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*dt_rank);
+                    ggml_tensor * C  = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state));
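+                    // Layout of x_db along ne0: [dt (dt_rank) | B (d_state) | C (d_state)]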
+
+                    // Project dt from dt_rank to d_inner
+                    dt = build_lora_mm(model.layers[il].ssm_dt, dt);
+                    cb(dt, "ssm_dt", il);
+
+                    // Simplified selective scan: the recurrent state is not carried across calls,
+                    // so this approximation may not capture all SSM dynamics
+
+                    // Create a dummy state tensor for ggml_ssm_scan
+                    ggml_tensor * dummy_s = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, d_state, d_inner, 1);
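+                    // Note: dummy_s is a fresh per-graph tensor, so no SSM state persists between calls;
+                    // a full implementation would load and store the per-sequence recurrent state here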
+
+                    // Use ggml_ssm_scan for proper SSM computation
+                    ggml_tensor * y_ssm = ggml_ssm_scan(ctx0, dummy_s, x, dt, model.layers[il].ssm_a, B, C);
+
+                    // Extract the output (first part of y_ssm)
+                    ggml_tensor * y = ggml_view_2d(ctx0, y_ssm, d_inner, n_tokens, y_ssm->nb[1], 0);
+
+                    // Add D parameter contribution
+                    y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
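+                    // Per the Mamba formulation, y = SSM(A, B, C)(x) + D*x, i.e. ssm_d acts as a per-channel skip connection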
+                    x = y;
+                }
+
+                // Gated output
+                x = ggml_mul(ctx0, x, ggml_silu(ctx0, ggml_cont(ctx0, z)));
+                cb(x, "ssm_gate", il);
+
+                // Output projection
+                cur = build_lora_mm(model.layers[il].ssm_out, x);
+                cb(cur, "ssm_out", il);
             } else {
                 // Default to attention-like processing for unknown layer types
                 cur = build_lora_mm(model.layers[il].wqkv, mixer_norm);