@@ -36,18 +36,18 @@ def _summarize_statistics(times, quantiles, return_mode):
3636 return getattr (torch , return_mode )(times ).item ()
3737
3838
39- def do_bench_ipex (fn , warmup = 25 , rep = 100 , grad_to_none = None , quantiles = None , return_mode = "mean" , device = "xpu" ,
39+ def do_bench_ipex (fn , n_warmup = 25 , n_repeat = 100 , grad_to_none = None , quantiles = None , return_mode = "mean" , device = "xpu" ,
4040 sync_submitting = True , kernel_name = None ): # pylint: disable=unused-argument
4141 """
4242 Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
4343 the 20-th and 80-th performance percentile.
4444
4545 :param fn: Function to benchmark
4646 :type fn: Callable
47- :param warmup: Warmup time (in ms)
48- :type warmup : int
49- :param rep: Repetition time (in ms)
50- :type rep : int
47+ :param n_warmup: Number of repetitions for warmup
48+ :type n_warmup : int
49+ :param n_repeat: Number of repetitions to collect measurements
50+ :type n_repeat : int
5151 :param grad_to_none: Reset the gradient of the provided tensor to None
5252 :type grad_to_none: torch.tensor, optional
5353 :param quantiles: Performance percentile to return in addition to the median.
@@ -69,20 +69,6 @@ def do_bench_ipex(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, ret
6969 cache_size = 256 * 1024 * 1024
7070 cache = torch .empty (int (cache_size // 4 ), dtype = torch .int , device = device )
7171
72- # Estimate the runtime of the function
73- start_event = torch .xpu .Event (enable_timing = True )
74- end_event = torch .xpu .Event (enable_timing = True )
75- start_event .record ()
76- for _ in range (5 ):
77- cache .zero_ ()
78- fn ()
79- end_event .record ()
80- synchronize ()
81- estimate_ms = start_event .elapsed_time (end_event ) / 5
82-
83- # compute number of warmup and repeat
84- n_warmup = max (1 , int (warmup / estimate_ms ))
85- n_repeat = max (1 , int (rep / estimate_ms ))
8672 # Warm-up
8773 for _ in range (n_warmup ):
8874 fn ()
@@ -121,18 +107,18 @@ def extract_kernels(funcs):
121107 return _summarize_statistics (times , quantiles , return_mode )
122108
123109
124- def do_bench_elapsed_time (fn , warmup = 25 , rep = 100 , grad_to_none = None , quantiles = None , return_mode = "mean" , device = "xpu " ,
125- kernel_name = None ): # pylint: disable=unused-argument
110+ def do_bench_elapsed_time (fn , n_warmup = 25 , n_repeat = 100 , grad_to_none = None , quantiles = None , return_mode = "mean" ,
111+ device = "xpu" , kernel_name = None ): # pylint: disable=unused-argument
126112 """
127113 Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
128114 the 20-th and 80-th performance percentile.
129115
130116 :param fn: Function to benchmark
131117 :type fn: Callable
132- :param warmup: Warmup time (in ms)
133- :type warmup : int
134- :param rep: Repetition time (in ms)
135- :type rep : int
118+ :param n_warmup: Number of repetitions for warmup
119+ :type n_warmup : int
120+ :param n_repeat: Number of repetitions to collect measurements
121+ :type n_repeat : int
136122 :param grad_to_none: Reset the gradient of the provided tensor to None
137123 :type grad_to_none: torch.tensor, optional
138124 :param quantiles: Performance percentile to return in addition to the median.
@@ -142,24 +128,49 @@ def do_bench_elapsed_time(fn, warmup=25, rep=100, grad_to_none=None, quantiles=N
142128 import torch
143129 from triton .testing import do_bench as triton_do_bench
144130
145- times = triton_do_bench (fn , warmup = warmup , rep = rep , grad_to_none = grad_to_none , return_mode = "all" ,
131+ # We maintain a buffer of 256 MB that we clear
132+ # before each kernel call to make sure that the L2
133+ # doesn't contain any input data before the run
134+ cache_size = 256 * 1024 * 1024
135+ cache = torch .empty (int (cache_size // 4 ), dtype = torch .int , device = device )
136+
137+ # Estimate the runtime of the function
138+ start_event = torch .xpu .Event (enable_timing = True )
139+ end_event = torch .xpu .Event (enable_timing = True )
140+ start_event .record ()
141+ for _ in range (5 ):
142+ cache .zero_ ()
143+ fn ()
144+ end_event .record ()
145+ synchronize ()
146+ estimate_ms = start_event .elapsed_time (end_event ) / 5
147+
148+ # `triton_do_bench` maintains its own cache buffer, so release this one
149+ # first to avoid doubling the amount of device memory in use.
150+ del cache
151+
152+ # Convert the requested iteration counts into the time budgets (in ms) that `triton_do_bench` expects
153+ warmup_time = n_warmup * estimate_ms
154+ rep_time = n_repeat * estimate_ms
155+
156+ times = triton_do_bench (fn , warmup = warmup_time , rep = rep_time , grad_to_none = grad_to_none , return_mode = "all" ,
146157 device_type = device )
147158 times = torch .tensor (times , dtype = torch .float )
148159 return _summarize_statistics (times , quantiles , return_mode )
149160
150161
151- def do_bench_upstream_pytorch_profiler (fn , warmup = 25 , rep = 100 , grad_to_none = None , quantiles = None , return_mode = "mean" ,
152- device = "xpu" , sync_submitting = True , kernel_name = None ):
162+ def do_bench_upstream_pytorch_profiler (fn , n_warmup = 25 , n_repeat = 100 , grad_to_none = None , quantiles = None ,
163+ return_mode = "mean" , device = "xpu" , sync_submitting = True , kernel_name = None ):
153164 """
154165 Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
155166 the 20-th and 80-th performance percentile.
156167
157168 :param fn: Function to benchmark
158169 :type fn: Callable
159- :param warmup: Warmup time (in ms)
160- :type warmup : int
161- :param rep: Repetition time (in ms)
162- :type rep : int
170+ :param n_warmup: Number of repetitions for warmup
171+ :type n_warmup : int
172+ :param n_repeat: Number of repetitions to collect measurements
173+ :type n_repeat : int
163174 :param grad_to_none: Reset the gradient of the provided tensor to None
164175 :type grad_to_none: torch.tensor, optional
165176 :param quantiles: Performance percentile to return in addition to the median.
@@ -179,20 +190,6 @@ def do_bench_upstream_pytorch_profiler(fn, warmup=25, rep=100, grad_to_none=None
179190 cache_size = 256 * 1024 * 1024
180191 cache = torch .empty (int (cache_size // 4 ), dtype = torch .int , device = device )
181192
182- # Estimate the runtime of the function
183- start_event = torch .xpu .Event (enable_timing = True )
184- end_event = torch .xpu .Event (enable_timing = True )
185- start_event .record ()
186- for _ in range (5 ):
187- cache .zero_ ()
188- fn ()
189- end_event .record ()
190- synchronize ()
191- estimate_ms = start_event .elapsed_time (end_event ) / 5
192-
193- # compute number of warmup and repeat
194- n_warmup = max (1 , int (warmup / estimate_ms ))
195- n_repeat = max (1 , int (rep / estimate_ms ))
196193 # Warm-up
197194 for _ in range (n_warmup ):
198195 fn ()
0 commit comments