10 changes: 6 additions & 4 deletions apps/sft_v2/llama3_8b.yaml
@@ -1,9 +1,11 @@
# >>> python -m apps.sft_v2.main --config apps/sft_v2/llama3_8b.yaml

# Config for supervised full finetuning using a Llama3.1 8B Instruct model
#
# This config assumes that you've run the following command before launching
# this run:
# export HF_HUB_DISABLE_XET=1
# uv run forge download meta-llama/Meta-Llama-3.1-8B-Instruct
# forge download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct
Contributor Author:

I didn't have uv installed, so this command didn't work. Is uv necessary?

Also, do we still need to do export HF_HUB_DISABLE_XET=1?
Contributor:

I tried HF_HUB_DISABLE_XET=1 about a week ago and it broke. I don't think it works well with our dev machines. @joecummings do you know?



# TODO: required by torchtitan
@@ -14,11 +16,11 @@ comm:
model:
name: llama3
flavor: 8B
tokenizer_path: /tmp/Llama-3.1-8B-Instruct
hf_assets_path: /tmp/Meta-Llama-3.1-8B-Instruct

processes:
scheduler: local # local | mast (not supported yet)
hosts: 1
# scheduler: local # local | mast (not supported yet)
# hosts: 1
procs: 8
with_gpus: true

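As a quick sanity check on the renamed field, here is a minimal sketch (assuming the config is loaded with OmegaConf, which the DictConfig type in main.py suggests; main.py itself goes through forge.cli.config.parse) that reads model.hf_assets_path and the slimmed-down processes block:

from omegaconf import OmegaConf

# Hypothetical standalone check, not part of the recipe itself.
cfg = OmegaConf.load("apps/sft_v2/llama3_8b.yaml")
print(cfg.model.hf_assets_path)               # expected: /tmp/Meta-Llama-3.1-8B-Instruct
print(OmegaConf.to_container(cfg.processes))  # e.g. {"procs": 8, "with_gpus": True}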
18 changes: 6 additions & 12 deletions apps/sft_v2/main.py
@@ -23,7 +23,7 @@

import torchtitan.experiments.forge.train_spec as forge_train_spec
from forge.cli.config import parse
from forge.controller import ForgeActor, spawn_actors
from forge.controller import ForgeActor
from forge.data.collate import collate_packed
from forge.data.datasets.packed import PackedDataset, TextPacker
from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset
@@ -130,16 +130,16 @@ async def setup(self):
# self.logger = self.setup_logger(self.train_config.logger_config)

def setup_data(self):
print(os.path.join(self.job_config.model.tokenizer_path, "tokenizer.json"))
print(os.path.join(self.job_config.model.hf_assets_path, "tokenizer.json"))
tokenizer = HuggingFaceModelTokenizer(
tokenizer_json_path=os.path.join(
self.job_config.model.tokenizer_path, "tokenizer.json"
self.job_config.model.hf_assets_path, "tokenizer.json"
),
tokenizer_config_json_path=os.path.join(
self.job_config.model.tokenizer_path, "tokenizer_config.json"
self.job_config.model.hf_assets_path, "tokenizer_config.json"
),
generation_config_path=os.path.join(
self.job_config.model.tokenizer_path, "generation_config.json"
self.job_config.model.hf_assets_path, "generation_config.json"
),
)

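Since setup_data now resolves the tokenizer files relative to hf_assets_path, a minimal sketch (stdlib only, hypothetical helper) for checking that the forge download step actually left the expected files in place before launching training:

import os

hf_assets_path = "/tmp/Meta-Llama-3.1-8B-Instruct"  # matches model.hf_assets_path in llama3_8b.yaml
for name in ("tokenizer.json", "tokenizer_config.json", "generation_config.json"):
    path = os.path.join(hf_assets_path, name)
    print(f"{path}: {'found' if os.path.exists(path) else 'MISSING'}")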
@@ -280,13 +280,7 @@ def __repr__(self) -> str:
async def run(cfg: DictConfig) -> None:
logging.info("Spawing recipe...")
process_cfg = cfg.pop("processes")
recipe = await spawn_actors(
"sft",
ForgeSFTRecipe,
{"config": cfg},
process_cfg,
set_address=True,
)
recipe = await ForgeSFTRecipe.options(**process_cfg).as_service(cfg)

logging.info("Created recipe, running setup.")
await recipe.setup.fanout()
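For reference, a minimal sketch of the new service-style spawn path, using only the calls visible in this diff; whether options(**process_cfg) accepts exactly the keys left in the yaml (procs, with_gpus) is an assumption, and the rest of the training loop is omitted:

async def run(cfg):
    process_cfg = cfg.pop("processes")  # e.g. {"procs": 8, "with_gpus": True}
    # Spawn ForgeSFTRecipe replicas behind a single service handle instead of via spawn_actors.
    recipe = await ForgeSFTRecipe.options(**process_cfg).as_service(cfg)
    await recipe.setup.fanout()  # fan setup out to every replica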