Skip to content

Commit a7b7a81

Browse files
committed
Update folder_utils.py
1 parent a8acd8b commit a7b7a81

File tree

1 file changed

+4
-45
lines changed

1 file changed

+4
-45
lines changed

src/pownet/folder_utils.py

Lines changed: 4 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -2,61 +2,20 @@
22

33

44
def get_pownet_dir() -> str:
5-
"""Does not assume the user saves the folder under their home directory"""
5+
"""Returns the root directory of the pownet package."""
66
return os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
77

88

9-
def get_output_dir() -> str:
10-
return os.path.join(get_pownet_dir(), "outputs")
11-
12-
13-
def get_input_dir() -> str:
14-
return os.path.join(get_pownet_dir(), "user_inputs")
15-
16-
17-
def get_temp_dir() -> str:
18-
return os.path.join(get_pownet_dir(), "temp")
19-
20-
21-
def get_model_dir() -> str:
22-
return os.path.join(get_pownet_dir(), "model_library")
23-
24-
25-
def get_reservoir_dir(model_name: str) -> str:
26-
return os.path.join(get_model_dir(), model_name, "reservoir")
27-
28-
29-
def get_reservoir_file(model_name: str, filename: str) -> str:
30-
return os.path.join(get_reservoir_dir(model_name), f"{filename}")
31-
32-
339
def get_home_dir() -> str:
10+
"""Returns the home directory of the user. This is useful for testing purposes."""
3411
return os.path.expanduser("~")
3512

3613

3714
def get_database_dir() -> str:
15+
"""Returns the database directory of the pownet package."""
3816
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "database")
3917

4018

4119
def get_test_dir() -> str:
20+
"""Returns the test directory of the pownet package."""
4221
return os.path.join(get_pownet_dir(), "src", "test_pownet")
43-
44-
45-
def delete_all_gurobi_solutions() -> None:
46-
"""Remove all Gurobi solution files from the output folder.
47-
Use this function at the beginning of the simulation when warmstart is on.
48-
"""
49-
solution_files = os.listdir(get_output_dir())
50-
for s_file in solution_files:
51-
file_extension = os.path.splitext(s_file)[1]
52-
if ".sol" in file_extension:
53-
os.remove(os.path.join(get_output_dir(), s_file))
54-
55-
56-
def count_mps_files(instance_folder: str) -> int:
57-
"""Count the number of files ending with .mps"""
58-
num_instances = 0
59-
for file in os.listdir(instance_folder):
60-
if file.endswith(".mps"):
61-
num_instances += 1
62-
return num_instances

0 commit comments

Comments
 (0)