Skip to content

Commit a2492c3

Browse files
committed
Better support for live reading and predictions
1 parent fd28328 commit a2492c3

File tree

1 file changed

+80
-82
lines changed

1 file changed

+80
-82
lines changed

gui/processing.py

Lines changed: 80 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import tempfile
44
import time
55
from pathlib import Path
6+
from typing import Callable
67

78
import pandas as pd
89
from data import generate_data, read_csv
@@ -37,8 +38,6 @@ def pysr_fit(queue: mp.Queue, out_queue: mp.Queue):
3738

3839

3940
def pysr_predict(queue: mp.Queue, out_queue: mp.Queue):
40-
import numpy as np
41-
4241
import pysr
4342

4443
while True:
@@ -49,7 +48,7 @@ def pysr_predict(queue: mp.Queue, out_queue: mp.Queue):
4948

5049
X = args["X"]
5150
equation_file = str(args["equation_file"])
52-
complexity = args["complexity"]
51+
index = args["index"]
5352

5453
equation_file_pkl = equation_file.replace(".csv", ".pkl")
5554
equation_file_bkup = equation_file + ".bkup"
@@ -66,31 +65,29 @@ def pysr_predict(queue: mp.Queue, out_queue: mp.Queue):
6665
except pd.errors.EmptyDataError:
6766
continue
6867

69-
index = np.abs(model.equations_.complexity - complexity).argmin
7068
ypred = model.predict(X, index)
7169

72-
out_queue.put(ypred)
70+
# Rename the columns to uppercase
71+
equations = model.equations_[["complexity", "loss", "equation"]].copy()
7372

73+
# Remove any row that has worse loss than previous row:
74+
equations = equations[equations["loss"].cummin() == equations["loss"]]
75+
# TODO: Why is this needed? Are rows not being removed?
7476

75-
class PySRProcess:
76-
def __init__(self):
77-
self.queue = mp.Queue()
78-
self.out_queue = mp.Queue()
79-
self.process = mp.Process(target=pysr_fit, args=(self.queue, self.out_queue))
80-
self.process.start()
77+
equations.columns = ["Complexity", "Loss", "Equation"]
78+
out_queue.put(dict(ypred=ypred, equations=equations))
8179

8280

83-
class PySRReaderProcess:
84-
def __init__(self):
85-
self.queue = mp.Queue()
86-
self.out_queue = mp.Queue()
87-
self.process = mp.Process(
88-
target=pysr_predict, args=(self.queue, self.out_queue)
89-
)
81+
class ProcessWrapper:
    """Own a worker subprocess together with its paired in/out queues.

    The target callable receives ``(in_queue, out_queue)``. Both queues are
    bounded to a single message, so a producer blocks until the worker has
    drained the previous item.
    """

    def __init__(self, target: Callable[[mp.Queue, mp.Queue], None]):
        # Bounded queues: at most one pending request and one pending reply.
        in_queue = mp.Queue(maxsize=1)
        out_queue = mp.Queue(maxsize=1)
        worker = mp.Process(target=target, args=(in_queue, out_queue))
        self.queue = in_queue
        self.out_queue = out_queue
        self.process = worker
        # Launch immediately; the caller communicates only via the queues.
        self.process.start()
9187

9288

9389
PERSISTENT_WRITER = None
90+
PERSISTENT_READER = None
9491

9592

9693
def processing(
@@ -118,9 +115,15 @@ def processing(
118115
):
119116
"""Load data, then spawn a process to run the greet function."""
120117
global PERSISTENT_WRITER
118+
global PERSISTENT_READER
119+
121120
if PERSISTENT_WRITER is None:
122-
print("Starting PySR process")
123-
PERSISTENT_WRITER = PySRProcess()
121+
print("Starting PySR fit process")
122+
PERSISTENT_WRITER = ProcessWrapper(pysr_fit)
123+
124+
if PERSISTENT_READER is None:
125+
print("Starting PySR predict process")
126+
PERSISTENT_READER = ProcessWrapper(pysr_predict)
124127

125128
if file_input is not None:
126129
try:
@@ -130,67 +133,62 @@ def processing(
130133
else:
131134
X, y = generate_data(test_equation, num_points, noise_level, data_seed)
132135

133-
with tempfile.TemporaryDirectory() as tmpdirname:
134-
base = Path(tmpdirname)
135-
equation_file = base / "hall_of_fame.csv"
136-
equation_file_bkup = base / "hall_of_fame.csv.bkup"
137-
# Check if queue is empty, if not, kill the process
138-
# and start a new one
139-
if not PERSISTENT_WRITER.queue.empty():
140-
print("Restarting PySR process")
141-
if PERSISTENT_WRITER.process.is_alive():
142-
PERSISTENT_WRITER.process.terminate()
143-
PERSISTENT_WRITER.process.join()
144-
145-
PERSISTENT_WRITER = PySRProcess()
146-
# Write these to queue instead:
147-
PERSISTENT_WRITER.queue.put(
148-
dict(
149-
X=X,
150-
y=y,
151-
kwargs=dict(
152-
niterations=niterations,
153-
maxsize=maxsize,
154-
binary_operators=binary_operators,
155-
unary_operators=unary_operators,
136+
tmpdirname = tempfile.mkdtemp()
137+
base = Path(tmpdirname)
138+
equation_file = base / "hall_of_fame.csv"
139+
# Check if queue is empty, if not, kill the process
140+
# and start a new one
141+
if not PERSISTENT_WRITER.queue.empty():
142+
print("Restarting PySR fit process")
143+
if PERSISTENT_WRITER.process.is_alive():
144+
PERSISTENT_WRITER.process.terminate()
145+
PERSISTENT_WRITER.process.join()
146+
147+
PERSISTENT_WRITER = ProcessWrapper(pysr_fit)
148+
149+
if not PERSISTENT_READER.queue.empty():
150+
print("Restarting PySR predict process")
151+
if PERSISTENT_READER.process.is_alive():
152+
PERSISTENT_READER.process.terminate()
153+
PERSISTENT_READER.process.join()
154+
155+
PERSISTENT_READER = ProcessWrapper(pysr_predict)
156+
157+
PERSISTENT_WRITER.queue.put(
158+
dict(
159+
X=X,
160+
y=y,
161+
kwargs=dict(
162+
niterations=niterations,
163+
maxsize=maxsize,
164+
binary_operators=binary_operators,
165+
unary_operators=unary_operators,
166+
equation_file=equation_file,
167+
parsimony=parsimony,
168+
populations=populations,
169+
population_size=population_size,
170+
ncycles_per_iteration=ncycles_per_iteration,
171+
elementwise_loss=elementwise_loss,
172+
adaptive_parsimony_scaling=adaptive_parsimony_scaling,
173+
optimizer_algorithm=optimizer_algorithm,
174+
optimizer_iterations=optimizer_iterations,
175+
batching=batching,
176+
batch_size=batch_size,
177+
),
178+
)
179+
)
180+
while PERSISTENT_WRITER.out_queue.empty():
181+
if equation_file.exists():
182+
# First, copy the file to the copy file
183+
PERSISTENT_READER.queue.put(
184+
dict(
185+
X=X,
156186
equation_file=equation_file,
157-
parsimony=parsimony,
158-
populations=populations,
159-
population_size=population_size,
160-
ncycles_per_iteration=ncycles_per_iteration,
161-
elementwise_loss=elementwise_loss,
162-
adaptive_parsimony_scaling=adaptive_parsimony_scaling,
163-
optimizer_algorithm=optimizer_algorithm,
164-
optimizer_iterations=optimizer_iterations,
165-
batching=batching,
166-
batch_size=batch_size,
167-
),
187+
index=-1,
188+
)
168189
)
169-
)
170-
while PERSISTENT_WRITER.out_queue.empty():
171-
if equation_file_bkup.exists():
172-
# First, copy the file to the copy file
173-
equation_file_copy = base / "hall_of_fame_copy.csv"
174-
os.system(f"cp {equation_file_bkup} {equation_file_copy}")
175-
try:
176-
equations = pd.read_csv(equation_file_copy)
177-
except pd.errors.EmptyDataError:
178-
continue
179-
180-
# Ensure it is pareto dominated, with more complex expressions
181-
# having higher loss. Otherwise remove those rows.
182-
# TODO: Not sure why this occurs; could be the result of a late copy?
183-
equations.sort_values("Complexity", ascending=True, inplace=True)
184-
equations.reset_index(inplace=True)
185-
bad_idx = []
186-
min_loss = None
187-
for i in equations.index:
188-
if min_loss is None or equations.loc[i, "Loss"] < min_loss:
189-
min_loss = float(equations.loc[i, "Loss"])
190-
else:
191-
bad_idx.append(i)
192-
equations.drop(index=bad_idx, inplace=True)
193-
194-
yield equations[["Complexity", "Loss", "Equation"]]
195-
196-
time.sleep(0.1)
190+
out = PERSISTENT_READER.out_queue.get()
191+
equations = out["equations"]
192+
yield equations[["Complexity", "Loss", "Equation"]]
193+
194+
time.sleep(0.1)

0 commit comments

Comments
 (0)