@@ -279,26 +279,11 @@ class CrossAttention : public GGMLBlock {
279279 int64_t n_context = context->ne [1 ];
280280 int64_t inner_dim = d_head * n_head;
281281
282- auto q = to_q->forward (ctx, x); // [N, n_token, inner_dim]
283- q = ggml_reshape_4d (ctx, q, d_head, n_head, n_token, n); // [N, n_token, n_head, d_head]
284- q = ggml_cont (ctx, ggml_permute (ctx, q, 0 , 2 , 1 , 3 )); // [N, n_head, n_token, d_head]
285- q = ggml_reshape_3d (ctx, q, d_head, n_token, n_head * n); // [N * n_head, n_token, d_head]
282+ auto q = to_q->forward (ctx, x); // [N, n_token, inner_dim]
283+ auto k = to_k->forward (ctx, context); // [N, n_context, inner_dim]
284+ auto v = to_v->forward (ctx, context); // [N, n_context, inner_dim]
286285
287- auto k = to_k->forward (ctx, context); // [N, n_context, inner_dim]
288- k = ggml_reshape_4d (ctx, k, d_head, n_head, n_context, n); // [N, n_context, n_head, d_head]
289- k = ggml_cont (ctx, ggml_permute (ctx, k, 0 , 2 , 1 , 3 )); // [N, n_head, n_context, d_head]
290- k = ggml_reshape_3d (ctx, k, d_head, n_context, n_head * n); // [N * n_head, n_context, d_head]
291-
292- auto v = to_v->forward (ctx, context); // [N, n_context, inner_dim]
293- v = ggml_reshape_4d (ctx, v, d_head, n_head, n_context, n); // [N, n_context, n_head, d_head]
294- v = ggml_cont (ctx, ggml_permute (ctx, v, 1 , 2 , 0 , 3 )); // [N, n_head, d_head, n_context]
295- v = ggml_reshape_3d (ctx, v, n_context, d_head, n_head * n); // [N * n_head, d_head, n_context]
296-
297- auto kqv = ggml_nn_attention (ctx, q, k, v, false ); // [N * n_head, n_token, d_head]
298- kqv = ggml_reshape_4d (ctx, kqv, d_head, n_token, n_head, n);
299- kqv = ggml_cont (ctx, ggml_permute (ctx, kqv, 0 , 2 , 1 , 3 )); // [N, n_token, n_head, d_head]
300-
301- x = ggml_reshape_3d (ctx, kqv, d_head * n_head, n_token, n); // [N, n_token, inner_dim]
286+ x = ggml_nn_attention_ext (ctx, q, k, v, n_head, NULL , false ); // [N, n_token, inner_dim]
302287
303288 x = to_out_0->forward (ctx, x); // [N, n_token, query_dim]
304289 return x;
@@ -382,7 +367,7 @@ class SpatialTransformer : public GGMLBlock {
382367 int64_t n_head;
383368 int64_t d_head;
384369 int64_t depth = 1 ; // 1
385- int64_t context_dim = 768 ; // hidden_size, 1024 for VERSION_2_x
370+ int64_t context_dim = 768 ; // hidden_size, 1024 for VERSION_SD2
386371
387372public:
388373 SpatialTransformer (int64_t in_channels,
0 commit comments