Skip to content

Commit 0674957

Browse files
committed
attempt #1
1 parent da62356 commit 0674957

File tree

1 file changed

+15
-3
lines changed

1 file changed

+15
-3
lines changed

crates/models/bert/src/lib.rs

Lines changed: 15 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -241,7 +241,7 @@ impl KnownModel for Bert {
241241

242242
// self-attention
243243
{
244-
current = ctx0.op_cont(&current);
244+
print_shape(&current, "current");
245245
let q_current = ctx0.op_reshape_3d(
246246
&ctx0.op_add(
247247
&ctx0.op_mul_mat(&self.layers[il].q_w, &current),
@@ -263,6 +263,10 @@ impl KnownModel for Bert {
263263
input_len,
264264
);
265265
let k = ctx0.op_permute(&k_current, (0, 2, 1, 3));
266+
let k = ctx0.op_cpy(
267+
&k,
268+
&ctx0.new_tensor_3d(ggml::Type::F32, d_head, input_len, n_head),
269+
);
266270

267271
let v_current = ctx0.op_reshape_3d(
268272
&ctx0.op_add(
@@ -274,9 +278,13 @@ impl KnownModel for Bert {
274278
input_len,
275279
);
276280
let mut v = ctx0.op_permute(&v_current, (0, 2, 1, 3));
281+
v = ctx0.op_cpy(
282+
&v,
283+
&ctx0.new_tensor_3d(ggml::Type::F32, d_head, input_len, n_head),
284+
);
277285

278-
let k = ctx0.op_cont(&k);
279-
let q = ctx0.op_cont(&q);
286+
print_shape(&k, "k");
287+
print_shape(&q, "q");
280288
let mut kq = ctx0.op_mul_mat(&k, &q);
281289

282290
// TODO: look into op_scale_inplace and op_soft_max_inplace
@@ -465,3 +473,7 @@ struct Layer {
465473
ff_o_w: ggml::Tensor,
466474
ff_o_b: ggml::Tensor,
467475
}
476+
477+
fn print_shape(t: &ggml::Tensor, name: &str) {
478+
println!("{name} {} {:?}", t.get_type(), t.get_ne());
479+
}

0 commit comments

Comments
 (0)