Skip to content

Commit 6a59ca1

Browse files
authored
lv2v: Restart infer.py process and cleanup previous stream quickly (#527)
* live/api: Clean up multipart temp files dir. As I was understanding the code to add the last params file, I cleaned up usage of that other one, which was a little confusing (I wrote it myself, haha) and had no explanations. * live/api: You know what? Remove multipart altogether. We don't use it anymore; the runner API talks to us only in JSON. * runner/api: Clean up previous stream trickle channels on start * runner/app: Restart infer.py process on crashes * runner/app: Add a 1s grace period for process startup * lv2v: Final fixes from testing. Turns out the stdout/err streams don't close automatically when the process exits... Tested many ways, but Python process management is really bad. Had to work around a potential thread leak that could happen. Also joined STDERR and STDOUT again on a single stream, as that seemed less error prone. Tested that both are still streamed.
1 parent 5f980fd commit 6a59ca1

File tree

2 files changed

+82
-83
lines changed

2 files changed

+82
-83
lines changed

runner/app/live/api/api.py

Lines changed: 36 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,25 @@
11
import asyncio
2-
import hashlib
32
import logging
4-
import mimetypes
3+
import json
54
import os
65
import tempfile
76
import time
87
from typing import Optional, cast
98

10-
from aiohttp import BodyPartReader, web
9+
from aiohttp import web
1110
from pydantic import BaseModel, Field
1211
from typing import Annotated, Dict
1312

1413
from streamer import PipelineStreamer, ProcessGuardian
1514
from streamer.protocol.trickle import TrickleProtocol
1615
from streamer.process import config_logging
1716

18-
TEMP_SUBDIR = "infer_temp"
1917
MAX_FILE_AGE = 86400 # 1 day
2018
STREAMER_INPUT_TIMEOUT = 60 # 60s
2119

20+
# File to store the last params that a stream was started with. Used to cleanup
21+
# left over resources (e.g. trickle channels) left by a crashed process.
22+
last_params_file = os.path.join(tempfile.gettempdir(), "ai_runner_last_params.json")
2223

2324
class StartStreamParams(BaseModel):
2425
subscribe_url: Annotated[
@@ -62,44 +63,32 @@ class StartStreamParams(BaseModel):
6263
Field(default="", description="Unique identifier for the stream."),
6364
]
6465

66+
async def cleanup_last_stream():
67+
if not os.path.exists(last_params_file):
68+
logging.debug("No last stream params found to cleanup")
69+
return
6570

66-
def cleanup_old_files(temp_dir):
67-
current_time = time.time()
68-
for filename in os.listdir(temp_dir):
69-
file_path = os.path.join(temp_dir, filename)
70-
if os.path.isfile(file_path):
71-
file_age = current_time - os.path.getmtime(file_path)
72-
if file_age > MAX_FILE_AGE:
73-
os.remove(file_path)
74-
logging.info(f"Removed old file: {file_path}")
71+
try:
72+
with open(last_params_file, "r") as f:
73+
params = StartStreamParams(**json.load(f))
74+
os.remove(last_params_file)
7575

76+
logging.info(f"Cleaning up last stream trickle channels for request_id={params.request_id} subscribe_url={params.subscribe_url} publish_url={params.publish_url} control_url={params.control_url} events_url={params.events_url}")
77+
protocol = TrickleProtocol(
78+
params.subscribe_url,
79+
params.publish_url,
80+
params.control_url,
81+
params.events_url,
82+
)
83+
# Start and stop the protocol to immediately to make sure trickle channels are closed.
84+
await protocol.start()
85+
await protocol.stop()
86+
except:
87+
logging.exception(f"Error cleaning up last stream trickle channels")
7688

77-
async def parse_request_data(request: web.Request, temp_dir: str) -> Dict:
89+
async def parse_request_data(request: web.Request) -> Dict:
7890
if request.content_type.startswith("application/json"):
7991
return await request.json()
80-
elif request.content_type.startswith("multipart/"):
81-
params_data = {}
82-
reader = await request.multipart()
83-
async for part in reader:
84-
if not isinstance(part, BodyPartReader):
85-
continue
86-
elif part.name == "params":
87-
part_data = await part.json()
88-
if part_data:
89-
params_data.update(part_data)
90-
else:
91-
content = await part.read()
92-
file_hash = hashlib.md5(content).hexdigest()
93-
content_type = part.headers.get(
94-
"Content-Type", "application/octet-stream"
95-
)
96-
ext = mimetypes.guess_extension(content_type) or ""
97-
new_filename = f"{file_hash}{ext}"
98-
file_path = os.path.join(temp_dir, new_filename)
99-
with open(file_path, "wb") as f:
100-
f.write(content)
101-
params_data[part.name] = file_path
102-
return params_data
10392
else:
10493
raise ValueError(f"Unknown content type: {request.content_type}")
10594

@@ -118,13 +107,15 @@ async def handle_start_stream(request: web.Request):
118107
logging.error(f"Timeout stopping streamer: {e}")
119108
raise web.HTTPBadRequest(text="Timeout stopping previous streamer")
120109

121-
temp_dir = os.path.join(tempfile.gettempdir(), TEMP_SUBDIR)
122-
os.makedirs(temp_dir, exist_ok=True)
123-
cleanup_old_files(temp_dir)
124-
125-
params_data = await parse_request_data(request, temp_dir)
110+
params_data = await parse_request_data(request)
126111
params = StartStreamParams(**params_data)
127112

113+
try:
114+
with open(last_params_file, "w") as f:
115+
json.dump(params.model_dump(), f)
116+
except Exception as e:
117+
logging.error(f"Error saving last params to file: {e}")
118+
128119
config_logging(request_id=params.request_id, stream_id=params.stream_id)
129120

130121
protocol = TrickleProtocol(
@@ -156,11 +147,7 @@ async def handle_start_stream(request: web.Request):
156147

157148
async def handle_params_update(request: web.Request):
158149
try:
159-
temp_dir = os.path.join(tempfile.gettempdir(), TEMP_SUBDIR)
160-
os.makedirs(temp_dir, exist_ok=True)
161-
cleanup_old_files(temp_dir)
162-
163-
params = await parse_request_data(request, temp_dir)
150+
params = await parse_request_data(request)
164151

165152
process = cast(ProcessGuardian, request.app["process"])
166153
await process.update_params(params)
@@ -180,6 +167,8 @@ async def handle_get_status(request: web.Request):
180167
async def start_http_server(
181168
port: int, process: ProcessGuardian, streamer: Optional[PipelineStreamer] = None
182169
):
170+
asyncio.create_task(cleanup_last_stream())
171+
183172
app = web.Application()
184173
app["process"] = process
185174
app["streamer"] = streamer

runner/app/pipelines/live_video_to_video.py

Lines changed: 46 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,8 @@ def __init__(self, model_id: str):
2525
self.infer_script_path = (
2626
Path(__file__).parent.parent / "live" / "infer.py"
2727
)
28-
try:
29-
logging.info("Starting pipeline process")
30-
self.start_process(
31-
pipeline=self.model_id, # we use the model_id as the pipeline name for now
32-
http_port=8888,
33-
# TODO: set torch device from self.torch_device
34-
)
35-
except Exception as e:
36-
raise InferenceError(original_exception=e)
37-
28+
self.restart_count = 0
29+
self.start_process()
3830

3931
def __call__( # type: ignore
4032
self, *, subscribe_url: str, publish_url: str, control_url: str, events_url: str, params: dict, request_id: str, stream_id: str, **kwargs
@@ -106,37 +98,34 @@ class PipelineStatus(BaseModel):
10698
threading.Thread(target=lambda: self.log_process_diagnostics(full=True)).start()
10799
raise ConnectionError(f"Failed to get status: {e}")
108100

109-
def start_process(self, **kwargs):
101+
def start_process(self):
102+
logging.info("Starting pipeline process")
110103
cmd = [sys.executable, str(self.infer_script_path)]
111-
112-
# Add any additional kwargs as command-line arguments
113-
for key, value in kwargs.items():
114-
kebab_key = key.replace("_", "-")
115-
if isinstance(value, str):
116-
escaped_value = str(value).replace("'", "'\\''")
117-
cmd.extend([f"--{kebab_key}", f"{escaped_value}"])
118-
else:
119-
cmd.extend([f"--{kebab_key}", f"{value}"])
104+
cmd.extend(["--pipeline", self.model_id]) # we use the model_id as the pipeline name for now
105+
cmd.extend(["--http-port", "8888"])
106+
# TODO: set torch device from self.torch_device
120107

121108
env = os.environ.copy()
122109
env["HUGGINGFACE_HUB_CACHE"] = str(self.model_dir)
123110

124111
try:
125112
self.process = subprocess.Popen(
126-
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, env=env
113+
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, env=env
127114
)
128115

129116
self.monitor_thread = threading.Thread(target=self.monitor_process)
130117
self.monitor_thread.start()
131-
self.stdout_log_thread = threading.Thread(target=log_output, args=(self.process.stdout,))
132-
self.stdout_log_thread.start()
133-
self.stderr_log_thread = threading.Thread(target=log_output, args=(self.process.stderr,))
134-
self.stderr_log_thread.start()
118+
self.log_thread = threading.Thread(target=log_output, daemon=True, args=(self.process.stdout,))
119+
self.log_thread.start()
135120

136121
except subprocess.CalledProcessError as e:
137122
raise InferenceError(f"Error starting infer.py: {e}")
138123

139124
def monitor_process(self):
125+
# Wait 1 sec before starting to monitor the process. This gives it some
126+
# time to start and also ensures we won't restart the process too often.
127+
time.sleep(1)
128+
140129
while True:
141130
if not self.process:
142131
logging.error("No process to monitor")
@@ -157,11 +146,33 @@ def monitor_process(self):
157146

158147
logging.info(f"infer.py process exited with return_code={return_code}")
159148
self.log_process_diagnostics(full=True)
160-
self.stop_process(is_monitor_thread=True)
161-
return
149+
break
150+
151+
self.restart_count += 1
152+
if self.restart_count > 10:
153+
logging.error("infer.py process has restarted more than 10 times. Exiting.")
154+
os._exit(1)
155+
156+
# Start a separate thread to restart the process since it will
157+
# restart the monitor thread itself (the current thread).
158+
def restart_process():
159+
try:
160+
logging.info(f"Restarting infer.py process restart_count={self.restart_count}")
161+
self.stop_process()
162+
self.start_process()
163+
except Exception as e:
164+
logging.error(f"Error restarting infer.py process: {e}")
165+
os._exit(1)
166+
threading.Thread(target=restart_process).start()
162167

163-
def stop_process(self, is_monitor_thread: bool = False):
168+
def stop_process(self):
164169
if self.process:
170+
if self.process.stdout:
171+
# Closing the output stream sometimes hangs, so we do it in a separate daemon thread
172+
# and join the log_thread below with a timeout. If it does hang there might be a thread
173+
# leak which is why we limit to up to 10 restarts.
174+
stdout = self.process.stdout
175+
threading.Thread(target=lambda: stdout.close(), daemon=True).start()
165176
self.process.terminate()
166177
try:
167178
self.process.wait(timeout=10)
@@ -174,15 +185,14 @@ def stop_process(self, is_monitor_thread: bool = False):
174185
logging.error(f"Error while force killing process: {e}")
175186
os._exit(1)
176187
self.process = None
177-
if self.monitor_thread and not is_monitor_thread:
188+
if self.monitor_thread:
178189
self.monitor_thread.join()
179190
self.monitor_thread = None
180-
if hasattr(self, 'stdout_log_thread') and self.stdout_log_thread:
181-
self.stdout_log_thread.join()
182-
self.stdout_log_thread = None
183-
if hasattr(self, 'stderr_log_thread') and self.stderr_log_thread:
184-
self.stderr_log_thread.join()
185-
self.stderr_log_thread = None
191+
if self.log_thread:
192+
self.log_thread.join(timeout=1)
193+
if self.log_thread.is_alive():
194+
logging.warning("Log thread did not finish")
195+
self.log_thread = None
186196
logging.info("Infer process stopped successfully")
187197

188198

@@ -252,7 +262,7 @@ def read_proc_as_map(path: str) -> dict | str:
252262
with open(path, "r") as f:
253263
return f.read()
254264

255-
os_proc_info = {}
265+
os_proc_info: dict[str, str | dict] = {}
256266
for proc_file in ["status", "wchan", "io"]:
257267
try:
258268
path = f"/proc/{pid}/{proc_file}"

0 commit comments

Comments
 (0)