Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "truss"
version = "0.13.3"
version = "0.13.4rc001"
description = "A seamless bridge from model development to model delivery"
authors = [
{ name = "Pankaj Gupta", email = "no-reply@baseten.co" },
Expand Down
115 changes: 112 additions & 3 deletions truss/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
import json
import os
import sys
import threading
import time
from pathlib import Path
from typing import Optional, cast

import requests as requests_lib
import rich.table
import rich_click as click
from rich import progress
Expand Down Expand Up @@ -36,7 +38,7 @@
get_dev_version_from_versions,
)
from truss.remote.baseten.remote import BasetenRemote
from truss.remote.baseten.service import BasetenService
from truss.remote.baseten.service import BasetenService, URLConfig
from truss.remote.remote_factory import USER_TRUSSRC_PATH, RemoteFactory
from truss.trt_llm.config_checks import (
has_no_tags_trt_llm_builder,
Expand Down Expand Up @@ -853,6 +855,43 @@ def model_logs(
cli_log_utils.output_log(log)


# Seconds `_keepalive_loop` waits between two consecutive keepalive pings.
_KEEPALIVE_INTERVAL_SEC = 30
# Number of back-to-back failed pings after which `_keepalive_loop` gives up
# and terminates the process.
_KEEPALIVE_MAX_CONSECUTIVE_FAILURES = 20 # be very generous


def _keepalive_loop(url: str, api_key: str, stop_event: threading.Event) -> None:
    """Ping ``url`` periodically until ``stop_event`` is set.

    Sends an authenticated GET every ``_KEEPALIVE_INTERVAL_SEC`` seconds. A
    200 response resets the failure counter. A non-200 response whose JSON
    ``error`` message says the model is still building or deploying is
    ignored (neither counted nor reset). Any other non-200 response, or a
    ``requests`` exception, counts as one failure. After
    ``_KEEPALIVE_MAX_CONSECUTIVE_FAILURES`` failures in a row, a warning is
    printed and the whole process exits with status 1.

    Args:
        url: Endpoint to ping.
        api_key: API key sent as ``Authorization: Api-Key <api_key>``.
        stop_event: Set by the caller to end the loop; also used to sleep
            between pings so that setting it wakes the loop immediately.
    """
    headers = {"Authorization": f"Api-Key {api_key}"}
    consecutive_failures = 0

    while not stop_event.is_set():
        try:
            resp = requests_lib.get(url, headers=headers, timeout=10)
        except requests_lib.RequestException:
            consecutive_failures += 1
        else:
            if resp.status_code == 200:
                consecutive_failures = 0
            else:
                try:
                    body = resp.json()
                except Exception:
                    # Best effort: an undecodable body is treated as having
                    # no error message.
                    body = None
                # The payload is not guaranteed to be a JSON object with a
                # string "error" field. Guard both so an unexpected shape
                # counts as a failure instead of raising and silently
                # killing this thread.
                error = body.get("error") if isinstance(body, dict) else None
                msg = error if isinstance(error, str) else ""
                if "Model is not ready, it is still building or deploying" not in msg:
                    # Readiness will fail when the model is being patched
                    # (status LOADING_MODEL); we don't want to count that as
                    # a failure.
                    # TODO: ideally we do this based on error code, but
                    # beefeater returns a generic 400.
                    consecutive_failures += 1

        if consecutive_failures >= _KEEPALIVE_MAX_CONSECUTIVE_FAILURES:
            console.print(
                f"⚠️ Keepalive ping failed {consecutive_failures} times in a row. "
                "Exiting truss watch.",
                style="red",
            )
            os._exit(1)  # kill process not just the thread

        # Sleep via the event so that setting it interrupts the wait.
        stop_event.wait(timeout=_KEEPALIVE_INTERVAL_SEC)


@truss_cli.command()
@click.argument("target_directory", required=False, default=os.getcwd())
@click.option(
Expand All @@ -874,12 +913,19 @@ def model_logs(
required=False,
help="Team name for the model to watch",
)
@click.option(
"--no-sleep",
is_flag=True,
default=False,
help="Keep the development model warm by preventing scale-to-zero while watching.",
)
@common.common_options()
def watch(
target_directory: str,
config: Optional[str],
remote: str,
provided_team_name: Optional[str] = None,
no_sleep: bool = False,
) -> None:
"""
Seamless remote development with truss
Expand Down Expand Up @@ -916,11 +962,74 @@ def watch(
sys.exit(1)

# Use model_id to get service (no additional resolution needed)
service = remote_provider.get_service(model_identifier=ModelId(model_id))
dev_version_id = dev_version["id"]
logs_url = URLConfig.model_logs_url(
remote_provider.remote_url, model_id, dev_version_id
)
console.print(
f"🪵 View logs for your deployment at {common.format_link(service.logs_url)}"
f"🪵 View logs for your development model at {common.format_link(logs_url)}"
)

stop_event = threading.Event()
if no_sleep:
model_hostname = resolved_model.get("hostname")
if not model_hostname:
console.print(
"❌ Could not determine model hostname for --no-sleep.", style="red"
)
sys.exit(1)

# Wake the model in case it's scaled to zero
wake_url = f"{model_hostname}/development/wake"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Should we log here too? Something along the lines of 'Model currently inactive, waking.'

api_key = remote_provider._auth_service.authenticate().value
Comment on lines +982 to +984

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Normalize model hostname before building keepalive URLs

If resolved_model["hostname"] is a bare host (e.g. the repo’s tests and API mocks use values like host.baseten.co/hostname without a scheme), then f"{model_hostname}/development/wake" (and the later keepalive URL) becomes an invalid URL and requests raises MissingSchema. That means --no-sleep will never successfully wake/keepalive and will eventually exit after 20 failures. Consider normalizing the hostname (e.g. prefix https:// when missing) or reusing a URL builder that guarantees a full scheme.

Useful? React with 👍 / 👎.

headers = {"Authorization": f"Api-Key {api_key}"}
try:
requests_lib.post(wake_url, headers=headers, timeout=10)
except requests_lib.RequestException:
# best effort
pass

# Wait for model to be ready before starting keepalive
with console.status(
"[bold green]Waiting for development model to be ready..."
) as status:
while True:
time.sleep(1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Thoughts about making this slightly less aggressive? Agreed it's a balance between appearance of quickness to user and risk to our servers

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current --wait functionality also uses a 1s sleep when fetching the status, but yes, I'm OK with increasing this value.

try:
deployment = remote_provider.api.get_deployment(
model_id, dev_version_id
)
deployment_status = deployment["status"]
except Exception:
continue
status.update(
f"[bold green]Waiting for development model to be ready... "
f"Current Status: {deployment_status}"
)
if deployment_status in [ACTIVE_STATUS] + ["LOADING_MODEL"]:
# by the time we have status LOADING_MODEL, we should be able to start patching.
# keepalive thread also handles this state
break
if deployment_status not in DEPLOYING_STATUSES + [
"SCALED_TO_ZERO",
"WAKING_UP",
"UPDATING",
]:
console.print(
f"❌ Development model failed with status {deployment_status}.",
style="red",
)
sys.exit(1)

keepalive_url = f"{model_hostname}/development/sync/v1/models/model"
console.print("💤 --no-sleep enabled: keeping development model warm")
keepalive_thread = threading.Thread(
target=_keepalive_loop,
args=(keepalive_url, api_key, stop_event),
daemon=True,
)
keepalive_thread.start()

if not os.path.isfile(target_directory):
# Pass the resolved model to avoid re-resolution
remote_provider.sync_truss_to_dev_version_with_model(
Expand Down
Loading
Loading