Skip to content

Commit f81b368

Browse files
committed
test_reference tweaks
1 parent af1f4cc commit f81b368

File tree

9 files changed

+104
-102
lines changed

9 files changed

+104
-102
lines changed

configs/local_test.yaml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,18 @@ ModelTraining:
1515
max_workers: 1 # suppress assertion for multigpu training
1616

1717
CP2K:
18-
cores_per_worker: 2
19-
max_evaluation_time: 1
18+
cores_per_worker: 1
19+
max_evaluation_time: 0.1
2020
container_uri: 'oras://ghcr.io/molmod/cp2k:2024.1'
2121

2222
GPAW:
23-
cores_per_worker: 2
24-
max_evaluation_time: 1
23+
cores_per_worker: 1
24+
max_evaluation_time: 0.1
2525
container_uri: 'oras://ghcr.io/molmod/gpaw:24.1'
2626

2727
ORCA:
28-
cores_per_worker: 2
29-
max_evaluation_time: 1
28+
cores_per_worker: 1
29+
max_evaluation_time: 0.1
3030

3131

3232
...

psiflow/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def resolve_and_check(path: Path) -> Path:
5656
# - see chatgpt convo for process memory limits and such
5757
# - make /tmp for app workdirs an option?
5858
# - what is scaling_cores_per_worker in WQ
59+
# - can we clean up psiflow_internal slightly?
5960
# -
6061
# TODO: REFERENCE
6162
# - reference MPI args not really checked

psiflow/execution.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535
logger = logging.getLogger(__name__) # logging per module
3636

3737

38+
PSIFLOW_INTERNAL = "psiflow_internal"
39+
3840
EXECUTION_KWARGS = (
3941
"container_uri",
4042
"container_engine",
@@ -552,10 +554,10 @@ def from_config(
552554
default_threads: int = 4,
553555
htex_address: str = "127.0.0.1",
554556
zip_staging: Optional[bool] = None,
555-
make_symlinks: bool = True,
557+
make_symlinks: bool = False,
556558
**kwargs,
557559
) -> ExecutionContext:
558-
path = Path.cwd().resolve() / "psiflow_internal"
560+
path = Path.cwd().resolve() / PSIFLOW_INTERNAL
559561
psiflow.resolve_and_check(path)
560562
if path.exists():
561563
shutil.rmtree(path)

psiflow/reference/dummy.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ def parse_output(self, stdout: str) -> dict:
4242
geom = Geometry.from_string(stdout)
4343
data = {
4444
"status": Status.SUCCESS,
45-
"runtime": np.nan,
4645
"positions": geom.per_atom.positions,
4746
"natoms": len(geom),
4847
"energy": np.random.uniform(),

psiflow/reference/reference.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations # necessary for type-guarding class methods
22

3-
import logging
43
import warnings
54
from typing import ClassVar, Optional, Union, Callable, Sequence
65
from pathlib import Path
@@ -30,14 +29,27 @@ class Status(Enum):
3029
INCONSISTENT = 2
3130

3231

32+
class SinglePointResult:
33+
"""All dict keys update_geometry understands"""
34+
35+
status: Status
36+
natoms: int # optional
37+
positions: np.ndarray
38+
energy: float
39+
forces: np.ndarray # optional
40+
stdout: Path
41+
stderr: Path # optional
42+
runtime: float # optional
43+
44+
3345
def update_geometry(geom: Geometry, data: dict) -> Geometry:
3446
""""""
3547
_, task_id, task_name = data["stdout"].stem.split("_", maxsplit=2)
3648
logger.info(f'Task "{task_name}" (ID {task_id}): {data["status"].name}')
3749

3850
geom = geom.copy()
3951
geom.reset()
40-
geom.order['status'], geom.order['task_id'] = data["status"].name, task_id
52+
geom.order["status"], geom.order["task_id"] = data["status"].name, task_id
4153
if data["status"] != Status.SUCCESS:
4254
return geom
4355
geom.order["runtime"] = data.get("runtime")

psiflow/utils/io.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def _save_txt(data: str, outputs: list[File] = []) -> None:
8282

8383
@typeguard.typechecked
8484
def _load_metrics(inputs: list = []) -> np.recarray:
85+
# TODO: stop using recarrays
8586
return np.load(inputs[0], allow_pickle=True)
8687

8788

@@ -90,6 +91,7 @@ def _load_metrics(inputs: list = []) -> np.recarray:
9091

9192
@typeguard.typechecked
9293
def _save_metrics(data: np.recarray, outputs: list = []) -> None:
94+
# TODO: stop using recarrays
9395
with open(outputs[0], "wb") as f:
9496
data.dump(f)
9597

psiflow/utils/parse.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
import datetime
2+
from pathlib import Path
23

34
import numpy as np
45

6+
from psiflow.execution import PSIFLOW_INTERNAL
7+
58

69
class LineNotFoundError(Exception):
710
"""Call to find_line failed"""
11+
812
pass
913

1014

@@ -27,7 +31,7 @@ def find_line(
2731
return idx_start + i
2832
else:
2933
return idx_start - i
30-
raise LineNotFoundError('Could not find line starting with \"{}\".'.format(line))
34+
raise LineNotFoundError('Could not find line starting with "{}".'.format(line))
3135

3236

3337
def lines_to_array(
@@ -38,8 +42,17 @@ def lines_to_array(
3842

3943

4044
def string_to_timedelta(timedelta: str) -> datetime.timedelta:
45+
""""""
4146
allowed_units = "weeks", "days", "hours", "minutes", "seconds"
4247
time_list = timedelta.split()
4348
values, units = time_list[:-1:2], time_list[1::2]
4449
kwargs = {u: float(v) for u, v in zip(units, values) if u in allowed_units}
4550
return datetime.timedelta(**kwargs)
51+
52+
53+
def get_task_logs(task_id: int) -> tuple[Path, Path]:
54+
""""""
55+
path = Path.cwd().resolve() / PSIFLOW_INTERNAL / "000/task_logs" # TODO
56+
stdout = next(path.rglob(f"task_{task_id}_*.stdout"))
57+
stderr = next(path.rglob(f"task_{task_id}_*.stderr"))
58+
return stdout, stderr

tests/conftest.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,9 +128,7 @@ def dataset_h2(context):
128128
)
129129
data = [h2.copy() for i in range(20)]
130130
for atoms in data:
131-
atoms.set_positions(
132-
atoms.get_positions() + np.random.uniform(-0.05, 0.05, size=(2, 3))
133-
)
131+
atoms.positions += np.random.uniform(-0.05, 0.05, size=(2, 3))
134132
return Dataset([Geometry.from_atoms(a) for a in data])
135133

136134

0 commit comments

Comments (0)