@@ -4830,7 +4830,6 @@ struct llama_model_loader {
             n_created++;
         }

-        ggml_set_param(nullptr, tensor);
         return tensor;

     }
@@ -22636,10 +22635,20 @@ void llama_log_callback_default(ggml_log_level level, const char * text, void *
 // training
 //

-static struct ggml_opt_optimizer_params llama_get_default_optimizer_params(void * userdata) {
-    struct ggml_opt_optimizer_params result = ggml_opt_get_default_optimizer_params(userdata);
-    result.adamw.alpha = 1e-6f;
-    return result;
+bool llama_opt_param_filter_all(const struct ggml_tensor * tensor, void * userdata) {
+    GGML_UNUSED(tensor);
+    GGML_UNUSED(userdata);
+    return true;
+}
+
+static void llama_set_param(struct ggml_tensor * tensor, llama_opt_param_filter param_filter, void * userdata) {
+    if (!tensor || tensor->type != GGML_TYPE_F32) {
+        return;
+    }
+    if (!param_filter(tensor, userdata)) {
+        return;
+    }
+    ggml_set_param(tensor);
 }

 void llama_opt_init(struct llama_context * lctx, struct llama_model * model, struct llama_opt_params lopt_params) {
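For reference, `llama_set_param` above marks a tensor as trainable only if it is non-null, of type `GGML_TYPE_F32`, and accepted by the caller-supplied `llama_opt_param_filter`; `llama_opt_param_filter_all` is the trivial accept-everything filter. A minimal sketch of a custom filter with the same callback signature (the helper name `llama_opt_param_filter_by_name` and the substring passed through `userdata` are illustrative assumptions, not part of this change):

```cpp
#include <cstring>
#include "ggml.h"

// Illustrative filter: only train tensors whose name contains the substring
// passed via userdata (e.g. "output"). Same signature as llama_opt_param_filter_all.
static bool llama_opt_param_filter_by_name(const struct ggml_tensor * tensor, void * userdata) {
    const char * pattern = (const char *) userdata;
    return strstr(ggml_get_name(tensor), pattern) != nullptr;
}
```

Such a filter would be passed as `lopt_params.param_filter`, with the substring as `lopt_params.param_filter_ud` (see the next hunk).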
@@ -22656,6 +22665,30 @@ void llama_opt_init(struct llama_context * lctx, struct llama_model * model, str
     opt_params.get_opt_pars_ud = lopt_params.get_opt_pars_ud;

     lctx->opt_ctx = ggml_opt_init(opt_params);
+
+    llama_opt_param_filter param_filter = lopt_params.param_filter;
+    void * param_filter_ud = lopt_params.param_filter_ud;
+
+    llama_set_param(model->tok_embd, param_filter, param_filter_ud);
+    llama_set_param(model->type_embd, param_filter, param_filter_ud);
+    llama_set_param(model->pos_embd, param_filter, param_filter_ud);
+    llama_set_param(model->tok_norm, param_filter, param_filter_ud);
+    llama_set_param(model->tok_norm_b, param_filter, param_filter_ud);
+    llama_set_param(model->output_norm, param_filter, param_filter_ud);
+    llama_set_param(model->output_norm_b, param_filter, param_filter_ud);
+    llama_set_param(model->output, param_filter, param_filter_ud);
+    llama_set_param(model->output_b, param_filter, param_filter_ud);
+    llama_set_param(model->output_norm_enc, param_filter, param_filter_ud);
+    llama_set_param(model->cls, param_filter, param_filter_ud);
+    llama_set_param(model->cls_b, param_filter, param_filter_ud);
+    llama_set_param(model->cls_out, param_filter, param_filter_ud);
+    llama_set_param(model->cls_out_b, param_filter, param_filter_ud);
+
+    for (struct llama_layer & layer : model->layers) {
+        for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) {
+            llama_set_param(reinterpret_cast<struct ggml_tensor **>(&layer)[i], param_filter, param_filter_ud);
+        }
+    }
 }

 static void llama_opt_epoch_iter(
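Putting the pieces together, a hedged caller-side sketch of how the filter hooks into `llama_opt_init` could look like the following. It assumes `struct llama_opt_params` exposes at least `param_filter`, `param_filter_ud`, and the optimizer-parameter callback fields forwarded above (`get_opt_pars` is assumed to exist alongside `get_opt_pars_ud`), and that `ctx` and `model` are an already-initialized `llama_context` / `llama_model` pair:

```cpp
// Minimal sketch, not taken from this change: train all F32 tensors and use
// ggml's default AdamW optimizer parameters.
struct llama_opt_params lopt_params = {};
lopt_params.param_filter    = llama_opt_param_filter_all;            // accept every F32 tensor
lopt_params.param_filter_ud = nullptr;
lopt_params.get_opt_pars    = ggml_opt_get_default_optimizer_params; // assumed field name
lopt_params.get_opt_pars_ud = nullptr;

llama_opt_init(ctx, model, lopt_params);
```

With a custom filter such as the name-based sketch above, `param_filter_ud` would instead carry the pattern string, e.g. `(void *) "output"`.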