@@ -1639,15 +1639,28 @@ namespace chatllm
16391639
16401640 ggml::tensor *GroupNorm::forward (ComputeContext *ctx, ggml::tensor *input)
16411641 {
1642+ ggml::tensor *output = nullptr ;
1643+
16421644 // input: [seqlen, normalized_shape]
1643- ggml::tensor *output = num_groups == ggml::get_dim (weight, 0 ) ?
1644- ggml::norm (ctx, input, eps) : ggml::group_norm (ctx, input, num_groups, eps);
1645- auto weight_view = ggml::reshape (ctx, weight, 1 , 1 , ggml::get_dim (weight, 0 ));
1646- output = ggml::mul (ctx, output, weight_view);
1647- if (bias)
1645+ if (num_groups == ggml::get_dim (weight, 0 ))
16481646 {
1649- auto bias_view = ggml::reshape (ctx, bias, 1 , 1 , ggml::get_dim (bias, 0 ));
1650- output = ggml::add (ctx, output, bias_view);
1647+ output = ggml::norm (ctx, input, eps);
1648+ output = ggml::mul (ctx, output, weight);
1649+ if (bias)
1650+ {
1651+ output = ggml::add (ctx, output, bias);
1652+ }
1653+ }
1654+ else
1655+ {
1656+ output = ggml::group_norm (ctx, input, num_groups, eps);
1657+ auto weight_view = ggml::reshape (ctx, weight, 1 , 1 , ggml::get_dim (weight, 0 ));
1658+ output = ggml::mul (ctx, output, weight_view);
1659+ if (bias)
1660+ {
1661+ auto bias_view = ggml::reshape (ctx, bias, 1 , 1 , ggml::get_dim (bias, 0 ));
1662+ output = ggml::add (ctx, output, bias_view);
1663+ }
16511664 }
16521665 return output;
16531666 }
0 commit comments