|
3172 | 3172 | " mean = default_optimizer_predict.loc[:, nf_model.__name__].mean()\n", |
3173 | 3173 | "\n", |
3174 | 3174 | " # using a customized optimizer\n", |
3175 | | - " params.update({\n", |
3176 | | - " \"optimizer\": torch.optim.Adadelta,\n", |
3177 | | - " \"optimizer_kwargs\": {\"rho\": 0.45}, \n", |
3178 | | - " })\n", |
| 3175 | + " optimizer = torch.optim.Adadelta(params=models2[0].parameters(), rho=0.75)\n", |
| 3176 | + " scheduler=torch.optim.lr_scheduler.StepLR(\n", |
| 3177 | + " optimizer=optimizer, step_size=10e7, gamma=0.5\n", |
| 3178 | + " )\n", |
| 3179 | + "\n", |
3179 | 3180 | " models2 = [nf_model(**params)]\n", |
| 3181 | + " models2[0].set_configure_optimizers(\n", |
| 3182 | + " optimizer=optimizer,\n", |
| 3183 | + " scheduler=scheduler,\n", |
| 3184 | + " )\n", |
| 3185 | + "\n", |
3180 | 3186 | " nf2 = NeuralForecast(models=models2, freq='M')\n", |
3181 | 3187 | " nf2.fit(AirPassengersPanel_train)\n", |
3182 | 3188 | " customized_optimizer_predict = nf2.predict()\n", |
3183 | 3189 | " mean2 = customized_optimizer_predict.loc[:, nf_model.__name__].mean()\n", |
| 3190 | + "\n", |
3184 | 3191 | " assert mean2 != mean" |
3185 | 3192 | ] |
3186 | 3193 | }, |
|
3194 | 3201 | "#| hide\n", |
3195 | 3202 | "# test that if the user-defined optimizer is not a subclass of torch.optim.optimizer, failed with exception\n", |
3196 | 3203 | "# tests cover different types of base classes such as BaseWindows, BaseRecurrent, BaseMultivariate\n", |
3197 | | - "test_fail(lambda: NHITS(h=12, input_size=24, max_steps=10, optimizer=torch.nn.Module), contains=\"optimizer is not a valid subclass of torch.optim.Optimizer\")\n", |
3198 | | - "test_fail(lambda: RNN(h=12, input_size=24, max_steps=10, optimizer=torch.nn.Module), contains=\"optimizer is not a valid subclass of torch.optim.Optimizer\")\n", |
3199 | | - "test_fail(lambda: StemGNN(h=12, input_size=24, max_steps=10, n_series=2, optimizer=torch.nn.Module), contains=\"optimizer is not a valid subclass of torch.optim.Optimizer\")\n" |
3200 | | - ] |
3201 | | - }, |
3202 | | - { |
3203 | | - "cell_type": "code", |
3204 | | - "execution_count": null, |
3205 | | - "id": "d908240f", |
3206 | | - "metadata": {}, |
3207 | | - "outputs": [], |
3208 | | - "source": [ |
3209 | | - "#| hide\n", |
3210 | | - "# test that if we pass \"lr\" parameter, we expect warning and it ignores the passed in 'lr' parameter\n", |
3211 | | - "# tests consider models implemented using different base classes such as BaseWindows, BaseRecurrent, BaseMultivariate\n", |
3212 | 3204 | "\n", |
3213 | | - "for nf_model in [NHITS, RNN, StemGNN]:\n", |
3214 | | - " params = {\n", |
3215 | | - " \"h\": 12, \n", |
3216 | | - " \"input_size\": 24, \n", |
3217 | | - " \"max_steps\": 1, \n", |
3218 | | - " \"optimizer\": torch.optim.Adadelta, \n", |
3219 | | - " \"optimizer_kwargs\": {\"lr\": 0.8, \"rho\": 0.45}\n", |
3220 | | - " }\n", |
| 3205 | + "for model_name in [NHITS, RNN, StemGNN]:\n", |
| 3206 | + " params = {\"h\": 12, \"input_size\": 24, \"max_steps\": 10}\n", |
3221 | 3207 | " if nf_model.__name__ == \"StemGNN\":\n", |
3222 | 3208 | " params.update({\"n_series\": 2})\n", |
3223 | | - " models = [nf_model(**params)]\n", |
3224 | | - " nf = NeuralForecast(models=models, freq='M')\n", |
3225 | | - " with warnings.catch_warnings(record=True) as issued_warnings:\n", |
3226 | | - " warnings.simplefilter('always', UserWarning)\n", |
3227 | | - " nf.fit(AirPassengersPanel_train)\n", |
3228 | | - " assert any(\"ignoring learning rate passed in optimizer_kwargs, using the model's learning rate\" in str(w.message) for w in issued_warnings)" |
3229 | | - ] |
3230 | | - }, |
3231 | | - { |
3232 | | - "cell_type": "code", |
3233 | | - "execution_count": null, |
3234 | | - "id": "c97858b5-e6a0-4353-a48f-5a5460eb2314", |
3235 | | - "metadata": {}, |
3236 | | - "outputs": [], |
3237 | | - "source": [ |
3238 | | - "#| hide\n", |
3239 | | - "# test that if we pass \"optimizer_kwargs\" but not \"optimizer\", we expect a warning\n", |
3240 | | - "# tests consider models implemented using different base classes such as BaseWindows, BaseRecurrent, BaseMultivariate\n", |
3241 | 3209 | "\n", |
3242 | | - "for nf_model in [NHITS, RNN, StemGNN]:\n", |
3243 | | - " params = {\n", |
3244 | | - " \"h\": 12, \n", |
3245 | | - " \"input_size\": 24, \n", |
3246 | | - " \"max_steps\": 1,\n", |
3247 | | - " \"optimizer_kwargs\": {\"lr\": 0.8, \"rho\": 0.45}\n", |
3248 | | - " }\n", |
3249 | | - " if nf_model.__name__ == \"StemGNN\":\n", |
3250 | | - " params.update({\"n_series\": 2})\n", |
3251 | | - " models = [nf_model(**params)]\n", |
3252 | | - " nf = NeuralForecast(models=models, freq='M')\n", |
3253 | | - " with warnings.catch_warnings(record=True) as issued_warnings:\n", |
3254 | | - " warnings.simplefilter('always', UserWarning)\n", |
3255 | | - " nf.fit(AirPassengersPanel_train)\n", |
3256 | | - " assert any(\"ignoring optimizer_kwargs as the optimizer is not specified\" in str(w.message) for w in issued_warnings)" |
3257 | | - ] |
3258 | | - }, |
3259 | | - { |
3260 | | - "cell_type": "code", |
3261 | | - "execution_count": null, |
3262 | | - "id": "24142322", |
3263 | | - "metadata": {}, |
3264 | | - "outputs": [], |
3265 | | - "source": [ |
3266 | | - "#| hide\n", |
3267 | | - "# test customized lr_scheduler behavior such that the user defined lr_scheduler result should differ from default\n", |
3268 | | - "# tests consider models implemented using different base classes such as BaseWindows, BaseRecurrent, BaseMultivariate\n", |
3269 | | - "\n", |
3270 | | - "for nf_model in [NHITS, RNN, StemGNN]:\n", |
3271 | | - " params = {\"h\": 12, \"input_size\": 24, \"max_steps\": 1}\n", |
3272 | | - " if nf_model.__name__ == \"StemGNN\":\n", |
3273 | | - " params.update({\"n_series\": 2})\n", |
3274 | | - " models = [nf_model(**params)]\n", |
3275 | | - " nf = NeuralForecast(models=models, freq='M')\n", |
3276 | | - " nf.fit(AirPassengersPanel_train)\n", |
3277 | | - " default_optimizer_predict = nf.predict()\n", |
3278 | | - " mean = default_optimizer_predict.loc[:, nf_model.__name__].mean()\n", |
3279 | | - "\n", |
3280 | | - " # using a customized lr_scheduler, default is StepLR\n", |
3281 | | - " params.update({\n", |
3282 | | - " \"lr_scheduler\": torch.optim.lr_scheduler.ConstantLR,\n", |
3283 | | - " \"lr_scheduler_kwargs\": {\"factor\": 0.78}, \n", |
3284 | | - " })\n", |
3285 | | - " models2 = [nf_model(**params)]\n", |
3286 | | - " nf2 = NeuralForecast(models=models2, freq='M')\n", |
3287 | | - " nf2.fit(AirPassengersPanel_train)\n", |
3288 | | - " customized_optimizer_predict = nf2.predict()\n", |
3289 | | - " mean2 = customized_optimizer_predict.loc[:, nf_model.__name__].mean()\n", |
3290 | | - " assert mean2 != mean" |
| 3210 | + " model = model_name(**params) \n", |
| 3211 | + " optimizer = torch.nn.Module()\n", |
| 3212 | + " scheduler = torch.optim.lr_scheduler.StepLR(\n", |
| 3213 | + " optimizer=torch.optim.Adam(model.parameters()), step_size=10e7, gamma=0.5\n", |
| 3214 | + " ) \n", |
| 3215 | + " test_fail(lambda: model.set_configure_optimizers(optimizer=optimizer, scheduler=scheduler), contains=\"optimizer is not a valid instance of torch.optim.Optimizer\")\n" |
3291 | 3216 | ] |
3292 | 3217 | }, |
3293 | 3218 | { |
|
3298 | 3223 | "outputs": [], |
3299 | 3224 | "source": [ |
3300 | 3225 | "#| hide\n", |
3301 | | - "# test that if the user-defined lr_scheduler is not a subclass of torch.optim.lr_scheduler, failed with exception\n", |
| 3226 | + "# test that if the user-defined scheduler is not a subclass of torch.optim.lr_scheduler, failed with exception\n", |
3302 | 3227 | "# tests cover different types of base classes such as BaseWindows, BaseRecurrent, BaseMultivariate\n", |
3303 | | - "test_fail(lambda: NHITS(h=12, input_size=24, max_steps=10, lr_scheduler=torch.nn.Module), contains=\"lr_scheduler is not a valid subclass of torch.optim.lr_scheduler.LRScheduler\")\n", |
3304 | | - "test_fail(lambda: RNN(h=12, input_size=24, max_steps=10, lr_scheduler=torch.nn.Module), contains=\"lr_scheduler is not a valid subclass of torch.optim.lr_scheduler.LRScheduler\")\n", |
3305 | | - "test_fail(lambda: StemGNN(h=12, input_size=24, max_steps=10, n_series=2, lr_scheduler=torch.nn.Module), contains=\"lr_scheduler is not a valid subclass of torch.optim.lr_scheduler.LRScheduler\")\n" |
3306 | | - ] |
3307 | | - }, |
3308 | | - { |
3309 | | - "cell_type": "code", |
3310 | | - "execution_count": null, |
3311 | | - "id": "b1d8bebb", |
3312 | | - "metadata": {}, |
3313 | | - "outputs": [], |
3314 | | - "source": [ |
3315 | | - "#| hide\n", |
3316 | | - "# test that if we pass in \"optimizer\" parameter, we expect warning and it ignores them\n", |
3317 | | - "# tests consider models implemented using different base classes such as BaseWindows, BaseRecurrent, BaseMultivariate\n", |
3318 | | - "\n", |
3319 | | - "for nf_model in [NHITS, RNN, StemGNN]:\n", |
3320 | | - " params = {\n", |
3321 | | - " \"h\": 12, \n", |
3322 | | - " \"input_size\": 24, \n", |
3323 | | - " \"max_steps\": 1, \n", |
3324 | | - " \"lr_scheduler\": torch.optim.lr_scheduler.ConstantLR, \n", |
3325 | | - " \"lr_scheduler_kwargs\": {\"optimizer\": torch.optim.Adadelta, \"factor\": 0.22}\n", |
3326 | | - " }\n", |
3327 | | - " if nf_model.__name__ == \"StemGNN\":\n", |
3328 | | - " params.update({\"n_series\": 2})\n", |
3329 | | - " models = [nf_model(**params)]\n", |
3330 | | - " nf = NeuralForecast(models=models, freq='M')\n", |
3331 | | - " with warnings.catch_warnings(record=True) as issued_warnings:\n", |
3332 | | - " warnings.simplefilter('always', UserWarning)\n", |
3333 | | - " nf.fit(AirPassengersPanel_train)\n", |
3334 | | - " assert any(\"ignoring optimizer passed in lr_scheduler_kwargs, using the model's optimizer\" in str(w.message) for w in issued_warnings)" |
3335 | | - ] |
3336 | | - }, |
3337 | | - { |
3338 | | - "cell_type": "code", |
3339 | | - "execution_count": null, |
3340 | | - "id": "06febece", |
3341 | | - "metadata": {}, |
3342 | | - "outputs": [], |
3343 | | - "source": [ |
3344 | | - "#| hide\n", |
3345 | | - "# test that if we pass in \"lr_scheduler_kwargs\" but not \"lr_scheduler\", we expect a warning\n", |
3346 | | - "# tests consider models implemented using different base classes such as BaseWindows, BaseRecurrent, BaseMultivariate\n", |
3347 | 3228 | "\n", |
3348 | | - "for nf_model in [NHITS, RNN, StemGNN]:\n", |
3349 | | - " params = {\n", |
3350 | | - " \"h\": 12, \n", |
3351 | | - " \"input_size\": 24, \n", |
3352 | | - " \"max_steps\": 1,\n", |
3353 | | - " \"lr_scheduler_kwargs\": {\"optimizer\": torch.optim.Adadelta, \"factor\": 0.22}\n", |
3354 | | - " }\n", |
| 3229 | + "for model_name in [NHITS, RNN, StemGNN]:\n", |
| 3230 | + " params = {\"h\": 12, \"input_size\": 24, \"max_steps\": 10}\n", |
3355 | 3231 | " if nf_model.__name__ == \"StemGNN\":\n", |
3356 | 3232 | " params.update({\"n_series\": 2})\n", |
3357 | | - " models = [nf_model(**params)]\n", |
3358 | | - " nf = NeuralForecast(models=models, freq='M')\n", |
3359 | | - " with warnings.catch_warnings(record=True) as issued_warnings:\n", |
3360 | | - " warnings.simplefilter('always', UserWarning)\n", |
3361 | | - " nf.fit(AirPassengersPanel_train)\n", |
3362 | | - " assert any(\"ignoring lr_scheduler_kwargs as the lr_scheduler is not specified\" in str(w.message) for w in issued_warnings)\n" |
| 3233 | + " model = model_name(**params)\n", |
| 3234 | + " optimizer = torch.optim.Adam(model.parameters())\n", |
| 3235 | + " test_fail(lambda: model.set_configure_optimizers(optimizer=optimizer, scheduler=torch.nn.Module), contains=\"scheduler is not a valid instance of torch.optim.lr_scheduler.LRScheduler\")" |
3363 | 3236 | ] |
3364 | 3237 | }, |
3365 | 3238 | { |
|
3493 | 3366 | " models[0].set_configure_optimizers(\n", |
3494 | 3367 | " optimizer=optimizer,\n", |
3495 | 3368 | " scheduler=scheduler,\n", |
3496 | | - "\n", |
3497 | 3369 | " )\n", |
3498 | 3370 | " nf2 = NeuralForecast(models=models, freq='M')\n", |
3499 | 3371 | " nf2.fit(AirPassengersPanel_train)\n", |
|
0 commit comments