@@ -7818,12 +7818,14 @@ struct ggml_tensor * ggml_cross_entropy_loss_back(
 struct ggml_tensor * ggml_opt_step_adamw(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
+        struct ggml_tensor  * grad,
         float                 alpha,
         float                 beta1,
         float                 beta2,
         float                 eps,
         float                 wd) {
     GGML_ASSERT(a->flags & GGML_TENSOR_FLAG_PARAM);
+    GGML_ASSERT(ggml_are_same_shape(a, grad));
     GGML_ASSERT(alpha >  0.0f);
     GGML_ASSERT(beta1 >= 0.0f && beta1 <= 1.0f);
     GGML_ASSERT(beta2 >= 0.0f && beta2 <= 1.0f);
@@ -7842,9 +7844,9 @@ struct ggml_tensor * ggml_opt_step_adamw(
 
     result->op     = GGML_OP_OPT_STEP_ADAMW;
     result->src[0] = a;
-    result->src[1] = a->grad;
-    result->src[2] = ggml_dup_tensor(ctx, a);
-    result->src[3] = ggml_dup_tensor(ctx, a);
+    result->src[1] = grad;
+    result->src[2] = ggml_dup_tensor(ctx, grad);
+    result->src[3] = ggml_dup_tensor(ctx, grad);
 
     return result;
 }
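Note on the new sources: src[1] now holds the explicitly passed gradient, while src[2] and src[3] are duplicated from it to serve as what are presumably the AdamW first- and second-moment state buffers. For reference, here is a scalar sketch of the standard AdamW update (Loshchilov & Hutter) that GGML_OP_OPT_STEP_ADAMW schedules; this paraphrases the published algorithm and is not the backend kernel from this diff:

#include <math.h>

// Scalar sketch of one AdamW update at step t (t >= 1); illustrative only.
static inline float adamw_step(float w, float g, float * m, float * v, int t,
                               float alpha, float beta1, float beta2,
                               float eps, float wd) {
    *m = beta1 * (*m) + (1.0f - beta1) * g;       // first moment  (cf. src[2])
    *v = beta2 * (*v) + (1.0f - beta2) * g * g;   // second moment (cf. src[3])
    const float mh = *m / (1.0f - powf(beta1, (float) t)); // bias-corrected m
    const float vh = *v / (1.0f - powf(beta2, (float) t)); // bias-corrected v
    return w - alpha * (mh / (sqrtf(vh) + eps) + wd * w);  // decoupled decay
}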
@@ -18769,7 +18771,7 @@ void ggml_build_opt_adamw(
 
         if (node->flags & GGML_TENSOR_FLAG_PARAM) {
             GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
-            struct ggml_tensor * opt_step = ggml_opt_step_adamw(ctx, node, alpha, beta1, beta2, eps, wd);
+            struct ggml_tensor * opt_step = ggml_opt_step_adamw(ctx, node, node->grad, alpha, beta1, beta2, eps, wd);
             ggml_build_forward_expand(gb, opt_step);
         }
     }
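Just as the call site above now forwards node->grad, any external caller must pass the gradient tensor explicitly. A minimal caller sketch under the new signature; the tensor names and hyperparameter values are illustrative, and ggml_set_param is assumed to be what sets GGML_TENSOR_FLAG_PARAM on the weights:

#include "ggml.h"

// Hypothetical usage of the new explicit-gradient signature.
static struct ggml_tensor * build_adamw_step(struct ggml_context * ctx,
                                             struct ggml_tensor  * w,
                                             struct ggml_tensor  * w_grad) {
    ggml_set_param(ctx, w); // first assert requires GGML_TENSOR_FLAG_PARAM
    // w and w_grad must have the same shape (enforced by the new assert)
    return ggml_opt_step_adamw(ctx, w, w_grad,
                               /*alpha =*/ 1e-3f,  // learning rate, must be > 0
                               /*beta1 =*/ 0.9f,   // in [0, 1]
                               /*beta2 =*/ 0.999f, // in [0, 1]
                               /*eps   =*/ 1e-8f,
                               /*wd    =*/ 0.0f);  // weight decay
}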