Skip to content

Commit 5030b1e

Browse files
committed
fix regression: LayerNorm
1 parent 08bfc9d commit 5030b1e

File tree

1 file changed

+20
-7
lines changed

1 file changed

+20
-7
lines changed

src/layers.cpp

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1639,15 +1639,28 @@ namespace chatllm
16391639

16401640
ggml::tensor *GroupNorm::forward(ComputeContext *ctx, ggml::tensor *input)
16411641
{
1642+
ggml::tensor *output = nullptr;
1643+
16421644
// input: [seqlen, normalized_shape]
1643-
ggml::tensor *output = num_groups == ggml::get_dim(weight, 0) ?
1644-
ggml::norm(ctx, input, eps) : ggml::group_norm(ctx, input, num_groups, eps);
1645-
auto weight_view = ggml::reshape(ctx, weight, 1, 1, ggml::get_dim(weight, 0));
1646-
output = ggml::mul(ctx, output, weight_view);
1647-
if (bias)
1645+
if (num_groups == ggml::get_dim(weight, 0))
16481646
{
1649-
auto bias_view = ggml::reshape(ctx, bias, 1, 1, ggml::get_dim(bias, 0));
1650-
output = ggml::add(ctx, output, bias_view);
1647+
output = ggml::norm(ctx, input, eps);
1648+
output = ggml::mul(ctx, output, weight);
1649+
if (bias)
1650+
{
1651+
output = ggml::add(ctx, output, bias);
1652+
}
1653+
}
1654+
else
1655+
{
1656+
output = ggml::group_norm(ctx, input, num_groups, eps);
1657+
auto weight_view = ggml::reshape(ctx, weight, 1, 1, ggml::get_dim(weight, 0));
1658+
output = ggml::mul(ctx, output, weight_view);
1659+
if (bias)
1660+
{
1661+
auto bias_view = ggml::reshape(ctx, bias, 1, 1, ggml::get_dim(bias, 0));
1662+
output = ggml::add(ctx, output, bias_view);
1663+
}
16511664
}
16521665
return output;
16531666
}

0 commit comments

Comments
 (0)