Skip to content

Commit 148878b

Browse files
Merge remote-tracking branch 'origin/slurm-update-and-fixes' into slurm-update-and-fixes
2 parents 9857490 + 686680c commit 148878b

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

bibigrid/core/actions/create.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,10 @@ def write_cluster_state(self, state):
8080
with open(CLUSTER_MEMORY_PATH, mode="w+", encoding="UTF-8") as cluster_memory_file:
8181
yaml.safe_dump(data=state, stream=cluster_memory_file)
8282
# all clusters
83-
with open(os.path.join(CLUSTER_INFO_FOLDER, f"{self.cluster_id}.yaml"), mode="w+",
84-
encoding="UTF-8") as cluster_info_file:
83+
cluster_info_path = os.path.normpath(os.path.join(CLUSTER_INFO_FOLDER, f"{self.cluster_id}.yaml"))
84+
if not cluster_info_path.startswith(os.path.normpath(CLUSTER_INFO_FOLDER)):
85+
raise Exception("Invalid cluster_id resulting in path traversal")
86+
with open(cluster_info_path, mode="w+", encoding="UTF-8") as cluster_info_file:
8587
yaml.safe_dump(data=state, stream=cluster_info_file)
8688

8789
def create_defaults(self):

bibigrid/core/actions/terminate.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,12 @@ def write_cluster_state(cluster_id, state):
2020
with open(CLUSTER_MEMORY_PATH, mode="w+", encoding="UTF-8") as cluster_memory_file:
2121
yaml.safe_dump(data=state, stream=cluster_memory_file)
2222
# all clusters
23-
with open(os.path.join(CLUSTER_INFO_FOLDER, f"{cluster_id}.yaml"), mode="w+",
24-
encoding="UTF-8") as cluster_info_file:
23+
cluster_info_path = os.path.normpath(os.path.join(CLUSTER_INFO_FOLDER, f"{cluster_id}.yaml"))
24+
if not cluster_info_path.startswith(CLUSTER_INFO_FOLDER):
25+
raise Exception("Invalid cluster_id")
26+
with open(cluster_info_path, mode="w+", encoding="UTF-8") as cluster_info_file:
2527
yaml.safe_dump(data=state, stream=cluster_info_file)
2628

27-
2829
def terminate(cluster_id, providers, log, debug=False, assume_yes=False):
2930
"""
3031
Goes through all providers and gets info of all servers which name contains cluster ID.

0 commit comments

Comments
 (0)