Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions apps/sft/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from torchtitan.distributed import ParallelDims, utils as dist_utils
from torchtitan.experiments.forge.engine import ForgeEngine
from torchtitan.experiments.forge.job_config import ForgeJobConfig
from torchtitan.models.attention import init_attention_mask
from tqdm import tqdm


Expand Down Expand Up @@ -92,7 +93,7 @@ def setup(self):
# self.logger = self.setup_logger(self.train_config.logger_config)

def setup_data(self, dataset_config, batch_size):
tokenizer = HuggingFaceModelTokenizer(
self.tokenizer = HuggingFaceModelTokenizer(
tokenizer_json_path=os.path.join(
self.job_config.model.hf_assets_path, "tokenizer.json"
),
Expand All @@ -105,7 +106,7 @@ def setup_data(self, dataset_config, batch_size):
)

dataset = sft_iterable_dataset(
model_transform=tokenizer,
model_transform=self.tokenizer,
message_transform=AlpacaToMessages(),
path=dataset_config.path,
split=dataset_config.split,
Expand Down Expand Up @@ -142,6 +143,13 @@ def forward_backward(
# apply context parallelism if cp is enabled
# ensure CP handles the separate freqs_cis buffer for each pp stage
inputs = input_dict["tokens"]

if getattr(self.model_args, "use_flex_attn", False):
cp_mesh = (
parallel_dims.world_mesh["cp"] if parallel_dims.cp_enabled else None
)
init_attention_mask(inputs, self.tokenizer.base_tokenizer.eos_id, cp_mesh)

optional_context_parallel_ctx = (
dist_utils.create_context_parallel_ctx(
cp_mesh=parallel_dims.world_mesh["cp"],
Expand Down
Loading