Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ dependencies = [
"requests>=2.31",
"rich>=13.4.2,<14",
"rich-click>=1.6.1,<2",
"ruamel.yaml>=0.18.0,<0.19",
"ruff>=0.4.8",
"tenacity>=8.0.1",
"watchfiles>=0.19.0,<0.20",
Expand Down
4 changes: 3 additions & 1 deletion truss/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ def login(api_key: Optional[str]):
def upgrade(ctx: click.Context, version: Optional[str]) -> None:
"""Upgrade truss to the latest (or specified) version."""
interactive = not ctx.obj.get("non_interactive", False)
# Check for migrations before upgrade (in case already up-to-date)
self_upgrade._notify_about_available_migrations()
self_upgrade.run_upgrade(version, interactive=interactive)


Expand Down Expand Up @@ -1034,4 +1036,4 @@ def kill_all() -> None:


# These imports are needed to register the subcommands
from truss.cli import chains_commands, train_commands # noqa: F401
from truss.cli import chains_commands, migrate_commands, train_commands # noqa: F401
319 changes: 319 additions & 0 deletions truss/cli/migrate_commands.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,319 @@
"""Migration commands for truss CLI.

Provides commands to migrate deprecated config formats to the new weights API.
"""

import difflib
import io
import os
import shutil
from datetime import datetime
from pathlib import Path
from typing import Any, Dict

import rich_click as click
from rich.console import Console
from rich.syntax import Syntax
from ruamel.yaml import YAML

from truss.base.constants import MODEL_CACHE_PATH
from truss.base.truss_config import ExternalDataItem, ModelRepo, ModelRepoSourceKind
from truss.cli.cli import truss_cli
from truss.cli.migrations.detection import get_available_migrations
from truss.cli.migrations.history import record_migration_applied
from truss.cli.utils import common

# Two rich consoles: `console` uses rich's default output stream and
# `error_console` writes to stderr (used for the fatal "no config.yaml" error).
console = Console()
error_console = Console(stderr=True)

# Base directory under which migrated external_data entries are mounted:
# each item's mount_location is DATA_DIR_PATH / local_data_path.
DATA_DIR_PATH = Path("/app/data")

# Shared ruamel.yaml instance used for every load/dump in this module so that
# config.yaml can be round-tripped (comments kept) when it is rewritten.
yaml = YAML()
# Keep the quoting of scalars as written in the loaded file.
yaml.preserve_quotes = True
# Emit collections in block style rather than inline flow style.
yaml.default_flow_style = False


def generate_source_uri(model: ModelRepo) -> str:
    """Build the weights ``source`` URI for a legacy ``ModelRepo`` entry.

    GCS, S3 and Azure entries return ``repo_id`` with the matching scheme
    (``gs://``, ``s3://``, ``azure://``) prepended unless it already starts
    with that scheme; ``revision`` is not used for these kinds.

    HuggingFace entries, and entries of any other kind, return
    ``hf://<repo_id>`` with ``@<revision>`` appended when a revision is set.
    """
    repo_id = model.repo_id

    # (kind, URI scheme) pairs for the bucket-style sources.
    cloud_schemes = (
        (ModelRepoSourceKind.GCS, "gs://"),
        (ModelRepoSourceKind.S3, "s3://"),
        (ModelRepoSourceKind.AZURE, "azure://"),
    )
    for source_kind, scheme in cloud_schemes:
        if model.kind == source_kind:
            return repo_id if repo_id.startswith(scheme) else f"{scheme}{repo_id}"

    # HuggingFace is both the explicit HF case and the fallback for any
    # kind not matched above.
    hf_uri = f"hf://{repo_id}"
    return f"{hf_uri}@{model.revision}" if model.revision else hf_uri


def generate_mount_location_for_model(model: ModelRepo) -> str:
    """Return the ``mount_location`` string for a legacy ``ModelRepo`` entry.

    * Entries with ``use_volume`` set and a non-empty ``volume_folder`` mount
      at ``MODEL_CACHE_PATH / volume_folder``.
    * GCS / S3 / Azure entries mount at ``MODEL_CACHE_PATH / <first path
      component of repo_id>`` after one leading ``gs://``, ``s3://`` or
      ``azure://`` scheme (if present) has been removed.
    * Everything else (HuggingFace and unrecognised kinds) mounts at
      ``MODEL_CACHE_PATH / repo_id`` with every ``/`` replaced by ``_``.
    """
    if model.use_volume and model.volume_folder:
        return str(MODEL_CACHE_PATH / model.volume_folder)

    repo_id = model.repo_id
    cloud_kinds = (
        ModelRepoSourceKind.GCS,
        ModelRepoSourceKind.S3,
        ModelRepoSourceKind.AZURE,
    )
    if model.kind not in cloud_kinds:
        # owner/repo -> owner_repo
        return str(MODEL_CACHE_PATH / repo_id.replace("/", "_"))

    # Drop at most one scheme prefix: the first one in this order that matches.
    schemes = ("gs://", "s3://", "azure://")
    without_scheme = next(
        (repo_id[len(scheme) :] for scheme in schemes if repo_id.startswith(scheme)),
        repo_id,
    )
    # Only the leading path component (the bucket/account segment) is kept.
    bucket_name, _, _ = without_scheme.partition("/")
    return str(MODEL_CACHE_PATH / bucket_name)


def convert_model_repo_to_weights(model: ModelRepo) -> Dict[str, Any]:
    """Translate a legacy ``ModelRepo`` entry into a weights-source dict.

    The result always has ``source`` and ``mount_location``. The keys
    ``auth_secret_name`` (taken from ``runtime_secret_name``),
    ``allow_patterns`` and ``ignore_patterns`` are added only when the
    corresponding model attribute is truthy.
    """
    weights: Dict[str, Any] = {
        "source": generate_source_uri(model),
        "mount_location": generate_mount_location_for_model(model),
    }

    # (output key, value on the model); empty/None values are left out.
    optional_fields = (
        ("auth_secret_name", model.runtime_secret_name),
        ("allow_patterns", model.allow_patterns),
        ("ignore_patterns", model.ignore_patterns),
    )
    for key, value in optional_fields:
        if value:
            weights[key] = value

    return weights


def convert_external_data_to_weights(item: ExternalDataItem) -> dict:
    """Translate a legacy ``external_data`` entry into a weights-source dict.

    ``item.url`` is used unchanged as ``source``; ``mount_location`` is
    ``item.local_data_path`` placed under ``DATA_DIR_PATH``.
    """
    return dict(
        source=item.url,
        mount_location=str(DATA_DIR_PATH / item.local_data_path),
    )


def migrate_config_data(config_data: dict) -> tuple[list[dict], list[str]]:
    """Build the ``weights`` entries for a config's legacy cache sections.

    Args:
        config_data: The raw config mapping. Only the ``model_cache`` and
            ``external_data`` keys are read; the mapping is not modified.

    Returns:
        Tuple of (weights_list, warnings). ``model_cache`` entries come
        first, followed by ``external_data`` entries, each in config order.
        One warning is added for every HuggingFace entry that does not use
        a volume, since its mount path changes.
    """
    weights: list[dict] = []
    notes: list[str] = []

    # A missing, None or empty section simply contributes nothing.
    for raw_model in config_data.get("model_cache") or ():
        # Validate through ModelRepo so attributes are typed/defaulted.
        repo = ModelRepo.model_validate(dict(raw_model))

        is_v1_hf = not repo.use_volume and repo.kind == ModelRepoSourceKind.HF
        if is_v1_hf:
            notes.append(
                f"v1 HuggingFace repo '{repo.repo_id}' migrated. "
                f"You may need to update model.py to use the new mount path: "
                f"{generate_mount_location_for_model(repo)}"
            )

        weights.append(convert_model_repo_to_weights(repo))

    weights.extend(
        convert_external_data_to_weights(
            ExternalDataItem.model_validate(dict(raw_item))
        )
        for raw_item in config_data.get("external_data") or ()
    )

    return weights, notes


def dump_yaml_to_string(data) -> str:
    """Serialize ``data`` with the module-level ``yaml`` instance and return the text."""
    with io.StringIO() as buffer:
        yaml.dump(data, buffer)
        return buffer.getvalue()


def show_diff(original_yaml: str, migrated_yaml: str) -> None:
    """Print a syntax-highlighted unified diff of two YAML documents.

    Args:
        original_yaml: Text of the config before migration.
        migrated_yaml: Text of the config after migration.

    Prints "No changes detected." instead when the two documents have the
    same lines. A difference only in the final line ending is not reported.
    """
    # Split WITHOUT line endings and pass lineterm="" so every line difflib
    # yields, including the "---", "+++" and "@@" control lines, carries no
    # terminator. Mixing keepends=True input with lineterm="" left the
    # control lines unterminated, so joining fused them with the next line.
    diff_lines = difflib.unified_diff(
        original_yaml.splitlines(),
        migrated_yaml.splitlines(),
        fromfile="config.yaml (original)",
        tofile="config.yaml (migrated)",
        lineterm="",
    )
    diff_text = "\n".join(diff_lines)

    if not diff_text:
        console.print("[yellow]No changes detected.[/yellow]")
        return

    console.print(Syntax(diff_text, "diff", theme="monokai", line_numbers=False))


def _load_config_or_empty(config_path: Path) -> Any:
    """Parse ``config_path`` with the shared ``yaml`` instance; an empty document yields ``{}``."""
    with config_path.open() as config_file:
        loaded = yaml.load(config_file)
    return {} if loaded is None else loaded


def _backup_config(config_path: Path) -> Path:
    """Copy ``config_path`` to a timestamped ``.yaml.bak.<timestamp>`` sibling and return that path."""
    stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    backup_path = config_path.with_suffix(f".yaml.bak.{stamp}")
    shutil.copy(config_path, backup_path)
    return backup_path


@truss_cli.command()
@click.argument("target_directory", required=False, default=os.getcwd())
@common.common_options()
def migrate(target_directory: str) -> None:
    """Migrate Truss config to newer formats.

    TARGET_DIRECTORY: A Truss directory. If none, use current directory.

    This command checks for available config migrations and applies them
    interactively. Migrations are version-gated and only shown if applicable.
    """
    # NOTE: the docstring above is the command's help text; keep it verbatim.
    truss_dir = Path(target_directory)
    config_path = truss_dir / "config.yaml"

    if not config_path.exists():
        error_console.print(f"[red]Error: No config.yaml found at {config_path}[/red]")
        raise SystemExit(1)

    # Initial parse is only used to decide which migrations apply.
    pending = get_available_migrations(truss_dir, _load_config_or_empty(config_path))

    if not pending:
        console.print(
            "[yellow]No applicable migrations found. Your config is up to date.[/yellow]"
        )
        return

    console.print(f"[bold]Found {len(pending)} applicable migration(s):[/bold]\n")

    for migration in pending:
        console.print(f"[bold cyan]📦 {migration.id}[/bold cyan]")
        console.print(f"   {migration.description}\n")

        # Re-read from disk on every iteration so each migration sees the
        # result of any migration applied before it.
        original_yaml = config_path.read_text()
        current_config = _load_config_or_empty(config_path)

        migrated_config, warnings = migration.apply_function(current_config)
        migrated_yaml = dump_yaml_to_string(migrated_config)

        for warning in warnings:
            console.print(f"[yellow]Warning:[/yellow] {warning}")

        console.print("\n[bold]Proposed changes:[/bold]\n")
        show_diff(original_yaml, migrated_yaml)
        console.print()

        # Nothing is written unless the user explicitly confirms.
        approved = click.confirm(f"Apply migration '{migration.id}'?", default=False)
        if not approved:
            console.print(f"[yellow]Migration '{migration.id}' cancelled.[/yellow]\n")
            continue

        # Back up first, then overwrite config.yaml, then record the migration.
        backup_path = _backup_config(config_path)
        console.print(f"[dim]Backup created: {backup_path}[/dim]")

        with config_path.open("w") as config_file:
            yaml.dump(migrated_config, config_file)

        record_migration_applied(truss_dir, migration.id, backup_path.name)

        console.print(f"[green]Migration '{migration.id}' complete![/green]\n")


# For backwards compatibility with tests that import migrate_config
def migrate_config(config_dict: dict) -> tuple[dict, list[str]]:
    """Return a copy of ``config_dict`` with legacy cache keys replaced by ``weights``.

    Args:
        config_dict: The raw config dictionary. It is not modified; the
            result is a new shallow ``dict`` copy.

    Returns:
        Tuple of (migrated_config_dict, warnings). ``model_cache`` and
        ``external_data`` are removed from the copy, and ``weights`` is set
        only when at least one entry was produced.
    """
    weights, notes = migrate_config_data(config_dict)

    result = dict(config_dict)
    for legacy_key in ("model_cache", "external_data"):
        result.pop(legacy_key, None)

    if weights:
        result["weights"] = weights

    return result, notes
Loading
Loading