|
125 | 125 | parser.add_argument("--verbose", action="store_true", help="increase output verbosity") |
126 | 126 | parser.add_argument("--plot", help="generate a performance comparison plot and save to specified file (e.g., plot.png)") |
127 | 127 | parser.add_argument("--plot_x", help="parameter to use as x axis for plotting (default: n_depth)", default="n_depth") |
| 128 | +parser.add_argument("--plot_log_scale", action="store_true", help="use log scale for x axis in plots (off by default)") |
128 | 129 |
|
129 | 130 | known_args, unknown_args = parser.parse_known_args() |
130 | 131 |
|
@@ -612,7 +613,7 @@ def valid_format(data_files: list[str]) -> bool: |
612 | 613 | headers += ["Test", f"t/s {name_baseline}", f"t/s {name_compare}", "Speedup"] |
613 | 614 |
|
614 | 615 | if known_args.plot: |
615 | | - def create_performance_plot(table_data: list[list[str]], headers: list[str], baseline_name: str, compare_name: str, output_file: str, plot_x_param: str): |
| 616 | + def create_performance_plot(table_data: list[list[str]], headers: list[str], baseline_name: str, compare_name: str, output_file: str, plot_x_param: str, log_scale: bool = False): |
616 | 617 | try: |
617 | 618 | import matplotlib.pyplot as plt |
618 | 619 | import matplotlib |
@@ -661,7 +662,7 @@ def create_performance_plot(table_data: list[list[str]], headers: list[str], bas |
661 | 662 | base_test = test_name.split("@d")[0] |
662 | 663 | x_value = int(test_name.split("@d")[1]) |
663 | 664 | else: |
664 | | - assert False |
| 665 | + base_test = test_name |
665 | 666 |
|
666 | 667 | if base_test.strip(): |
667 | 668 | group_key_parts.append(f"Test={base_test}") |
@@ -731,7 +732,7 @@ def make_axes(num_groups, max_cols=2, base_size=(8, 4)): |
731 | 732 | ax.plot(x_values, compare_vals, 's--', color='lightcoral', alpha=0.8, |
732 | 733 | label=f'{compare_name}', linewidth=2, markersize=6) |
733 | 734 |
|
734 | | - if plot_x_param == "n_depth" and min(x_values) > 0 and max(x_values) > min(x_values) * 4: |
| 735 | + if log_scale and min(x_values) > 0: |
735 | 736 | ax.set_xscale('log', base=2) |
736 | 737 | unique_x = sorted(set(x_values)) |
737 | 738 | ax.set_xticks(unique_x) |
@@ -764,7 +765,7 @@ def make_axes(num_groups, max_cols=2, base_size=(8, 4)): |
764 | 765 | plt.savefig(output_file, dpi=300, bbox_inches='tight') |
765 | 766 | plt.close() |
766 | 767 |
|
767 | | - create_performance_plot(table, headers, name_baseline, name_compare, known_args.plot, known_args.plot_x) |
| 768 | + create_performance_plot(table, headers, name_baseline, name_compare, known_args.plot, known_args.plot_x, known_args.plot_log_scale) |
768 | 769 |
|
769 | 770 | print(tabulate( # noqa: NP100 |
770 | 771 | table, |
|
0 commit comments