@@ -329,6 +329,32 @@ def group_split(
329
329
).items ()
330
330
}
331
331
332
+
333
+ def new_columns (
334
+ self : Dataset [T ], ** kwargs : Callable [pd .Dataframe , pd .Series ]
335
+ ) -> Dataset [T ]:
336
+ '''
337
+ Append new column(s) to the :attr:`.Dataset.dataframe` by passing the
338
+ new column names as keywords with functions that take the
339
+ :attr:`.Dataset.dataframe` as input and return :func:`pandas.Series`.
340
+
341
+ >>> (
342
+ ... Dataset.from_dataframe(pd.DataFrame(dict(number=[1, 2, 3])))
343
+ ... .new_columns(twice=lambda df: df['number'] * 2)
344
+ ... .map(lambda row: row['twice'])
345
+ ... )[-1]
346
+ 6
347
+ '''
348
+ if len (set (kwargs .keys ()) & set (self .dataframe .columns )) >= 1 :
349
+ raise ValueError ('Should not replace existing columns' )
350
+
351
+ dataframe = self .dataframe .assign (** kwargs )
352
+ return Dataset (
353
+ dataframe = dataframe ,
354
+ length = len (dataframe ),
355
+ get_item = self .get_item ,
356
+ )
357
+
332
358
def zip_index (self : Dataset [T ]) -> Dataset [Tuple [T , int ]]:
333
359
'''
334
360
Zip the output with its index. The output of the pipeline will be
@@ -550,6 +576,18 @@ def test_subset():
550
576
assert dataset [0 ]['number' ] == numbers [2 ]
551
577
552
578
579
+ def test_new_columns ():
580
+ from pytest import raises
581
+
582
+ with raises (ValueError ):
583
+ dataset = (
584
+ Dataset .from_dataframe (pd .DataFrame (dict (
585
+ key = np .arange (100 ),
586
+ )))
587
+ .new_columns (key = lambda df : df ['key' ] * 2 )
588
+ )
589
+
590
+
553
591
def test_concat_dataset ():
554
592
dataset = Dataset .concat ([
555
593
Dataset .from_subscriptable (list (range (5 ))),
0 commit comments