File tree Expand file tree Collapse file tree 1 file changed +4
-7
lines changed Expand file tree Collapse file tree 1 file changed +4
-7
lines changed Original file line number Diff line number Diff line change @@ -1031,13 +1031,8 @@ def run(
1031
1031
# before we hand them to the captured graph. Otherwise we can
1032
1032
# read partially initialized values (e.g. from torch.randint)
1033
1033
# and hit device-side asserts in the baseline kernels.
1034
- if (
1035
- self .use_cuda_graphs
1036
- and self .device
1037
- and self .device .startswith ("cuda" )
1038
- and torch .cuda .is_available ()
1039
- ):
1040
- torch .cuda .synchronize ()
1034
+ if self .use_cuda_graphs :
1035
+ torch .accelerator .synchronize ()
1041
1036
self .baseline_fn = None
1042
1037
self .baseline_metrics = None
1043
1038
self ._op_flops = {}
@@ -1108,6 +1103,8 @@ def _reduce_benchmarks(acc, bm_name: str):
1108
1103
quantiles = quantiles ,
1109
1104
baseline = baseline ,
1110
1105
)
1106
+ # Synchronize after each benchmark to make errors surface sooner
1107
+ torch .accelerator .synchronize ()
1111
1108
if baseline :
1112
1109
self .baseline_metrics = acc [bm_name ]
1113
1110
if sleep :
You can’t perform that action at this time.
0 commit comments