 import tempfile
 import time
 from pathlib import Path
+from typing import Callable

 import pandas as pd
 from data import generate_data, read_csv
@@ -37,8 +38,6 @@ def pysr_fit(queue: mp.Queue, out_queue: mp.Queue):


 def pysr_predict(queue: mp.Queue, out_queue: mp.Queue):
-    import numpy as np
-
     import pysr

     while True:
@@ -49,7 +48,7 @@ def pysr_predict(queue: mp.Queue, out_queue: mp.Queue):

         X = args["X"]
         equation_file = str(args["equation_file"])
-        complexity = args["complexity"]
+        index = args["index"]

         equation_file_pkl = equation_file.replace(".csv", ".pkl")
         equation_file_bkup = equation_file + ".bkup"
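
For reference, the `index` received here is passed straight to `model.predict(X, index)` in the next hunk. A minimal standalone sketch of that call, assuming a finished fit whose hall-of-fame CSV has its pickle checkpoint saved alongside it (the paths and array shapes are illustrative, not part of this diff):

    import numpy as np
    import pysr

    # from_file loads the pickled model when "hall_of_fame.pkl" exists
    # next to the CSV (the same layout pysr_predict checks for above).
    model = pysr.PySRRegressor.from_file("hall_of_fame.csv")

    X = np.random.randn(100, 2)  # must match the number of fitted features
    # The second argument selects a row of model.equations_; index=-1 picks
    # the last (most complex) equation, which is what processing() below requests.
    ypred = model.predict(X, -1)
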
@@ -66,31 +65,29 @@ def pysr_predict(queue: mp.Queue, out_queue: mp.Queue):
         except pd.errors.EmptyDataError:
             continue

-        index = np.abs(model.equations_.complexity - complexity).argmin()
         ypred = model.predict(X, index)

-        out_queue.put(ypred)
+        # Select the relevant columns (they are renamed to uppercase below)
+        equations = model.equations_[["complexity", "loss", "equation"]].copy()

+        # Remove any row with a worse loss than a previous, simpler row:
+        equations = equations[equations["loss"].cummin() == equations["loss"]]
+        # TODO: Why is this needed? Are rows not being removed?

-class PySRProcess:
-    def __init__(self):
-        self.queue = mp.Queue()
-        self.out_queue = mp.Queue()
-        self.process = mp.Process(target=pysr_fit, args=(self.queue, self.out_queue))
-        self.process.start()
+        equations.columns = ["Complexity", "Loss", "Equation"]
+        out_queue.put(dict(ypred=ypred, equations=equations))


-class PySRReaderProcess:
-    def __init__(self):
-        self.queue = mp.Queue()
-        self.out_queue = mp.Queue()
-        self.process = mp.Process(
-            target=pysr_predict, args=(self.queue, self.out_queue)
-        )
+class ProcessWrapper:
+    def __init__(self, target: Callable[[mp.Queue, mp.Queue], None]):
+        self.queue = mp.Queue(maxsize=1)
+        self.out_queue = mp.Queue(maxsize=1)
+        self.process = mp.Process(target=target, args=(self.queue, self.out_queue))
         self.process.start()


 PERSISTENT_WRITER = None
+PERSISTENT_READER = None


 def processing(
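
The `cummin` filter added in `pysr_predict` above keeps a row only when its loss equals the best loss seen so far, so any row whose loss is strictly worse than an earlier (simpler) equation is dropped. A small self-contained illustration with made-up numbers:

    import pandas as pd

    equations = pd.DataFrame(
        {
            "complexity": [1, 3, 5, 7],
            "loss": [1.0, 0.5, 0.7, 0.4],
        }
    )

    # The complexity-5 row is dropped: its loss (0.7) is worse than the
    # 0.5 already reached at complexity 3. All other rows are kept.
    filtered = equations[equations["loss"].cummin() == equations["loss"]]
    print(filtered)
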
@@ -118,9 +115,15 @@ def processing(
 ):
     """Load data, then spawn a process to run the greet function."""
     global PERSISTENT_WRITER
+    global PERSISTENT_READER
+
     if PERSISTENT_WRITER is None:
-        print("Starting PySR process")
-        PERSISTENT_WRITER = PySRProcess()
+        print("Starting PySR fit process")
+        PERSISTENT_WRITER = ProcessWrapper(pysr_fit)
+
+    if PERSISTENT_READER is None:
+        print("Starting PySR predict process")
+        PERSISTENT_READER = ProcessWrapper(pysr_predict)

     if file_input is not None:
         try:
@@ -130,67 +133,62 @@ def processing(
     else:
         X, y = generate_data(test_equation, num_points, noise_level, data_seed)

-    with tempfile.TemporaryDirectory() as tmpdirname:
-        base = Path(tmpdirname)
-        equation_file = base / "hall_of_fame.csv"
-        equation_file_bkup = base / "hall_of_fame.csv.bkup"
-        # Check if queue is empty, if not, kill the process
-        # and start a new one
-        if not PERSISTENT_WRITER.queue.empty():
-            print("Restarting PySR process")
-            if PERSISTENT_WRITER.process.is_alive():
-                PERSISTENT_WRITER.process.terminate()
-                PERSISTENT_WRITER.process.join()
-
-            PERSISTENT_WRITER = PySRProcess()
-        # Write these to queue instead:
-        PERSISTENT_WRITER.queue.put(
-            dict(
-                X=X,
-                y=y,
-                kwargs=dict(
-                    niterations=niterations,
-                    maxsize=maxsize,
-                    binary_operators=binary_operators,
-                    unary_operators=unary_operators,
+    tmpdirname = tempfile.mkdtemp()
+    base = Path(tmpdirname)
+    equation_file = base / "hall_of_fame.csv"
+    # Check if queue is empty, if not, kill the process
+    # and start a new one
+    if not PERSISTENT_WRITER.queue.empty():
+        print("Restarting PySR fit process")
+        if PERSISTENT_WRITER.process.is_alive():
+            PERSISTENT_WRITER.process.terminate()
+            PERSISTENT_WRITER.process.join()
+
+        PERSISTENT_WRITER = ProcessWrapper(pysr_fit)
+
+    if not PERSISTENT_READER.queue.empty():
+        print("Restarting PySR predict process")
+        if PERSISTENT_READER.process.is_alive():
+            PERSISTENT_READER.process.terminate()
+            PERSISTENT_READER.process.join()
+
+        PERSISTENT_READER = ProcessWrapper(pysr_predict)
+
+    PERSISTENT_WRITER.queue.put(
+        dict(
+            X=X,
+            y=y,
+            kwargs=dict(
+                niterations=niterations,
+                maxsize=maxsize,
+                binary_operators=binary_operators,
+                unary_operators=unary_operators,
+                equation_file=equation_file,
+                parsimony=parsimony,
+                populations=populations,
+                population_size=population_size,
+                ncycles_per_iteration=ncycles_per_iteration,
+                elementwise_loss=elementwise_loss,
+                adaptive_parsimony_scaling=adaptive_parsimony_scaling,
+                optimizer_algorithm=optimizer_algorithm,
+                optimizer_iterations=optimizer_iterations,
+                batching=batching,
+                batch_size=batch_size,
+            ),
+        )
+    )
+    while PERSISTENT_WRITER.out_queue.empty():
+        if equation_file.exists():
+            # Ask the reader process for predictions and the latest equations
+            PERSISTENT_READER.queue.put(
+                dict(
+                    X=X,
                     equation_file=equation_file,
-                    parsimony=parsimony,
-                    populations=populations,
-                    population_size=population_size,
-                    ncycles_per_iteration=ncycles_per_iteration,
-                    elementwise_loss=elementwise_loss,
-                    adaptive_parsimony_scaling=adaptive_parsimony_scaling,
-                    optimizer_algorithm=optimizer_algorithm,
-                    optimizer_iterations=optimizer_iterations,
-                    batching=batching,
-                    batch_size=batch_size,
-                ),
+                    index=-1,
+                )
             )
-        )
-        while PERSISTENT_WRITER.out_queue.empty():
-            if equation_file_bkup.exists():
-                # First, copy the file to a the copy file
-                equation_file_copy = base / "hall_of_fame_copy.csv"
-                os.system(f"cp {equation_file_bkup} {equation_file_copy}")
-                try:
-                    equations = pd.read_csv(equation_file_copy)
-                except pd.errors.EmptyDataError:
-                    continue
-
-                # Ensure it is pareto dominated, with more complex expressions
-                # having higher loss. Otherwise remove those rows.
-                # TODO: Not sure why this occurs; could be the result of a late copy?
-                equations.sort_values("Complexity", ascending=True, inplace=True)
-                equations.reset_index(inplace=True)
-                bad_idx = []
-                min_loss = None
-                for i in equations.index:
-                    if min_loss is None or equations.loc[i, "Loss"] < min_loss:
-                        min_loss = float(equations.loc[i, "Loss"])
-                    else:
-                        bad_idx.append(i)
-                equations.drop(index=bad_idx, inplace=True)
-
-                yield equations[["Complexity", "Loss", "Equation"]]
-
-            time.sleep(0.1)
+            out = PERSISTENT_READER.out_queue.get()
+            equations = out["equations"]
+            yield equations[["Complexity", "Loss", "Equation"]]
+
+        time.sleep(0.1)
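
Taken together, the commit replaces the two ad-hoc process classes with a single `ProcessWrapper` that pairs a size-1 request queue with a size-1 response queue. A stripped-down sketch of that request/response pattern with a toy worker standing in for `pysr_fit`/`pysr_predict` (the names in the main block are illustrative only):

    import multiprocessing as mp
    from typing import Callable


    def square_worker(queue: mp.Queue, out_queue: mp.Queue):
        # Stand-in for pysr_fit / pysr_predict: block until a request
        # arrives, do the work, then push a single result back.
        while True:
            args = queue.get()
            out_queue.put({"result": args["x"] ** 2})


    class ProcessWrapper:
        def __init__(self, target: Callable[[mp.Queue, mp.Queue], None]):
            self.queue = mp.Queue(maxsize=1)
            self.out_queue = mp.Queue(maxsize=1)
            self.process = mp.Process(target=target, args=(self.queue, self.out_queue))
            self.process.start()


    if __name__ == "__main__":
        worker = ProcessWrapper(square_worker)
        worker.queue.put({"x": 3})     # send one request; blocks if one is already pending
        print(worker.out_queue.get())  # {'result': 9}
        worker.process.terminate()     # same restart/shutdown path used in processing()
        worker.process.join()
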