63 changes: 63 additions & 0 deletions apps/sft_v2/qwen3_8b.yaml
@@ -0,0 +1,63 @@
# >>> python -m apps.sft_v2.main --config apps/sft_v2/qwen3_8b.yaml


# TODO: required by torchtitan
# https://github.com/pytorch/torchtitan/blob/2f1c814da071cc8ad165d00be6f9c1a66f8e1cce/torchtitan/distributed/utils.py#L265
Comment on lines +4 to +5

Contributor:

    can remove this now?

@daniellepintz (Contributor, Author) replied on Oct 13, 2025.

comm:
trace_buf_size: 0

model_name: "Qwen/Qwen3-8B"

model:
name: qwen3
flavor: 8B
hf_assets_path: hf://${model_name}

processes:
procs: 8
with_gpus: true

optimizer:
name: AdamW
lr: 1e-5
eps: 1e-8

lr_scheduler:
warmup_steps: 200

training:
local_batch_size: 1
seq_len: 2048
max_norm: 1.0
steps: 1000
compile: false
dataset: "c4"

parallelism:
data_parallel_replicate_degree: 1
data_parallel_shard_degree: -1
tensor_parallel_degree: 1
pipeline_parallel_degree: 1
context_parallel_degree: 1
expert_parallel_degree: 1
disable_loss_parallel: false

checkpoint:
enable: true
initial_load_path: hf://${model_name}
initial_load_in_hf: true
last_save_in_hf: true
interval: 500
async_mode: "disabled"

activation_checkpoint:
mode: selective
selective_ac_option: op

# profiling:
# enable_profiling: false

# metrics:
# log_freq: 10
# enable_tensorboard: true
# save_tb_folder: "tb"
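
For reference, `${model_name}` in the paths above is an OmegaConf-style interpolation of the top-level `model_name` key. A minimal sketch of how it resolves, assuming the config is loaded with OmegaConf (the actual loader used by `apps.sft_v2.main` is not shown in this diff):

```python
from omegaconf import OmegaConf

# Load the new config and resolve ${model_name} interpolations.
cfg = OmegaConf.load("apps/sft_v2/qwen3_8b.yaml")
OmegaConf.resolve(cfg)

print(cfg.model.hf_assets_path)          # hf://Qwen/Qwen3-8B
print(cfg.checkpoint.initial_load_path)  # hf://Qwen/Qwen3-8B
```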
26 changes: 20 additions & 6 deletions src/forge/data/tokenizer.py
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import json
-from typing import Any
+from typing import Any, Optional
 
 import jinja2
 from jinja2 import StrictUndefined
@@ -61,7 +61,7 @@ def __init__(
         self._infer_bos_eos_tokens()
         self._infer_should_add_bos_eos()
 
-    def _get_token_from_config(self, config: dict[str, Any], key: str) -> str:
+    def _get_token_from_config(self, config: dict[str, Any], key: str) -> Optional[str]:
         """
         HF BOS/EOS tokens are either stored as e.g. {'bos_token': 5}
         or {'bos_token': {'content': 5, ...}}. This utility handles both.
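
Taking this signature hunk together with the body hunk just below, here is a self-contained sketch of the resulting behavior (the initial `config.get(key)` lookup is assumed, since the diff does not show it; the token values are illustrative):

```python
from typing import Any, Optional

def get_token_from_config(config: dict[str, Any], key: str) -> Optional[str]:
    # Inlined reconstruction of the patched method, for illustration only.
    token = config.get(key)  # assumed lookup; not visible in the diff
    if isinstance(token, dict):
        if "content" not in token:
            raise ValueError(f"Could not parse {key} from config")
        token = token["content"]
    elif token is not None and not isinstance(token, str):
        raise ValueError(f"Could not parse {key} from config")
    return token

assert get_token_from_config({"bos_token": "<s>"}, "bos_token") == "<s>"
assert get_token_from_config({"bos_token": {"content": "<s>"}}, "bos_token") == "<s>"
# Tokenizers that define no BOS token at all now yield None instead of raising.
assert get_token_from_config({}, "bos_token") is None
```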
@@ -72,7 +72,7 @@ def _get_token_from_config(self, config: dict[str, Any], key: str) -> str:
                 raise ValueError(f"Could not parse {key} from config")
             token = token["content"]
         else:
-            if not isinstance(token, str):
+            if token is not None and not isinstance(token, str):
                 raise ValueError(f"Could not parse {key} from config")
         return token
 
@@ -137,7 +137,12 @@ def encode(
             list[int]: The list of token ids.
         """
         token_ids = self.tokenizer.encode(text).ids
-        if add_bos and not self.hf_adds_bos and self.bos_token not in text:
+        if (
+            add_bos
+            and not self.hf_adds_bos
+            and self.bos_token is not None
+            and self.bos_token not in text
+        ):
             token_ids.insert(0, self.bos_id)
         if add_eos and not self.hf_adds_eos:
             token_ids.append(self.eos_id)
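
The new `self.bos_token is not None` guard matters because Python's `in` test against a string requires a string left operand; with a tokenizer that defines no BOS token, the old condition raised a `TypeError` instead of simply skipping. A minimal reproduction:

```python
text = "some prompt"
bos_token = None  # tokenizer defines no BOS token

try:
    bos_token in text  # the old condition's membership test
except TypeError as e:
    print(e)  # 'in <string>' requires string as left operand, not NoneType

# The new condition short-circuits safely:
if bos_token is not None and bos_token not in text:
    print("would insert BOS id")  # never reached when there is no BOS
```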
@@ -262,8 +267,14 @@ def extract_top_level_variables(self, config):
     def render_template(
         self, messages: list[dict[str, str]], add_eos: bool = True
     ) -> str:
+        # Need to set tool_calls to something for qwen chat_template
+        if self.base_tokenizer.config["tokenizer_class"] == "Qwen2Tokenizer":
+            for message in messages:
+                if "tool_calls" not in message:
+                    message["tool_calls"] = {}
         rendered = self.template.render(
             messages=messages,
+            tools=None,
             add_generation_prompt=add_eos,
             **self.special_tokens_mapping,  # We assume that the naming is consistent
             **self.top_level_variables,
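
Context for the `tool_calls` defaulting: the template is rendered with Jinja2's `StrictUndefined` (see the import hunk above), so a Qwen chat template that tests `message.tool_calls` would raise `UndefinedError` for messages lacking the key. A sketch of the effect on a hypothetical input:

```python
messages = [
    {"role": "user", "content": "hi"},
    {"role": "assistant", "content": "hello"},
]

# What the new branch does for Qwen2Tokenizer-backed models: default the
# key to an empty (falsy) dict, so the template's tool-call branch stays
# inactive instead of tripping StrictUndefined.
for message in messages:
    if "tool_calls" not in message:
        message["tool_calls"] = {}

# messages[0] is now {'role': 'user', 'content': 'hi', 'tool_calls': {}}
```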
@@ -291,10 +302,13 @@ def tokenize_messages(
                 add_eos=add_eos if i == len(messages) - 1 else False,
             )
 
-            current_tokens = self.base_tokenizer.encode(rendered, add_eos=False)
+            current_tokens = self.base_tokenizer.encode(
+                rendered, add_bos=False, add_eos=False
+            )
 
             if (
-                self.base_tokenizer.bos_token in rendered
+                self.base_tokenizer.bos_token is not None
+                and self.base_tokenizer.bos_token in rendered
                 and self.base_tokenizer.hf_adds_bos
             ):
                 del current_tokens[0]
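
Passing `add_bos=False` makes the per-message `encode` call explicitly BOS-free: `tokenize_messages` renders and encodes each message separately, so a BOS inserted inside `encode` could be repeated once per message rather than appearing once per conversation. A toy illustration with hypothetical token ids (not the real tokenizer):

```python
BOS = 1

def toy_encode(text: str, add_bos: bool) -> list[int]:
    ids = [ord(c) for c in text]
    return ([BOS] + ids) if add_bos else ids

messages = ["hi", "there"]
per_message_bos = [t for m in messages for t in toy_encode(m, add_bos=True)]
conversation_bos = [BOS] + [t for m in messages for t in toy_encode(m, add_bos=False)]

assert per_message_bos.count(BOS) == 2   # one BOS per message: wrong
assert conversation_bos.count(BOS) == 1  # one BOS per conversation: intended
```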