|
1 | 1 | import argparse |
2 | 2 | import itertools |
3 | 3 | import os |
4 | | -from typing import Any, Dict, List |
| 4 | + |
| 5 | +from triton.testing import Benchmark |
5 | 6 |
|
6 | 7 | BENCHMARKING_METHOD = os.getenv("BENCHMARKING_METHOD", "UPSTREAM_PYTORCH_PROFILER") |
7 | 8 |
|
@@ -171,73 +172,6 @@ def perf_report(benchmarks): |
171 | 172 | return wrapper |
172 | 173 |
|
173 | 174 |
|
174 | | -# # pylint: disable=too-many-instance-attributes |
175 | | -class Benchmark: |
176 | | - """ |
177 | | - This class is used by the :code:`perf_report` function to generate line plots with a concise API. |
178 | | - """ |
179 | | - |
180 | | - def __init__( |
181 | | - self, |
182 | | - x_names: List[str], |
183 | | - x_vals: List[Any], |
184 | | - line_arg: str, |
185 | | - line_vals: List[Any], |
186 | | - line_names: List[str], |
187 | | - plot_name: str, |
188 | | - args: Dict[str, Any], |
189 | | - xlabel: str = "", |
190 | | - ylabel: str = "", |
191 | | - x_log: bool = False, |
192 | | - y_log: bool = False, |
193 | | - color=None, # pylint: disable=unused-argument |
194 | | - styles=None, |
195 | | - ): |
196 | | - """ |
197 | | - Constructor. |
198 | | - x_vals can be a list of scalars or a list of tuples/lists. If x_vals is a list |
199 | | - of scalars and there are multiple x_names, all arguments will have the same value. |
200 | | - If x_vals is a list of tuples/lists, each element should have the same length as |
201 | | - x_names. |
202 | | -
|
203 | | - :param x_names: Name of the arguments that should appear on the x axis of the plot. |
204 | | - :type x_names: List[str] |
205 | | - :param x_vals: List of values to use for the arguments in :code:`x_names`. |
206 | | - :type x_vals: List[Any] |
207 | | - :param line_arg: Argument name for which different values correspond to different lines in the plot. |
208 | | - :type line_arg: str |
209 | | - :param line_vals: List of values to use for the arguments in :code:`line_arg`. |
210 | | - :type line_vals: List[Any] |
211 | | - :param line_names: Label names for the different lines. |
212 | | - :type line_names: List[str] |
213 | | - :param plot_name: Name of the plot. |
214 | | - :type plot_name: str |
215 | | - :param args: Dictionary of keyword arguments to remain fixed throughout the benchmark. |
216 | | - :type args: Dict[str, Any] |
217 | | - :param xlabel: Label for the x axis of the plot. |
218 | | - :type xlabel: str, optional |
219 | | - :param ylabel: Label for the y axis of the plot. |
220 | | - :type ylabel: str, optional |
221 | | - :param x_log: Whether the x axis should be log scale. |
222 | | - :type x_log: bool, optional |
223 | | - :param y_log: Whether the y axis should be log scale. |
224 | | - :type y_log: bool, optional |
225 | | - """ |
226 | | - self.x_names = x_names |
227 | | - self.x_vals = x_vals |
228 | | - self.x_log = x_log |
229 | | - self.line_arg = line_arg |
230 | | - self.line_vals = line_vals |
231 | | - self.line_names = line_names |
232 | | - self.y_log = y_log |
233 | | - self.styles = styles |
234 | | - # plot info |
235 | | - self.xlabel = xlabel |
236 | | - self.ylabel = ylabel |
237 | | - self.plot_name = plot_name |
238 | | - self.args = args |
239 | | - |
240 | | - |
241 | 175 | class Mark: |
242 | 176 |
|
243 | 177 | def __init__(self, fn, benchmarks): |
|
0 commit comments