
Commit e2ad0f0

feat: move callbacks to fms-hf-tuning
Signed-off-by: Mehant Kammakomati <[email protected]>
1 parent 2b64383 commit e2ad0f0

File tree

4 files changed: +113 −3 lines changed

tuning/config/acceleration_configs/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@
 # Local
 from .acceleration_framework_config import AccelerationFrameworkConfig
 from .attention_and_distributed_packing import AttentionAndDistributedPackingConfig
+from .callbacks import get_additional_accel_framework_callbacks
 from .fast_moe import FastMoeConfig
 from .fused_ops_and_kernels import FusedOpsAndKernelsConfig
 from .quantized_lora_config import QuantizedLoraConfig
tuning/config/acceleration_configs/callbacks.py

Lines changed: 10 additions & 0 deletions

@@ -0,0 +1,10 @@
+# Local
+from .fast_moe import get_callbacks
+
+
+def get_additional_accel_framework_callbacks(active_plugins, **kwargs):
+    callbacks = []
+    for active_plugin in active_plugins:
+        if "ScatterMoEAccelerationPlugin" == active_plugin[0]:
+            callbacks.extend(get_callbacks(**kwargs))
+    return callbacks
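
A minimal usage sketch of the new dispatcher (not part of the commit): it matches on the first element of each active-plugin entry, so only the ScatterMoE plugin contributes callbacks. The `trainer` object and the plugin-list shape below are assumptions for illustration; the real call site is the sft_trainer.py hunk further down.

# Hedged sketch: assumes a transformers.Trainer named `trainer` already exists and
# that active_plugins entries are (plugin_name, ...) tuples, as callbacks.py expects.
from tuning.config.acceleration_configs import get_additional_accel_framework_callbacks

active_plugins = [("ScatterMoEAccelerationPlugin", None)]  # hypothetical plugin list

for cb in get_additional_accel_framework_callbacks(
    active_plugins=active_plugins,
    trainer=trainer,
    pretrained_model_name_or_path="my-org/my-moe-model",  # hypothetical model id
):
    trainer.add_callback(cb)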

tuning/config/acceleration_configs/fast_moe.py

Lines changed: 94 additions & 0 deletions
@@ -14,10 +14,30 @@

 # Standard
 from dataclasses import dataclass
+import os
+
+# Third Party
+from transformers import (
+    Trainer,
+    TrainerCallback,
+    TrainerControl,
+    TrainerState,
+    TrainingArguments,
+)
+from transformers.trainer import TRAINING_ARGS_NAME
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+import torch

 # Local
 from .utils import ensure_nested_dataclasses_initialized, parsable_dataclass

+is_recover_safetensors_from_dcp_available = True
+try:
+    # Third Party
+    from fms_acceleration_moe.utils import recover_safetensors_from_dcp
+except ImportError:
+    is_recover_safetensors_from_dcp_available = False
+

 @parsable_dataclass
 @dataclass
@@ -34,3 +54,77 @@ class FastMoeConfig:
     def __post_init__(self):
         # ensure nested dataclasses initialized
         ensure_nested_dataclasses_initialized(self)
+
+
+def get_callbacks(**kwargs):
+    pretrained_model_name_or_path = kwargs.pop("pretrained_model_name_or_path")
+    trainer = kwargs.pop("trainer")
+    callbacks = []
+    if is_recover_safetensors_from_dcp_available:
+
+        class ConvertAndSaveHFCheckpointAtEverySave(TrainerCallback):
+            def __init__(self, pretrained_model_name_or_path: str, trainer: Trainer):
+                self.pretrained_model_name_or_path = pretrained_model_name_or_path
+                self.trainer = trainer
+
+            def on_save(
+                self,
+                args: TrainingArguments,
+                state: TrainerState,
+                control: TrainerControl,
+                **kwargs,
+            ):
+                """
+                Save all HF files and convert the dcp checkpoint to safetensors at every save.
+                """
+
+                def checkpoint():
+                    checkpoint_dir = os.path.join(
+                        args.output_dir,
+                        f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
+                    )
+                    hf_converted_output_dir = os.path.join(
+                        checkpoint_dir, "hf_converted_checkpoint"
+                    )
+                    if os.path.exists(hf_converted_output_dir):
+                        # if the folder already exists, we return,
+                        # since this can happen when the checkpoint is
+                        # saved again at the end of training
+                        return
+                    os.mkdir(hf_converted_output_dir)
+                    try:
+                        recover_safetensors_from_dcp(
+                            checkpoint_dir,
+                            self.pretrained_model_name_or_path,
+                            hf_converted_output_dir,
+                        )
+                        # save tokenizer
+                        if self.trainer.processing_class:
+                            self.trainer.processing_class.save_pretrained(
+                                hf_converted_output_dir
+                            )
+                        # save training args
+                        torch.save(
+                            args,
+                            os.path.join(hf_converted_output_dir, TRAINING_ARGS_NAME),
+                        )
+                        # save model config files
+                        self.trainer.model.config.save_pretrained(
+                            hf_converted_output_dir
+                        )
+
+                    except Exception as e:
+                        raise ValueError(
+                            f"Failed to convert the checkpoint {checkpoint_dir} "
+                            "to a HF compatible checkpoint"
+                        ) from e
+
+                if state.is_world_process_zero:
+                    checkpoint()
+
+        callbacks.append(
+            ConvertAndSaveHFCheckpointAtEverySave(
+                pretrained_model_name_or_path, trainer
+            )
+        )
+    return callbacks
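
To illustrate the effect of the callback above (an illustration, not part of the commit): on every save it converts the ScatterMoE DCP checkpoint into an HF-compatible folder named hf_converted_checkpoint inside the regular checkpoint directory, containing safetensors weights, config.json, tokenizer files, and training_args.bin. Assuming a causal-LM model and hypothetical paths, such a folder should load with the stock transformers APIs:

# Hedged sketch with hypothetical paths; the directory is produced by
# ConvertAndSaveHFCheckpointAtEverySave at each Trainer save step.
import os

from transformers import AutoModelForCausalLM, AutoTokenizer

output_dir = "outputs"   # hypothetical training --output_dir
global_step = 100        # hypothetical save step
hf_dir = os.path.join(output_dir, f"checkpoint-{global_step}", "hf_converted_checkpoint")

model = AutoModelForCausalLM.from_pretrained(hf_dir)   # weights + config.json
tokenizer = AutoTokenizer.from_pretrained(hf_dir)      # files saved via processing_class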

tuning/sft_trainer.py

Lines changed: 8 additions & 3 deletions
@@ -48,6 +48,7 @@
     FastMoeConfig,
     FusedOpsAndKernelsConfig,
     QuantizedLoraConfig,
+    get_additional_accel_framework_callbacks,
 )
 from tuning.config.tracker_configs import (
     AimConfig,
@@ -408,10 +409,14 @@ def train(
     accelerator = None if not is_accelerate_available() else trainer.accelerator

     # ready for train may produce additional callbacks for the trainer
-    for x in framework.get_callbacks_and_ready_for_train(
-        model, accelerator, trainer, model_args.model_name_or_path
-    ):
+    for x in framework.get_callbacks_and_ready_for_train(model, accelerator):
         trainer.add_callback(x)
+    for clb in get_additional_accel_framework_callbacks(
+        active_plugins=framework.active_plugins,
+        trainer=trainer,
+        pretrained_model_name_or_path=model_args.model_name_or_path,
+    ):
+        trainer.add_callback(clb)

     resume_from_checkpoint = None
     # Check if resume flag is not passed (None), or if flag is true and
