@@ -56,7 +56,7 @@ def __init__(self, model, *column_names, **extras):
56
56
self .model = model
57
57
self ._distinct = None
58
58
if column_names :
59
- self .columns = [ getattr (model , name ) for name in column_names ]
59
+ self .columns = self . _column_loader (model , column_names )
60
60
else :
61
61
self .columns = model
62
62
self .extras = dict ((key , self .get (value ))
@@ -123,11 +123,28 @@ def get_from(self):
123
123
124
124
def load (self , * column_names , ** extras ):
125
125
if column_names :
126
- self .columns = [getattr (self .model , name ) for name in column_names ]
126
+ self .columns = self ._column_loader (self .model , column_names )
127
+
127
128
self .extras .update ((key , self .get (value ))
128
129
for key , value in extras .items ())
129
130
return self
130
131
132
+ @classmethod
133
+ def _column_loader (cls , model , column_names ):
134
+ def column_formatter (column_name ):
135
+ if isinstance (column_name , str ):
136
+ return getattr (model , column_name )
137
+ elif isinstance (column_name , Column ):
138
+ if column_name not in model :
139
+ raise AttributeError ('Column {} does not belong '
140
+ 'to this model' .format (column_name ))
141
+ return column_name
142
+ else :
143
+ raise TypeError ('Unknown column name {} type {}' .
144
+ format (column_name , type (column_name )))
145
+
146
+ return [column_formatter (column_name ) for column_name in column_names ]
147
+
131
148
def on (self , on_clause ):
132
149
self .on_clause = on_clause
133
150
return self
0 commit comments