Skip to content

Commit 1a21945

Browse files
committed
fix sft v2
1 parent 70a1ba6 commit 1a21945

File tree

2 files changed

+12
-16
lines changed

2 files changed

+12
-16
lines changed

apps/sft_v2/llama3_8b.yaml

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
# >>> python -m apps.sft_v2.main --config apps/sft_v2/llama3_8b.yaml
2+
13
# Config for supervised full finetuning using a Llama3.1 8B Instruct model
24
#
35
# This config assumes that you've run the following command before launching
46
# this run:
57
# export HF_HUB_DISABLE_XET=1
6-
# uv run forge download meta-llama/Meta-Llama-3.1-8B-Instruct
8+
# forge download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct
79

810

911
# TODO: required by torchtitan
@@ -14,11 +16,11 @@ comm:
1416
model:
1517
name: llama3
1618
flavor: 8B
17-
tokenizer_path: /tmp/Llama-3.1-8B-Instruct
19+
hf_assets_path: /tmp/Meta-Llama-3.1-8B-Instruct
1820

1921
processes:
20-
scheduler: local # local | mast (not supported yet)
21-
hosts: 1
22+
# scheduler: local # local | mast (not supported yet)
23+
# hosts: 1
2224
procs: 8
2325
with_gpus: true
2426

apps/sft_v2/main.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
import torchtitan.experiments.forge.train_spec as forge_train_spec
2525
from forge.cli.config import parse
26-
from forge.controller import ForgeActor, spawn_actors
26+
from forge.controller import ForgeActor
2727
from forge.data.collate import collate_packed
2828
from forge.data.datasets.packed import PackedDataset, TextPacker
2929
from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset
@@ -130,16 +130,16 @@ async def setup(self):
130130
# self.logger = self.setup_logger(self.train_config.logger_config)
131131

132132
def setup_data(self):
133-
print(os.path.join(self.job_config.model.tokenizer_path, "tokenizer.json"))
133+
print(os.path.join(self.job_config.model.hf_assets_path, "tokenizer.json"))
134134
tokenizer = HuggingFaceModelTokenizer(
135135
tokenizer_json_path=os.path.join(
136-
self.job_config.model.tokenizer_path, "tokenizer.json"
136+
self.job_config.model.hf_assets_path, "tokenizer.json"
137137
),
138138
tokenizer_config_json_path=os.path.join(
139-
self.job_config.model.tokenizer_path, "tokenizer_config.json"
139+
self.job_config.model.hf_assets_path, "tokenizer_config.json"
140140
),
141141
generation_config_path=os.path.join(
142-
self.job_config.model.tokenizer_path, "generation_config.json"
142+
self.job_config.model.hf_assets_path, "generation_config.json"
143143
),
144144
)
145145

@@ -280,13 +280,7 @@ def __repr__(self) -> str:
280280
async def run(cfg: DictConfig) -> None:
281281
logging.info("Spawing recipe...")
282282
process_cfg = cfg.pop("processes")
283-
recipe = await spawn_actors(
284-
"sft",
285-
ForgeSFTRecipe,
286-
{"config": cfg},
287-
process_cfg,
288-
set_address=True,
289-
)
283+
recipe = await ForgeSFTRecipe.options(**process_cfg).as_service(cfg)
290284

291285
logging.info("Created recipe, running setup.")
292286
await recipe.setup.fanout()

0 commit comments

Comments
 (0)