Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions apps/grpo/qwen3_1_7b.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ ref_model:
context_parallel_degree: 1
expert_parallel_degree: 1
checkpoint:
enable: true
initial_load_path: hf://${model}
initial_load_in_hf: true

Expand Down
8 changes: 3 additions & 5 deletions src/forge/actors/reference_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class ReferenceModel(ForgeActor):
def __post_init__(self):
"""Initializes config types and env variables."""
super().__init__()

# Instantiate dict fields
for f in fields(self):
attr = getattr(self, f.name)
Expand All @@ -60,13 +61,9 @@ def __post_init__(self):
f"{f.name} should be a {f.type} type or a dict like object"
)

"""
torchrun normally handles env variables, but we need to do it ourselves
in monarch for now.
"""
self.step = 0
self.rank = current_rank().rank
self.size = math.prod(current_size().values())
self.step = 0

env = {
"RANK": str(self.rank),
Expand All @@ -86,6 +83,7 @@ def __post_init__(self):
async def setup(self):
engine_config = {f.name: getattr(self, f.name) for f in fields(self)}
self.engine = ForgeEngine(ForgeJobConfig(**engine_config))
self.engine.checkpointer.load()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was actually staring at this in the trainer side .. It's unclear at a glance how the checkpointer.load is associated with loading the HF model weights

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah it's not very clear without digging into the TorchTitan checkpointing code here: https://github.com/pytorch/torchtitan/blob/5b5d46856b400c8550989415bee91473aab4f921/torchtitan/components/checkpoint.py#L523

All the information is taken from the config and instantiated into the CheckpointManager. Then the load call only takes a "step", which in our case isn't needed b/c it should be a static model every time.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🙃

self.model = self.engine.model_parts[0] # No pipeline parallelism yet
self.model.eval()

Expand Down
59 changes: 30 additions & 29 deletions src/forge/cli/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,36 @@ def _merge_yaml_and_cli_args(yaml_args: Namespace, cli_args: list[str]) -> DictC
return OmegaConf.merge(yaml_conf, cli_conf)


def _resolve_hf_model_path(hf_url: str) -> str:
"""Resolve HuggingFace model URL to local path using snapshot_download."""
if not hf_url.startswith("hf://"):
raise ValueError(f"Invalid HuggingFace URL format: {hf_url}")

repo_name = hf_url.replace("hf://", "")
if not repo_name:
raise ValueError("Empty repository name in HuggingFace URL")

try:
# First, try to get from cache only (local_files_only=True)
# This checks if the model is already cached without downloading
try:
local_dir = snapshot_download(
repo_name, revision="main", local_files_only=True
)
return local_dir
except LocalEntryNotFoundError:
# Model not in cache, download it (local_files_only=False)
local_dir = snapshot_download(
repo_name, revision="main", local_files_only=False
)
return local_dir

except Exception as e:
raise Exception(
f"Failed to resolve HuggingFace model '{repo_name}': {e}"
) from e


def resolve_hf_hub_paths(cfg: DictConfig) -> DictConfig:
"""
Resolves HuggingFace Hub URLs in configuration by downloading models and
Expand Down Expand Up @@ -168,35 +198,6 @@ def resolve_hf_hub_paths(cfg: DictConfig) -> DictConfig:
if not OmegaConf.is_config(cfg):
raise ValueError(f"Input must be an OmegaConf config object, got {type(cfg)}")

def _resolve_hf_model_path(hf_url: str) -> str:
"""Resolve HuggingFace model URL to local path using snapshot_download."""
if not hf_url.startswith("hf://"):
raise ValueError(f"Invalid HuggingFace URL format: {hf_url}")

repo_name = hf_url.replace("hf://", "")
if not repo_name:
raise ValueError("Empty repository name in HuggingFace URL")

try:
# First, try to get from cache only (local_files_only=True)
# This checks if the model is already cached without downloading
try:
local_dir = snapshot_download(
repo_name, revision="main", local_files_only=True
)
return local_dir
except LocalEntryNotFoundError:
# Model not in cache, download it (local_files_only=False)
local_dir = snapshot_download(
repo_name, revision="main", local_files_only=False
)
return local_dir

except Exception as e:
raise Exception(
f"Failed to resolve HuggingFace model '{repo_name}': {e}"
) from e

def _recursively_resolve_paths(obj: Any) -> Any:
"""Recursively resolve hf:// paths in nested data structures."""
if isinstance(obj, str) and obj.startswith("hf://"):
Expand Down
Loading
Loading