@@ -424,7 +424,7 @@ def __init__(
424424 self .extra_args = extra_args
425425
def make_run_fn(
    self, *input_tensors: torch.Tensor, out: torch.Tensor
) -> Callable[[], None]:
    """Build the zero-argument callable that will be benchmarked.

    This base implementation is an abstract hook and always raises.
    ``benchmark`` calls it as ``self.make_run_fn(*input_tensors, out=out)``
    and hands the returned callable to ``do_bench``.

    Args:
        *input_tensors: Input tensors, passed positionally.
        out: Output tensor; keyword-only because it follows ``*input_tensors``.

    Returns:
        A callable taking no arguments (in overriding implementations).

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError
430430
@@ -435,30 +435,30 @@ def do_bench(
435435 self ,
436436 fn ,
437437 * input_tensors : torch .Tensor ,
438- output_tensor : Optional [torch .Tensor ] = None ,
438+ out : Optional [torch .Tensor ] = None ,
439439 ) -> float :
440440 raise NotImplementedError
441441
442442 def benchmark (
443443 self ,
444444 * input_tensors : torch .Tensor ,
445- output_tensor : Optional [torch .Tensor ] = None ,
445+ out : Optional [torch .Tensor ] = None ,
446446 ) -> float :
447447 debug = autotuning_log .isEnabledFor (logging .DEBUG )
448448 if debug :
449449 start_ts = time .time ()
450450
451451 # create args and out tensor
452- if output_tensor is None :
452+ if out is None :
453453 assert len (input_tensors ) == 0
454454 input_tensors = tuple (x .to_tensor () for x in self .input_tensor_meta )
455- output_tensor = self .output_tensor_meta .to_tensor ()
455+ out = self .output_tensor_meta .to_tensor ()
456456
457457 if debug :
458458 create_tensor_elapse = time .time () - start_ts # type: ignore[possibly-undefined]
459459 start_ts = time .time ()
460460 try :
461- fn = self .make_run_fn (* input_tensors , output_tensor = output_tensor )
461+ fn = self .make_run_fn (* input_tensors , out = out )
462462 except NonzeroWorkspaceNotSupportedError :
463463 # Skipping all ops with nonzero workspace requirements
464464 autotuning_log .info ("Skipping op due to nonzero workspace requirement" )
@@ -468,7 +468,7 @@ def benchmark(
468468 load_elapse = time .time () - start_ts # type: ignore[possibly-undefined]
469469 start_ts = time .time ()
470470
471- out = self .do_bench (fn , * input_tensors , output_tensor )
471+ res = self .do_bench (fn , * input_tensors , out )
472472
473473 if debug :
474474 bench_elapse = time .time () - start_ts # type: ignore[possibly-undefined]
@@ -480,7 +480,7 @@ def benchmark(
480480 bench_elapse ,
481481 )
482482 self .cleanup_run_fn ()
483- return out
483+ return res
484484
485485
486486class _TestBenchmarkRequest (BenchmarkRequest ):
@@ -504,7 +504,7 @@ def __init__(
504504 self .crash = crash
505505
506506 def benchmark (
507- self , * input_tensors : torch .Tensor , output_tensor : Optional [torch .Tensor ] = None
507+ self , * input_tensors : torch .Tensor , out : Optional [torch .Tensor ] = None
508508 ) -> float :
509509 if self .device is not None :
510510 assert os .environ .get (CUDA_VISIBLE_DEVICES , None ) == str (self .device )
@@ -522,11 +522,11 @@ def do_bench(
522522 self ,
523523 fn ,
524524 * input_tensors : torch .Tensor ,
525- output_tensor : Optional [torch .Tensor ] = None ,
525+ out : Optional [torch .Tensor ] = None ,
526526 ) -> float :
527527 device_idx_set = OrderedSet (
528528 tensor .device .index
529- for tensor in [* input_tensors , output_tensor ]
529+ for tensor in [* input_tensors , out ]
530530 if isinstance (tensor , torch .Tensor )
531531 and is_gpu (tensor .device .type )
532532 and tensor .device .index is not None
@@ -546,18 +546,18 @@ def do_bench(
546546 else :
547547 device_idx = device_interface .current_device ()
548548 with device_interface .device (device_idx ): # type: ignore[attr-defined]
549- out = benchmarker .benchmark_gpu (fn )
549+ res = benchmarker .benchmark_gpu (fn )
550550 device_interface .synchronize () # shake out any CUDA errors
551551
552- return out
552+ return res
553553
554554
class CPUDeviceBenchmarkMixin:
    """Mixin providing a CPU implementation of ``do_bench``."""

    def do_bench(
        self,
        fn,
        *input_tensors: torch.Tensor,
        out: Optional[torch.Tensor] = None,
    ) -> float:
        """Benchmark ``fn`` by delegating to ``benchmarker.benchmark_cpu``.

        Args:
            fn: The callable to time.
            *input_tensors: Accepted for signature parity with the other
                ``do_bench`` implementations; not used here.
            out: Optional output tensor, keyword-only; not used here.

        Returns:
            Whatever ``benchmarker.benchmark_cpu(fn)`` returns (annotated
            as ``float``).
        """
        return benchmarker.benchmark_cpu(fn)
563563
@@ -593,7 +593,7 @@ def __init__(
593593 self .kpack = kpack
594594
595595 def make_run_fn (
596- self , * input_tensors : torch .Tensor , output_tensor : torch .Tensor
596+ self , * input_tensors : torch .Tensor , out : torch .Tensor
597597 ) -> Callable [[], None ]:
598598 mod = PyCodeCache .load_by_key_path (self .module_cache_key , self .module_path )
599599 autotuning_log .debug (
@@ -614,10 +614,10 @@ def make_run_fn(
614614 if "warmup" in inspect .signature (run_method ).parameters :
615615 warmup_arg ["warmup" ] = False
616616
617- if output_tensor .device .type == "cpu" :
617+ if out .device .type == "cpu" :
618618 stream = 0
619619 else :
620- device_type = output_tensor .device .type
620+ device_type = out .device .type
621621 device_interface = get_interface_for_device (device_type )
622622 stream = device_interface .get_raw_stream (
623623 self .output_tensor_meta .device .index
@@ -630,7 +630,7 @@ def make_run_fn(
630630 return functools .partial (
631631 run_method ,
632632 * input_tensors ,
633- output_tensor ,
633+ out ,
634634 * extra_args ,
635635 ** warmup_arg ,
636636 stream = stream ,
@@ -639,7 +639,7 @@ def make_run_fn(
639639 return functools .partial (
640640 run_method ,
641641 * input_tensors ,
642- output_tensor ,
642+ out ,
643643 * extra_args ,
644644 ** warmup_arg ,
645645 stream = stream ,
@@ -692,14 +692,11 @@ def precompile(self):
692692 autotuning_log .debug ("Done precompiling %s" , self )
693693
694694 def make_run_fn (
695- self , * input_tensors : torch .Tensor , output_tensor : torch .Tensor
695+ self , * input_tensors : torch .Tensor , out : torch .Tensor
696696 ) -> Callable [[], None ]:
697697 self .ensure_dll_loaded ()
698698 self .update_workspace_size ()
699- args = [
700- c_void_p (tensor .data_ptr ())
701- for tensor in list (input_tensors ) + [output_tensor ]
702- ]
699+ args = [c_void_p (tensor .data_ptr ()) for tensor in list (input_tensors ) + [out ]]
703700 autotuning_log .debug (
704701 "make_run_fn: self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s" ,
705702 self .kernel_name ,
@@ -716,7 +713,7 @@ def make_run_fn(
716713 self .workspace = torch .zeros (
717714 (self .workspace_size + 7 ) // 8 ,
718715 dtype = torch .float64 ,
719- device = output_tensor .device ,
716+ device = out .device ,
720717 )
721718 workspace_ptr = c_void_p (self .workspace .data_ptr ())
722719
@@ -806,11 +803,11 @@ def precompile(self):
806803 autotuning_log .debug ("Done precompiling %s" , self )
807804
808805 def make_run_fn (
809- self , * input_tensors : torch .Tensor , output_tensor : torch .Tensor
806+ self , * input_tensors : torch .Tensor , out : torch .Tensor
810807 ) -> Callable [[], None ]:
811808 # TODO(jgong5): use CppPythonBindingsCodeCache for better binding perf
812809 self .DLL = CppCodeCache .load (self .source_code , device_type = "cpu" )
813- args = [tensor .data_ptr () for tensor in list (input_tensors ) + [output_tensor ]]
810+ args = [tensor .data_ptr () for tensor in list (input_tensors ) + [out ]]
814811 autotuning_log .debug (
815812 "make_run_fn: self.kernel_name=%s, self.DLL=%s, args=%s, self.extra_args=%s" ,
816813 self .kernel_name ,
0 commit comments