Skip to content

Commit 8ba2d91

Browse files
authored
Implements a csv observer (#294)
* Starts implementing a csv observer
* Adds a test for initializing with csv observer init info
1 parent 373821f commit 8ba2d91

File tree

3 files changed

+224
-0
lines changed

3 files changed

+224
-0
lines changed

src/poli/core/exceptions.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,9 @@ class FoldXNotFoundException(PoliException):
1717
"""Exception raised when FoldX wasn't found in ~/foldx/foldx."""
1818

1919
pass
20+
21+
22+
class ObserverNotInitializedError(PoliException):
    """Exception signalling that an observer was used before being initialized."""
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
from __future__ import annotations

import csv
from dataclasses import dataclass
from pathlib import Path
from time import time
from uuid import uuid4

import numpy as np

from poli.core.black_box_information import BlackBoxInformation
from poli.core.exceptions import ObserverNotInitializedError
from poli.core.util.abstract_observer import AbstractObserver
13+
14+
15+
@dataclass
class CSVObserverInitInfo:
    """
    Bundle of settings handed to a CSVObserver when it is initialized.

    Attributes
    ----------
    experiment_id : str
        Identifier of the experiment. Required; it has no default.
    experiment_path : str | Path
        Location of the results folder. Defaults to "./poli_results".
    """

    # Required field: callers must always provide an id.
    experiment_id: str
    # Optional field: a plain string or a Path object are both accepted.
    experiment_path: str | Path = "./poli_results"
21+
22+
23+
class CSVObserver(AbstractObserver):
    """
    A simple observer that logs to a CSV file, appending rows on each query.

    The observer has to be set up with `initialize_observer` before
    `observe` is called; otherwise `observe` raises an
    `ObserverNotInitializedError`.
    """

    def __init__(self):
        self.has_been_initialized = False
        super().__init__()

    def initialize_observer(
        self,
        problem_setup_info: BlackBoxInformation,
        caller_info: CSVObserverInitInfo | dict,
        seed: int,
    ) -> object:
        """
        Initializes the observer with the given information.

        Creates `<experiment_path>/<black box name>/`, writes a `.gitignore`
        into `<experiment_path>` if there is none yet, and (re)creates
        `<experiment_id>.csv` containing only the header row.

        Parameters
        ----------
        problem_setup_info : BlackBoxInformation
            The information about the black box. Its `name` is used for the
            results sub-folder and for the default experiment id.
        caller_info : dict | CSVObserverInitInfo
            Information used for logging. If a dictionary, it may contain the
            keys `experiment_id` and `experiment_path`; both are optional.
            A missing `experiment_path` falls back to "./poli_results" and a
            missing `experiment_id` is generated from the current time, the
            black box name, the seed and a short random suffix.
        seed : int
            The seed used for the experiment. It is stored and used in the
            default experiment id; nothing is seeded here.
        """
        self.info = problem_setup_info
        self.seed = seed
        # First 8 characters of a random UUID keep generated ids apart.
        self.unique_id = f"{uuid4()}"[:8]

        if isinstance(caller_info, CSVObserverInitInfo):
            # From here on only the mapping interface (`.get`) is used.
            caller_info = vars(caller_info)

        self.all_results_path = Path(
            caller_info.get("experiment_path", "./poli_results")
        )
        self.experiment_path = self.all_results_path / problem_setup_info.name
        # Must exist before the .gitignore can be written into its parent.
        self.experiment_path.mkdir(exist_ok=True, parents=True)
        self._write_gitignore()

        self.experiment_id = caller_info.get(
            "experiment_id",
            f"{int(time())}_experiment_{problem_setup_info.name}_{seed}_{self.unique_id}",
        )

        self.csv_file_path = self.experiment_path / f"{self.experiment_id}.csv"
        self.save_header()
        self.has_been_initialized = True

    def _write_gitignore(self):
        """Writes a catch-all `.gitignore` into the results folder, unless one exists."""
        gitignore_path = self.all_results_path / ".gitignore"
        if not gitignore_path.exists():
            with open(gitignore_path, "w") as f:
                f.write("*\n")

    def _make_folder_for_experiment(self):
        """Creates the experiment folder (and its parents) if it is missing."""
        self.experiment_path.mkdir(exist_ok=True, parents=True)

    def _validate_input(self, x: np.ndarray, y: np.ndarray) -> None:
        """
        Checks that `x` and `y` are 2D and have the same number of rows.

        Raises
        ------
        ValueError
            If either array is not 2D, or if their first dimensions differ.
        """
        if x.ndim != 2:
            raise ValueError(f"x should be 2D, got {x.ndim}D instead.")
        if y.ndim != 2:
            raise ValueError(f"y should be 2D, got {y.ndim}D instead.")
        if x.shape[0] != y.shape[0]:
            raise ValueError(
                f"x and y should have the same number of samples, got {x.shape[0]} and {y.shape[0]} respectively."
            )

    def _ensure_proper_shape(self, x: np.ndarray) -> np.ndarray:
        """Returns `x` as a column (n, 1) if it is 1D, and unchanged otherwise."""
        if x.ndim == 1:
            return x.reshape(-1, 1)
        return x

    def observe(self, x: np.ndarray, y: np.ndarray, context=None) -> None:
        """
        Appends one CSV row per entry of `x`, paired with the values of `y`.

        Parameters
        ----------
        x : np.ndarray
            The queried inputs. A 1D array is reshaped to a column; each row
            is concatenated into a single string.
        y : np.ndarray
            The observed values, as a 2D array with one row per row of `x`.
        context : optional
            Accepted but not used by this observer.

        Raises
        ------
        ObserverNotInitializedError
            If `initialize_observer` has not been called yet.
        ValueError
            If the shapes of `x` and `y` do not pass validation.
        """
        if not self.has_been_initialized:
            raise ObserverNotInitializedError(
                "The observer has not been initialized. Please call `initialize_observer` first."
            )
        x = self._ensure_proper_shape(x)
        self._validate_input(x, y)
        # Tokens go through `str` so that non-string entries (e.g. integer
        # tokens) can be joined as well.
        joined_inputs = ["".join(map(str, x_i)) for x_i in x]
        # NOTE(review): if `y` has more than one column, the flattened values
        # no longer line up one-to-one with the rows of `x` -- confirm that
        # only single-output black boxes are expected here.
        self.append_results(joined_inputs, list(y.flatten()))

    def save_header(self):
        """Creates (or truncates) the CSV file and writes the `x,y` header row."""
        self._make_folder_for_experiment()
        # newline="" is what the csv module asks for; the explicit "\n"
        # terminator keeps the line endings the file used so far.
        with open(self.csv_file_path, "w", newline="") as f:
            csv.writer(f, lineterminator="\n").writerow(["x", "y"])

    def append_results(self, x: list[str], y: list[float]):
        """
        Appends one `x,y` row per pair to the CSV file.

        Rows are written through `csv.writer`, so inputs containing commas,
        quotes or newlines are quoted instead of corrupting the file.
        """
        with open(self.csv_file_path, "a", newline="") as f:
            csv.writer(f, lineterminator="\n").writerows(zip(x, y))
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
import csv
2+
3+
import numpy as np
4+
import pytest
5+
6+
from poli.core.exceptions import ObserverNotInitializedError
7+
from poli.core.util.observers.csv_observer import CSVObserver, CSVObserverInitInfo
8+
from poli.repository import AlohaBlackBox
9+
10+
11+
def test_csv_observer_logs_on_aloha():
    """A dict-initialized CSVObserver logs every Aloha query to its CSV file."""
    black_box = AlohaBlackBox()
    observer = CSVObserver()
    init_info = {
        "experiment_id": "test_csv_observer_logs_on_aloha",
        "experiment_path": "./poli_results",
    }
    observer.initialize_observer(black_box.info, init_info, seed=0)

    black_box.set_observer(observer)
    # Two single-sequence queries, then one batched query with two sequences.
    queries = [
        [list("MIGUE")],
        [list("ALOOF")],
        [list("ALOHA"), list("OMAHA")],
    ]
    for query in queries:
        black_box(np.array(query))

    assert observer.csv_file_path.exists()

    # Read the log back and compare it row by row.
    with open(observer.csv_file_path, "r") as csv_file:
        rows = list(csv.reader(csv_file))

    assert rows[0] == ["x", "y"]
    expected_rows = [("MIGUE", 0.0), ("ALOOF", 3.0), ("ALOHA", 5.0), ("OMAHA", 2.0)]
    for row_index, (sequence, value) in enumerate(expected_rows, start=1):
        assert rows[row_index][0] == sequence and float(rows[row_index][1]) == value
40+
41+
42+
def test_csv_observer_works_with_incomplete_caller_info():
    """An empty caller-info dict is enough: the observer still logs every query."""
    black_box = AlohaBlackBox()
    observer = CSVObserver()
    # No experiment id and no path are given on purpose.
    observer.initialize_observer(black_box.info, {}, seed=0)

    black_box.set_observer(observer)
    for batch in (["MIGUE"], ["ALOOF"], ["ALOHA", "OMAHA"]):
        black_box(np.array([list(sequence) for sequence in batch]))

    assert observer.csv_file_path.exists()

    # Read the log back and compare it row by row.
    with open(observer.csv_file_path, "r") as csv_file:
        logged = list(csv.reader(csv_file))

    assert logged[0] == ["x", "y"]
    expectations = {1: ("MIGUE", 0.0), 2: ("ALOOF", 3.0), 3: ("ALOHA", 5.0), 4: ("OMAHA", 2.0)}
    for position, (sequence, value) in expectations.items():
        assert logged[position][0] == sequence and float(logged[position][1]) == value
68+
69+
70+
def test_observer_without_initialization():
    """Querying through an observer that was never initialized raises."""
    black_box = AlohaBlackBox()
    uninitialized_observer = CSVObserver()
    black_box.set_observer(uninitialized_observer)

    query = np.array([list("MIGUE")])
    with pytest.raises(ObserverNotInitializedError):
        black_box(query)
78+
79+
80+
def test_works_with_csv_init_object():
    """A CSVObserver initialized from a CSVObserverInitInfo logs every query."""
    black_box = AlohaBlackBox()
    observer = CSVObserver()
    observer.initialize_observer(
        black_box.info,
        CSVObserverInitInfo(
            # Own id, so this test does not truncate and write the very same
            # CSV file as test_csv_observer_logs_on_aloha.
            experiment_id="test_works_with_csv_init_object",
            experiment_path="./poli_results",
        ),
        seed=0,
    )
    black_box.set_observer(observer)
    black_box(np.array([list("MIGUE")]))
    black_box(np.array([list("ALOOF")]))
    black_box(np.array([list("ALOHA"), list("OMAHA")]))
    assert observer.csv_file_path.exists()

    # Loading up the csv and checking results
    with open(observer.csv_file_path, "r") as csv_file:
        results = list(csv.reader(csv_file))

    assert results[0] == ["x", "y"]
    assert results[1][0] == "MIGUE" and float(results[1][1]) == 0.0
    assert results[2][0] == "ALOOF" and float(results[2][1]) == 3.0
    assert results[3][0] == "ALOHA" and float(results[3][1]) == 5.0
    assert results[4][0] == "OMAHA" and float(results[4][1]) == 2.0

0 commit comments

Comments
 (0)