Skip to content

Commit 3924876

Browse files
committed
Remove old interface and deprecate the arguments
1 parent c0a7eb8 commit 3924876

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

73 files changed

+75 −1,076 lines changed

nbs/common.base_model.ipynb

Lines changed: 38 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,24 @@
2727
"execution_count": null,
2828
"id": "1c7c2ba5-19ee-421e-9252-7224b03f5201",
2929
"metadata": {},
30-
"outputs": [],
30+
"outputs": [
31+
{
32+
"name": "stderr",
33+
"output_type": "stream",
34+
"text": [
35+
"/root/miniconda3/envs/neuralforecast/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
36+
" from .autonotebook import tqdm as notebook_tqdm\n",
37+
"2024-12-11 17:06:11,409\tINFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n",
38+
"2024-12-11 17:06:11,467\tINFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n"
39+
]
40+
}
41+
],
3142
"source": [
3243
"#| export\n",
3344
"import inspect\n",
3445
"import random\n",
3546
"import warnings\n",
3647
"from contextlib import contextmanager\n",
37-
"from copy import deepcopy\n",
3848
"from dataclasses import dataclass\n",
3949
"\n",
4050
"import fsspec\n",
@@ -121,10 +131,6 @@
121131
" random_seed,\n",
122132
" loss,\n",
123133
" valid_loss,\n",
124-
" optimizer,\n",
125-
" optimizer_kwargs,\n",
126-
" lr_scheduler,\n",
127-
" lr_scheduler_kwargs,\n",
128134
" futr_exog_list,\n",
129135
" hist_exog_list,\n",
130136
" stat_exog_list,\n",
@@ -150,18 +156,6 @@
150156
" self.train_trajectories = []\n",
151157
" self.valid_trajectories = []\n",
152158
"\n",
153-
" # Optimization\n",
154-
" if optimizer is not None and not issubclass(optimizer, torch.optim.Optimizer):\n",
155-
" raise TypeError(\"optimizer is not a valid subclass of torch.optim.Optimizer\")\n",
156-
" self.optimizer = optimizer\n",
157-
" self.optimizer_kwargs = optimizer_kwargs if optimizer_kwargs is not None else {}\n",
158-
"\n",
159-
" # lr scheduler\n",
160-
" if lr_scheduler is not None and not issubclass(lr_scheduler, torch.optim.lr_scheduler.LRScheduler):\n",
161-
" raise TypeError(\"lr_scheduler is not a valid subclass of torch.optim.lr_scheduler.LRScheduler\")\n",
162-
" self.lr_scheduler = lr_scheduler\n",
163-
" self.lr_scheduler_kwargs = lr_scheduler_kwargs if lr_scheduler_kwargs is not None else {}\n",
164-
"\n",
165159
" # customized by set_configure_optimizers()\n",
166160
" self.config_optimizers = None\n",
167161
"\n",
@@ -412,41 +406,19 @@
412406
"\n",
413407
" def configure_optimizers(self):\n",
414408
" if self.config_optimizers is not None:\n",
409+
" # return the customized optimizer settings if specified\n",
415410
" return self.config_optimizers\n",
416-
" \n",
417-
" if self.optimizer:\n",
418-
" optimizer_signature = inspect.signature(self.optimizer)\n",
419-
" optimizer_kwargs = deepcopy(self.optimizer_kwargs)\n",
420-
" if 'lr' in optimizer_signature.parameters:\n",
421-
" if 'lr' in optimizer_kwargs:\n",
422-
" warnings.warn(\"ignoring learning rate passed in optimizer_kwargs, using the model's learning rate\")\n",
423-
" optimizer_kwargs['lr'] = self.learning_rate\n",
424-
" optimizer = self.optimizer(params=self.parameters(), **optimizer_kwargs)\n",
425-
" else:\n",
426-
" if self.optimizer_kwargs:\n",
427-
" warnings.warn(\n",
428-
" \"ignoring optimizer_kwargs as the optimizer is not specified\"\n",
429-
" )\n",
430-
" optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
431411
" \n",
432-
" lr_scheduler = {'frequency': 1, 'interval': 'step'}\n",
433-
" if self.lr_scheduler:\n",
434-
" lr_scheduler_signature = inspect.signature(self.lr_scheduler)\n",
435-
" lr_scheduler_kwargs = deepcopy(self.lr_scheduler_kwargs)\n",
436-
" if 'optimizer' in lr_scheduler_signature.parameters:\n",
437-
" if 'optimizer' in lr_scheduler_kwargs:\n",
438-
" warnings.warn(\"ignoring optimizer passed in lr_scheduler_kwargs, using the model's optimizer\")\n",
439-
" del lr_scheduler_kwargs['optimizer']\n",
440-
" lr_scheduler['scheduler'] = self.lr_scheduler(optimizer=optimizer, **lr_scheduler_kwargs)\n",
441-
" else:\n",
442-
" if self.lr_scheduler_kwargs:\n",
443-
" warnings.warn(\n",
444-
" \"ignoring lr_scheduler_kwargs as the lr_scheduler is not specified\"\n",
445-
" ) \n",
446-
" lr_scheduler['scheduler'] = torch.optim.lr_scheduler.StepLR(\n",
412+
" # default choice\n",
413+
" optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
414+
" scheduler = {\n",
415+
" \"scheduler\": torch.optim.lr_scheduler.StepLR(\n",
447416
" optimizer=optimizer, step_size=self.lr_decay_steps, gamma=0.5\n",
448-
" )\n",
449-
" return {'optimizer': optimizer, 'lr_scheduler': lr_scheduler}\n",
417+
" ),\n",
418+
" \"frequency\": 1,\n",
419+
" \"interval\": \"step\",\n",
420+
" }\n",
421+
" return {\"optimizer\": optimizer, \"lr_scheduler\": scheduler}\n",
450422
"\n",
451423
" def set_configure_optimizers(\n",
452424
" self, \n",
@@ -528,6 +500,22 @@
528500
" model.load_state_dict(content[\"state_dict\"], strict=True)\n",
529501
" return model"
530502
]
503+
},
504+
{
505+
"cell_type": "code",
506+
"execution_count": null,
507+
"id": "077ea025",
508+
"metadata": {},
509+
"outputs": [],
510+
"source": []
511+
},
512+
{
513+
"cell_type": "code",
514+
"execution_count": null,
515+
"id": "2b36e87a",
516+
"metadata": {},
517+
"outputs": [],
518+
"source": []
531519
}
532520
],
533521
"metadata": {

nbs/common.base_multivariate.ipynb

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -105,20 +105,12 @@
105105
" drop_last_loader=False,\n",
106106
" random_seed=1, \n",
107107
" alias=None,\n",
108-
" optimizer=None,\n",
109-
" optimizer_kwargs=None,\n",
110-
" lr_scheduler=None,\n",
111-
" lr_scheduler_kwargs=None,\n",
112108
" dataloader_kwargs=None,\n",
113109
" **trainer_kwargs):\n",
114110
" super().__init__(\n",
115111
" random_seed=random_seed,\n",
116112
" loss=loss,\n",
117-
" valid_loss=valid_loss,\n",
118-
" optimizer=optimizer,\n",
119-
" optimizer_kwargs=optimizer_kwargs,\n",
120-
" lr_scheduler=lr_scheduler,\n",
121-
" lr_scheduler_kwargs=lr_scheduler_kwargs, \n",
113+
" valid_loss=valid_loss, \n",
122114
" futr_exog_list=futr_exog_list,\n",
123115
" hist_exog_list=hist_exog_list,\n",
124116
" stat_exog_list=stat_exog_list,\n",

nbs/common.base_recurrent.ipynb

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -111,20 +111,12 @@
111111
" drop_last_loader=False,\n",
112112
" random_seed=1, \n",
113113
" alias=None,\n",
114-
" optimizer=None,\n",
115-
" optimizer_kwargs=None,\n",
116-
" lr_scheduler=None,\n",
117-
" lr_scheduler_kwargs=None,\n",
118114
" dataloader_kwargs=None,\n",
119115
" **trainer_kwargs):\n",
120116
" super().__init__(\n",
121117
" random_seed=random_seed,\n",
122118
" loss=loss,\n",
123119
" valid_loss=valid_loss,\n",
124-
" optimizer=optimizer,\n",
125-
" optimizer_kwargs=optimizer_kwargs,\n",
126-
" lr_scheduler=lr_scheduler,\n",
127-
" lr_scheduler_kwargs=lr_scheduler_kwargs,\n",
128120
" futr_exog_list=futr_exog_list,\n",
129121
" hist_exog_list=hist_exog_list,\n",
130122
" stat_exog_list=stat_exog_list,\n",

nbs/common.base_windows.ipynb

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -115,20 +115,12 @@
115115
" drop_last_loader=False,\n",
116116
" random_seed=1,\n",
117117
" alias=None,\n",
118-
" optimizer=None,\n",
119-
" optimizer_kwargs=None,\n",
120-
" lr_scheduler=None,\n",
121-
" lr_scheduler_kwargs=None,\n",
122118
" dataloader_kwargs=None,\n",
123119
" **trainer_kwargs):\n",
124120
" super().__init__(\n",
125121
" random_seed=random_seed,\n",
126122
" loss=loss,\n",
127123
" valid_loss=valid_loss,\n",
128-
" optimizer=optimizer,\n",
129-
" optimizer_kwargs=optimizer_kwargs,\n",
130-
" lr_scheduler=lr_scheduler,\n",
131-
" lr_scheduler_kwargs=lr_scheduler_kwargs,\n",
132124
" futr_exog_list=futr_exog_list,\n",
133125
" hist_exog_list=hist_exog_list,\n",
134126
" stat_exog_list=stat_exog_list,\n",

0 commit comments

Comments (0)