Skip to content

Commit de94f54

Browse files
marcopeix and elephaint authored
FIX: Add logic to load custom models when using ReduceLROnPlateau (#1340)
Co-authored-by: Olivier Sprangers <45119856+elephaint@users.noreply.github.com>
1 parent 0614b6c commit de94f54

File tree

2 files changed

+52
-31
lines changed

2 files changed

+52
-31
lines changed

nbs/core.ipynb

Lines changed: 43 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1481,7 +1481,14 @@
14811481
" continue\n",
14821482
"\n",
14831483
" model_name = repr(model)\n",
1484-
" model_class_name = model.__class__.__name__.lower()\n",
1484+
" if model.__class__.__name__.lower() in MODEL_FILENAME_DICT:\n",
1485+
" model_class_name = model.__class__.__name__.lower()\n",
1486+
" elif model.__class__.__base__.__name__.lower() in MODEL_FILENAME_DICT:\n",
1487+
" model_class_name = model.__class__.__base__.__name__.lower()\n",
1488+
" else:\n",
1489+
" raise ValueError(\n",
1490+
" f\"Model {model.__class__.__name__} is not supported for saving.\"\n",
1491+
" )\n",
14851492
" alias_to_model[model_name] = model_class_name\n",
14861493
" count_names[model_name] = count_names.get(model_name, -1) + 1\n",
14871494
" model.save(f\"{path}/{model_name}_{count_names[model_name]}.ckpt\")\n",
@@ -1529,7 +1536,7 @@
15291536
"\n",
15301537
" with fsspec.open(f\"{path}/configuration.pkl\", \"wb\") as f:\n",
15311538
" pickle.dump(config_dict, f)\n",
1532-
"\n",
1539+
" \n",
15331540
" @staticmethod\n",
15341541
" def load(path, verbose=False, **kwargs):\n",
15351542
" \"\"\"Load NeuralForecast\n",
@@ -1550,50 +1557,60 @@
15501557
" Instantiated `NeuralForecast` class.\n",
15511558
" \"\"\"\n",
15521559
" # Standarize path without '/'\n",
1553-
" if path[-1] == '/':\n",
1560+
" if path[-1] == \"/\":\n",
15541561
" path = path[:-1]\n",
1555-
" \n",
1562+
"\n",
15561563
" fs, _, _ = fsspec.get_fs_token_paths(path)\n",
1557-
" files = [f.split('/')[-1] for f in fs.ls(path) if fs.isfile(f)]\n",
1564+
" files = [f.split(\"/\")[-1] for f in fs.ls(path) if fs.isfile(f)]\n",
15581565
"\n",
15591566
" # Load models\n",
1560-
" models_ckpt = [f for f in files if f.endswith('.ckpt')]\n",
1567+
" models_ckpt = [f for f in files if f.endswith(\".ckpt\")]\n",
15611568
" if len(models_ckpt) == 0:\n",
1562-
" raise Exception('No model found in directory.') \n",
1563-
" \n",
1564-
" if verbose: print(10 * '-' + ' Loading models ' + 10 * '-')\n",
1569+
" raise Exception(\"No model found in directory.\")\n",
1570+
"\n",
1571+
" if verbose:\n",
1572+
" print(10 * \"-\" + \" Loading models \" + 10 * \"-\")\n",
15651573
" models = []\n",
15661574
" try:\n",
1567-
" with fsspec.open(f'{path}/alias_to_model.pkl', 'rb') as f:\n",
1575+
" with fsspec.open(f\"{path}/alias_to_model.pkl\", \"rb\") as f:\n",
15681576
" alias_to_model = pickle.load(f)\n",
15691577
" except FileNotFoundError:\n",
15701578
" alias_to_model = {}\n",
1579+
" \n",
15711580
" for model in models_ckpt:\n",
1572-
" model_name = '_'.join(model.split('_')[:-1])\n",
1581+
" model_name = \"_\".join(model.split(\"_\")[:-1])\n",
15731582
" model_class_name = alias_to_model.get(model_name, model_name)\n",
1574-
" loaded_model = MODEL_FILENAME_DICT[model_class_name].load(f'{path}/{model}', **kwargs)\n",
1583+
" loaded_model = MODEL_FILENAME_DICT[model_class_name].load(\n",
1584+
" f\"{path}/{model}\", **kwargs\n",
1585+
" )\n",
15751586
" loaded_model.alias = model_name\n",
15761587
" models.append(loaded_model)\n",
1577-
" if verbose: print(f\"Model {model_name} loaded.\")\n",
1588+
" if verbose:\n",
1589+
" print(f\"Model {model_name} loaded.\")\n",
15781590
"\n",
1579-
" if verbose: print(10*'-' + ' Loading dataset ' + 10*'-')\n",
1591+
" if verbose:\n",
1592+
" print(10 * \"-\" + \" Loading dataset \" + 10 * \"-\")\n",
15801593
" # Load dataset\n",
15811594
" try:\n",
15821595
" with fsspec.open(f\"{path}/dataset.pkl\", \"rb\") as f:\n",
15831596
" dataset = pickle.load(f)\n",
1584-
" if verbose: print('Dataset loaded.')\n",
1597+
" if verbose:\n",
1598+
" print(\"Dataset loaded.\")\n",
15851599
" except FileNotFoundError:\n",
15861600
" dataset = None\n",
1587-
" if verbose: print('No dataset found in directory.')\n",
1601+
" if verbose:\n",
1602+
" print(\"No dataset found in directory.\")\n",
15881603
"\n",
1589-
" if verbose: print(10*'-' + ' Loading configuration ' + 10*'-')\n",
1604+
" if verbose:\n",
1605+
" print(10 * \"-\" + \" Loading configuration \" + 10 * \"-\")\n",
15901606
" # Load configuration\n",
15911607
" try:\n",
15921608
" with fsspec.open(f\"{path}/configuration.pkl\", \"rb\") as f:\n",
15931609
" config_dict = pickle.load(f)\n",
1594-
" if verbose: print('Configuration loaded.')\n",
1610+
" if verbose:\n",
1611+
" print(\"Configuration loaded.\")\n",
15951612
" except FileNotFoundError:\n",
1596-
" raise Exception('No configuration found in directory.')\n",
1613+
" raise Exception(\"No configuration found in directory.\")\n",
15971614
"\n",
15981615
" # in 1.6.4, `local_scaler_type` / `scalers_` lived on the dataset.\n",
15991616
" # in order to preserve backwards-compatibility, we check to see if these are found on the dataset\n",
@@ -1604,34 +1621,30 @@
16041621
" # Create NeuralForecast object\n",
16051622
" neuralforecast = NeuralForecast(\n",
16061623
" models=models,\n",
1607-
" freq=config_dict['freq'],\n",
1624+
" freq=config_dict[\"freq\"],\n",
16081625
" local_scaler_type=config_dict.get(\"local_scaler_type\", default_scalar_type),\n",
16091626
" )\n",
16101627
"\n",
1611-
" attr_to_default = {\n",
1612-
" \"id_col\": \"unique_id\",\n",
1613-
" \"time_col\": \"ds\",\n",
1614-
" \"target_col\": \"y\"\n",
1615-
" }\n",
1628+
" attr_to_default = {\"id_col\": \"unique_id\", \"time_col\": \"ds\", \"target_col\": \"y\"}\n",
16161629
" for attr, default in attr_to_default.items():\n",
16171630
" setattr(neuralforecast, attr, config_dict.get(attr, default))\n",
16181631
" # only restore attribute if available\n",
1619-
" for attr in ['prediction_intervals', '_cs_df']:\n",
1632+
" for attr in [\"prediction_intervals\", \"_cs_df\"]:\n",
16201633
" setattr(neuralforecast, attr, config_dict.get(attr, None))\n",
16211634
"\n",
16221635
" # Dataset\n",
16231636
" if dataset is not None:\n",
16241637
" neuralforecast.dataset = dataset\n",
16251638
" restore_attrs = [\n",
1626-
" 'uids',\n",
1627-
" 'last_dates',\n",
1628-
" 'ds',\n",
1639+
" \"uids\",\n",
1640+
" \"last_dates\",\n",
1641+
" \"ds\",\n",
16291642
" ]\n",
16301643
" for attr in restore_attrs:\n",
16311644
" setattr(neuralforecast, attr, config_dict[attr])\n",
16321645
"\n",
16331646
" # Fitted flag\n",
1634-
" neuralforecast._fitted = config_dict['_fitted']\n",
1647+
" neuralforecast._fitted = config_dict[\"_fitted\"]\n",
16351648
"\n",
16361649
" neuralforecast.scalers_ = config_dict.get(\"scalers_\", default_scalars_)\n",
16371650
"\n",

neuralforecast/core.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1488,7 +1488,14 @@ def save(
14881488
continue
14891489

14901490
model_name = repr(model)
1491-
model_class_name = model.__class__.__name__.lower()
1491+
if model.__class__.__name__.lower() in MODEL_FILENAME_DICT:
1492+
model_class_name = model.__class__.__name__.lower()
1493+
elif model.__class__.__base__.__name__.lower() in MODEL_FILENAME_DICT:
1494+
model_class_name = model.__class__.__base__.__name__.lower()
1495+
else:
1496+
raise ValueError(
1497+
f"Model {model.__class__.__name__} is not supported for saving."
1498+
)
14921499
alias_to_model[model_name] = model_class_name
14931500
count_names[model_name] = count_names.get(model_name, -1) + 1
14941501
model.save(f"{path}/{model_name}_{count_names[model_name]}.ckpt")
@@ -1577,6 +1584,7 @@ def load(path, verbose=False, **kwargs):
15771584
alias_to_model = pickle.load(f)
15781585
except FileNotFoundError:
15791586
alias_to_model = {}
1587+
15801588
for model in models_ckpt:
15811589
model_name = "_".join(model.split("_")[:-1])
15821590
model_class_name = alias_to_model.get(model_name, model_name)

0 commit comments

Comments (0)