99
1010
1111def draw_curves (
12- for_pr : bool = True ,
12+ mode : str ,
1313 axes_setting : dict = None ,
1414 curves_npy_path : list = None ,
1515 row_num : int = 1 ,
@@ -20,14 +20,13 @@ def draw_curves(
2020 ncol_of_legend : int = 1 ,
2121 separated_legend : bool = False ,
2222 sharey : bool = False ,
23- line_styles = ("-" , "--" ),
2423 line_width = 3 ,
2524 save_name = None ,
2625):
2726 """A better curve painter!
2827
2928 Args:
30- for_pr (bool, optional ): Plot for PR curves or FM curves. Defaults to True .
29+ mode (str ): `pr` for PR curves, `fm` for F-measure curves, and `em' for E-measure curves .
3130 axes_setting (dict, optional): Setting for axes. Defaults to None.
3231 curves_npy_path (list, optional): Paths of curve npy files. Defaults to None.
3332 row_num (int, optional): Number of rows. Defaults to 1.
@@ -38,11 +37,10 @@ def draw_curves(
3837 ncol_of_legend (int, optional): Number of columns for the legend. Defaults to 1.
3938 separated_legend (bool, optional): Use the separated legend. Defaults to False.
4039 sharey (bool, optional): Use a shared y-axis. Defaults to False.
41- line_styles (tuple, optional): Styles of lines. Defaults to ("-", "--").
4240 line_width (int, optional): Width of lines. Defaults to 3.
4341 save_name (str, optional): Name or path (without the extension format). Defaults to None.
4442 """
45- mode = "pr" if for_pr else "fm"
43+ assert mode in [ "pr" , "fm" , "em" ]
4644 save_name = save_name or mode
4745 mode_axes_setting = axes_setting [mode ]
4846
@@ -97,23 +95,36 @@ def draw_curves(
9795 # assert len(our_methods) <= len(line_styles)
9896 else :
9997 our_methods = []
98+ num_our_methods = len (our_methods )
10099
101- # Give each method a unique color.
100+ # Give each method a unique color and style .
102101 color_table = sorted (
103102 [
104103 color
105104 for name , color in colors .cnames .items ()
106105 if name not in ["red" , "white" ] or not name .startswith ("light" ) or "gray" in name
107106 ]
108107 )
108+ style_table = ["-" , "--" , "-." , ":" , "." ]
109+
109110 unique_method_settings = OrderedDict ()
110111 for i , method_name in enumerate (target_unique_method_names ):
112+ if i < num_our_methods :
113+ line_color = "red"
114+ line_style = style_table [i % len (style_table )]
115+ else :
116+ other_idx = i - num_our_methods
117+ line_color = color_table [other_idx ]
118+ line_style = style_table [other_idx % 2 ]
119+
111120 unique_method_settings [method_name ] = {
112- "line_color" : "red" if i < len ( our_methods ) else color_table [ i ] ,
121+ "line_color" : line_color ,
113122 "line_label" : method_aliases .get (method_name , method_name ),
114- "line_style" : line_styles [ i % len ( line_styles )] ,
123+ "line_style" : line_style ,
115124 "line_width" : line_width ,
116125 }
126+ # ensure that our methods are drawn last to avoid being overwritten by other methods
127+ target_unique_method_names .reverse ()
117128
118129 curve_drawer = CurveDrawer (
119130 row_num = row_num ,
@@ -135,9 +146,13 @@ def draw_curves(
135146 y_ticks = y_ticks ,
136147 )
137148
138- for method_name , method_setting in unique_method_settings .items ():
149+ for method_name in target_unique_method_names :
150+ method_setting = unique_method_settings [method_name ]
151+
139152 if method_name not in dataset_results :
140- raise KeyError (f"{ method_name } not in { sorted (dataset_results .keys ())} " )
153+ print (f"{ method_name } will be skipped for { dataset_name } !" )
154+ continue
155+
141156 method_results = dataset_results [method_name ]
142157 if mode == "pr" :
143158 y_data = method_results .get ("p" )
0 commit comments