diff --git a/apps/sft/main.py b/apps/sft/main.py
index b5ae6fc16..3ca1eba9d 100644
--- a/apps/sft/main.py
+++ b/apps/sft/main.py
@@ -30,6 +30,7 @@
 from torchtitan.distributed import ParallelDims, utils as dist_utils
 from torchtitan.experiments.forge.engine import ForgeEngine
 from torchtitan.experiments.forge.job_config import ForgeJobConfig
+from torchtitan.models.attention import init_attention_mask
 from tqdm import tqdm
@@ -92,7 +93,7 @@ def setup(self):
         # self.logger = self.setup_logger(self.train_config.logger_config)
 
     def setup_data(self, dataset_config, batch_size):
-        tokenizer = HuggingFaceModelTokenizer(
+        self.tokenizer = HuggingFaceModelTokenizer(
             tokenizer_json_path=os.path.join(
                 self.job_config.model.hf_assets_path, "tokenizer.json"
             ),
@@ -105,7 +106,7 @@ def setup_data(self, dataset_config, batch_size):
         )
 
         dataset = sft_iterable_dataset(
-            model_transform=tokenizer,
+            model_transform=self.tokenizer,
             message_transform=AlpacaToMessages(),
             path=dataset_config.path,
             split=dataset_config.split,
@@ -142,6 +143,13 @@ def forward_backward(
         # apply context parallelism if cp is enabled
         # ensure CP handles the separate freqs_cis buffer for each pp stage
         inputs = input_dict["tokens"]
+
+        if getattr(self.model_args, "use_flex_attn", False):
+            cp_mesh = (
+                parallel_dims.world_mesh["cp"] if parallel_dims.cp_enabled else None
+            )
+            init_attention_mask(inputs, self.tokenizer.base_tokenizer.eos_id, cp_mesh)
+
         optional_context_parallel_ctx = (
             dist_utils.create_context_parallel_ctx(
                 cp_mesh=parallel_dims.world_mesh["cp"],
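
Context for the new init_attention_mask call: when use_flex_attn is enabled, the trainer now keeps the tokenizer on self so forward_backward can derive the flex-attention mask from the raw token ids, using the tokenizer's EOS id to mark document boundaries, and passes the cp mesh so the mask can be coordinated with context parallelism when it is enabled. The snippet below is only a minimal sketch of that idea, assuming PyTorch's torch.nn.attention.flex_attention API; the helper name make_document_causal_mask and the exact masking rule are illustrative assumptions, not torchtitan's actual implementation (which also handles sharding the mask over the cp mesh).

# Illustrative sketch only; NOT torchtitan's init_attention_mask.
# Assumes PyTorch >= 2.5 with torch.nn.attention.flex_attention available.
import torch
from torch.nn.attention.flex_attention import create_block_mask


def make_document_causal_mask(tokens: torch.Tensor, eos_id: int):
    # tokens: [batch, seq_len] of token ids. Positions share a "document"
    # until an EOS token is crossed, so packed samples cannot attend to
    # each other.
    batch, seq_len = tokens.shape
    eos = (tokens == eos_id).to(torch.int32)
    # Document id of position i = number of EOS tokens strictly before i,
    # so each EOS stays with the document it terminates.
    doc_ids = torch.nn.functional.pad(torch.cumsum(eos, dim=-1), (1, 0))[:, :-1]

    def mask_mod(b, h, q_idx, kv_idx):
        # Causal within a document, fully blocked across documents.
        same_doc = doc_ids[b, q_idx] == doc_ids[b, kv_idx]
        return same_doc & (q_idx >= kv_idx)

    return create_block_mask(
        mask_mod, B=batch, H=None, Q_LEN=seq_len, KV_LEN=seq_len, device=tokens.device
    )

A mask like this would then be consumed by flex_attention(q, k, v, block_mask=...) inside the model. In this diff that wiring lives in torchtitan; the trainer only seeds it with the inputs, the EOS id, and the cp mesh (or None when context parallelism is off).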