|
1481 | 1481 | " continue\n", |
1482 | 1482 | "\n", |
1483 | 1483 | " model_name = repr(model)\n", |
1484 | | - " model_class_name = model.__class__.__name__.lower()\n", |
| 1484 | + " if model.__class__.__name__.lower() in MODEL_FILENAME_DICT:\n", |
| 1485 | + " model_class_name = model.__class__.__name__.lower()\n", |
| 1486 | + " elif model.__class__.__base__.__name__.lower() in MODEL_FILENAME_DICT:\n", |
| 1487 | + " model_class_name = model.__class__.__base__.__name__.lower()\n", |
| 1488 | + " else:\n", |
| 1489 | + " raise ValueError(\n", |
| 1490 | + " f\"Model {model.__class__.__name__} is not supported for saving.\"\n", |
| 1491 | + " )\n", |
1485 | 1492 | " alias_to_model[model_name] = model_class_name\n", |
1486 | 1493 | " count_names[model_name] = count_names.get(model_name, -1) + 1\n", |
1487 | 1494 | " model.save(f\"{path}/{model_name}_{count_names[model_name]}.ckpt\")\n", |
|
1529 | 1536 | "\n", |
1530 | 1537 | " with fsspec.open(f\"{path}/configuration.pkl\", \"wb\") as f:\n", |
1531 | 1538 | " pickle.dump(config_dict, f)\n", |
1532 | | - "\n", |
| 1539 | + " \n", |
1533 | 1540 | " @staticmethod\n", |
1534 | 1541 | " def load(path, verbose=False, **kwargs):\n", |
1535 | 1542 | " \"\"\"Load NeuralForecast\n", |
|
1550 | 1557 | " Instantiated `NeuralForecast` class.\n", |
1551 | 1558 | " \"\"\"\n", |
1552 | 1559 | " # Standarize path without '/'\n", |
1553 | | - " if path[-1] == '/':\n", |
| 1560 | + " if path[-1] == \"/\":\n", |
1554 | 1561 | " path = path[:-1]\n", |
1555 | | - " \n", |
| 1562 | + "\n", |
1556 | 1563 | " fs, _, _ = fsspec.get_fs_token_paths(path)\n", |
1557 | | - " files = [f.split('/')[-1] for f in fs.ls(path) if fs.isfile(f)]\n", |
| 1564 | + " files = [f.split(\"/\")[-1] for f in fs.ls(path) if fs.isfile(f)]\n", |
1558 | 1565 | "\n", |
1559 | 1566 | " # Load models\n", |
1560 | | - " models_ckpt = [f for f in files if f.endswith('.ckpt')]\n", |
| 1567 | + " models_ckpt = [f for f in files if f.endswith(\".ckpt\")]\n", |
1561 | 1568 | " if len(models_ckpt) == 0:\n", |
1562 | | - " raise Exception('No model found in directory.') \n", |
1563 | | - " \n", |
1564 | | - " if verbose: print(10 * '-' + ' Loading models ' + 10 * '-')\n", |
| 1569 | + " raise Exception(\"No model found in directory.\")\n", |
| 1570 | + "\n", |
| 1571 | + " if verbose:\n", |
| 1572 | + " print(10 * \"-\" + \" Loading models \" + 10 * \"-\")\n", |
1565 | 1573 | " models = []\n", |
1566 | 1574 | " try:\n", |
1567 | | - " with fsspec.open(f'{path}/alias_to_model.pkl', 'rb') as f:\n", |
| 1575 | + " with fsspec.open(f\"{path}/alias_to_model.pkl\", \"rb\") as f:\n", |
1568 | 1576 | " alias_to_model = pickle.load(f)\n", |
1569 | 1577 | " except FileNotFoundError:\n", |
1570 | 1578 | " alias_to_model = {}\n", |
| 1579 | + " \n", |
1571 | 1580 | " for model in models_ckpt:\n", |
1572 | | - " model_name = '_'.join(model.split('_')[:-1])\n", |
| 1581 | + " model_name = \"_\".join(model.split(\"_\")[:-1])\n", |
1573 | 1582 | " model_class_name = alias_to_model.get(model_name, model_name)\n", |
1574 | | - " loaded_model = MODEL_FILENAME_DICT[model_class_name].load(f'{path}/{model}', **kwargs)\n", |
| 1583 | + " loaded_model = MODEL_FILENAME_DICT[model_class_name].load(\n", |
| 1584 | + " f\"{path}/{model}\", **kwargs\n", |
| 1585 | + " )\n", |
1575 | 1586 | " loaded_model.alias = model_name\n", |
1576 | 1587 | " models.append(loaded_model)\n", |
1577 | | - " if verbose: print(f\"Model {model_name} loaded.\")\n", |
| 1588 | + " if verbose:\n", |
| 1589 | + " print(f\"Model {model_name} loaded.\")\n", |
1578 | 1590 | "\n", |
1579 | | - " if verbose: print(10*'-' + ' Loading dataset ' + 10*'-')\n", |
| 1591 | + " if verbose:\n", |
| 1592 | + " print(10 * \"-\" + \" Loading dataset \" + 10 * \"-\")\n", |
1580 | 1593 | " # Load dataset\n", |
1581 | 1594 | " try:\n", |
1582 | 1595 | " with fsspec.open(f\"{path}/dataset.pkl\", \"rb\") as f:\n", |
1583 | 1596 | " dataset = pickle.load(f)\n", |
1584 | | - " if verbose: print('Dataset loaded.')\n", |
| 1597 | + " if verbose:\n", |
| 1598 | + " print(\"Dataset loaded.\")\n", |
1585 | 1599 | " except FileNotFoundError:\n", |
1586 | 1600 | " dataset = None\n", |
1587 | | - " if verbose: print('No dataset found in directory.')\n", |
| 1601 | + " if verbose:\n", |
| 1602 | + " print(\"No dataset found in directory.\")\n", |
1588 | 1603 | "\n", |
1589 | | - " if verbose: print(10*'-' + ' Loading configuration ' + 10*'-')\n", |
| 1604 | + " if verbose:\n", |
| 1605 | + " print(10 * \"-\" + \" Loading configuration \" + 10 * \"-\")\n", |
1590 | 1606 | " # Load configuration\n", |
1591 | 1607 | " try:\n", |
1592 | 1608 | " with fsspec.open(f\"{path}/configuration.pkl\", \"rb\") as f:\n", |
1593 | 1609 | " config_dict = pickle.load(f)\n", |
1594 | | - " if verbose: print('Configuration loaded.')\n", |
| 1610 | + " if verbose:\n", |
| 1611 | + " print(\"Configuration loaded.\")\n", |
1595 | 1612 | " except FileNotFoundError:\n", |
1596 | | - " raise Exception('No configuration found in directory.')\n", |
| 1613 | + " raise Exception(\"No configuration found in directory.\")\n", |
1597 | 1614 | "\n", |
1598 | 1615 | " # in 1.6.4, `local_scaler_type` / `scalers_` lived on the dataset.\n", |
1599 | 1616 | " # in order to preserve backwards-compatibility, we check to see if these are found on the dataset\n", |
|
1604 | 1621 | " # Create NeuralForecast object\n", |
1605 | 1622 | " neuralforecast = NeuralForecast(\n", |
1606 | 1623 | " models=models,\n", |
1607 | | - " freq=config_dict['freq'],\n", |
| 1624 | + " freq=config_dict[\"freq\"],\n", |
1608 | 1625 | " local_scaler_type=config_dict.get(\"local_scaler_type\", default_scalar_type),\n", |
1609 | 1626 | " )\n", |
1610 | 1627 | "\n", |
1611 | | - " attr_to_default = {\n", |
1612 | | - " \"id_col\": \"unique_id\",\n", |
1613 | | - " \"time_col\": \"ds\",\n", |
1614 | | - " \"target_col\": \"y\"\n", |
1615 | | - " }\n", |
| 1628 | + " attr_to_default = {\"id_col\": \"unique_id\", \"time_col\": \"ds\", \"target_col\": \"y\"}\n", |
1616 | 1629 | " for attr, default in attr_to_default.items():\n", |
1617 | 1630 | " setattr(neuralforecast, attr, config_dict.get(attr, default))\n", |
1618 | 1631 | " # only restore attribute if available\n", |
1619 | | - " for attr in ['prediction_intervals', '_cs_df']:\n", |
| 1632 | + " for attr in [\"prediction_intervals\", \"_cs_df\"]:\n", |
1620 | 1633 | " setattr(neuralforecast, attr, config_dict.get(attr, None))\n", |
1621 | 1634 | "\n", |
1622 | 1635 | " # Dataset\n", |
1623 | 1636 | " if dataset is not None:\n", |
1624 | 1637 | " neuralforecast.dataset = dataset\n", |
1625 | 1638 | " restore_attrs = [\n", |
1626 | | - " 'uids',\n", |
1627 | | - " 'last_dates',\n", |
1628 | | - " 'ds',\n", |
| 1639 | + " \"uids\",\n", |
| 1640 | + " \"last_dates\",\n", |
| 1641 | + " \"ds\",\n", |
1629 | 1642 | " ]\n", |
1630 | 1643 | " for attr in restore_attrs:\n", |
1631 | 1644 | " setattr(neuralforecast, attr, config_dict[attr])\n", |
1632 | 1645 | "\n", |
1633 | 1646 | " # Fitted flag\n", |
1634 | | - " neuralforecast._fitted = config_dict['_fitted']\n", |
| 1647 | + " neuralforecast._fitted = config_dict[\"_fitted\"]\n", |
1635 | 1648 | "\n", |
1636 | 1649 | " neuralforecast.scalers_ = config_dict.get(\"scalers_\", default_scalars_)\n", |
1637 | 1650 | "\n", |
|
0 commit comments