@@ -169,6 +169,9 @@ def get_sub_sheets_lines(self):
169
169
def get_sheet_lines (self , sheet_name ):
170
170
raise NotImplementedError
171
171
172
+ def get_sheet_headings (self , sheet_name ):
173
+ raise NotImplementedError
174
+
172
175
def read_sheets (self ):
173
176
raise NotImplementedError
174
177
@@ -190,6 +193,11 @@ def do_unflatten(self):
190
193
sheets = [(self .main_sheet_name , self .get_main_sheet_lines ())] + list (self .get_sub_sheets_lines ())
191
194
for i , sheet in enumerate (sheets ):
192
195
sheet_name , lines = sheet
196
+ try :
197
+ actual_headings = self .get_sheet_headings (sheet_name )
198
+ except NotImplementedError :
199
+ # The ListInput type used in the tests doesn't support getting headings.
200
+ actual_headings = None
193
201
for j , line in enumerate (lines ):
194
202
if all (x is None or x == '' for x in line .values ()):
195
203
#if all(x == '' for x in line.values()):
@@ -198,7 +206,10 @@ def do_unflatten(self):
198
206
if WITH_CELLS :
199
207
cells = OrderedDict ()
200
208
for k , header in enumerate (line ):
201
- cells [header ] = Cell (line [header ], (sheet_name , _get_column_letter (k + 1 ), j + 2 , header ))
209
+ if actual_headings :
210
+ cells [header ] = Cell (line [header ], (sheet_name , _get_column_letter (k + 1 ), j + 2 , actual_headings [k ]))
211
+ else :
212
+ cells [header ] = Cell (line [header ], (sheet_name , _get_column_letter (k + 1 ), j + 2 , header ))
202
213
unflattened = unflatten_main_with_parser (self .parser , cells , self .timezone )
203
214
else :
204
215
unflattened = unflatten_main_with_parser (self .parser , line , self .timezone )
@@ -241,10 +252,11 @@ def fancy_unflatten(self):
241
252
cell_tree = self .do_unflatten ()
242
253
result = extract_list_to_value (cell_tree )
243
254
cell_source_map = extract_list_to_error_path ([self .main_sheet_name .lower ()], cell_tree )
244
- ordered_cell_source_map = OrderedDict (( '/' .join (str (x ) for x in path ), location ) for path , location in sorted (cell_source_map .items ()))
255
+ ordered_items = sorted (cell_source_map .items ())
256
+ ordered_cell_source_map = OrderedDict (( '/' .join (str (x ) for x in path ), location ) for path , location in ordered_items )
245
257
row_source_map = OrderedDict ()
246
258
heading_source_map = {}
247
- for path in cell_source_map :
259
+ for path , _ in ordered_items :
248
260
cells = cell_source_map [path ]
249
261
# Prepare row_source_map key
250
262
key = '/' .join (str (x ) for x in path [:- 1 ])
@@ -327,6 +339,20 @@ def extract_dict_to_value(input):
327
339
class CSVInput (SpreadsheetInput ):
328
340
encoding = 'utf-8'
329
341
342
+ def get_sheet_headings (self , sheet_name ):
343
+ if sys .version > '3' : # If Python 3 or greater
344
+ with open (os .path .join (self .input_name , sheet_name + '.csv' ), encoding = self .encoding ) as main_sheet_file :
345
+ r = csvreader (main_sheet_file )
346
+ for row in enumerate (r ):
347
+ # Just return the first row
348
+ return row [1 ]
349
+ else : # If Python 2
350
+ with open (os .path .join (self .input_name , sheet_name + '.csv' )) as main_sheet_file :
351
+ r = csvreader (main_sheet_file , encoding = self .encoding )
352
+ for row in enumerate (r ):
353
+ # Just return the first row
354
+ return row [1 ]
355
+
330
356
def read_sheets (self ):
331
357
sheet_file_names = os .listdir (self .input_name )
332
358
if self .main_sheet_name + '.csv' not in sheet_file_names :
@@ -367,6 +393,10 @@ def read_sheets(self):
367
393
sheet_names .remove (self .main_sheet_name )
368
394
self .sub_sheet_names = sheet_names
369
395
396
+ def get_sheet_headings (self , sheet_name ):
397
+ worksheet = self .workbook [self .sheet_names_map [sheet_name ]]
398
+ return [cell .value for cell in worksheet .rows [0 ]]
399
+
370
400
def get_sheet_lines (self , sheet_name ):
371
401
worksheet = self .workbook [self .sheet_names_map [sheet_name ]]
372
402
header_row = worksheet .rows [0 ]
0 commit comments