2626from fsspec .implementations .local import LocalFileSystem
2727from omegaconf import Container , OmegaConf
2828from omegaconf .dictconfig import DictConfig
29-
3029from torch .utils .data import DataLoader
3130
3231from pytorch_lightning import LightningModule , Trainer
@@ -484,16 +483,20 @@ class Options(str, Enum):
484483 option1name = "option1val"
485484 option2name = "option2val"
486485 option3name = "option3val"
486+
487487 hparams = dict (
488- batch_size = 32 , learning_rate = 0.001 , data_root = "./any/path/here" , nested = dict (any_num = 123 , anystr = "abcd" ),
489- switch = Options .option3name
488+ batch_size = 32 ,
489+ learning_rate = 0.001 ,
490+ data_root = "./any/path/here" ,
491+ nested = dict (any_num = 123 , anystr = "abcd" ),
492+ switch = Options .option3name ,
490493 )
491494 path_yaml = os .path .join (tmpdir , "testing-hparams.yaml" )
492495
493496 def compare (loadedParams , defaultParams : dict ):
494497 assert isinstance (loadedParams , (dict , DictConfig ))
495498 assert loadedParams .keys () == defaultParams .keys ()
496- for k ,v in defaultParams .items ():
499+ for k , v in defaultParams .items ():
497500 if isinstance (v , Enum ):
498501 assert v .name == loadedParams [k ]
499502 else :
@@ -506,7 +509,7 @@ def compare(loadedParams, defaultParams: dict):
506509 compare (load_hparams_from_yaml (path_yaml , use_omegaconf = False ), hparams )
507510
508511 save_hparams_to_yaml (path_yaml , AttributeDict (hparams ))
509- compare (load_hparams_from_yaml (path_yaml , use_omegaconf = False ),hparams )
512+ compare (load_hparams_from_yaml (path_yaml , use_omegaconf = False ), hparams )
510513
511514 save_hparams_to_yaml (path_yaml , OmegaConf .create (hparams ))
512515 compare (load_hparams_from_yaml (path_yaml ), hparams )
0 commit comments