|
2 | 2 |
|
3 | 3 |
|
4 | 4 | def get_pownet_dir() -> str: |
5 | | - """Does not assume the user saves the folder under their home directory""" |
| 5 | + """Returns the root directory of the pownet package.""" |
6 | 6 | return os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
7 | 7 |
|
8 | 8 |
|
9 | | -def get_output_dir() -> str: |
10 | | - return os.path.join(get_pownet_dir(), "outputs") |
11 | | - |
12 | | - |
13 | | -def get_input_dir() -> str: |
14 | | - return os.path.join(get_pownet_dir(), "user_inputs") |
15 | | - |
16 | | - |
17 | | -def get_temp_dir() -> str: |
18 | | - return os.path.join(get_pownet_dir(), "temp") |
19 | | - |
20 | | - |
21 | | -def get_model_dir() -> str: |
22 | | - return os.path.join(get_pownet_dir(), "model_library") |
23 | | - |
24 | | - |
25 | | -def get_reservoir_dir(model_name: str) -> str: |
26 | | - return os.path.join(get_model_dir(), model_name, "reservoir") |
27 | | - |
28 | | - |
29 | | -def get_reservoir_file(model_name: str, filename: str) -> str: |
30 | | - return os.path.join(get_reservoir_dir(model_name), f"{filename}") |
31 | | - |
32 | | - |
33 | 9 | def get_home_dir() -> str: |
| 10 | + """Returns the home directory of the user. This is useful for testing purposes.""" |
34 | 11 | return os.path.expanduser("~") |
35 | 12 |
|
36 | 13 |
|
37 | 14 | def get_database_dir() -> str: |
| 15 | + """Returns the database directory of the pownet package.""" |
38 | 16 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "database") |
39 | 17 |
|
40 | 18 |
|
41 | 19 | def get_test_dir() -> str: |
| 20 | + """Returns the test directory of the pownet package.""" |
42 | 21 | return os.path.join(get_pownet_dir(), "src", "test_pownet") |
43 | | - |
44 | | - |
45 | | -def delete_all_gurobi_solutions() -> None: |
46 | | - """Remove all Gurobi solution files from the output folder. |
47 | | - Use this function at the beginning of the simulation when warmstart is on. |
48 | | - """ |
49 | | - solution_files = os.listdir(get_output_dir()) |
50 | | - for s_file in solution_files: |
51 | | - file_extension = os.path.splitext(s_file)[1] |
52 | | - if ".sol" in file_extension: |
53 | | - os.remove(os.path.join(get_output_dir(), s_file)) |
54 | | - |
55 | | - |
56 | | -def count_mps_files(instance_folder: str) -> int: |
57 | | - """Count the number of files ending with .mps""" |
58 | | - num_instances = 0 |
59 | | - for file in os.listdir(instance_folder): |
60 | | - if file.endswith(".mps"): |
61 | | - num_instances += 1 |
62 | | - return num_instances |
0 commit comments