Skip to content

Commit 89994fd

Browse files
authored
support lists in assign_columns (#94)
1 parent 675172e commit 89994fd

File tree

3 files changed

+42
-9
lines changed

3 files changed

+42
-9
lines changed

.pre-commit-config.yaml

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
fail_fast: true
2+
3+
repos:
4+
- repo: https://github.com/fastai/nbdev
5+
rev: 2.2.10
6+
hooks:
7+
- id: nbdev_clean
8+
- id: nbdev_export
9+
- repo: https://github.com/astral-sh/ruff-pre-commit
10+
rev: v0.2.1
11+
hooks:
12+
- id: ruff
13+
- repo: https://github.com/pre-commit/mirrors-mypy
14+
rev: v1.8.0
15+
hooks:
16+
- id: mypy
17+
args: [--ignore-missing-imports]

nbs/processing.ipynb

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,9 @@
202202
"outputs": [],
203203
"source": [
204204
"#| export\n",
205-
"def assign_columns(df: DataFrame, names: Union[str, List[str]], values: Union[np.ndarray, pd.Series, pl_Series]) -> DataFrame:\n",
205+
"def assign_columns(df: DataFrame, names: Union[str, List[str]], values: Union[np.ndarray, pd.Series, pl_Series, List[float]]) -> DataFrame:\n",
206+
" if isinstance(values, list) and (len(values) != df.shape[0] or not isinstance(names, str)):\n",
207+
" raise ValueError('Only single column assignment is supported for lists.')\n",
206208
" if isinstance(df, pd.DataFrame):\n",
207209
" df[names] = values\n",
208210
" else:\n",
@@ -214,9 +216,13 @@
214216
" assert isinstance(names, str)\n",
215217
" vals = values.alias(names)\n",
216218
" else:\n",
217-
" if isinstance(names, str):\n",
218-
" names = [names]\n",
219-
" vals = pl.from_numpy(values, schema=names, orient='row')\n",
219+
" if isinstance(values, np.ndarray):\n",
220+
" if isinstance(names, str):\n",
221+
" names = [names]\n",
222+
" vals = pl.from_numpy(values, schema=names, orient='row')\n",
223+
" elif isinstance(values, list):\n",
224+
" assert isinstance(names, str)\n",
225+
" vals = pl_Series(name=names, values=values)\n",
220226
" df = df.with_columns(vals)\n",
221227
" return df"
222228
]
@@ -248,12 +254,14 @@
248254
" series = assign_columns(series, 'ones', 1)\n",
249255
" series = assign_columns(series, 'zeros', np.zeros(series.shape[0]))\n",
250256
" series = assign_columns(series, 'as', 'a')\n",
257+
" series = assign_columns(series, 'bs', series.shape[0] * ['b'])\n",
251258
" np.testing.assert_allclose(\n",
252259
" series[['x', 'y', 'z']],\n",
253260
" np.vstack([x, x, x]).T\n",
254261
" )\n",
255262
" np.testing.assert_equal(series['ones'], np.ones(series.shape[0]))\n",
256-
" np.testing.assert_equal(series['as'], np.full(series.shape[0], 'a'))"
263+
" np.testing.assert_equal(series['as'], np.full(series.shape[0], 'a'))\n",
264+
" np.testing.assert_equal(series['bs'], np.full(series.shape[0], 'b'))"
257265
]
258266
},
259267
{

utilsforecast/processing.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,12 @@ def maybe_compute_sort_indices(
120120
def assign_columns(
121121
df: DataFrame,
122122
names: Union[str, List[str]],
123-
values: Union[np.ndarray, pd.Series, pl_Series],
123+
values: Union[np.ndarray, pd.Series, pl_Series, List[float]],
124124
) -> DataFrame:
125+
if isinstance(values, list) and (
126+
len(values) != df.shape[0] or not isinstance(names, str)
127+
):
128+
raise ValueError("Only single column assignment is supported for lists.")
125129
if isinstance(df, pd.DataFrame):
126130
df[names] = values
127131
else:
@@ -133,9 +137,13 @@ def assign_columns(
133137
assert isinstance(names, str)
134138
vals = values.alias(names)
135139
else:
136-
if isinstance(names, str):
137-
names = [names]
138-
vals = pl.from_numpy(values, schema=names, orient="row")
140+
if isinstance(values, np.ndarray):
141+
if isinstance(names, str):
142+
names = [names]
143+
vals = pl.from_numpy(values, schema=names, orient="row")
144+
elif isinstance(values, list):
145+
assert isinstance(names, str)
146+
vals = pl_Series(name=names, values=values)
139147
df = df.with_columns(vals)
140148
return df
141149

0 commit comments

Comments
 (0)