@@ -30,6 +30,18 @@ class Extend(data_algebra.pipe.PipeStep):
3030 ops : Dict [str , data_algebra .expr_rep .Expression ]
3131
3232 def __init__ (self , ops , * , partition_by = None , order_by = None , reverse = None ):
33+ if isinstance (partition_by , str ):
34+ partition_by = [partition_by ]
35+ if isinstance (order_by , str ):
36+ order_by = [order_by ]
37+ if isinstance (reverse , str ):
38+ reverse = [reverse ]
39+ if reverse is not None and len (reverse ) > 0 :
40+ if order_by is None :
41+ raise ValueError ("set is None when order_by is not None" )
42+ unknown = set (reverse ) - set (order_by )
43+ if len (unknown ) > 0 :
44+ raise ValueError ("columns in reverse that are not in order_by: " + str (unknown ))
3345 data_algebra .pipe .PipeStep .__init__ (self , name = "Extend" )
3446 self ._ops = ops
3547 self .partition_by = partition_by
@@ -68,6 +80,8 @@ class Project(data_algebra.pipe.PipeStep):
6880 ops : Dict [str , data_algebra .expr_rep .Expression ]
6981
7082 def __init__ (self , ops , * , group_by = None ):
83+ if isinstance (group_by , str ):
84+ group_by = [group_by ]
7185 data_algebra .pipe .PipeStep .__init__ (self , name = "Project" )
7286 self ._ops = ops
7387 self .group_by = group_by
@@ -124,6 +138,8 @@ class SelectColumns(data_algebra.pipe.PipeStep):
124138 column_selection : List [str ]
125139
126140 def __init__ (self , columns ):
141+ if isinstance (columns , str ):
142+ columns = [columns ]
127143 column_selection = [c for c in columns ]
128144 self .column_selection = column_selection
129145 data_algebra .pipe .PipeStep .__init__ (self , name = "SelectColumns" )
@@ -151,6 +167,8 @@ class DropColumns(data_algebra.pipe.PipeStep):
151167 column_deletions : List [str ]
152168
153169 def __init__ (self , column_deletions ):
170+ if isinstance (column_deletions , str ):
171+ column_deletions = [column_deletions ]
154172 column_deletions = [c for c in column_deletions ]
155173 self .column_deletions = column_deletions
156174 data_algebra .pipe .PipeStep .__init__ (self , name = "DropColumns" )
@@ -179,6 +197,16 @@ class OrderRows(data_algebra.pipe.PipeStep):
179197 reverse : List [str ]
180198
181199 def __init__ (self , columns , * , reverse = None , limit = None ):
200+ if isinstance (columns , str ):
201+ columns = [columns ]
202+ if isinstance (reverse , str ):
203+ reverse = [reverse ]
204+ if reverse is not None and len (reverse ) > 0 :
205+ if columns is None :
206+ raise ValueError ("set is None when order_by is not None" )
207+ unknown = set (reverse ) - set (columns )
208+ if len (unknown ) > 0 :
209+ raise ValueError ("columns in reverse that are not in order_by: " + str (unknown ))
182210 self .order_columns = [c for c in columns ]
183211 if reverse is None :
184212 reverse = []
@@ -239,6 +267,8 @@ class NaturalJoin(data_algebra.pipe.PipeStep):
239267 def __init__ (self , * , b = None , by = None , jointype = "INNER" ):
240268 if not isinstance (b , data_algebra .data_ops .ViewRepresentation ):
241269 raise TypeError ("b must be a data_algebra.data_ops.ViewRepresentation" )
270+ if isinstance (by , str ):
271+ by = [by ]
242272 self ._by = by
243273 self ._jointype = jointype
244274 self ._b = b
0 commit comments