|
27 | 27 | "execution_count": null, |
28 | 28 | "id": "1c7c2ba5-19ee-421e-9252-7224b03f5201", |
29 | 29 | "metadata": {}, |
30 | | - "outputs": [], |
| 30 | + "outputs": [ |
| 31 | + { |
| 32 | + "name": "stderr", |
| 33 | + "output_type": "stream", |
| 34 | + "text": [ |
| 35 | + "/root/miniconda3/envs/neuralforecast/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", |
| 36 | + " from .autonotebook import tqdm as notebook_tqdm\n", |
| 37 | + "2024-12-11 17:06:11,409\tINFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n", |
| 38 | + "2024-12-11 17:06:11,467\tINFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n" |
| 39 | + ] |
| 40 | + } |
| 41 | + ], |
31 | 42 | "source": [ |
32 | 43 | "#| export\n", |
33 | 44 | "import inspect\n", |
34 | 45 | "import random\n", |
35 | 46 | "import warnings\n", |
36 | 47 | "from contextlib import contextmanager\n", |
37 | | - "from copy import deepcopy\n", |
38 | 48 | "from dataclasses import dataclass\n", |
39 | 49 | "\n", |
40 | 50 | "import fsspec\n", |
|
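The `from copy import deepcopy` import is dropped along with the branches that used it. For reference, a standalone sketch of the copy-before-mutate pattern it supported in the removed code (the dict contents here are hypothetical):

```python
# Standalone sketch of the copy-before-mutate pattern that required
# `deepcopy` in the removed branches: the stored kwargs dict is copied
# so injecting the model's learning rate does not mutate user input.
from copy import deepcopy

stored_kwargs = {"weight_decay": 1e-4}      # hypothetical user kwargs
optimizer_kwargs = deepcopy(stored_kwargs)  # protect the original
optimizer_kwargs["lr"] = 1e-3               # inject the model's lr
assert "lr" not in stored_kwargs            # original left untouched
```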
121 | 131 | " random_seed,\n", |
122 | 132 | " loss,\n", |
123 | 133 | " valid_loss,\n", |
124 | | - " optimizer,\n", |
125 | | - " optimizer_kwargs,\n", |
126 | | - " lr_scheduler,\n", |
127 | | - " lr_scheduler_kwargs,\n", |
128 | 134 | " futr_exog_list,\n", |
129 | 135 | " hist_exog_list,\n", |
130 | 136 | " stat_exog_list,\n", |
|
150 | 156 | " self.train_trajectories = []\n", |
151 | 157 | " self.valid_trajectories = []\n", |
152 | 158 | "\n", |
153 | | - " # Optimization\n", |
154 | | - " if optimizer is not None and not issubclass(optimizer, torch.optim.Optimizer):\n", |
155 | | - " raise TypeError(\"optimizer is not a valid subclass of torch.optim.Optimizer\")\n", |
156 | | - " self.optimizer = optimizer\n", |
157 | | - " self.optimizer_kwargs = optimizer_kwargs if optimizer_kwargs is not None else {}\n", |
158 | | - "\n", |
159 | | - " # lr scheduler\n", |
160 | | - " if lr_scheduler is not None and not issubclass(lr_scheduler, torch.optim.lr_scheduler.LRScheduler):\n", |
161 | | - " raise TypeError(\"lr_scheduler is not a valid subclass of torch.optim.lr_scheduler.LRScheduler\")\n", |
162 | | - " self.lr_scheduler = lr_scheduler\n", |
163 | | - " self.lr_scheduler_kwargs = lr_scheduler_kwargs if lr_scheduler_kwargs is not None else {}\n", |
164 | | - "\n", |
165 | 159 | " # customized by set_configure_optimizers()\n", |
166 | 160 | " self.config_optimizers = None\n", |
167 | 161 | "\n", |
|
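This hunk removes the constructor-level plumbing for user-supplied optimizer and scheduler classes; the only remaining customization hook is `self.config_optimizers`, populated later by `set_configure_optimizers()`. For reference, a standalone sketch of the `issubclass` validation the removed lines performed (the class chosen here is a hypothetical user input):

```python
# Standalone sketch of the subclass validation removed above: a
# user-supplied class had to derive from torch.optim.Optimizer.
import torch

optimizer_cls = torch.optim.Adam  # hypothetical user-supplied class

if optimizer_cls is not None and not issubclass(optimizer_cls, torch.optim.Optimizer):
    raise TypeError("optimizer is not a valid subclass of torch.optim.Optimizer")
```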
412 | 406 | "\n", |
413 | 407 | " def configure_optimizers(self):\n", |
414 | 408 | " if self.config_optimizers is not None:\n", |
| 409 | + " # return the customized optimizer settings if specified\n", |
415 | 410 | " return self.config_optimizers\n", |
416 | | - " \n", |
417 | | - " if self.optimizer:\n", |
418 | | - " optimizer_signature = inspect.signature(self.optimizer)\n", |
419 | | - " optimizer_kwargs = deepcopy(self.optimizer_kwargs)\n", |
420 | | - " if 'lr' in optimizer_signature.parameters:\n", |
421 | | - " if 'lr' in optimizer_kwargs:\n", |
422 | | - " warnings.warn(\"ignoring learning rate passed in optimizer_kwargs, using the model's learning rate\")\n", |
423 | | - " optimizer_kwargs['lr'] = self.learning_rate\n", |
424 | | - " optimizer = self.optimizer(params=self.parameters(), **optimizer_kwargs)\n", |
425 | | - " else:\n", |
426 | | - " if self.optimizer_kwargs:\n", |
427 | | - " warnings.warn(\n", |
428 | | - " \"ignoring optimizer_kwargs as the optimizer is not specified\"\n", |
429 | | - " )\n", |
430 | | - " optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n", |
431 | 411 | " \n", |
432 | | - " lr_scheduler = {'frequency': 1, 'interval': 'step'}\n", |
433 | | - " if self.lr_scheduler:\n", |
434 | | - " lr_scheduler_signature = inspect.signature(self.lr_scheduler)\n", |
435 | | - " lr_scheduler_kwargs = deepcopy(self.lr_scheduler_kwargs)\n", |
436 | | - " if 'optimizer' in lr_scheduler_signature.parameters:\n", |
437 | | - " if 'optimizer' in lr_scheduler_kwargs:\n", |
438 | | - " warnings.warn(\"ignoring optimizer passed in lr_scheduler_kwargs, using the model's optimizer\")\n", |
439 | | - " del lr_scheduler_kwargs['optimizer']\n", |
440 | | - " lr_scheduler['scheduler'] = self.lr_scheduler(optimizer=optimizer, **lr_scheduler_kwargs)\n", |
441 | | - " else:\n", |
442 | | - " if self.lr_scheduler_kwargs:\n", |
443 | | - " warnings.warn(\n", |
444 | | - " \"ignoring lr_scheduler_kwargs as the lr_scheduler is not specified\"\n", |
445 | | - " ) \n", |
446 | | - " lr_scheduler['scheduler'] = torch.optim.lr_scheduler.StepLR(\n", |
| 412 | + " # default choice\n", |
| 413 | + " optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n", |
| 414 | + " scheduler = {\n", |
| 415 | + " \"scheduler\": torch.optim.lr_scheduler.StepLR(\n", |
447 | 416 | " optimizer=optimizer, step_size=self.lr_decay_steps, gamma=0.5\n", |
448 | | - " )\n", |
449 | | - " return {'optimizer': optimizer, 'lr_scheduler': lr_scheduler}\n", |
| 417 | + " ),\n", |
| 418 | + " \"frequency\": 1,\n", |
| 419 | + " \"interval\": \"step\",\n", |
| 420 | + " }\n", |
| 421 | + " return {\"optimizer\": optimizer, \"lr_scheduler\": scheduler}\n", |
450 | 422 | "\n", |
451 | 423 | " def set_configure_optimizers(\n", |
452 | 424 | " self, \n", |
|
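With the custom-class branches gone, `configure_optimizers` collapses to a fixed default: Adam plus a `StepLR` schedule that halves the learning rate every `lr_decay_steps`, stepped per training step in Lightning's dict format. A minimal standalone sketch of the structure it now returns, using a stand-in module and hypothetical hyperparameter values:

```python
# Minimal standalone sketch of the default wiring above; the Linear
# module and the hyperparameter values are hypothetical stand-ins.
import torch

module = torch.nn.Linear(4, 1)
learning_rate, lr_decay_steps = 1e-3, 100

optimizer = torch.optim.Adam(module.parameters(), lr=learning_rate)
scheduler = {
    "scheduler": torch.optim.lr_scheduler.StepLR(
        optimizer=optimizer, step_size=lr_decay_steps, gamma=0.5
    ),
    "frequency": 1,      # step once per interval
    "interval": "step",  # per training step, not per epoch
}
config = {"optimizer": optimizer, "lr_scheduler": scheduler}  # Lightning's format
```

Any customization now goes through `set_configure_optimizers()`, whose stored result short-circuits this default via the `config_optimizers` early return at the top of the method.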
528 | 500 | " model.load_state_dict(content[\"state_dict\"], strict=True)\n", |
529 | 501 | " return model" |
530 | 502 | ] |
| 503 | + }, |
| 504 | + { |
| 505 | + "cell_type": "code", |
| 506 | + "execution_count": null, |
| 507 | + "id": "077ea025", |
| 508 | + "metadata": {}, |
| 509 | + "outputs": [], |
| 510 | + "source": [] |
| 511 | + }, |
| 512 | + { |
| 513 | + "cell_type": "code", |
| 514 | + "execution_count": null, |
| 515 | + "id": "2b36e87a", |
| 516 | + "metadata": {}, |
| 517 | + "outputs": [], |
| 518 | + "source": [] |
531 | 519 | } |
532 | 520 | ], |
533 | 521 | "metadata": { |
|