Skip to content

Commit 2e3d7c6

Browse files
committed
Remove the old interface and deprecate its arguments
Ran `nbdev_clean --clear_all` to remove unnecessary changes
1 parent c0a7eb8 commit 2e3d7c6

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

73 files changed

+47
-1075
lines changed

nbs/common.base_model.ipynb

Lines changed: 10 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
"import random\n",
3535
"import warnings\n",
3636
"from contextlib import contextmanager\n",
37-
"from copy import deepcopy\n",
3837
"from dataclasses import dataclass\n",
3938
"\n",
4039
"import fsspec\n",
@@ -121,10 +120,6 @@
121120
" random_seed,\n",
122121
" loss,\n",
123122
" valid_loss,\n",
124-
" optimizer,\n",
125-
" optimizer_kwargs,\n",
126-
" lr_scheduler,\n",
127-
" lr_scheduler_kwargs,\n",
128123
" futr_exog_list,\n",
129124
" hist_exog_list,\n",
130125
" stat_exog_list,\n",
@@ -150,18 +145,6 @@
150145
" self.train_trajectories = []\n",
151146
" self.valid_trajectories = []\n",
152147
"\n",
153-
" # Optimization\n",
154-
" if optimizer is not None and not issubclass(optimizer, torch.optim.Optimizer):\n",
155-
" raise TypeError(\"optimizer is not a valid subclass of torch.optim.Optimizer\")\n",
156-
" self.optimizer = optimizer\n",
157-
" self.optimizer_kwargs = optimizer_kwargs if optimizer_kwargs is not None else {}\n",
158-
"\n",
159-
" # lr scheduler\n",
160-
" if lr_scheduler is not None and not issubclass(lr_scheduler, torch.optim.lr_scheduler.LRScheduler):\n",
161-
" raise TypeError(\"lr_scheduler is not a valid subclass of torch.optim.lr_scheduler.LRScheduler\")\n",
162-
" self.lr_scheduler = lr_scheduler\n",
163-
" self.lr_scheduler_kwargs = lr_scheduler_kwargs if lr_scheduler_kwargs is not None else {}\n",
164-
"\n",
165148
" # customized by set_configure_optimizers()\n",
166149
" self.config_optimizers = None\n",
167150
"\n",
@@ -412,41 +395,19 @@
412395
"\n",
413396
" def configure_optimizers(self):\n",
414397
" if self.config_optimizers is not None:\n",
398+
" # return the customized optimizer settings if specified\n",
415399
" return self.config_optimizers\n",
416-
" \n",
417-
" if self.optimizer:\n",
418-
" optimizer_signature = inspect.signature(self.optimizer)\n",
419-
" optimizer_kwargs = deepcopy(self.optimizer_kwargs)\n",
420-
" if 'lr' in optimizer_signature.parameters:\n",
421-
" if 'lr' in optimizer_kwargs:\n",
422-
" warnings.warn(\"ignoring learning rate passed in optimizer_kwargs, using the model's learning rate\")\n",
423-
" optimizer_kwargs['lr'] = self.learning_rate\n",
424-
" optimizer = self.optimizer(params=self.parameters(), **optimizer_kwargs)\n",
425-
" else:\n",
426-
" if self.optimizer_kwargs:\n",
427-
" warnings.warn(\n",
428-
" \"ignoring optimizer_kwargs as the optimizer is not specified\"\n",
429-
" )\n",
430-
" optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
431400
" \n",
432-
" lr_scheduler = {'frequency': 1, 'interval': 'step'}\n",
433-
" if self.lr_scheduler:\n",
434-
" lr_scheduler_signature = inspect.signature(self.lr_scheduler)\n",
435-
" lr_scheduler_kwargs = deepcopy(self.lr_scheduler_kwargs)\n",
436-
" if 'optimizer' in lr_scheduler_signature.parameters:\n",
437-
" if 'optimizer' in lr_scheduler_kwargs:\n",
438-
" warnings.warn(\"ignoring optimizer passed in lr_scheduler_kwargs, using the model's optimizer\")\n",
439-
" del lr_scheduler_kwargs['optimizer']\n",
440-
" lr_scheduler['scheduler'] = self.lr_scheduler(optimizer=optimizer, **lr_scheduler_kwargs)\n",
441-
" else:\n",
442-
" if self.lr_scheduler_kwargs:\n",
443-
" warnings.warn(\n",
444-
" \"ignoring lr_scheduler_kwargs as the lr_scheduler is not specified\"\n",
445-
" ) \n",
446-
" lr_scheduler['scheduler'] = torch.optim.lr_scheduler.StepLR(\n",
401+
" # default choice\n",
402+
" optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
403+
" scheduler = {\n",
404+
" \"scheduler\": torch.optim.lr_scheduler.StepLR(\n",
447405
" optimizer=optimizer, step_size=self.lr_decay_steps, gamma=0.5\n",
448-
" )\n",
449-
" return {'optimizer': optimizer, 'lr_scheduler': lr_scheduler}\n",
406+
" ),\n",
407+
" \"frequency\": 1,\n",
408+
" \"interval\": \"step\",\n",
409+
" }\n",
410+
" return {\"optimizer\": optimizer, \"lr_scheduler\": scheduler}\n",
450411
"\n",
451412
" def set_configure_optimizers(\n",
452413
" self, \n",

nbs/common.base_multivariate.ipynb

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -105,20 +105,12 @@
105105
" drop_last_loader=False,\n",
106106
" random_seed=1, \n",
107107
" alias=None,\n",
108-
" optimizer=None,\n",
109-
" optimizer_kwargs=None,\n",
110-
" lr_scheduler=None,\n",
111-
" lr_scheduler_kwargs=None,\n",
112108
" dataloader_kwargs=None,\n",
113109
" **trainer_kwargs):\n",
114110
" super().__init__(\n",
115111
" random_seed=random_seed,\n",
116112
" loss=loss,\n",
117-
" valid_loss=valid_loss,\n",
118-
" optimizer=optimizer,\n",
119-
" optimizer_kwargs=optimizer_kwargs,\n",
120-
" lr_scheduler=lr_scheduler,\n",
121-
" lr_scheduler_kwargs=lr_scheduler_kwargs, \n",
113+
" valid_loss=valid_loss, \n",
122114
" futr_exog_list=futr_exog_list,\n",
123115
" hist_exog_list=hist_exog_list,\n",
124116
" stat_exog_list=stat_exog_list,\n",

nbs/common.base_recurrent.ipynb

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -111,20 +111,12 @@
111111
" drop_last_loader=False,\n",
112112
" random_seed=1, \n",
113113
" alias=None,\n",
114-
" optimizer=None,\n",
115-
" optimizer_kwargs=None,\n",
116-
" lr_scheduler=None,\n",
117-
" lr_scheduler_kwargs=None,\n",
118114
" dataloader_kwargs=None,\n",
119115
" **trainer_kwargs):\n",
120116
" super().__init__(\n",
121117
" random_seed=random_seed,\n",
122118
" loss=loss,\n",
123119
" valid_loss=valid_loss,\n",
124-
" optimizer=optimizer,\n",
125-
" optimizer_kwargs=optimizer_kwargs,\n",
126-
" lr_scheduler=lr_scheduler,\n",
127-
" lr_scheduler_kwargs=lr_scheduler_kwargs,\n",
128120
" futr_exog_list=futr_exog_list,\n",
129121
" hist_exog_list=hist_exog_list,\n",
130122
" stat_exog_list=stat_exog_list,\n",

nbs/common.base_windows.ipynb

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -115,20 +115,12 @@
115115
" drop_last_loader=False,\n",
116116
" random_seed=1,\n",
117117
" alias=None,\n",
118-
" optimizer=None,\n",
119-
" optimizer_kwargs=None,\n",
120-
" lr_scheduler=None,\n",
121-
" lr_scheduler_kwargs=None,\n",
122118
" dataloader_kwargs=None,\n",
123119
" **trainer_kwargs):\n",
124120
" super().__init__(\n",
125121
" random_seed=random_seed,\n",
126122
" loss=loss,\n",
127123
" valid_loss=valid_loss,\n",
128-
" optimizer=optimizer,\n",
129-
" optimizer_kwargs=optimizer_kwargs,\n",
130-
" lr_scheduler=lr_scheduler,\n",
131-
" lr_scheduler_kwargs=lr_scheduler_kwargs,\n",
132124
" futr_exog_list=futr_exog_list,\n",
133125
" hist_exog_list=hist_exog_list,\n",
134126
" stat_exog_list=stat_exog_list,\n",

nbs/core.ipynb

Lines changed: 25 additions & 153 deletions
Original file line numberDiff line numberDiff line change
@@ -3172,15 +3172,22 @@
31723172
" mean = default_optimizer_predict.loc[:, nf_model.__name__].mean()\n",
31733173
"\n",
31743174
" # using a customized optimizer\n",
3175-
" params.update({\n",
3176-
" \"optimizer\": torch.optim.Adadelta,\n",
3177-
" \"optimizer_kwargs\": {\"rho\": 0.45}, \n",
3178-
" })\n",
3175+
" optimizer = torch.optim.Adadelta(params=models2[0].parameters(), rho=0.75)\n",
3176+
" scheduler=torch.optim.lr_scheduler.StepLR(\n",
3177+
" optimizer=optimizer, step_size=10e7, gamma=0.5\n",
3178+
" )\n",
3179+
"\n",
31793180
" models2 = [nf_model(**params)]\n",
3181+
" models2[0].set_configure_optimizers(\n",
3182+
" optimizer=optimizer,\n",
3183+
" scheduler=scheduler,\n",
3184+
" )\n",
3185+
"\n",
31803186
" nf2 = NeuralForecast(models=models2, freq='M')\n",
31813187
" nf2.fit(AirPassengersPanel_train)\n",
31823188
" customized_optimizer_predict = nf2.predict()\n",
31833189
" mean2 = customized_optimizer_predict.loc[:, nf_model.__name__].mean()\n",
3190+
"\n",
31843191
" assert mean2 != mean"
31853192
]
31863193
},
@@ -3194,100 +3201,18 @@
31943201
"#| hide\n",
31953202
"# test that if the user-defined optimizer is not a subclass of torch.optim.optimizer, failed with exception\n",
31963203
"# tests cover different types of base classes such as BaseWindows, BaseRecurrent, BaseMultivariate\n",
3197-
"test_fail(lambda: NHITS(h=12, input_size=24, max_steps=10, optimizer=torch.nn.Module), contains=\"optimizer is not a valid subclass of torch.optim.Optimizer\")\n",
3198-
"test_fail(lambda: RNN(h=12, input_size=24, max_steps=10, optimizer=torch.nn.Module), contains=\"optimizer is not a valid subclass of torch.optim.Optimizer\")\n",
3199-
"test_fail(lambda: StemGNN(h=12, input_size=24, max_steps=10, n_series=2, optimizer=torch.nn.Module), contains=\"optimizer is not a valid subclass of torch.optim.Optimizer\")\n"
3200-
]
3201-
},
3202-
{
3203-
"cell_type": "code",
3204-
"execution_count": null,
3205-
"id": "d908240f",
3206-
"metadata": {},
3207-
"outputs": [],
3208-
"source": [
3209-
"#| hide\n",
3210-
"# test that if we pass \"lr\" parameter, we expect warning and it ignores the passed in 'lr' parameter\n",
3211-
"# tests consider models implemented using different base classes such as BaseWindows, BaseRecurrent, BaseMultivariate\n",
32123204
"\n",
3213-
"for nf_model in [NHITS, RNN, StemGNN]:\n",
3214-
" params = {\n",
3215-
" \"h\": 12, \n",
3216-
" \"input_size\": 24, \n",
3217-
" \"max_steps\": 1, \n",
3218-
" \"optimizer\": torch.optim.Adadelta, \n",
3219-
" \"optimizer_kwargs\": {\"lr\": 0.8, \"rho\": 0.45}\n",
3220-
" }\n",
3205+
"for model_name in [NHITS, RNN, StemGNN]:\n",
3206+
" params = {\"h\": 12, \"input_size\": 24, \"max_steps\": 10}\n",
32213207
" if nf_model.__name__ == \"StemGNN\":\n",
32223208
" params.update({\"n_series\": 2})\n",
3223-
" models = [nf_model(**params)]\n",
3224-
" nf = NeuralForecast(models=models, freq='M')\n",
3225-
" with warnings.catch_warnings(record=True) as issued_warnings:\n",
3226-
" warnings.simplefilter('always', UserWarning)\n",
3227-
" nf.fit(AirPassengersPanel_train)\n",
3228-
" assert any(\"ignoring learning rate passed in optimizer_kwargs, using the model's learning rate\" in str(w.message) for w in issued_warnings)"
3229-
]
3230-
},
3231-
{
3232-
"cell_type": "code",
3233-
"execution_count": null,
3234-
"id": "c97858b5-e6a0-4353-a48f-5a5460eb2314",
3235-
"metadata": {},
3236-
"outputs": [],
3237-
"source": [
3238-
"#| hide\n",
3239-
"# test that if we pass \"optimizer_kwargs\" but not \"optimizer\", we expect a warning\n",
3240-
"# tests consider models implemented using different base classes such as BaseWindows, BaseRecurrent, BaseMultivariate\n",
32413209
"\n",
3242-
"for nf_model in [NHITS, RNN, StemGNN]:\n",
3243-
" params = {\n",
3244-
" \"h\": 12, \n",
3245-
" \"input_size\": 24, \n",
3246-
" \"max_steps\": 1,\n",
3247-
" \"optimizer_kwargs\": {\"lr\": 0.8, \"rho\": 0.45}\n",
3248-
" }\n",
3249-
" if nf_model.__name__ == \"StemGNN\":\n",
3250-
" params.update({\"n_series\": 2})\n",
3251-
" models = [nf_model(**params)]\n",
3252-
" nf = NeuralForecast(models=models, freq='M')\n",
3253-
" with warnings.catch_warnings(record=True) as issued_warnings:\n",
3254-
" warnings.simplefilter('always', UserWarning)\n",
3255-
" nf.fit(AirPassengersPanel_train)\n",
3256-
" assert any(\"ignoring optimizer_kwargs as the optimizer is not specified\" in str(w.message) for w in issued_warnings)"
3257-
]
3258-
},
3259-
{
3260-
"cell_type": "code",
3261-
"execution_count": null,
3262-
"id": "24142322",
3263-
"metadata": {},
3264-
"outputs": [],
3265-
"source": [
3266-
"#| hide\n",
3267-
"# test customized lr_scheduler behavior such that the user defined lr_scheduler result should differ from default\n",
3268-
"# tests consider models implemented using different base classes such as BaseWindows, BaseRecurrent, BaseMultivariate\n",
3269-
"\n",
3270-
"for nf_model in [NHITS, RNN, StemGNN]:\n",
3271-
" params = {\"h\": 12, \"input_size\": 24, \"max_steps\": 1}\n",
3272-
" if nf_model.__name__ == \"StemGNN\":\n",
3273-
" params.update({\"n_series\": 2})\n",
3274-
" models = [nf_model(**params)]\n",
3275-
" nf = NeuralForecast(models=models, freq='M')\n",
3276-
" nf.fit(AirPassengersPanel_train)\n",
3277-
" default_optimizer_predict = nf.predict()\n",
3278-
" mean = default_optimizer_predict.loc[:, nf_model.__name__].mean()\n",
3279-
"\n",
3280-
" # using a customized lr_scheduler, default is StepLR\n",
3281-
" params.update({\n",
3282-
" \"lr_scheduler\": torch.optim.lr_scheduler.ConstantLR,\n",
3283-
" \"lr_scheduler_kwargs\": {\"factor\": 0.78}, \n",
3284-
" })\n",
3285-
" models2 = [nf_model(**params)]\n",
3286-
" nf2 = NeuralForecast(models=models2, freq='M')\n",
3287-
" nf2.fit(AirPassengersPanel_train)\n",
3288-
" customized_optimizer_predict = nf2.predict()\n",
3289-
" mean2 = customized_optimizer_predict.loc[:, nf_model.__name__].mean()\n",
3290-
" assert mean2 != mean"
3210+
" model = model_name(**params) \n",
3211+
" optimizer = torch.nn.Module()\n",
3212+
" scheduler = torch.optim.lr_scheduler.StepLR(\n",
3213+
" optimizer=torch.optim.Adam(model.parameters()), step_size=10e7, gamma=0.5\n",
3214+
" ) \n",
3215+
" test_fail(lambda: model.set_configure_optimizers(optimizer=optimizer, scheduler=scheduler), contains=\"optimizer is not a valid instance of torch.optim.Optimizer\")\n"
32913216
]
32923217
},
32933218
{
@@ -3298,68 +3223,16 @@
32983223
"outputs": [],
32993224
"source": [
33003225
"#| hide\n",
3301-
"# test that if the user-defined lr_scheduler is not a subclass of torch.optim.lr_scheduler, failed with exception\n",
3226+
"# test that if the user-defined scheduler is not a subclass of torch.optim.lr_scheduler, failed with exception\n",
33023227
"# tests cover different types of base classes such as BaseWindows, BaseRecurrent, BaseMultivariate\n",
3303-
"test_fail(lambda: NHITS(h=12, input_size=24, max_steps=10, lr_scheduler=torch.nn.Module), contains=\"lr_scheduler is not a valid subclass of torch.optim.lr_scheduler.LRScheduler\")\n",
3304-
"test_fail(lambda: RNN(h=12, input_size=24, max_steps=10, lr_scheduler=torch.nn.Module), contains=\"lr_scheduler is not a valid subclass of torch.optim.lr_scheduler.LRScheduler\")\n",
3305-
"test_fail(lambda: StemGNN(h=12, input_size=24, max_steps=10, n_series=2, lr_scheduler=torch.nn.Module), contains=\"lr_scheduler is not a valid subclass of torch.optim.lr_scheduler.LRScheduler\")\n"
3306-
]
3307-
},
3308-
{
3309-
"cell_type": "code",
3310-
"execution_count": null,
3311-
"id": "b1d8bebb",
3312-
"metadata": {},
3313-
"outputs": [],
3314-
"source": [
3315-
"#| hide\n",
3316-
"# test that if we pass in \"optimizer\" parameter, we expect warning and it ignores them\n",
3317-
"# tests consider models implemented using different base classes such as BaseWindows, BaseRecurrent, BaseMultivariate\n",
3318-
"\n",
3319-
"for nf_model in [NHITS, RNN, StemGNN]:\n",
3320-
" params = {\n",
3321-
" \"h\": 12, \n",
3322-
" \"input_size\": 24, \n",
3323-
" \"max_steps\": 1, \n",
3324-
" \"lr_scheduler\": torch.optim.lr_scheduler.ConstantLR, \n",
3325-
" \"lr_scheduler_kwargs\": {\"optimizer\": torch.optim.Adadelta, \"factor\": 0.22}\n",
3326-
" }\n",
3327-
" if nf_model.__name__ == \"StemGNN\":\n",
3328-
" params.update({\"n_series\": 2})\n",
3329-
" models = [nf_model(**params)]\n",
3330-
" nf = NeuralForecast(models=models, freq='M')\n",
3331-
" with warnings.catch_warnings(record=True) as issued_warnings:\n",
3332-
" warnings.simplefilter('always', UserWarning)\n",
3333-
" nf.fit(AirPassengersPanel_train)\n",
3334-
" assert any(\"ignoring optimizer passed in lr_scheduler_kwargs, using the model's optimizer\" in str(w.message) for w in issued_warnings)"
3335-
]
3336-
},
3337-
{
3338-
"cell_type": "code",
3339-
"execution_count": null,
3340-
"id": "06febece",
3341-
"metadata": {},
3342-
"outputs": [],
3343-
"source": [
3344-
"#| hide\n",
3345-
"# test that if we pass in \"lr_scheduler_kwargs\" but not \"lr_scheduler\", we expect a warning\n",
3346-
"# tests consider models implemented using different base classes such as BaseWindows, BaseRecurrent, BaseMultivariate\n",
33473228
"\n",
3348-
"for nf_model in [NHITS, RNN, StemGNN]:\n",
3349-
" params = {\n",
3350-
" \"h\": 12, \n",
3351-
" \"input_size\": 24, \n",
3352-
" \"max_steps\": 1,\n",
3353-
" \"lr_scheduler_kwargs\": {\"optimizer\": torch.optim.Adadelta, \"factor\": 0.22}\n",
3354-
" }\n",
3229+
"for model_name in [NHITS, RNN, StemGNN]:\n",
3230+
" params = {\"h\": 12, \"input_size\": 24, \"max_steps\": 10}\n",
33553231
" if nf_model.__name__ == \"StemGNN\":\n",
33563232
" params.update({\"n_series\": 2})\n",
3357-
" models = [nf_model(**params)]\n",
3358-
" nf = NeuralForecast(models=models, freq='M')\n",
3359-
" with warnings.catch_warnings(record=True) as issued_warnings:\n",
3360-
" warnings.simplefilter('always', UserWarning)\n",
3361-
" nf.fit(AirPassengersPanel_train)\n",
3362-
" assert any(\"ignoring lr_scheduler_kwargs as the lr_scheduler is not specified\" in str(w.message) for w in issued_warnings)\n"
3233+
" model = model_name(**params)\n",
3234+
" optimizer = torch.optim.Adam(model.parameters())\n",
3235+
" test_fail(lambda: model.set_configure_optimizers(optimizer=optimizer, scheduler=torch.nn.Module), contains=\"scheduler is not a valid instance of torch.optim.lr_scheduler.LRScheduler\")"
33633236
]
33643237
},
33653238
{
@@ -3493,7 +3366,6 @@
34933366
" models[0].set_configure_optimizers(\n",
34943367
" optimizer=optimizer,\n",
34953368
" scheduler=scheduler,\n",
3496-
"\n",
34973369
" )\n",
34983370
" nf2 = NeuralForecast(models=models, freq='M')\n",
34993371
" nf2.fit(AirPassengersPanel_train)\n",

0 commit comments

Comments
 (0)