@@ -20,7 +20,6 @@ def plot_pareto_curve(df: pd.DataFrame, maxsize: int):
2020 if len (df ) == 0 or "Equation" not in df .columns :
2121 return fig
2222
23- # Plotting the data
2423 ax .loglog (
2524 df ["Complexity" ],
2625 df ["Loss" ],
@@ -31,23 +30,12 @@ def plot_pareto_curve(df: pd.DataFrame, maxsize: int):
3130 markersize = 6 ,
3231 )
3332
34- # Set the axis limits
3533 ax .set_xlim (0.5 , maxsize + 1 )
3634 ytop = 2 ** (np .ceil (np .log2 (df ["Loss" ].max ())))
3735 ybottom = 2 ** (np .floor (np .log2 (df ["Loss" ].min () + 1e-20 )))
3836 ax .set_ylim (ybottom , ytop )
3937
40- ax .grid (True , which = "both" , ls = "--" , linewidth = 0.5 , color = "gray" , alpha = 0.5 )
41- ax .spines ["top" ].set_visible (False )
42- ax .spines ["right" ].set_visible (False )
43-
44- # Range-frame the plot
45- for direction in ["bottom" , "left" ]:
46- ax .spines [direction ].set_position (("outward" , 10 ))
47-
48- # Delete far ticks
49- ax .tick_params (axis = "both" , which = "major" , labelsize = 10 , direction = "out" , length = 5 )
50- ax .tick_params (axis = "both" , which = "minor" , labelsize = 8 , direction = "out" , length = 3 )
38+ stylize_axis (ax )
5139
5240 ax .set_xlabel ("Complexity" )
5341 ax .set_ylabel ("Loss" )
@@ -57,14 +45,23 @@ def plot_pareto_curve(df: pd.DataFrame, maxsize: int):
5745
5846
5947def plot_example_data (test_equation , num_points , noise_level , data_seed ):
48+ fig , ax = plt .subplots (figsize = (6 , 6 ), dpi = 100 )
49+
6050 X , y = generate_data (test_equation , num_points , noise_level , data_seed )
6151 x = X ["x" ]
6252
63- plt .rcParams ["font.family" ] = "IBM Plex Mono"
64- fig , ax = plt .subplots (figsize = (6 , 6 ), dpi = 100 )
65-
6653 ax .scatter (x , y , alpha = 0.7 , edgecolors = "w" , s = 50 )
6754
55+ stylize_axis (ax )
56+
57+ ax .set_xlabel ("x" )
58+ ax .set_ylabel ("y" )
59+ fig .tight_layout (pad = 2 )
60+
61+ return fig
62+
63+
64+ def stylize_axis (ax ):
6865 ax .grid (True , which = "both" , ls = "--" , linewidth = 0.5 , color = "gray" , alpha = 0.5 )
6966 ax .spines ["top" ].set_visible (False )
7067 ax .spines ["right" ].set_visible (False )
@@ -76,9 +73,3 @@ def plot_example_data(test_equation, num_points, noise_level, data_seed):
7673 # Delete far ticks
7774 ax .tick_params (axis = "both" , which = "major" , labelsize = 10 , direction = "out" , length = 5 )
7875 ax .tick_params (axis = "both" , which = "minor" , labelsize = 8 , direction = "out" , length = 3 )
79-
80- ax .set_xlabel ("x" )
81- ax .set_ylabel ("y" )
82- fig .tight_layout (pad = 2 )
83-
84- return fig
0 commit comments