Skip to content

Commit 2c2eba5

Browse files
carmoccarasbt
authored andcommitted
lit_config.json -> model_config.yaml (#1096)
1 parent 2fd90af commit 2c2eba5

39 files changed

+105
-87
lines changed

eval/lm_eval_harness.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def run_eval_harness(
157157
check_valid_checkpoint_dir(checkpoint_dir)
158158
tokenizer = Tokenizer(checkpoint_dir)
159159

160-
config = Config.from_json(checkpoint_dir / "lit_config.json")
160+
config = Config.from_file(checkpoint_dir / "model_config.yaml")
161161

162162
checkpoint_path = checkpoint_dir / "lit_model.pth"
163163

litgpt/chat/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def main(
132132

133133
check_valid_checkpoint_dir(checkpoint_dir)
134134

135-
config = Config.from_json(checkpoint_dir / "lit_config.json")
135+
config = Config.from_file(checkpoint_dir / "model_config.yaml")
136136

137137
checkpoint_path = checkpoint_dir / "lit_model.pth"
138138

litgpt/config.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
22

3-
import json
3+
import yaml
44
from copy import deepcopy
55
from dataclasses import dataclass, field
66
from pathlib import Path
@@ -120,8 +120,9 @@ def from_name(cls, name: str, **kwargs: Any) -> Self:
120120
return cls(**conf_dict)
121121

122122
@classmethod
123-
def from_json(cls, path: Union[str, Path], **kwargs: Any) -> Self:
123+
def from_file(cls, path: Union[str, Path], **kwargs: Any) -> Self:
124124
with open(path, encoding="utf-8") as fp:
125+
<<<<<<< HEAD
125126
json_kwargs = json.load(fp)
126127

127128
if "condense_ratio" in json_kwargs: # legacy name
@@ -139,24 +140,29 @@ def from_json(cls, path: Union[str, Path], **kwargs: Any) -> Self:
139140

140141
json_kwargs.update(kwargs)
141142
return cls(**json_kwargs)
143+
=======
144+
file_kwargs = yaml.safe_load(fp)
145+
file_kwargs.update(kwargs)
146+
return cls(**file_kwargs)
147+
>>>>>>> 674f315 (lit_config.json -> model_config.yaml (#1096))
142148

143149
@classmethod
144150
def from_checkpoint(cls, path: Path, **kwargs: Any) -> Self:
145-
"""Automatically load `lit_config.json` and if it doesn't exist - a matching config from `litgpt/config.py`."""
146-
if (config_path := path / "lit_config.json").is_file():
147-
return cls.from_json(config_path, **kwargs)
151+
"""Automatically load `model_config.yaml` and if it doesn't exist - a matching config from `litgpt/config.py`."""
152+
if (config_path := path / "model_config.yaml").is_file():
153+
return cls.from_file(config_path, **kwargs)
148154
if (model_name := path.name) in name_to_config:
149155
return cls.from_name(model_name, **kwargs)
150-
raise FileNotFoundError(f"For {str(path)!r} neither 'lit_config.json' nor matching config exists.")
156+
raise FileNotFoundError(f"For {str(path)!r} neither 'model_config.yaml' nor matching config exists.")
151157

152158
@property
153159
def mlp_class(self) -> Type:
154-
# `self.mlp_class_name` cannot be the type to keep the config json serializable
160+
# `self.mlp_class_name` cannot be the type to keep the config serializable
155161
return getattr(litgpt.model, self.mlp_class_name)
156162

157163
@property
158164
def norm_class(self) -> Type:
159-
# `self.norm_class_name` cannot be the type to keep the config json serializable
165+
# `self.norm_class_name` cannot be the type to keep the config serializable
160166
if self.norm_class_name == "RMSNorm":
161167
from functools import partial
162168

litgpt/generate/adapter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def main(
6161

6262
check_valid_checkpoint_dir(checkpoint_dir)
6363

64-
config = Config.from_json(checkpoint_dir / "lit_config.json")
64+
config = Config.from_file(checkpoint_dir / "model_config.yaml")
6565

6666
checkpoint_path = checkpoint_dir / "lit_model.pth"
6767

litgpt/generate/adapter_v2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def main(
6161

6262
check_valid_checkpoint_dir(checkpoint_dir)
6363

64-
config = Config.from_json(checkpoint_dir / "lit_config.json")
64+
config = Config.from_file(checkpoint_dir / "model_config.yaml")
6565

6666
checkpoint_path = checkpoint_dir / "lit_model.pth"
6767

litgpt/generate/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def main(
134134

135135
check_valid_checkpoint_dir(checkpoint_dir)
136136

137-
config = Config.from_json(checkpoint_dir / "lit_config.json")
137+
config = Config.from_file(checkpoint_dir / "model_config.yaml")
138138

139139
checkpoint_path = checkpoint_dir / "lit_model.pth"
140140

litgpt/generate/full.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def main(
6060

6161
check_valid_checkpoint_dir(checkpoint_dir)
6262

63-
config = Config.from_json(checkpoint_dir / "lit_config.json")
63+
config = Config.from_file(checkpoint_dir / "model_config.yaml")
6464

6565
checkpoint_path = finetuned_path
6666

litgpt/generate/lora.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ def main(
7070

7171
check_valid_checkpoint_dir(checkpoint_dir)
7272

73-
config = Config.from_json(
74-
checkpoint_dir / "lit_config.json",
73+
config = Config.from_file(
74+
checkpoint_dir / "model_config.yaml",
7575
lora_r=lora_r,
7676
lora_alpha=lora_alpha,
7777
lora_dropout=lora_dropout,

litgpt/generate/sequentially.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def main(
159159

160160
check_valid_checkpoint_dir(checkpoint_dir)
161161

162-
config = Config.from_json(checkpoint_dir / "lit_config.json")
162+
config = Config.from_file(checkpoint_dir / "model_config.yaml")
163163

164164
checkpoint_path = checkpoint_dir / "lit_model.pth"
165165

litgpt/generate/tp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def main(
138138

139139
check_valid_checkpoint_dir(checkpoint_dir)
140140

141-
config = Config.from_json(checkpoint_dir / "lit_config.json")
141+
config = Config.from_file(checkpoint_dir / "model_config.yaml")
142142

143143
model_file = "lit_model.pth"
144144
checkpoint_path = checkpoint_dir / model_file

0 commit comments

Comments
 (0)