Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ParProcCo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
version_info = (2, 1, 2)
version_info = (2, 2, 0)
__version__ = ".".join(str(c) for c in version_info)
13 changes: 7 additions & 6 deletions ParProcCo/job_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
JobSubmitReq,
JobStateEnum,
StringArray,
Uint32NoVal,
Uint64NoVal,
Uint32NoValStruct,
Uint64NoValStruct,
)
from .utils import check_jobscript_is_readable

Expand Down Expand Up @@ -151,7 +151,7 @@ def fetch_and_update_state(
job_id = job_info.job_id
if job_id is None or job_id < 0:
raise ValueError(f"Job info has invalid job id: {job_info}")
state = job_info.job_state
state = job_info.job_state.root
slurm_state = SLURMSTATE[state[0].value] if state else None

start_time = (
Expand Down Expand Up @@ -327,23 +327,24 @@ def make_job_submission(
cpus_per_task=job_scheduling_info.job_resources.cpu_cores,
tres_per_task=f"gres/gpu:{job_scheduling_info.job_resources.gpus}",
tasks=1,
time_limit=Uint32NoVal(
time_limit=Uint32NoValStruct(
number=int((job_scheduling_info.timeout.total_seconds() + 59) // 60),
set=True,
),
environment=StringArray(root=env_list),
memory_per_cpu=Uint64NoVal(
memory_per_cpu=Uint64NoValStruct(
number=job_scheduling_info.job_resources.memory, set=True
),
current_working_directory=str(job_scheduling_info.working_directory),
standard_output=str(job_scheduling_info.get_stdout_path()),
standard_error=str(job_scheduling_info.get_stderr_path()),
script=job_script_command,
)
if job_scheduling_info.job_resources.extra_properties:
for k, v in job_scheduling_info.job_resources.extra_properties.items():
setattr(job, k, v)

return JobSubmitReq(script=job_script_command, job=job)
return JobSubmitReq(job=job)

def wait_all_jobs(
self,
Expand Down
53 changes: 31 additions & 22 deletions ParProcCo/slurm/generate_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,46 +4,47 @@
import yaml


def replace_refs(in_dict: dict, version_prefix: str, db_version_prefix: str):
def filter_refs(in_dict: dict, version_prefix: str, all_refs: set):
sind = len("#/components/schemas/")
for k, v in in_dict.items():
if isinstance(v, dict):
replace_refs(v, version_prefix, db_version_prefix)
filter_refs(v, version_prefix, all_refs)
elif isinstance(v, list):
for i in v:
if isinstance(i, dict):
replace_refs(i, version_prefix, db_version_prefix)
filter_refs(i, version_prefix, all_refs)
if k == "$ref": # and isinstance(v, str):
assert isinstance(v, str)
nv = v.replace(db_version_prefix, "db").replace(version_prefix, "")
in_dict[k] = nv
if version_prefix in v:
nv = v.replace(version_prefix, "")
all_refs.add(nv[sind:])
in_dict[k] = nv


def filter_paths(paths: dict, version: str, slurm_only: bool):
new_dict = {}
path_head = "slurm" if slurm_only else "slurmdb"
for k, v in paths.items():
kparts = k.split("/")
kparts = k.split(
"/"
) # '/slurm/v0.0.40/shares' => ['', 'slurm', 'v0.0.40', 'shares']
kp1 = kparts[1]
if len(kparts) > 2:
if kparts[2] == version:
if (slurm_only and kp1 == "slurm") or (
not slurm_only and kp1 == "slurmdb"
):
new_dict[k] = v
else:
if kp1 == path_head and kparts[2] == version:
new_dict[k] = v
else: # global paths
new_dict[k] = v
print(new_dict.keys())
return new_dict


def filter_components(components: dict, version: str, slurm_only: bool):
def filter_components(components: dict, version_prefix: str, all_refs: dict):
new_dict = {}
if not slurm_only:
version = f"db{version}"
vind = len(version) + 1
kp = "" if slurm_only else "db_"
vind = len(version_prefix)
for k, v in components.items():
if k.startswith(version):
new_dict[kp + k[vind:]] = v
if k.startswith(version_prefix):
filter_refs(v, version_prefix, all_refs)
new_dict[k[vind:]] = v
return new_dict


Expand All @@ -52,10 +53,18 @@ def generate_slurm_models(input_file: str, version: str, slurm_only: bool):
schema = json.load(f)

schema["paths"] = filter_paths(schema["paths"], version, slurm_only)
schema["components"]["schemas"] = filter_components(
schema["components"]["schemas"], version, slurm_only
all_refs = set()
version_prefix = f"{version}_"
filter_refs(schema["paths"], version_prefix, all_refs)
all_schemas = filter_components(
schema["components"]["schemas"], version_prefix, all_refs
)
replace_refs(schema, f"{version}_", f"db{version}")
print(
"Removing these unreferenced schema parts:", set(all_schemas.keys()) - all_refs
)
schema["components"]["schemas"] = {
k: s for k, s in all_schemas.items() if k in all_refs
}
return schema


Expand Down
19 changes: 9 additions & 10 deletions ParProcCo/slurm/slurm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,13 @@

from .slurm_rest import (
JobInfo,
JobSubmitResponseMsg,
JobSubmitReq,
OpenapiJobInfoResp,
OpenapiJobSubmitResponse,
OpenapiResp,
OpenapiKillJobResp,
)

_SLURM_VERSION = "v0.0.40"
_SLURM_VERSION = "v0.0.42"


def get_slurm_token() -> str:
Expand Down Expand Up @@ -77,7 +76,7 @@ def _get_response_json(self, response: requests.Response) -> dict:
def _has_openapi_errors(
self,
heading: str,
oar: OpenapiResp | OpenapiJobInfoResp | OpenapiJobSubmitResponse,
oar: OpenapiJobInfoResp | OpenapiJobSubmitResponse | OpenapiKillJobResp,
) -> bool:
if oar.warnings and oar.warnings.root:
logging.warning(heading)
Expand Down Expand Up @@ -111,21 +110,21 @@ def get_job(self, job_id: int) -> JobInfo:
raise ValueError(f"Multiple jobs returned {jobs}")
raise ValueError(f"No job info found for job id {job_id}")

def submit_job(self, job_submission: JobSubmitReq) -> JobSubmitResponseMsg:
def submit_job(self, job_submission: JobSubmitReq) -> OpenapiJobSubmitResponse:
response = self._post("job/submit", job_submission)
if not response.ok:
logging.error(job_submission.model_dump(exclude_defaults=True))
ojsr = OpenapiJobSubmitResponse.model_validate(
self._get_response_json(response)
)
self._has_openapi_errors(
f"Job submit {ojsr.result.job_id if ojsr.result else 'None'}:", ojsr
f"Job submit {ojsr.job_id if ojsr.job_id else 'None'}:", ojsr
)
response.raise_for_status()
assert ojsr.result
return ojsr.result
assert ojsr.job_id is not None
return ojsr

def cancel_job(self, job_id: int) -> bool:
response = self._delete(f"job/{job_id}")
oar = OpenapiResp.model_validate(self._get_response_json(response))
return not self._has_openapi_errors(f"Job query {job_id}:", oar)
oar = OpenapiKillJobResp.model_validate(self._get_response_json(response))
return not self._has_openapi_errors(f"Job delete {job_id}:", oar)
Loading