11# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
22
3- import json
3+ import yaml
44from copy import deepcopy
55from dataclasses import dataclass , field
66from pathlib import Path
@@ -120,8 +120,9 @@ def from_name(cls, name: str, **kwargs: Any) -> Self:
120120 return cls (** conf_dict )
121121
122122 @classmethod
123- def from_json (cls , path : Union [str , Path ], ** kwargs : Any ) -> Self :
123+ def from_file (cls , path : Union [str , Path ], ** kwargs : Any ) -> Self :
124124 with open (path , encoding = "utf-8" ) as fp :
125+ < << << << HEAD
125126 json_kwargs = json .load (fp )
126127
127128 if "condense_ratio" in json_kwargs : # legacy name
@@ -139,24 +140,29 @@ def from_json(cls, path: Union[str, Path], **kwargs: Any) -> Self:
139140
140141 json_kwargs .update (kwargs )
141142 return cls (** json_kwargs )
143+ == == == =
144+ file_kwargs = yaml .safe_load (fp )
145+ file_kwargs .update (kwargs )
146+ return cls (** file_kwargs )
147+ >> >> >> > 674 f315 (lit_config .json - > model_config .yaml (#1096))
142148
143149 @classmethod
144150 def from_checkpoint (cls , path : Path , ** kwargs : Any ) -> Self :
145- """Automatically load `lit_config.json ` and if it doesn't exist - a matching config from `litgpt/config.py`."""
146- if (config_path := path / "lit_config.json " ).is_file ():
147- return cls .from_json (config_path , ** kwargs )
151+ """Automatically load `model_config.yaml ` and if it doesn't exist - a matching config from `litgpt/config.py`."""
152+ if (config_path := path / "model_config.yaml " ).is_file ():
153+ return cls .from_file (config_path , ** kwargs )
148154 if (model_name := path .name ) in name_to_config :
149155 return cls .from_name (model_name , ** kwargs )
150- raise FileNotFoundError (f"For { str (path )!r} neither 'lit_config.json ' nor matching config exists." )
156+ raise FileNotFoundError (f"For { str (path )!r} neither 'model_config.yaml ' nor matching config exists." )
151157
152158 @property
153159 def mlp_class (self ) -> Type :
154- # `self.mlp_class_name` cannot be the type to keep the config json serializable
160+ # `self.mlp_class_name` cannot be the type to keep the config serializable
155161 return getattr (litgpt .model , self .mlp_class_name )
156162
157163 @property
158164 def norm_class (self ) -> Type :
159- # `self.norm_class_name` cannot be the type to keep the config json serializable
165+ # `self.norm_class_name` cannot be the type to keep the config serializable
160166 if self .norm_class_name == "RMSNorm" :
161167 from functools import partial
162168
0 commit comments