Skip to content

Commit bbf39ca

Browse files
nmacchioni authored
and pytorchmergebot committed
[inductor][fix] subproc autotuning respect cache dir changes (pytorch#167918)
Summary: Noticed this bug with subproc autotuning while working on async autotuning. The created subprocs don't respect changes to cache dirs — specifically the Triton cache dir — which causes subproc autotuning to cache-miss on otherwise cached Triton kernels. The net effect is that precompile in the gemm autotuning path became an expensive no-op. On the torchbench model I tested with, compile time with subproc autotuning went down from ~1k seconds to ~500 seconds, now matching in-process autotuning. Test Plan: CI. Differential Revision: D87170069. Pull Request resolved: pytorch#167918. Approved by: https://github.com/aorenste
1 parent 654f3f6 commit bbf39ca

File tree

1 file changed

+13
-7
lines changed

1 file changed

+13
-7
lines changed

torch/_inductor/autotune_process.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,13 @@ def process_main(read_pipe: IO[bytes], write_pipe: IO[bytes]) -> None:
7878

7979
def workloop():
8080
while True:
81-
job = TuningProcess.recv(read_pipe)
81+
job, extra_env = TuningProcess.recv(read_pipe)
8282
if job is None:
8383
# None is a sentinel for the child to shut down
8484
break
8585
try:
86+
if extra_env:
87+
os.environ.update(extra_env)
8688
result = job()
8789
except Exception as e:
8890
result = e
@@ -95,8 +97,10 @@ def workloop():
9597
pass
9698

9799
@staticmethod
98-
def send(obj: Any, write_pipe: IO[bytes]) -> None:
99-
pickle.dump(obj, write_pipe)
100+
def send(
101+
obj: Any, write_pipe: IO[bytes], extra_env: dict[str, str] | None = None
102+
) -> None:
103+
pickle.dump((obj, extra_env), write_pipe)
100104
write_pipe.flush()
101105

102106
@staticmethod
@@ -158,13 +162,13 @@ def alive(self) -> bool:
158162
"""
159163
return self.running and self.process.poll() is None
160164

161-
def put(self, req: Any) -> None:
165+
def put(self, req: Any, extra_env: dict[str, str] | None = None) -> None:
162166
"""
163167
Push a work item to the child process.
164168
"""
165169
if not self.alive():
166170
self.start()
167-
TuningProcess.send(req, self.write_pipe)
171+
TuningProcess.send(req, self.write_pipe, extra_env=extra_env)
168172

169173
def get(self, timeout: float = 120.0) -> Any:
170174
"""
@@ -174,7 +178,7 @@ def get(self, timeout: float = 120.0) -> Any:
174178
try:
175179
if not self.selector.select(timeout):
176180
raise TimeoutError(f"Timeout in autotune subprocess {self.process.pid}")
177-
result = TuningProcess.recv(self.read_pipe)
181+
result, _ = TuningProcess.recv(self.read_pipe)
178182
except TimeoutError:
179183
self.kill()
180184
raise
@@ -305,8 +309,10 @@ def target(self, choice: TritonTemplateCaller) -> float:
305309
"""
306310
assert choice.bmreq is not None
307311

312+
env_vars = ["TORCHINDUCTOR_CACHE_DIR", "TRITON_CACHE_DIR"]
313+
extra_env = {v: os.environ[v] for v in env_vars if v in os.environ}
308314
process = self.process_queue.get()
309-
process.put(choice.bmreq.benchmark)
315+
process.put(choice.bmreq.benchmark, extra_env=extra_env)
310316
try:
311317
return process.get(
312318
config.max_autotune_subproc_result_timeout_seconds,

0 commit comments

Comments
 (0)