@@ -796,22 +796,17 @@ struct clip_graph {
796796 // resampler projector (it is just another transformer)
797797
798798 ggml_tensor * q = model.mm_model_query ;
799- { // layernorm
800- q = ggml_norm (ctx0, q, eps);
801- q = ggml_add (ctx0, ggml_mul (ctx0, q, model.mm_model_ln_q_w ), model.mm_model_ln_q_b );
802- }
803799 ggml_tensor * v = ggml_mul_mat (ctx0, model.mm_model_kv_proj , embeddings);
804- { // layernorm
805- v = ggml_norm (ctx0, v, eps);
806- v = ggml_add (ctx0, ggml_mul (ctx0, v, model.mm_model_ln_kv_w ), model.mm_model_ln_kv_b );
807- }
808- ggml_tensor * k;
809- { // position
810- // q = ggml_add(ctx0, q, model.mm_model_pos_embed);
811- k = ggml_add (ctx0, v, pos_embed);
812- }
813800
814- { // attention
801+ // norm
802+ q = build_norm (q, model.mm_model_ln_q_w , model.mm_model_ln_q_b , NORM_TYPE_NORMAL, eps, -1 );
803+ v = build_norm (v, model.mm_model_ln_kv_w , model.mm_model_ln_kv_b , NORM_TYPE_NORMAL, eps, -1 );
804+
805+ // k = v + pos_embed
806+ ggml_tensor * k = ggml_add (ctx0, v, pos_embed);
807+
808+ // attention
809+ {
815810 int n_embd = clip_n_mmproj_embd (ctx);
816811 const int d_head = 128 ;
817812 int n_head = n_embd/d_head;
@@ -824,32 +819,34 @@ struct clip_graph {
824819 num_query = 64 ;
825820 }
826821
827- ggml_tensor * Q = ggml_add (ctx0, ggml_mul_mat (ctx0, model.mm_model_attn_q_w , q), model.mm_model_attn_q_b );
828- ggml_tensor * K = ggml_add (ctx0, ggml_mul_mat (ctx0, model.mm_model_attn_k_w , k), model.mm_model_attn_k_b );
829- ggml_tensor * V = ggml_add (ctx0, ggml_mul_mat (ctx0, model.mm_model_attn_v_w , v), model.mm_model_attn_v_b );
830- // permute
831- Q = ggml_reshape_4d (ctx0, Q, d_head, n_head, num_query, batch_size);
832- Q = ggml_cont (ctx0, ggml_permute (ctx0, Q, 0 , 2 , 1 , 3 ));
833- Q = ggml_reshape_3d (ctx0, Q, d_head, num_query, n_head * batch_size);
834- K = ggml_reshape_4d (ctx0, K, d_head, n_head, n_pos, batch_size);
835- K = ggml_cont (ctx0, ggml_permute (ctx0, K, 0 , 2 , 1 , 3 ));
836- K = ggml_reshape_3d (ctx0, K, d_head, n_pos, n_head * batch_size);
837- V = ggml_reshape_4d (ctx0, V, d_head, n_head, n_pos, batch_size);
838- V = ggml_cont (ctx0, ggml_permute (ctx0, V, 1 , 2 , 0 , 3 ));
839- V = ggml_reshape_3d (ctx0, V, n_pos, d_head, n_head * batch_size);
840- ggml_tensor * KQ = ggml_mul_mat (ctx0, K, Q);
841- KQ = ggml_soft_max_ext (ctx0, KQ, nullptr , 1 .0f / sqrtf ((float )d_head), 0 .0f );
842- ggml_tensor * KQV = ggml_mul_mat (ctx0, V, KQ);
843- KQV = ggml_reshape_4d (ctx0, KQV, d_head, num_query, n_head, batch_size);
844- KQV = ggml_permute (ctx0, KQV, 0 , 2 , 1 , 3 );
845- KQV = ggml_cont_3d (ctx0, KQV, n_embd, num_query, batch_size);
846-
847- embeddings = ggml_add (ctx0, ggml_mul_mat (ctx0, model.mm_model_attn_o_w , KQV), model.mm_model_attn_o_b );
848- }
849- { // layernorm
850- embeddings = ggml_norm (ctx0, embeddings, eps);
851- embeddings = ggml_add (ctx0, ggml_mul (ctx0, embeddings, model.mm_model_ln_post_w ), model.mm_model_ln_post_b );
852- }
822+ ggml_tensor * Q = ggml_add (ctx0,
823+ ggml_mul_mat (ctx0, model.mm_model_attn_q_w , q),
824+ model.mm_model_attn_q_b );
825+ ggml_tensor * K = ggml_add (ctx0,
826+ ggml_mul_mat (ctx0, model.mm_model_attn_k_w , k),
827+ model.mm_model_attn_k_b );
828+ ggml_tensor * V = ggml_add (ctx0,
829+ ggml_mul_mat (ctx0, model.mm_model_attn_v_w , v),
830+ model.mm_model_attn_v_b );
831+
832+ Q = ggml_reshape_3d (ctx0, Q, d_head, n_head, num_query);
833+ K = ggml_reshape_3d (ctx0, K, d_head, n_head, n_pos);
834+ V = ggml_reshape_3d (ctx0, V, d_head, n_head, n_pos);
835+
836+ cb (Q, " resampler_Q" , -1 );
837+ cb (K, " resampler_K" , -1 );
838+ cb (V, " resampler_V" , -1 );
839+
840+ embeddings = build_attn (
841+ model.mm_model_attn_o_w ,
842+ model.mm_model_attn_o_b ,
843+ Q, K, V, nullptr , kq_scale, -1 );
844+ cb (embeddings, " resampler_attn_out" , -1 );
845+ }
846+ // layernorm
847+ embeddings = build_norm (embeddings, model.mm_model_ln_post_w , model.mm_model_ln_post_b , NORM_TYPE_NORMAL, eps, -1 );
848+
849+ // projection
853850 embeddings = ggml_mul_mat (ctx0, model.mm_model_proj , embeddings);
854851
855852 // build the graph
0 commit comments