Skip to content

Commit fa9f8ce

Browse files
committed
tmp
1 parent 012de9b commit fa9f8ce

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

65 files changed

+1139
-1234
lines changed

src/fairseq2/checkpoint/_manager.py

Lines changed: 135 additions & 135 deletions
Large diffs are not rendered by default.

src/fairseq2/checkpoint/_metadata_provider.py

Lines changed: 19 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
from __future__ import annotations
88

9-
import json
109
from abc import ABC, abstractmethod
1110
from pathlib import Path
1211
from typing import Iterable, final
@@ -18,10 +17,8 @@
1817
AssetMetadataSaveError,
1918
CachedAssetMetadataProvider,
2019
)
21-
from fairseq2.file_system import FileMode, FileSystem
20+
from fairseq2.file_system import FileSystem
2221
from fairseq2.gang import GangError, Gangs
23-
from fairseq2.models.llama import LLAMA_MODEL_FAMILY, LLaMAConfig
24-
from fairseq2.models.llama.integ import convert_to_hg_llama_config
2522
from fairseq2.utils.structured import unstructure
2623
from fairseq2.utils.yaml import YamlDumper
2724

@@ -52,37 +49,7 @@ def __init__(
5249

5350
def save(self, model_family: str, model_config: object) -> None:
5451
if self._gangs.root.rank == 0:
55-
unstructured_config = unstructure(model_config)
56-
57-
metadata: dict[str, object] = {
58-
"name": "checkpoint",
59-
"model_family": model_family,
60-
"model_config": {
61-
"_set_": unstructured_config,
62-
},
63-
}
64-
65-
if self._gangs.tp.size != 1:
66-
metadata["num_shards"] = self._gangs.tp.size
67-
68-
metadata_file = self._checkpoint_dir.joinpath("model.yaml")
69-
70-
def save_error() -> AssetMetadataSaveError:
71-
return AssetMetadataSaveError(
72-
f"The checkpoint metadata cannot be saved to the '{metadata_file}' file. See the nested exception for details."
73-
)
74-
75-
try:
76-
self._file_system.make_directory(metadata_file.parent)
77-
except OSError as ex:
78-
raise save_error() from ex
79-
80-
try:
81-
self._yaml_dumper.dump(metadata, metadata_file)
82-
except OSError as ex:
83-
raise save_error() from ex
84-
85-
self._save_huggingface_config(model_family, model_config)
52+
self._save_asset_card(model_family, model_config)
8653

8754
try:
8855
self._gangs.root.barrier()
@@ -91,40 +58,36 @@ def save_error() -> AssetMetadataSaveError:
9158
"The collective barrier after the checkpoint metadata save operation has failed. See the nested exception for details."
9259
) from ex
9360

94-
def _save_huggingface_config(self, model_family: str, model_config: object) -> None:
95-
if model_family != LLAMA_MODEL_FAMILY:
96-
return
61+
def _save_asset_card(self, model_family: str, model_config: object) -> None:
62+
unstructured_model_config = unstructure(model_config)
9763

98-
if not isinstance(model_config, LLaMAConfig):
99-
raise TypeError(
100-
f"`model_config` must be of type `{LLaMAConfig}`, but is of type `{type(model_config)}` instead."
101-
)
64+
metadata: dict[str, object] = {
65+
"name": "checkpoint",
66+
"model_family": model_family,
67+
"model_config": {
68+
"_set_": unstructured_model_config,
69+
},
70+
}
10271

103-
hg_config = convert_to_hg_llama_config(model_config)
72+
if self._gangs.tp.size != 1:
73+
metadata["num_shards"] = self._gangs.tp.size
10474

105-
hg_config_file = self._checkpoint_dir.joinpath("cc/config.json")
75+
metadata_file = self._checkpoint_dir.joinpath("model.yaml")
10676

10777
def save_error() -> AssetMetadataSaveError:
10878
return AssetMetadataSaveError(
109-
f"The Hugging Face model configuration cannot be saved to the '{hg_config_file}' file. See the nested exception for details."
79+
f"The checkpoint metadata cannot be saved to the '{metadata_file}' file. See the nested exception for details."
11080
)
11181

11282
try:
113-
self._file_system.make_directory(hg_config_file.parent)
114-
except OSError as ex:
115-
raise save_error() from ex
116-
117-
try:
118-
fp = self._file_system.open_text(hg_config_file, mode=FileMode.WRITE)
83+
self._file_system.make_directory(metadata_file.parent)
11984
except OSError as ex:
12085
raise save_error() from ex
12186

12287
try:
123-
json.dump(hg_config, fp, indent=2, sort_keys=True)
88+
self._yaml_dumper.dump(metadata, metadata_file)
12489
except OSError as ex:
12590
raise save_error() from ex
126-
finally:
127-
fp.close()
12891

12992

13093
@final
@@ -170,23 +133,10 @@ def _load_cache(self) -> dict[str, dict[str, object]]:
170133
"The checkpoint metadata does not have a 'checkpoint@' entry."
171134
) from None
172135

173-
num_shards = metadata.get("num_shards", 1)
174-
175-
if not isinstance(num_shards, int) or num_shards < 1:
176-
raise AssetMetadataLoadError(
177-
"The 'num_shards' value in the checkpoint metadata is not a positive integer."
178-
)
179-
180-
if num_shards == 1:
181-
filename = "model.pt"
182-
else:
183-
# TODO: Fix once DownloadManager refactoring complete!
184-
filename = "model.0{shard_idx}.pt"
185-
186136
def add_checkpoint_metadata(name: str, step_nr: int) -> None:
187-
file = self._checkpoint_dir.joinpath(f"step_{step_nr}/{filename}")
137+
path = self._checkpoint_dir.joinpath(f"step_{step_nr}")
188138

189-
cache[name] = {"base": "checkpoint", "checkpoint": str(file)}
139+
cache[name] = {"base": "checkpoint", "checkpoint": str(path)}
190140

191141
max_step_nr = -1
192142

src/fairseq2/cli/_setup.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,7 @@
99
from fairseq2.chatbots import UnknownChatbotError
1010
from fairseq2.cli.commands.assets import ListAssetsHandler, ShowAssetHandler
1111
from fairseq2.cli.commands.chatbot import RunChatbotHandler
12-
from fairseq2.cli.commands.llama import (
13-
ConvertLLaMACheckpointHandler,
14-
WriteHFLLaMAConfigHandler,
15-
)
12+
from fairseq2.cli.commands.llama import ConvertLLaMACheckpointHandler
1613
from fairseq2.cli.commands.recipe import RecipeCommandHandler
1714
from fairseq2.context import RuntimeContext
1815
from fairseq2.data.text.tokenizers import (
@@ -37,7 +34,6 @@
3734
from fairseq2.metrics.text import UnknownBleuTokenizerError
3835
from fairseq2.models import (
3936
InvalidModelTypeError,
40-
ShardedModelLoadError,
4137
UnknownModelArchitectureError,
4238
UnknownModelError,
4339
UnknownModelFamilyError,
@@ -120,7 +116,7 @@ def setup_cli(context: RuntimeContext) -> Cli:
120116

121117
signature = "extension_function(context: RuntimeContext, cli: Cli) -> None"
122118

123-
run_extensions("fairseq2.cli", signature, context, cli)
119+
run_extensions("fairseq2.cli", signature, cli, context)
124120

125121
return cli
126122

@@ -179,12 +175,6 @@ def _register_llama_cli(cli: Cli) -> None:
179175
help="convert fairseq2 LLaMA checkpoints to reference checkpoints",
180176
)
181177

182-
group.add_command(
183-
name="write_hf_config",
184-
handler=WriteHFLLaMAConfigHandler(),
185-
help="write fairseq2 LLaMA configurations in Hugging Face format",
186-
)
187-
188178

189179
def _register_lm_cli(cli: Cli) -> None:
190180
group = cli.add_group("lm", help="language model recipes")
@@ -357,7 +347,6 @@ def _register_user_error_types(cli: Cli) -> None:
357347
cli.register_user_error_type(ModelCompilationNotSupportedError)
358348
cli.register_user_error_type(ModelParallelismNotSupportedError)
359349
cli.register_user_error_type(ModelPathNotFoundError)
360-
cli.register_user_error_type(ShardedModelLoadError)
361350
cli.register_user_error_type(UnknownBeamSearchAlgorithmError)
362351
cli.register_user_error_type(UnknownBleuTokenizerError)
363352
cli.register_user_error_type(UnknownChatbotError)

src/fairseq2/cli/commands/llama/__init__.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,3 @@
99
from fairseq2.cli.commands.llama._convert_checkpoint import (
1010
ConvertLLaMACheckpointHandler as ConvertLLaMACheckpointHandler,
1111
)
12-
from fairseq2.cli.commands.llama._write_hf_config import (
13-
WriteHFLLaMAConfigHandler as WriteHFLLaMAConfigHandler,
14-
)

src/fairseq2/cli/commands/llama/_convert_checkpoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ def file_write_error() -> CliCommandError:
232232
"dim": model_config.model_dim,
233233
"n_layers": model_config.num_layers,
234234
"n_heads": model_config.num_attn_heads,
235-
"multiple_of": model_config.ffn_inner_dim_to_multiple,
235+
"multiple_of": model_config.ffn_inner_dim_multiple_of,
236236
"rope_theta": model_config.rope_theta,
237237
"norm_eps": 1e-5,
238238
}

src/fairseq2/cli/commands/llama/_write_hf_config.py

Lines changed: 0 additions & 122 deletions
This file was deleted.

src/fairseq2/gang.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -591,7 +591,7 @@ def setup_parallel_gangs(root_gang: Gang, *, tp_size: int = 1) -> Gangs:
591591
mesh = torch.arange(root_gang.size).view(dp_size, tp_size)
592592

593593
# Get the coordinate of this process in the mesh.
594-
rank_coords = [x.item() for x in torch.where(mesh == root_gang.rank)]
594+
rank_coord = [x.item() for x in torch.where(mesh == root_gang.rank)]
595595

596596
dp_gang: Gang | None = None
597597

@@ -619,7 +619,7 @@ def setup_parallel_gangs(root_gang: Gang, *, tp_size: int = 1) -> Gangs:
619619
else:
620620
for i in range(tp_size):
621621
sub_gang = root_gang.create_gang(mesh[:, i].tolist())
622-
if i == rank_coords[1]:
622+
if i == rank_coord[1]:
623623
dp_gang = sub_gang
624624

625625
if dp_gang is None:
@@ -651,7 +651,7 @@ def setup_parallel_gangs(root_gang: Gang, *, tp_size: int = 1) -> Gangs:
651651
else:
652652
for i in range(dp_size):
653653
sub_gang = root_gang.create_gang(mesh[i, :].tolist())
654-
if i == rank_coords[0]:
654+
if i == rank_coord[0]:
655655
tp_gang = sub_gang
656656

657657
if tp_gang is None:
@@ -701,7 +701,7 @@ def setup_fsdp_gangs(gangs: Gangs, intra_node_size: int | None = None) -> Gangs:
701701
mesh = torch.arange(dp_gang.size).view(inter_node_size, intra_node_size)
702702

703703
# Get the coordinate of this process in the mesh.
704-
rank_coords = [x.item() for x in torch.where(mesh == dp_gang.rank)]
704+
rank_coord = [x.item() for x in torch.where(mesh == dp_gang.rank)]
705705

706706
inter_gang: Gang | None = None
707707

@@ -729,7 +729,7 @@ def setup_fsdp_gangs(gangs: Gangs, intra_node_size: int | None = None) -> Gangs:
729729
else:
730730
for i in range(intra_node_size):
731731
sub_gang = dp_gang.create_gang(mesh[:, i].tolist())
732-
if i == rank_coords[1]:
732+
if i == rank_coord[1]:
733733
inter_gang = sub_gang
734734

735735
if inter_gang is None:
@@ -761,7 +761,7 @@ def setup_fsdp_gangs(gangs: Gangs, intra_node_size: int | None = None) -> Gangs:
761761
else:
762762
for i in range(inter_node_size):
763763
sub_gang = dp_gang.create_gang(mesh[i, :].tolist())
764-
if i == rank_coords[0]:
764+
if i == rank_coord[0]:
765765
intra_gang = sub_gang
766766

767767
if intra_gang is None:

src/fairseq2/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from fairseq2.models._handler import CheckpointConverter as CheckpointConverter
2626
from fairseq2.models._handler import DelegatingModelHandler as DelegatingModelHandler
2727
from fairseq2.models._handler import FsdpApplier as FsdpApplier
28+
from fairseq2.models._handler import HuggingFaceExporter as HuggingFaceExporter
2829
from fairseq2.models._handler import ModelCompiler as ModelCompiler
2930
from fairseq2.models._handler import ModelFactory as ModelFactory
3031
from fairseq2.models._handler import ModelHandler as ModelHandler

0 commit comments

Comments
 (0)