Commit 4dd5439

Author: lexasub
Commit message: zzz
1 parent: 29c0f78

File tree

1 file changed: +16 additions, -10 deletions


tools/finetune-gguf-dataset/finetune-gguf.cpp

Lines changed: 16 additions & 10 deletions
@@ -124,15 +124,21 @@ int main(int argc, char ** argv) {
 
     LOG_INF("%s: Dataset loaded. Total sequences: %" PRId64 "\n", __func__, total_sequences);
 
-    if (n_ctx_train == 0) {
+    int32_t effective_n_ctx_train = n_ctx_train;
+    if (effective_n_ctx_train == 0) {
         uint32_t max_seq_len_in_dataset = 0;
         for (int64_t i = 0; i < total_sequences; ++i) {
             max_seq_len_in_dataset = std::max(max_seq_len_in_dataset, static_cast<uint32_t>(dataset_reader->llama_gguf_reader_get_tensor_size(i)) / static_cast<uint32_t>(sizeof(llama_token)));
         }
-        n_ctx_train = max_seq_len_in_dataset;
-        LOG_INF("%s: Auto-determined training context size (n_ctx_train): %d\n", __func__, n_ctx_train);
-        if (n_ctx_train > llama_n_ctx(ctx)) {
-            LOG_DBG("%s: Auto-determined training context size (%d) is larger than model's context size (%d). Sequences will be truncated.\n", __func__, n_ctx_train, llama_n_ctx(ctx));
+        effective_n_ctx_train = max_seq_len_in_dataset;
+        LOG_INF("%s: Auto-determined training context size (n_ctx_train): %d\n", __func__, effective_n_ctx_train);
+        if (effective_n_ctx_train > llama_model_n_ctx_train(model)) {
+            LOG_DBG("%s: Auto-determined training context size (%d) is larger than model's native context size (%d). Sequences will be truncated by llama_opt_dataset_add_data.\n", __func__, effective_n_ctx_train, llama_model_n_ctx_train(model));
+        }
+    } else {
+        LOG_INF("%s: Using user-specified training context size (n_ctx_train): %d\n", __func__, effective_n_ctx_train);
+        if (effective_n_ctx_train > llama_model_n_ctx_train(model)) {
+            LOG_DBG("%s: User-specified training context size (%d) is larger than model's native context size (%d). Sequences will be truncated by llama_opt_dataset_add_data.\n", __func__, effective_n_ctx_train, llama_model_n_ctx_train(model));
         }
     }
 
@@ -192,12 +198,12 @@ int main(int argc, char ** argv) {
         (unsigned) lr.epochs, (double) params.n_batch / params.n_ubatch, (double) params.val_split);
 
     struct llama_opt_params lopt_params {
-        /*n_ctx_train =*/ 0,
-        /*param_filter =*/ llama_opt_param_filter_all,
+        /*n_ctx_train =*/ static_cast<uint32_t>(effective_n_ctx_train), // Use the determined or user-specified training context size
+        /*param_filter =*/ llama_opt_param_filter_all, // Train all parameters
         /*param_filter_ud =*/ nullptr,
-        /*get_opt_pars =*/ common_opt_lr_pars,
-        /*get_opt_pars_ud =*/ &params.lr,
-        /*optimizer_type =*/ params.optimizer,
+        /*get_opt_pars =*/ common_opt_lr_pars, // Use common learning rate scheduler
+        /*get_opt_pars_ud =*/ &params.lr, // Pass params.lr struct
+        /*optimizer_type =*/ params.optimizer, // Use optimizer type from common_params
     };
     llama_opt_init(ctx, model, lopt_params);
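
For reference, a minimal standalone sketch of the context-size selection introduced above. The helper name pick_effective_n_ctx_train, its parameters, and the printf logging are illustrative only and are not part of finetune-gguf.cpp; in the actual tool this logic runs inline in main() and the result is passed to llama_opt_init via llama_opt_params.n_ctx_train.

#include <cstdint>
#include <cstdio>

// Hypothetical helper: condenses the selection logic from the hunk above.
//   n_ctx_train_arg        - user-supplied value, 0 means "auto"
//   max_seq_len_in_dataset - longest tokenized sequence found in the GGUF dataset
//   n_ctx_train_model      - the model's native training context,
//                            i.e. what llama_model_n_ctx_train(model) would return
static int32_t pick_effective_n_ctx_train(int32_t n_ctx_train_arg,
                                          uint32_t max_seq_len_in_dataset,
                                          int32_t  n_ctx_train_model) {
    int32_t effective = n_ctx_train_arg;
    if (effective == 0) {
        // auto mode: size the training context to the longest dataset sequence
        effective = (int32_t) max_seq_len_in_dataset;
        printf("auto-determined n_ctx_train: %d\n", effective);
    } else {
        printf("user-specified n_ctx_train: %d\n", effective);
    }
    if (effective > n_ctx_train_model) {
        // longer sequences will be truncated when optimizer batches are built
        printf("warning: n_ctx_train (%d) exceeds the model's native context (%d)\n",
               effective, n_ctx_train_model);
    }
    return effective;
}

int main() {
    // example: auto mode, longest dataset sequence of 4096 tokens,
    // model trained with a 2048-token context -> expect a truncation warning
    const int32_t n_ctx = pick_effective_n_ctx_train(0, 4096, 2048);
    printf("effective n_ctx_train: %d\n", n_ctx);
    return 0;
}

In the commit itself the chosen value is cast to uint32_t, stored in lopt_params.n_ctx_train, and handed to llama_opt_init(ctx, model, lopt_params).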
