Skip to content

Commit 86e9755

Browse files
committed
Integrate gui in main codebase
1 parent 519fcb9 commit 86e9755

File tree

6 files changed

+39
-16
lines changed

6 files changed

+39
-16
lines changed

pysr/_cli/main.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import click
77

8+
from ..gui import main as gui_main
89
from ..test import (
910
get_runtests_cli,
1011
runtests,
@@ -48,6 +49,11 @@ def _install(julia_project, quiet, precompile):
4849
)
4950

5051

52+
@pysr.command("gui", help="Start a Gradio-based GUI.")
53+
def _gui():
54+
gui_main()
55+
56+
5157
TEST_OPTIONS = {"main", "jax", "torch", "cli", "dev", "startup"}
5258

5359

gui/app.py renamed to pysr/gui/app.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
1-
import gradio as gr
2-
31
from .data import test_equations
42
from .plots import replot, replot_pareto
53
from .processing import processing
64

75

6+
def get_gr():
7+
import gradio as gr
8+
9+
return gr
10+
11+
812
def _data_layout():
13+
gr = get_gr()
14+
915
with gr.Tab("Example Data"):
1016
# Plot of the example data:
1117
with gr.Row():
@@ -43,6 +49,8 @@ def _data_layout():
4349

4450

4551
def _settings_layout():
52+
gr = get_gr()
53+
4654
with gr.Tab("Basic Settings"):
4755
binary_operators = gr.CheckboxGroup(
4856
choices=["+", "-", "*", "/", "^", "max", "min", "mod", "cond"],
@@ -171,6 +179,8 @@ def _settings_layout():
171179

172180

173181
def main():
182+
gr = get_gr()
183+
174184
blocks = {}
175185
with gr.Blocks() as demo:
176186
with gr.Row():
@@ -245,7 +255,3 @@ def main():
245255
demo.load(replot, eqn_components, blocks["example_plot"])
246256

247257
demo.launch(debug=True)
248-
249-
250-
if __name__ == "__main__":
251-
main()
File renamed without changes.

gui/plots.py renamed to pysr/gui/plots.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,30 @@
11
import numpy as np
22
import pandas as pd
3-
from matplotlib import pyplot as plt
4-
5-
plt.ioff()
6-
plt.rcParams["font.family"] = [
7-
"IBM Plex Mono",
8-
# Fallback fonts:
9-
"DejaVu Sans Mono",
10-
"Courier New",
11-
"monospace",
12-
]
133

144
from .data import generate_data
155

6+
FIRST_LOAD = True
7+
8+
9+
def get_plt():
10+
from matplotlib import pyplot as plt
11+
12+
if FIRST_LOAD:
13+
plt.ioff()
14+
plt.rcParams["font.family"] = [
15+
"IBM Plex Mono",
16+
# Fallback fonts:
17+
"DejaVu Sans Mono",
18+
"Courier New",
19+
"monospace",
20+
]
21+
22+
FIRST_LOAD = False
23+
return plt
24+
1625

1726
def replot_pareto(df: pd.DataFrame, maxsize: int):
27+
plt = get_plt()
1828
fig, ax = plt.subplots(figsize=(6, 6), dpi=100)
1929

2030
if len(df) == 0 or "Equation" not in df.columns:
@@ -60,6 +70,7 @@ def replot(test_equation, num_points, noise_level, data_seed):
6070
X, y = generate_data(test_equation, num_points, noise_level, data_seed)
6171
x = X["x"]
6272

73+
plt = get_plt()
6374
plt.rcParams["font.family"] = "IBM Plex Mono"
6475
fig, ax = plt.subplots(figsize=(6, 6), dpi=100)
6576

File renamed without changes.
File renamed without changes.

0 commit comments

Comments
 (0)