Skip to content

Commit f1de3f9

Browse files
masnesralpytorchmergebot
authored andcommitted
Rename "output_tensor" -> "out" in autotune_process.py (pytorch#153169)
Summary: This change is to support remote autotuning. I want to use all the same benchmarking utilities in select_algorithm.py. For remote autotuning, I'll reuse the TritonBenchmarkRequest class used for subprocess autotuning because it's already serializable. That class is also used in standard, in-process autotuning, but via TritonTemplateCaller.benchmark() which sets the output_tensor param when calling the underlying TritonBenchmarkRequest. For remote, I'll be using the TritonBenchmarkRequest request directly so I want the parameter to be named 'out' to avoid "got an unexpected keyword argument 'out'". Test Plan: Existing unit tests Pull Request resolved: pytorch#153169 Approved by: https://github.com/aorenste, https://github.com/eellison
1 parent 9f98e37 commit f1de3f9

File tree

6 files changed

+31
-37
lines changed

6 files changed

+31
-37
lines changed

torch/_inductor/autotune_process.py

Lines changed: 24 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,7 @@ def __init__(
424424
self.extra_args = extra_args
425425

426426
def make_run_fn(
427-
self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
427+
self, *input_tensors: torch.Tensor, out: torch.Tensor
428428
) -> Callable[[], None]:
429429
raise NotImplementedError
430430

@@ -435,30 +435,30 @@ def do_bench(
435435
self,
436436
fn,
437437
*input_tensors: torch.Tensor,
438-
output_tensor: Optional[torch.Tensor] = None,
438+
out: Optional[torch.Tensor] = None,
439439
) -> float:
440440
raise NotImplementedError
441441

442442
def benchmark(
443443
self,
444444
*input_tensors: torch.Tensor,
445-
output_tensor: Optional[torch.Tensor] = None,
445+
out: Optional[torch.Tensor] = None,
446446
) -> float:
447447
debug = autotuning_log.isEnabledFor(logging.DEBUG)
448448
if debug:
449449
start_ts = time.time()
450450

451451
# create args and out tensor
452-
if output_tensor is None:
452+
if out is None:
453453
assert len(input_tensors) == 0
454454
input_tensors = tuple(x.to_tensor() for x in self.input_tensor_meta)
455-
output_tensor = self.output_tensor_meta.to_tensor()
455+
out = self.output_tensor_meta.to_tensor()
456456

457457
if debug:
458458
create_tensor_elapse = time.time() - start_ts # type: ignore[possibly-undefined]
459459
start_ts = time.time()
460460
try:
461-
fn = self.make_run_fn(*input_tensors, output_tensor=output_tensor)
461+
fn = self.make_run_fn(*input_tensors, out=out)
462462
except NonzeroWorkspaceNotSupportedError:
463463
# Skipping all ops with nonzero workspace requirements
464464
autotuning_log.info("Skipping op due to nonzero workspace requirement")
@@ -468,7 +468,7 @@ def benchmark(
468468
load_elapse = time.time() - start_ts # type: ignore[possibly-undefined]
469469
start_ts = time.time()
470470

471-
out = self.do_bench(fn, *input_tensors, output_tensor)
471+
res = self.do_bench(fn, *input_tensors, out)
472472

473473
if debug:
474474
bench_elapse = time.time() - start_ts # type: ignore[possibly-undefined]
@@ -480,7 +480,7 @@ def benchmark(
480480
bench_elapse,
481481
)
482482
self.cleanup_run_fn()
483-
return out
483+
return res
484484

485485

486486
class _TestBenchmarkRequest(BenchmarkRequest):
@@ -504,7 +504,7 @@ def __init__(
504504
self.crash = crash
505505

506506
def benchmark(
507-
self, *input_tensors: torch.Tensor, output_tensor: Optional[torch.Tensor] = None
507+
self, *input_tensors: torch.Tensor, out: Optional[torch.Tensor] = None
508508
) -> float:
509509
if self.device is not None:
510510
assert os.environ.get(CUDA_VISIBLE_DEVICES, None) == str(self.device)
@@ -522,11 +522,11 @@ def do_bench(
522522
self,
523523
fn,
524524
*input_tensors: torch.Tensor,
525-
output_tensor: Optional[torch.Tensor] = None,
525+
out: Optional[torch.Tensor] = None,
526526
) -> float:
527527
device_idx_set = OrderedSet(
528528
tensor.device.index
529-
for tensor in [*input_tensors, output_tensor]
529+
for tensor in [*input_tensors, out]
530530
if isinstance(tensor, torch.Tensor)
531531
and is_gpu(tensor.device.type)
532532
and tensor.device.index is not None
@@ -546,18 +546,18 @@ def do_bench(
546546
else:
547547
device_idx = device_interface.current_device()
548548
with device_interface.device(device_idx): # type: ignore[attr-defined]
549-
out = benchmarker.benchmark_gpu(fn)
549+
res = benchmarker.benchmark_gpu(fn)
550550
device_interface.synchronize() # shake out any CUDA errors
551551

552-
return out
552+
return res
553553

554554

555555
class CPUDeviceBenchmarkMixin:
556556
def do_bench(
557557
self,
558558
fn,
559559
*input_tensors: torch.Tensor,
560-
output_tensor: Optional[torch.Tensor] = None,
560+
out: Optional[torch.Tensor] = None,
561561
) -> float:
562562
return benchmarker.benchmark_cpu(fn)
563563

@@ -593,7 +593,7 @@ def __init__(
593593
self.kpack = kpack
594594

595595
def make_run_fn(
596-
self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
596+
self, *input_tensors: torch.Tensor, out: torch.Tensor
597597
) -> Callable[[], None]:
598598
mod = PyCodeCache.load_by_key_path(self.module_cache_key, self.module_path)
599599
autotuning_log.debug(
@@ -614,10 +614,10 @@ def make_run_fn(
614614
if "warmup" in inspect.signature(run_method).parameters:
615615
warmup_arg["warmup"] = False
616616

617-
if output_tensor.device.type == "cpu":
617+
if out.device.type == "cpu":
618618
stream = 0
619619
else:
620-
device_type = output_tensor.device.type
620+
device_type = out.device.type
621621
device_interface = get_interface_for_device(device_type)
622622
stream = device_interface.get_raw_stream(
623623
self.output_tensor_meta.device.index
@@ -630,7 +630,7 @@ def make_run_fn(
630630
return functools.partial(
631631
run_method,
632632
*input_tensors,
633-
output_tensor,
633+
out,
634634
*extra_args,
635635
**warmup_arg,
636636
stream=stream,
@@ -639,7 +639,7 @@ def make_run_fn(
639639
return functools.partial(
640640
run_method,
641641
*input_tensors,
642-
output_tensor,
642+
out,
643643
*extra_args,
644644
**warmup_arg,
645645
stream=stream,
@@ -692,14 +692,11 @@ def precompile(self):
692692
autotuning_log.debug("Done precompiling %s", self)
693693

694694
def make_run_fn(
695-
self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
695+
self, *input_tensors: torch.Tensor, out: torch.Tensor
696696
) -> Callable[[], None]:
697697
self.ensure_dll_loaded()
698698
self.update_workspace_size()
699-
args = [
700-
c_void_p(tensor.data_ptr())
701-
for tensor in list(input_tensors) + [output_tensor]
702-
]
699+
args = [c_void_p(tensor.data_ptr()) for tensor in list(input_tensors) + [out]]
703700
autotuning_log.debug(
704701
"make_run_fn: self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s",
705702
self.kernel_name,
@@ -716,7 +713,7 @@ def make_run_fn(
716713
self.workspace = torch.zeros(
717714
(self.workspace_size + 7) // 8,
718715
dtype=torch.float64,
719-
device=output_tensor.device,
716+
device=out.device,
720717
)
721718
workspace_ptr = c_void_p(self.workspace.data_ptr())
722719

@@ -806,11 +803,11 @@ def precompile(self):
806803
autotuning_log.debug("Done precompiling %s", self)
807804

808805
def make_run_fn(
809-
self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
806+
self, *input_tensors: torch.Tensor, out: torch.Tensor
810807
) -> Callable[[], None]:
811808
# TODO(jgong5): use CppPythonBindingsCodeCache for better binding perf
812809
self.DLL = CppCodeCache.load(self.source_code, device_type="cpu")
813-
args = [tensor.data_ptr() for tensor in list(input_tensors) + [output_tensor]]
810+
args = [tensor.data_ptr() for tensor in list(input_tensors) + [out]]
814811
autotuning_log.debug(
815812
"make_run_fn: self.kernel_name=%s, self.DLL=%s, args=%s, self.extra_args=%s",
816813
self.kernel_name,

torch/_inductor/codegen/cpp_template_kernel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -553,7 +553,7 @@ def precompile(self) -> None:
553553

554554
def benchmark(self, *args, out) -> float:
555555
assert self.bmreq is not None
556-
return self.bmreq.benchmark(*args, output_tensor=out)
556+
return self.bmreq.benchmark(*args, out=out)
557557

558558
def hash_key(self) -> str:
559559
return "-".join(

torch/_inductor/codegen/cuda/cuda_kernel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -590,7 +590,7 @@ def precompile(self) -> None:
590590
def benchmark(self, *args, out) -> float:
591591
assert self.bmreq is not None
592592
return self.bmreq.benchmark(
593-
*args, output_tensor=out
593+
*args, out=out
594594
) # @TODO: Hack for ensuring that Cutlass Kernel is preferred
595595

596596
def __str__(self) -> str:

torch/_inductor/codegen/rocm/rocm_benchmark_request.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,14 +55,11 @@ def precompile(self):
5555
log.debug("Done precompiling %s", self)
5656

5757
def make_run_fn(
58-
self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
58+
self, *input_tensors: torch.Tensor, out: torch.Tensor
5959
) -> Callable[[], None]:
6060
self.ensure_dll_loaded()
6161
self.update_workspace_size()
62-
args = [
63-
c_void_p(tensor.data_ptr())
64-
for tensor in list(input_tensors) + [output_tensor]
65-
]
62+
args = [c_void_p(tensor.data_ptr()) for tensor in list(input_tensors) + [out]]
6663
size_args = [c_int(arg) for arg in self.extra_args]
6764
log.debug(
6865
"make_run_fn: self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s",
@@ -80,7 +77,7 @@ def make_run_fn(
8077
self.workspace = torch.zeros(
8178
(self.workspace_size + 7) // 8,
8279
dtype=torch.float64,
83-
device=output_tensor.device,
80+
device=out.device,
8481
)
8582
workspace_ptr = c_void_p(self.workspace.data_ptr())
8683

torch/_inductor/codegen/rocm/rocm_kernel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def precompile(self) -> None:
246246

247247
def benchmark(self, *args, out) -> float:
248248
assert self.bmreq is not None
249-
return self.bmreq.benchmark(*args, output_tensor=out)
249+
return self.bmreq.benchmark(*args, out=out)
250250

251251
def __str__(self) -> str:
252252
return f"ROCmTemplateCaller(source_file={self.bmreq.source_file}, {self.info_dict()})"

torch/_inductor/select_algorithm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1500,7 +1500,7 @@ def __init__(
15001500

15011501
def benchmark(self, *args, out):
15021502
assert self.bmreq is not None
1503-
return self.bmreq.benchmark(*args, output_tensor=out)
1503+
return self.bmreq.benchmark(*args, out=out)
15041504

15051505
def precompile(self):
15061506
assert self.bmreq is not None

0 commit comments

Comments
 (0)