-
Notifications
You must be signed in to change notification settings - Fork 91
truss watch --no-sleep #2211
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
truss watch --no-sleep #2211
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,10 +2,12 @@ | |
| import json | ||
| import os | ||
| import sys | ||
| import threading | ||
| import time | ||
| from pathlib import Path | ||
| from typing import Optional, cast | ||
|
|
||
| import requests as requests_lib | ||
| import rich.table | ||
| import rich_click as click | ||
| from rich import progress | ||
|
|
@@ -36,7 +38,7 @@ | |
| get_dev_version_from_versions, | ||
| ) | ||
| from truss.remote.baseten.remote import BasetenRemote | ||
| from truss.remote.baseten.service import BasetenService | ||
| from truss.remote.baseten.service import BasetenService, URLConfig | ||
| from truss.remote.remote_factory import USER_TRUSSRC_PATH, RemoteFactory | ||
| from truss.trt_llm.config_checks import ( | ||
| has_no_tags_trt_llm_builder, | ||
|
|
@@ -853,6 +855,43 @@ def model_logs( | |
| cli_log_utils.output_log(log) | ||
|
|
||
|
|
||
# Seconds to wait between two keepalive pings.
_KEEPALIVE_INTERVAL_SEC = 30
# Number of back-to-back failed pings tolerated before the process is killed.
_KEEPALIVE_MAX_CONSECUTIVE_FAILURES = 20  # be very generous


def _keepalive_loop(url: str, api_key: str, stop_event: threading.Event) -> None:
    """Ping ``url`` periodically until ``stop_event`` is set.

    Each iteration sends a GET request with an ``Api-Key`` authorization
    header. A 200 response resets the failure counter. A request exception,
    or a non-200 response whose JSON ``error`` message does not contain the
    "still building or deploying" text, increments it. Once the counter
    reaches ``_KEEPALIVE_MAX_CONSECUTIVE_FAILURES`` a warning is printed and
    the whole process is terminated with exit code 1.

    Args:
        url: Endpoint to ping.
        api_key: API key sent in the ``Authorization`` header.
        stop_event: Event that ends the loop; it is also used for the wait
            between pings, so setting it interrupts the wait immediately.
    """
    headers = {"Authorization": f"Api-Key {api_key}"}
    consecutive_failures = 0

    while not stop_event.is_set():
        try:
            resp = requests_lib.get(url, headers=headers, timeout=10)
        except requests_lib.RequestException:
            consecutive_failures += 1
        else:
            if resp.status_code == 200:
                consecutive_failures = 0
            else:
                try:
                    body = resp.json()
                except Exception:
                    # Best effort: an unparsable body is treated like an
                    # empty one and therefore counted as a failure below.
                    body = {}
                # The body may be any JSON value and "error" may be null or
                # structured. Guard both so an unexpected payload is counted
                # as a failure instead of raising and silently killing this
                # thread.
                error = body.get("error") if isinstance(body, dict) else None
                msg = error if isinstance(error, str) else ""
                if "Model is not ready, it is still building or deploying" not in msg:
                    # Readiness will fail when the model is being patched
                    # (status LOADING_MODEL), we don't want to count that as
                    # a failure.
                    # TODO, ideally we do this based on error code, but
                    # beefeater returns a generic 400
                    consecutive_failures += 1

        if consecutive_failures >= _KEEPALIVE_MAX_CONSECUTIVE_FAILURES:
            console.print(
                f"⚠️ Keepalive ping failed {consecutive_failures} times in a row. "
                "Exiting truss watch.",
                style="red",
            )
            os._exit(1)  # kill process not just the thread

        stop_event.wait(timeout=_KEEPALIVE_INTERVAL_SEC)
|
|
||
|
|
||
| @truss_cli.command() | ||
| @click.argument("target_directory", required=False, default=os.getcwd()) | ||
| @click.option( | ||
|
|
@@ -874,12 +913,19 @@ def model_logs( | |
| required=False, | ||
| help="Team name for the model to watch", | ||
| ) | ||
| @click.option( | ||
| "--no-sleep", | ||
| is_flag=True, | ||
| default=False, | ||
| help="Keep the development model warm by preventing scale-to-zero while watching.", | ||
| ) | ||
| @common.common_options() | ||
| def watch( | ||
| target_directory: str, | ||
| config: Optional[str], | ||
| remote: str, | ||
| provided_team_name: Optional[str] = None, | ||
| no_sleep: bool = False, | ||
| ) -> None: | ||
| """ | ||
| Seamless remote development with truss | ||
|
|
@@ -916,11 +962,74 @@ def watch( | |
| sys.exit(1) | ||
|
|
||
| # Use model_id to get service (no additional resolution needed) | ||
| service = remote_provider.get_service(model_identifier=ModelId(model_id)) | ||
| dev_version_id = dev_version["id"] | ||
| logs_url = URLConfig.model_logs_url( | ||
| remote_provider.remote_url, model_id, dev_version_id | ||
| ) | ||
| console.print( | ||
| f"🪵 View logs for your deployment at {common.format_link(service.logs_url)}" | ||
| f"🪵 View logs for your development model at {common.format_link(logs_url)}" | ||
| ) | ||
|
|
||
| stop_event = threading.Event() | ||
| if no_sleep: | ||
| model_hostname = resolved_model.get("hostname") | ||
| if not model_hostname: | ||
| console.print( | ||
| "❌ Could not determine model hostname for --no-sleep.", style="red" | ||
| ) | ||
| sys.exit(1) | ||
|
|
||
| # Wake the model in case it's scaled to zero | ||
| wake_url = f"{model_hostname}/development/wake" | ||
| api_key = remote_provider._auth_service.authenticate().value | ||
|
Comment on lines
+982
to
+984
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
If Useful? React with 👍 / 👎. |
||
| headers = {"Authorization": f"Api-Key {api_key}"} | ||
| try: | ||
| requests_lib.post(wake_url, headers=headers, timeout=10) | ||
| except requests_lib.RequestException: | ||
| # best effort | ||
| pass | ||
|
|
||
| # Wait for model to be ready before starting keepalive | ||
| with console.status( | ||
| "[bold green]Waiting for development model to be ready..." | ||
| ) as status: | ||
| while True: | ||
| time.sleep(1) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: Thoughts about making this slightly less aggressive? Agreed, it's a balance between the appearance of quickness to the user and risk to our servers.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. current |
||
| try: | ||
| deployment = remote_provider.api.get_deployment( | ||
| model_id, dev_version_id | ||
| ) | ||
| deployment_status = deployment["status"] | ||
| except Exception: | ||
| continue | ||
| status.update( | ||
| f"[bold green]Waiting for development model to be ready... " | ||
| f"Current Status: {deployment_status}" | ||
| ) | ||
| if deployment_status in [ACTIVE_STATUS] + ["LOADING_MODEL"]: | ||
| # by the time we have status LOADING_MODEL, we should be able to start patching. | ||
| # keepalive thread also handles this state | ||
| break | ||
| if deployment_status not in DEPLOYING_STATUSES + [ | ||
| "SCALED_TO_ZERO", | ||
| "WAKING_UP", | ||
| "UPDATING", | ||
| ]: | ||
| console.print( | ||
| f"❌ Development model failed with status {deployment_status}.", | ||
| style="red", | ||
| ) | ||
| sys.exit(1) | ||
|
|
||
| keepalive_url = f"{model_hostname}/development/sync/v1/models/model" | ||
| console.print("💤 --no-sleep enabled: keeping development model warm") | ||
| keepalive_thread = threading.Thread( | ||
| target=_keepalive_loop, | ||
| args=(keepalive_url, api_key, stop_event), | ||
| daemon=True, | ||
| ) | ||
| keepalive_thread.start() | ||
|
|
||
| if not os.path.isfile(target_directory): | ||
| # Pass the resolved model to avoid re-resolution | ||
| remote_provider.sync_truss_to_dev_version_with_model( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Should we log here too? Something along the lines of 'Model currently inactive, waking.'