Skip to content

Commit a6a1e3f

Browse files
Update ggml/src/ggml-opt.cpp
Co-authored-by: Johannes Gäßler <[email protected]>
1 parent 085f870 commit a6a1e3f

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

ggml/src/ggml-opt.cpp

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -550,9 +550,18 @@ static void ggml_opt_build(ggml_opt_context_t opt_ctx) {
550550
ggml_set_name(m, (std::string("AdamW m for ") + std::string(node->name)).c_str());
551551
ggml_set_name(v, (std::string("AdamW v for ") + std::string(node->name)).c_str());
552552
}
553-
struct ggml_tensor * opt_step =
554-
m ? ggml_opt_step_adamw(opt_ctx->ctx_compute, node, grad, m, v, adamw_params) :
555-
ggml_opt_step_sgd(opt_ctx->ctx_compute, node, grad, adamw_params);
553+
struct ggml_tensor * opt_step;
554+
switch (opt_ctx->optimizer_type) {
555+
case GGML_OPT_OPTIMIZER_ADAMW:
556+
opt_step = ggml_opt_step_adamw(opt_ctx->ctx_compute, node, grad, m, v, adamw_params);
557+
break;
558+
case GGML_OPT_OPTIMIZER_SGD:
559+
opt_step = ggml_opt_step_sgd(opt_ctx->ctx_compute, node, grad, adamw_params);
560+
break;
561+
default:
562+
GGML_ABORT("fatal error");
563+
break;
564+
}
556565
step_prefix.resize(step_prefix_len); // to avoid recreating a new step_prefix string temp n_nodes times
557566
ggml_set_name(opt_step, (step_prefix += node->name).c_str());
558567
ggml_build_forward_expand(opt_ctx->gb_opt, opt_step);

0 commit comments

Comments
 (0)