|
31 | 31 | "import numpy as np\n", |
32 | 32 | "import pandas as pd\n", |
33 | 33 | "\n", |
34 | | - "from utilsforecast.compat import DataFrame, pl_DataFrame" |
| 34 | + "from utilsforecast.compat import DataFrame, Series, pl_DataFrame, pl_Series, pl" |
| 35 | + ] |
| 36 | + }, |
| 37 | + { |
| 38 | + "cell_type": "code", |
| 39 | + "execution_count": null, |
| 40 | + "id": "7638f2a6-0303-41a9-a103-d3055e7bc572", |
| 41 | + "metadata": {}, |
| 42 | + "outputs": [], |
| 43 | + "source": [ |
| 44 | + "from fastcore.test import test_fail" |
| 45 | + ] |
| 46 | + }, |
| 47 | + { |
| 48 | + "cell_type": "code", |
| 49 | + "execution_count": null, |
| 50 | + "id": "e001379f-a552-4fd9-ac82-3f4d5c76f555", |
| 51 | + "metadata": {}, |
| 52 | + "outputs": [], |
| 53 | + "source": [ |
| 54 | + "#| polars\n", |
| 55 | + "import polars.testing" |
| 56 | + ] |
| 57 | + }, |
| 58 | + { |
| 59 | + "cell_type": "code", |
| 60 | + "execution_count": null, |
| 61 | + "id": "98e0f72a-3128-46d7-8d6c-cbd0c62ce68a", |
| 62 | + "metadata": {}, |
| 63 | + "outputs": [], |
| 64 | + "source": [ |
| 65 | + "#| exporti\n", |
| 66 | + "def _is_dt_or_int(s: Series) -> bool:\n", |
| 67 | + " dtype = s.head(1).to_numpy().dtype\n", |
| 68 | + " is_dt = np.issubdtype(dtype, np.datetime64)\n", |
| 69 | + " is_int = np.issubdtype(dtype, np.integer)\n", |
| 70 | + " return is_dt or is_int" |
| 71 | + ] |
| 72 | + }, |
| 73 | + { |
| 74 | + "cell_type": "code", |
| 75 | + "execution_count": null, |
| 76 | + "id": "6679513a-4f2b-4eda-9970-e5c6725dd761", |
| 77 | + "metadata": {}, |
| 78 | + "outputs": [], |
| 79 | + "source": [ |
| 80 | + "#| export\n", |
| 81 | + "def ensure_shallow_copy(df: pd.DataFrame) -> pd.DataFrame:\n", |
| 82 | + " from packaging.version import Version\n", |
| 83 | + "\n", |
| 84 | + " if Version(pd.__version__) < Version(\"1.4\"):\n", |
| 85 | + " # https://github.com/pandas-dev/pandas/pull/43406\n", |
| 86 | + " df = df.copy()\n", |
| 87 | + " return df" |
| 88 | + ] |
| 89 | + }, |
| 90 | + { |
| 91 | + "cell_type": "code", |
| 92 | + "execution_count": null, |
| 93 | + "id": "ccb15f20-56e4-4d0d-96bc-830c5effff19", |
| 94 | + "metadata": {}, |
| 95 | + "outputs": [], |
| 96 | + "source": [ |
| 97 | + "#| export\n", |
| 98 | + "def ensure_time_dtype(df: DataFrame, time_col: str = 'ds') -> DataFrame:\n", |
| 99 | + " \"\"\"Make sure that `time_col` contains timestamps or integers.\n", |
| 100 | + " If it contains strings, try to cast them as timestamps.\"\"\"\n", |
| 101 | + " times = df[time_col]\n", |
| 102 | + " if _is_dt_or_int(times):\n", |
| 103 | + " return df\n", |
| 104 | + " parse_err_msg = (\n", |
| 105 | + " f\"Failed to parse '{time_col}' from string to datetime. \"\n", |
| 106 | + " 'Please make sure that it contains valid timestamps or integers.'\n", |
| 107 | + " )\n", |
| 108 | + " if isinstance(times, pd.Series) and pd.api.types.is_object_dtype(times):\n", |
| 109 | + " try:\n", |
| 110 | + " times = pd.to_datetime(times)\n", |
| 111 | + " except ValueError:\n", |
| 112 | + " raise ValueError(parse_err_msg)\n", |
| 113 | + " df = ensure_shallow_copy(df.copy(deep=False))\n", |
| 114 | + " df[time_col] = times\n", |
| 115 | + " elif isinstance(times, pl_Series) and times.dtype == pl.Utf8:\n", |
| 116 | + " try:\n", |
| 117 | + " times = times.str.to_datetime()\n", |
| 118 | + " except pl.exceptions.ComputeError:\n", |
| 119 | + " raise ValueError(parse_err_msg)\n", |
| 120 | + " df = df.with_columns(times)\n", |
| 121 | + " else:\n", |
| 122 | + " raise ValueError(f\"'{time_col}' should have valid timestamps or integers.\")\n", |
| 123 | + " return df" |
| 124 | + ] |
| 125 | + }, |
| 126 | + { |
| 127 | + "cell_type": "code", |
| 128 | + "execution_count": null, |
| 129 | + "id": "604ec21b-314f-42cb-9ee0-5950fd611b97", |
| 130 | + "metadata": {}, |
| 131 | + "outputs": [], |
| 132 | + "source": [ |
| 133 | + "pd.testing.assert_frame_equal(\n", |
| 134 | + " ensure_time_dtype(pd.DataFrame({'ds': ['2000-01-01']})),\n", |
| 135 | + " pd.DataFrame({'ds': pd.to_datetime(['2000-01-01'])})\n", |
| 136 | + ")\n", |
| 137 | + "df = pd.DataFrame({'ds': [1, 2]})\n", |
| 138 | + "assert df is ensure_time_dtype(df)\n", |
| 139 | + "test_fail(\n", |
| 140 | + " lambda: ensure_time_dtype(pd.DataFrame({'ds': ['2000-14-14']})),\n", |
| 141 | + " contains='Please make sure that it contains valid timestamps',\n", |
| 142 | + ")" |
| 143 | + ] |
| 144 | + }, |
| 145 | + { |
| 146 | + "cell_type": "code", |
| 147 | + "execution_count": null, |
| 148 | + "id": "5335d217-e240-46df-90a0-6aeeb07a0586", |
| 149 | + "metadata": {}, |
| 150 | + "outputs": [], |
| 151 | + "source": [ |
| 152 | + "#| polars\n", |
| 153 | + "pl.testing.assert_frame_equal(\n", |
| 154 | + " ensure_time_dtype(pl.DataFrame({'ds': ['2000-01-01']})),\n", |
| 155 | + " pl.DataFrame().with_columns(ds=pl.datetime(2000, 1, 1))\n", |
| 156 | + ")\n", |
| 157 | + "df = pl.DataFrame({'ds': [1, 2]})\n", |
| 158 | + "assert df is ensure_time_dtype(df)\n", |
| 159 | + "test_fail(\n", |
| 160 | + " lambda: ensure_time_dtype(pl.DataFrame({'ds': ['hello']})),\n", |
| 161 | + " contains='Please make sure that it contains valid timestamps',\n", |
| 162 | + ")" |
35 | 163 | ] |
36 | 164 | }, |
37 | 165 | { |
|
76 | 204 | " raise ValueError(f\"The following columns are missing: {missing_cols}\")\n", |
77 | 205 | "\n", |
78 | 206 | " # time col\n", |
79 | | - " times_dtype = df[time_col].head(1).to_numpy().dtype\n", |
80 | | - " if not (np.issubdtype(times_dtype, np.datetime64) or np.issubdtype(times_dtype, np.integer)):\n", |
| 207 | + " if not _is_dt_or_int(df[time_col]):\n", |
| 208 | + " times_dtype = df[time_col].head(1).to_numpy().dtype\n", |
81 | 209 | " raise ValueError(f\"The time column ('{time_col}') should have either timestamps or integers, got '{times_dtype}'.\")\n", |
82 | 210 | "\n", |
83 | 211 | " # target col\n", |
84 | | - " target_dtype = df[target_col].head(1).to_numpy().dtype\n", |
85 | | - " if not np.issubdtype(target_dtype, np.number):\n", |
86 | | - " raise ValueError(f\"The target column ('{target_col}') should have a numeric data type, got '{target_dtype}')\")" |
| 212 | + " target = df[target_col]\n", |
| 213 | + " if isinstance(target, pd.Series):\n", |
| 214 | + " is_numeric = np.issubdtype(target.dtype.type, np.number)\n", |
| 215 | + " else:\n", |
| 216 | + " is_numeric = target.is_numeric()\n", |
| 217 | + " if not is_numeric:\n", |
| 218 | + " raise ValueError(f\"The target column ('{target_col}') should have a numeric data type, got '{target.dtype}')\")" |
87 | 219 | ] |
88 | 220 | }, |
89 | 221 | { |
|
108 | 240 | "text/markdown": [ |
109 | 241 | "---\n", |
110 | 242 | "\n", |
111 | | - "[source](https://github.com/Nixtla/utilsforecast/blob/main/utilsforecast/validation.py#L12){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
| 243 | + "[source](https://github.com/Nixtla/utilsforecast/blob/main/utilsforecast/validation.py#L57){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
112 | 244 | "\n", |
113 | 245 | "### validate_format\n", |
114 | 246 | "\n", |
|
130 | 262 | "text/plain": [ |
131 | 263 | "---\n", |
132 | 264 | "\n", |
133 | | - "[source](https://github.com/Nixtla/utilsforecast/blob/main/utilsforecast/validation.py#L12){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
| 265 | + "[source](https://github.com/Nixtla/utilsforecast/blob/main/utilsforecast/validation.py#L57){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
134 | 266 | "\n", |
135 | 267 | "### validate_format\n", |
136 | 268 | "\n", |
|
168 | 300 | "source": [ |
169 | 301 | "import datetime\n", |
170 | 302 | "\n", |
171 | | - "import pandas as pd\n", |
172 | | - "from fastcore.test import test_fail\n", |
173 | | - "\n", |
174 | | - "from utilsforecast.compat import POLARS_INSTALLED, pl\n", |
| 303 | + "from utilsforecast.compat import POLARS_INSTALLED\n", |
175 | 304 | "from utilsforecast.data import generate_series" |
176 | 305 | ] |
177 | 306 | }, |
|
0 commit comments