|
202 | 202 | "outputs": [], |
203 | 203 | "source": [ |
204 | 204 | "#| export\n", |
205 | | - "def assign_columns(df: DataFrame, names: Union[str, List[str]], values: Union[np.ndarray, pd.Series, pl_Series]) -> DataFrame:\n", |
| 205 | + "def assign_columns(df: DataFrame, names: Union[str, List[str]], values: Union[np.ndarray, pd.Series, pl_Series, List[float]]) -> DataFrame:\n", |
| 206 | + " if isinstance(values, list) and (len(values) != df.shape[0] or not isinstance(names, str)):\n", |
| 207 | + " raise ValueError('Only single column assignment is supported for lists.')\n", |
206 | 208 | " if isinstance(df, pd.DataFrame):\n", |
207 | 209 | " df[names] = values\n", |
208 | 210 | " else:\n", |
|
214 | 216 | " assert isinstance(names, str)\n", |
215 | 217 | " vals = values.alias(names)\n", |
216 | 218 | " else:\n", |
217 | | - " if isinstance(names, str):\n", |
218 | | - " names = [names]\n", |
219 | | - " vals = pl.from_numpy(values, schema=names, orient='row')\n", |
| 219 | + " if isinstance(values, np.ndarray):\n", |
| 220 | + " if isinstance(names, str):\n", |
| 221 | + " names = [names]\n", |
| 222 | + " vals = pl.from_numpy(values, schema=names, orient='row')\n", |
| 223 | + " elif isinstance(values, list):\n", |
| 224 | + " assert isinstance(names, str)\n", |
| 225 | + " vals = pl_Series(name=names, values=values)\n", |
220 | 226 | " df = df.with_columns(vals)\n", |
221 | 227 | " return df" |
222 | 228 | ] |
|
248 | 254 | " series = assign_columns(series, 'ones', 1)\n", |
249 | 255 | " series = assign_columns(series, 'zeros', np.zeros(series.shape[0]))\n", |
250 | 256 | " series = assign_columns(series, 'as', 'a')\n", |
| 257 | + " series = assign_columns(series, 'bs', series.shape[0] * ['b'])\n", |
251 | 258 | " np.testing.assert_allclose(\n", |
252 | 259 | " series[['x', 'y', 'z']],\n", |
253 | 260 | " np.vstack([x, x, x]).T\n", |
254 | 261 | " )\n", |
255 | 262 | " np.testing.assert_equal(series['ones'], np.ones(series.shape[0]))\n", |
256 | | - " np.testing.assert_equal(series['as'], np.full(series.shape[0], 'a'))" |
| 263 | + " np.testing.assert_equal(series['as'], np.full(series.shape[0], 'a'))\n", |
| 264 | + " np.testing.assert_equal(series['bs'], np.full(series.shape[0], 'b'))" |
257 | 265 | ] |
258 | 266 | }, |
259 | 267 | { |
|
0 commit comments