33
44import unittest
55import os
6+ from functools import wraps
7+ from typing import Callable
68
79import numpy as np
810
@@ -13,13 +15,42 @@ def get_dataset(name):
1315 return os .path .join (os .path .dirname (__file__ ), "xlsx_files" , name )
1416
1517
16- def read_file (name ):
17- return io .ExcelReader (get_dataset (name )).read ()
18+ def get_xlsx_reader (name : str ) -> io .ExcelReader :
19+ return io .ExcelReader (get_dataset (name ))
20+
21+
22+ def get_xls_reader (name : str ) -> io .XlsReader :
23+ return io .XlsReader (get_dataset (name ))
24+
25+
26+ def read_file (reader : Callable , name : str ) -> Table :
27+ return reader (name ).read ()
28+
29+
30+ def test_xlsx_xls (f ):
31+ @wraps (f )
32+ def wrapper (self ):
33+ f (self , get_xlsx_reader )
34+ f (self , get_xls_reader )
35+ return wrapper
36+
37+
38+ class TestExcelReader (unittest .TestCase ):
39+ def test_read_round_floats (self ):
40+ table = read_file (get_xlsx_reader , "round_floats.xlsx" )
41+ domain = table .domain
42+ self .assertIsNone (domain .class_var )
43+ self .assertEqual (len (domain .metas ), 0 )
44+ self .assertEqual (len (domain .attributes ), 3 )
45+ self .assertIsInstance (domain [0 ], ContinuousVariable )
46+ self .assertIsInstance (domain [1 ], ContinuousVariable )
47+ self .assertListEqual (domain [2 ].values , ["1" , "2" ])
1848
1949
2050class TestExcelHeader0 (unittest .TestCase ):
21- def test_read (self ):
22- table = read_file ("header_0.xlsx" )
51+ @test_xlsx_xls
52+ def test_read (self , reader : Callable [[str ], io .FileFormat ]):
53+ table = read_file (reader , "header_0.xlsx" )
2354 domain = table .domain
2455 self .assertIsNone (domain .class_var )
2556 self .assertEqual (len (domain .metas ), 0 )
@@ -35,29 +66,37 @@ def test_read(self):
3566
3667
3768class TextExcelSheets (unittest .TestCase ):
38- def setUp (self ):
39- self .reader = io .ExcelReader (get_dataset ("header_0_sheet.xlsx" ))
40-
41- def test_sheets (self ):
42- self .assertSequenceEqual (self .reader .sheets ,
69+ @test_xlsx_xls
70+ def test_sheets (self , reader : Callable [[str ], io .FileFormat ]):
71+ reader = reader ("header_0_sheet.xlsx" )
72+ self .assertSequenceEqual (reader .sheets ,
4373 ["Sheet1" , "my_sheet" , "Sheet3" ])
4474
45- def test_named_sheet (self ):
46- self .reader .select_sheet ("my_sheet" )
47- table = self .reader .read ()
75+ @test_xlsx_xls
76+ def test_named_sheet (self , reader : Callable [[str ], io .FileFormat ]):
77+ reader = reader ("header_0_sheet.xlsx" )
78+ reader .select_sheet ("my_sheet" )
79+ table = reader .read ()
4880 self .assertEqual (len (table .domain .attributes ), 4 )
4981 self .assertEqual (table .name , 'header_0_sheet-my_sheet' )
5082
51- def test_named_sheet_table (self ):
83+ def test_named_sheet_table_xlsx (self ):
5284 table = Table .from_file (get_dataset ("header_0_sheet.xlsx" ),
5385 sheet = "my_sheet" )
5486 self .assertEqual (len (table .domain .attributes ), 4 )
5587 self .assertEqual (table .name , 'header_0_sheet-my_sheet' )
5688
89+ def test_named_sheet_table_xls (self ):
90+ table = Table .from_file (get_dataset ("header_0_sheet.xls" ),
91+ sheet = "my_sheet" )
92+ self .assertEqual (len (table .domain .attributes ), 4 )
93+ self .assertEqual (table .name , 'header_0_sheet-my_sheet' )
94+
5795
5896class TestExcelHeader1 (unittest .TestCase ):
59- def test_no_flags (self ):
60- table = read_file ("header_1_no_flags.xlsx" )
97+ @test_xlsx_xls
98+ def test_no_flags (self , reader : Callable [[str ], io .FileFormat ]):
99+ table = read_file (reader , "header_1_no_flags.xlsx" )
61100 domain = table .domain
62101 self .assertEqual (len (domain .metas ), 0 )
63102 self .assertEqual (len (domain .attributes ), 4 )
@@ -74,8 +113,9 @@ def test_no_flags(self):
74113 [0 , 0 , np .nan , 0 ]]))
75114 np .testing .assert_equal (table .Y , np .array ([]).reshape (3 , 0 ))
76115
77- def test_flags (self ):
78- table = read_file ("header_1_flags.xlsx" )
116+ @test_xlsx_xls
117+ def test_flags (self , reader : Callable [[str ], io .FileFormat ]):
118+ table = read_file (reader , "header_1_flags.xlsx" )
79119 domain = table .domain
80120
81121 self .assertEqual (len (domain .attributes ), 1 )
@@ -104,8 +144,9 @@ def test_flags(self):
104144
105145
106146class TestExcelHeader3 (unittest .TestCase ):
107- def test_read (self ):
108- table = read_file ("header_3.xlsx" )
147+ @test_xlsx_xls
148+ def test_read (self , reader : Callable [[str ], io .FileFormat ]):
149+ table = read_file (reader , "header_3.xlsx" )
109150 domain = table .domain
110151
111152 self .assertEqual (len (domain .attributes ), 2 )
0 commit comments