@@ -5374,6 +5374,7 @@ static struct ggml_tensor * ggml_group_norm_impl(
53745374 struct ggml_context * ctx,
53755375 struct ggml_tensor * a,
53765376 int n_groups,
5377+ float eps,
53775378 bool inplace) {
53785379
53795380 bool is_node = false;
@@ -5384,7 +5385,8 @@ static struct ggml_tensor * ggml_group_norm_impl(
53845385
53855386 struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
53865387
5387- result->op_params[0] = n_groups;
5388+ ggml_set_op_params_i32(result, 0, n_groups);
5389+ ggml_set_op_params_f32(result, 1, eps);
53885390
53895391 result->op = GGML_OP_GROUP_NORM;
53905392 result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -5396,15 +5398,17 @@ static struct ggml_tensor * ggml_group_norm_impl(
53965398struct ggml_tensor * ggml_group_norm(
53975399 struct ggml_context * ctx,
53985400 struct ggml_tensor * a,
5399- int n_groups) {
5400- return ggml_group_norm_impl(ctx, a, n_groups, false);
5401+ int n_groups,
5402+ float eps) {
5403+ return ggml_group_norm_impl(ctx, a, n_groups, eps, false);
54015404}
54025405
54035406struct ggml_tensor * ggml_group_norm_inplace(
54045407 struct ggml_context * ctx,
54055408 struct ggml_tensor * a,
5406- int n_groups) {
5407- return ggml_group_norm_impl(ctx, a, n_groups, true);
5409+ int n_groups,
5410+ float eps) {
5411+ return ggml_group_norm_impl(ctx, a, n_groups, eps, true);
54085412}
54095413
54105414// ggml_mul_mat
@@ -12095,10 +12099,11 @@ static void ggml_compute_forward_group_norm_f32(
1209512099
1209612100 GGML_TENSOR_UNARY_OP_LOCALS
1209712101
12098- const float eps = 1e-6f; // TODO: make this a parameter
12099-
1210012102 // TODO: optimize
1210112103
12104+ float eps;
12105+ memcpy(&eps, dst->op_params + 1, sizeof(float));
12106+
1210212107 int n_channels = src0->ne[2];
1210312108 int n_groups = dst->op_params[0];
1210412109 int n_channels_per_group = (n_channels + n_groups - 1) / n_groups;