5
5
import os
6
6
import tempfile
7
7
from pathlib import Path
8
+ from tempfile import NamedTemporaryFile
8
9
from typing import Union
9
10
from urllib .parse import urlparse
10
11
@@ -158,7 +159,9 @@ def plugin_set_cfg(cfg: DictDefault):
158
159
plugin_manager .cfg = cfg
159
160
160
161
161
- def load_cfg (config : Union [str , Path ] = Path ("examples/" ), ** kwargs ) -> DictDefault :
162
+ def load_cfg (
163
+ config : str | Path | DictDefault = Path ("examples/" ), ** kwargs
164
+ ) -> DictDefault :
162
165
"""
163
166
Loads the `axolotl` configuration stored at `config`, validates it, and performs
164
167
various setup.
@@ -170,13 +173,24 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefa
170
173
Returns:
171
174
`DictDefault` mapping configuration keys to values.
172
175
"""
173
- config = check_remote_config (config )
174
- if Path (config ).is_dir ():
175
- config = choose_config (Path (config ))
176
-
177
- # Load the config from the yaml file
178
- with open (config , encoding = "utf-8" ) as file :
179
- cfg : DictDefault = DictDefault (yaml .safe_load (file ))
176
+ if isinstance (config , (str , Path )):
177
+ config = check_remote_config (config )
178
+ if Path (config ).is_dir ():
179
+ config = choose_config (Path (config ))
180
+
181
+ # Load the config from the yaml file
182
+ with open (config , encoding = "utf-8" ) as file :
183
+ cfg : DictDefault = DictDefault (yaml .safe_load (file ))
184
+
185
+ cfg .axolotl_config_path = config
186
+ else :
187
+ cfg = config
188
+ with NamedTemporaryFile (
189
+ mode = "w" , delete = False , suffix = ".yml" , prefix = "axolotl_config_"
190
+ ) as temp_file :
191
+ temp_file .write (yaml .dump (config .to_dict ()))
192
+ temp_file .close ()
193
+ cfg .axolotl_config_path = temp_file .name
180
194
181
195
# If there are any options passed in the cli, if it is something that seems valid
182
196
# from the yaml, then overwrite the value
@@ -190,8 +204,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefa
190
204
else :
191
205
cfg [k ] = kwargs [k ]
192
206
193
- cfg .axolotl_config_path = config
194
-
195
207
try :
196
208
device_props = torch .cuda .get_device_properties ("cuda" )
197
209
gpu_version = "sm_" + str (device_props .major ) + str (device_props .minor )
0 commit comments