Skip to content

Commit 18d697e

Browse files
committed
Change helper function to return single output dir
1 parent 9052ec3 commit 18d697e

File tree

2 files changed

+8
-9
lines changed

2 files changed

+8
-9
lines changed

tests/test_helper.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,9 @@ def test_input_directories():
6868
assert len(input_dirs) > 0
6969

7070

71-
def test_output_directories():
72-
output_dirs = hlp.get_output_directories()
73-
assert isinstance(output_dirs, list)
74-
assert len(output_dirs) > 0
71+
def test_output_directory():
72+
output_dir = hlp.get_output_directory()
73+
assert isinstance(output_dir, pathlib.Path)
7574

7675

7776
def test_tau_thresholds():
@@ -389,8 +388,8 @@ def test_get_output_row(mocker):
389388

390389

391390
def test_get_output_filename(mocker):
392-
get_output_directories_mock = mocker.patch('tokenomics_decentralization.helper.get_output_directories')
393-
get_output_directories_mock.return_value = [pathlib.Path(__file__).resolve().parent]
391+
get_output_directory_mock = mocker.patch('tokenomics_decentralization.helper.get_output_directory')
392+
get_output_directory_mock.return_value = pathlib.Path(__file__).resolve().parent
394393
get_exclude_contracts_mock = mocker.patch('tokenomics_decentralization.helper.get_exclude_contracts_flag')
395394
get_exclude_contracts_mock.return_value = False
396395
get_exclude_below_fees_mock = mocker.patch('tokenomics_decentralization.helper.get_exclude_below_fees_flag')

tokenomics_decentralization/helper.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -141,13 +141,13 @@ def increment_date(date, by):
141141
raise ValueError(f'Invalid granularity: {by}')
142142

143143

144-
def get_output_directories():
144+
def get_output_directory():
145145
"""
146146
Reads the config file and retrieves the output directories
147147
:returns: a list of directories that might contain the db files
148148
"""
149149
config = get_config_data()
150-
return [pathlib.Path(db_dir).resolve() for db_dir in config['output_directories']]
150+
return [pathlib.Path(db_dir).resolve() for db_dir in config['output_directories']][0]
151151

152152

153153
def get_input_directories():
@@ -496,7 +496,7 @@ def get_output_filename():
496496
if exclude_below_usd_cent_flag:
497497
output_filename += '-exclude_below_usd_cent'
498498
output_filename += '.csv'
499-
return get_output_directories()[0] / output_filename
499+
return get_output_directory() / output_filename
500500

501501

502502
def write_csv_output(output_rows):

0 commit comments

Comments
 (0)