Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ernie/dataset/vl_sft_reader/vl_sft_dataset_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,7 +751,7 @@ def gen_sample_list(self):
for i, _ in enumerate(self.task_group):
sample_size = int(self.weight_list[i] * self.length)
print(
f"Take {sample_size} samples from {self.task_group[i]._file_name} (total length: {len(self.task_group[i].exs)}) to construct current sample list"
f"Take {sample_size} samples from {self.task_group[i]._file_name} to construct current sample list"
)
indices.extend([i] * sample_size)
return indices
Expand Down
8 changes: 6 additions & 2 deletions erniekit/train/ocr_vl_sft/pretraining_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1655,7 +1655,11 @@ def lr_ratio_fn(param):

return self.optimizer

def save_model(self, output_dir=None):
def save_model(
self,
output_dir=None,
merge_tensor_parallel=False,
):
"""
Saves the model and associated configuration files to the specified directory.

Expand All @@ -1669,7 +1673,7 @@ def save_model(self, output_dir=None):
None

"""
super().save_model(output_dir)
super().save_model(output_dir, merge_tensor_parallel)
if self.args.should_save:
with open(
os.path.join(output_dir, "static_name_to_dyg_name.json"), "w"
Expand Down
1 change: 1 addition & 0 deletions erniekit/train/ocr_vl_sft/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def run_ocr_vl_sft(
preprocess_args.batch_size = finetuning_args.batch_size
finetuning_args.max_seq_len = data_args.max_seq_len
finetuning_args.max_seq_length = data_args.max_seq_len
finetuning_args.packing = data_args.packing

# create output dir
os.makedirs(finetuning_args.output_dir, exist_ok=True)
Expand Down