@@ -7818,12 +7818,14 @@ struct ggml_tensor * ggml_cross_entropy_loss_back(
 struct ggml_tensor * ggml_opt_step_adamw(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
+        struct ggml_tensor  * grad,
         float                 alpha,
         float                 beta1,
         float                 beta2,
         float                 eps,
         float                 wd) {
     GGML_ASSERT(a->flags & GGML_TENSOR_FLAG_PARAM);
+    GGML_ASSERT(ggml_are_same_shape(a, grad));
     GGML_ASSERT(alpha > 0.0f);
     GGML_ASSERT(beta1 >= 0.0f && beta1 <= 1.0f);
     GGML_ASSERT(beta2 >= 0.0f && beta2 <= 1.0f);
@@ -7842,9 +7844,9 @@ struct ggml_tensor * ggml_opt_step_adamw(
 
     result->op     = GGML_OP_OPT_STEP_ADAMW;
     result->src[0] = a;
-    result->src[1] = a->grad;
-    result->src[2] = ggml_dup_tensor(ctx, a);
-    result->src[3] = ggml_dup_tensor(ctx, a);
+    result->src[1] = grad;
+    result->src[2] = ggml_dup_tensor(ctx, grad);
+    result->src[3] = ggml_dup_tensor(ctx, grad);
 
     return result;
 }
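For orientation: the two tensors duplicated into src[2] and src[3] above are, as far as this diff shows, the per-parameter AdamW moment buffers, and duplicating grad rather than a keeps them shaped like the gradient. A minimal sketch of the per-element update such buffers feed follows; this is the standard AdamW rule, not the actual GGML_OP_OPT_STEP_ADAMW kernel, and the function name is hypothetical.

#include <math.h>

// Sketch of one AdamW update for a single element x with gradient g.
// m and v are assumed to correspond to the buffers wired into src[2]
// and src[3] above; t is the 1-based iteration counter.
static inline float adamw_step(float x, float g, float * m, float * v,
                               float alpha, float beta1, float beta2,
                               float eps, float wd, int t) {
    *m = beta1*(*m) + (1.0f - beta1)*g;    // first moment estimate
    *v = beta2*(*v) + (1.0f - beta2)*g*g;  // second moment estimate
    const float mh = *m / (1.0f - powf(beta1, (float) t)); // bias-corrected moments
    const float vh = *v / (1.0f - powf(beta2, (float) t));
    // decoupled weight decay, then the Adam step
    return x*(1.0f - alpha*wd) - alpha*mh/(sqrtf(vh) + eps);
}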
@@ -18769,7 +18771,7 @@ void ggml_build_opt_adamw(
 
         if (node->flags & GGML_TENSOR_FLAG_PARAM) {
             GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
-            struct ggml_tensor * opt_step = ggml_opt_step_adamw(ctx, node, alpha, beta1, beta2, eps, wd);
+            struct ggml_tensor * opt_step = ggml_opt_step_adamw(ctx, node, node->grad, alpha, beta1, beta2, eps, wd);
             ggml_build_forward_expand(gb, opt_step);
         }
     }
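A hedged usage sketch of the new signature: the caller now supplies the gradient tensor explicitly instead of having ggml_opt_step_adamw read a->grad itself, which is exactly what the ggml_build_opt_adamw change above does with node->grad. The wrapper function and hyperparameter values below are illustrative and not part of the diff.

#include "ggml.h"

// Illustrative wrapper: schedule one AdamW step for a parameter tensor.
// `param` must have GGML_TENSOR_FLAG_PARAM set and `grad` must match its
// shape, per the asserts added above.
static void schedule_adamw_step(struct ggml_context * ctx, struct ggml_cgraph * gb,
                                struct ggml_tensor * param, struct ggml_tensor * grad) {
    struct ggml_tensor * opt_step = ggml_opt_step_adamw(
        ctx, param, grad,
        /*alpha =*/ 1e-3f,   // learning rate
        /*beta1 =*/ 0.9f,
        /*beta2 =*/ 0.999f,
        /*eps   =*/ 1e-8f,
        /*wd    =*/ 0.0f);   // weight decay
    ggml_build_forward_expand(gb, opt_step);
}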