diff --git a/flattentool/json_input.py b/flattentool/json_input.py
index e1cc6a1..8b737ff 100644
--- a/flattentool/json_input.py
+++ b/flattentool/json_input.py
@@ -161,12 +161,11 @@ def __init__(
             self.main_sheet = PersistentSheet.from_sheet(
                 schema_parser.main_sheet, self.connection
             )
-            for sheet_name, sheet in list(self.sub_sheets.items()):
+            for sheet_name, sheet in list(schema_parser.sub_sheets.items()):
                 self.sub_sheets[sheet_name] = PersistentSheet.from_sheet(
                     sheet, self.connection
                 )
-            self.sub_sheets = copy.deepcopy(schema_parser.sub_sheets)
             if remove_empty_schema_columns:
                 # Don't use columns from the schema parser
                 # (avoids empty columns)
@@ -312,7 +311,7 @@ def parse(self):
         if self.remove_empty_schema_columns:
             # Remove sheets with no lines of data
             for sheet_name, sheet in list(self.sub_sheets.items()):
-                if not sheet.lines:
+                if not sheet.line_count:
                     del self.sub_sheets[sheet_name]

         if self.preserve_fields_input:
diff --git a/flattentool/sheet.py b/flattentool/sheet.py
index 48d7a98..2531f5f 100644
--- a/flattentool/sheet.py
+++ b/flattentool/sheet.py
@@ -21,6 +21,10 @@ def __init__(self, columns=None, root_id="", name=None):
     def lines(self):
         return self._lines

+    @property
+    def line_count(self):
+        return len(self._lines)
+
     def add_field(self, field, id_field=False):
         columns = self.id_columns if id_field else self.columns
         if field not in columns:
@@ -65,6 +69,10 @@ def lines(self):
                 self.connection.cacheMinimize()
             yield value

+    @property
+    def line_count(self):
+        return self.index
+
     def append_line(self, flattened_dict):
         self.connection.root.sheet_store[self.name][self.index] = flattened_dict
         self.index += 1
diff --git a/flattentool/tests/test_json_input.py b/flattentool/tests/test_json_input.py
index 3535786..b9e067e 100644
--- a/flattentool/tests/test_json_input.py
+++ b/flattentool/tests/test_json_input.py
@@ -327,7 +327,7 @@ def test_sub_sheets(self, tmpdir, remove_empty_schema_columns):
             assert list(parser.sub_sheets["g"]) == list(["ocid", "g/0/h"])
         else:
             assert list(parser.sub_sheets["c"]) == list(["ocid", "c/0/d"])
-            assert parser.sub_sheets["c"].lines == [{"c/0/d": "e"}]
+            assert list(parser.sub_sheets["c"].lines) == [{"c/0/d": "e"}]

     def test_column_matching(self, tmpdir):
         test_schema = tmpdir.join("test.json")
@@ -387,7 +387,7 @@ def test_rollup(self):
         assert set(parser.sub_sheets["testA"]) == set(
             ["ocid", "testA/0/testB", "testA/0/testC"]
         )
-        assert parser.sub_sheets["testA"].lines == [
+        assert list(parser.sub_sheets["testA"].lines) == [
             {"testA/0/testB": "1", "testA/0/testC": "2"}
         ]

@@ -438,7 +438,7 @@ def test_rollup_multiple_values(self, recwarn):
         assert set(parser.sub_sheets["testA"]) == set(
             ["testA/0/testB", "testA/0/testC"]
         )
-        assert parser.sub_sheets["testA"].lines == [
+        assert list(parser.sub_sheets["testA"].lines) == [
             {"testA/0/testB": "1", "testA/0/testC": "2"},
             {"testA/0/testB": "3", "testA/0/testC": "4"},
         ]
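Why these changes hang together: on the plain Sheet class, lines is an ordinary list, so truthiness checks and direct list comparisons behave as expected; on PersistentSheet, lines is a generator, which is always truthy and never compares equal to a list. That is presumably why the empty-sheet check in parse() moves to the new line_count property and why the tests now wrap .lines in list(...). Below is a minimal, self-contained sketch of the pitfall; the class names are illustrative stand-ins only, not flattentool's actual implementation, which persists rows through a database connection (self.connection.root.sheet_store).

class ListBackedSheet:
    """Analogue of Sheet: rows kept in an in-memory list."""

    def __init__(self):
        self._lines = []

    @property
    def lines(self):
        return self._lines

    @property
    def line_count(self):
        return len(self._lines)


class GeneratorBackedSheet:
    """Analogue of PersistentSheet: rows yielded lazily from a store."""

    def __init__(self):
        self._store = {}
        self.index = 0

    def append_line(self, row):
        self._store[self.index] = row
        self.index += 1

    @property
    def lines(self):
        # A generator object: always truthy, never equal to a list.
        for value in self._store.values():
            yield value

    @property
    def line_count(self):
        return self.index


assert not ListBackedSheet().lines       # empty list is falsy: the old check worked here
empty = GeneratorBackedSheet()
assert bool(empty.lines)                 # generator is truthy even with no data
assert empty.line_count == 0             # emptiness has to be checked via line_count
full = GeneratorBackedSheet()
full.append_line({"c/0/d": "e"})
assert full.lines != [{"c/0/d": "e"}]        # a generator never equals a list
assert list(full.lines) == [{"c/0/d": "e"}]  # hence list(...) in the updated tests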