diff --git a/.github/workflows/test-assets.yml b/.github/workflows/test-assets.yml new file mode 100644 index 000000000000..4ae26ba5ff2d --- /dev/null +++ b/.github/workflows/test-assets.yml @@ -0,0 +1,173 @@ +name: Asset System Tests + +on: + push: + paths: + - 'app/**' + - 'tests-assets/**' + - '.github/workflows/test-assets.yml' + - 'requirements.txt' + pull_request: + branches: [master] + workflow_dispatch: + +permissions: + contents: read + +env: + PIP_DISABLE_PIP_VERSION_CHECK: '1' + PYTHONUNBUFFERED: '1' + +jobs: + sqlite: + name: SQLite (${{ matrix.sqlite_mode }}) • Python ${{ matrix.python }} + runs-on: ubuntu-latest + timeout-minutes: 40 + strategy: + fail-fast: false + matrix: + python: ['3.9', '3.12'] + sqlite_mode: ['memory', 'file'] + + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python }} + + - name: Install dependencies + run: | + python -m pip install -U pip wheel + pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu + pip install -r requirements.txt + pip install pytest pytest-aiohttp pytest-asyncio + + - name: Set deterministic test base dir + id: basedir + shell: bash + run: | + BASE="$RUNNER_TEMP/comfyui-assets-tests-${{ matrix.python }}-${{ matrix.sqlite_mode }}-${{ github.run_id }}-${{ github.run_attempt }}" + echo "ASSETS_TEST_BASE_DIR=$BASE" >> "$GITHUB_ENV" + echo "ASSETS_TEST_LOGS=$BASE/logs" >> "$GITHUB_ENV" + mkdir -p "$BASE/logs" + echo "ASSETS_TEST_BASE_DIR=$BASE" + + - name: Set DB URL for SQLite + id: setdb + shell: bash + run: | + if [ "${{ matrix.sqlite_mode }}" = "memory" ]; then + echo "ASSETS_TEST_DB_URL=sqlite+aiosqlite:///:memory:" >> "$GITHUB_ENV" + else + DBFILE="$RUNNER_TEMP/assets-tests.sqlite" + mkdir -p "$(dirname "$DBFILE")" + echo "ASSETS_TEST_DB_URL=sqlite+aiosqlite:///$DBFILE" >> "$GITHUB_ENV" + fi + + - name: Run tests + run: python -m pytest tests-assets + + - name: Show ComfyUI logs + if: always() + shell: bash + run: | + echo "==== ASSETS_TEST_BASE_DIR: $ASSETS_TEST_BASE_DIR ====" + echo "==== ASSETS_TEST_LOGS: $ASSETS_TEST_LOGS ====" + ls -la "$ASSETS_TEST_LOGS" || true + for f in "$ASSETS_TEST_LOGS"/stdout.log "$ASSETS_TEST_LOGS"/stderr.log; do + if [ -f "$f" ]; then + echo "----- BEGIN $f -----" + sed -n '1,400p' "$f" + echo "----- END $f -----" + fi + done + + - name: Upload ComfyUI logs + if: always() + uses: actions/upload-artifact@v4 + with: + name: asset-logs-sqlite-${{ matrix.sqlite_mode }}-py${{ matrix.python }} + path: ${{ env.ASSETS_TEST_LOGS }}/*.log + if-no-files-found: warn + + postgres: + name: PostgreSQL ${{ matrix.pgsql }} • Python ${{ matrix.python }} + runs-on: ubuntu-latest + timeout-minutes: 40 + strategy: + fail-fast: false + matrix: + python: ['3.9', '3.12'] + pgsql: ['16', '18'] + + services: + postgres: + image: postgres:${{ matrix.pgsql }} + env: + POSTGRES_DB: assets + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + ports: + - 5432:5432 + options: >- + --health-cmd "pg_isready -U postgres -d assets" + --health-interval 10s + --health-timeout 5s + --health-retries 12 + + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python }} + + - name: Install dependencies + run: | + python -m pip install -U pip wheel + pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu + pip install -r requirements.txt + pip install pytest pytest-aiohttp pytest-asyncio + pip install 
greenlet psycopg + + - name: Set deterministic test base dir + id: basedir + shell: bash + run: | + BASE="$RUNNER_TEMP/comfyui-assets-tests-${{ matrix.python }}-${{ matrix.pgsql }}-${{ github.run_id }}-${{ github.run_attempt }}" + echo "ASSETS_TEST_BASE_DIR=$BASE" >> "$GITHUB_ENV" + echo "ASSETS_TEST_LOGS=$BASE/logs" >> "$GITHUB_ENV" + mkdir -p "$BASE/logs" + echo "ASSETS_TEST_BASE_DIR=$BASE" + + - name: Set DB URL for PostgreSQL + shell: bash + run: | + echo "ASSETS_TEST_DB_URL=postgresql+psycopg://postgres:postgres@localhost:5432/assets" >> "$GITHUB_ENV" + + - name: Run tests + run: python -m pytest tests-assets + + - name: Show ComfyUI logs + if: always() + shell: bash + run: | + echo "==== ASSETS_TEST_BASE_DIR: $ASSETS_TEST_BASE_DIR ====" + echo "==== ASSETS_TEST_LOGS: $ASSETS_TEST_LOGS ====" + ls -la "$ASSETS_TEST_LOGS" || true + for f in "$ASSETS_TEST_LOGS"/stdout.log "$ASSETS_TEST_LOGS"/stderr.log; do + if [ -f "$f" ]; then + echo "----- BEGIN $f -----" + sed -n '1,400p' "$f" + echo "----- END $f -----" + fi + done + + - name: Upload ComfyUI logs + if: always() + uses: actions/upload-artifact@v4 + with: + name: asset-logs-pgsql-${{ matrix.pgsql }}-py${{ matrix.python }} + path: ${{ env.ASSETS_TEST_LOGS }}/*.log + if-no-files-found: warn diff --git a/alembic.ini b/alembic.ini index 12f18712f430..360efd386901 100644 --- a/alembic.ini +++ b/alembic.ini @@ -3,7 +3,7 @@ [alembic] # path to migration scripts # Use forward slashes (/) also on windows to provide an os agnostic path -script_location = alembic_db +script_location = app/alembic_db # template used to generate migration file names; The default value is %%(rev)s_%%(slug)s # Uncomment the line below if you want the files to be prepended with date and time diff --git a/alembic_db/README.md b/app/alembic_db/README.md similarity index 100% rename from alembic_db/README.md rename to app/alembic_db/README.md diff --git a/alembic_db/env.py b/app/alembic_db/env.py similarity index 97% rename from alembic_db/env.py rename to app/alembic_db/env.py index 4d7770679875..44f4e1a0c9e3 100644 --- a/alembic_db/env.py +++ b/app/alembic_db/env.py @@ -2,13 +2,12 @@ from sqlalchemy import pool from alembic import context +from app.assets.database.models import Base # this is the Alembic Config object, which provides # access to the values within the .ini file in use. 
config = context.config - -from app.database.models import Base target_metadata = Base.metadata # other values from the config, defined by the needs of env.py, diff --git a/alembic_db/script.py.mako b/app/alembic_db/script.py.mako similarity index 100% rename from alembic_db/script.py.mako rename to app/alembic_db/script.py.mako diff --git a/app/alembic_db/versions/0001_assets.py b/app/alembic_db/versions/0001_assets.py new file mode 100644 index 000000000000..589b22ac8a49 --- /dev/null +++ b/app/alembic_db/versions/0001_assets.py @@ -0,0 +1,175 @@ +"""initial assets schema + +Revision ID: 0001_assets +Revises: +Create Date: 2025-08-20 00:00:00 +""" + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +revision = "0001_assets" +down_revision = None +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ASSETS: content identity + op.create_table( + "assets", + sa.Column("id", sa.String(length=36), primary_key=True), + sa.Column("hash", sa.String(length=256), nullable=True), + sa.Column("size_bytes", sa.BigInteger(), nullable=False, server_default="0"), + sa.Column("mime_type", sa.String(length=255), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=False), nullable=False), + sa.CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"), + ) + op.create_index("uq_assets_hash", "assets", ["hash"], unique=True) + op.create_index("ix_assets_mime_type", "assets", ["mime_type"]) + + # ASSETS_INFO: user-visible references + op.create_table( + "assets_info", + sa.Column("id", sa.String(length=36), primary_key=True), + sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""), + sa.Column("name", sa.String(length=512), nullable=False), + sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False), + sa.Column("preview_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="SET NULL"), nullable=True), + sa.Column("user_metadata", sa.JSON(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=False), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False), + sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False), + sa.UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"), + ) + op.create_index("ix_assets_info_owner_id", "assets_info", ["owner_id"]) + op.create_index("ix_assets_info_asset_id", "assets_info", ["asset_id"]) + op.create_index("ix_assets_info_name", "assets_info", ["name"]) + op.create_index("ix_assets_info_created_at", "assets_info", ["created_at"]) + op.create_index("ix_assets_info_last_access_time", "assets_info", ["last_access_time"]) + op.create_index("ix_assets_info_owner_name", "assets_info", ["owner_id", "name"]) + + # TAGS: normalized tag vocabulary + op.create_table( + "tags", + sa.Column("name", sa.String(length=512), primary_key=True), + sa.Column("tag_type", sa.String(length=32), nullable=False, server_default="user"), + sa.CheckConstraint("name = lower(name)", name="ck_tags_lowercase"), + ) + op.create_index("ix_tags_tag_type", "tags", ["tag_type"]) + + # ASSET_INFO_TAGS: many-to-many for tags on AssetInfo + op.create_table( + "asset_info_tags", + sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False), + sa.Column("tag_name", sa.String(length=512), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False), + sa.Column("origin", sa.String(length=32), 
nullable=False, server_default="manual"), + sa.Column("added_at", sa.DateTime(timezone=False), nullable=False), + sa.PrimaryKeyConstraint("asset_info_id", "tag_name", name="pk_asset_info_tags"), + ) + op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"]) + op.create_index("ix_asset_info_tags_asset_info_id", "asset_info_tags", ["asset_info_id"]) + + # ASSET_CACHE_STATE: N:1 local cache rows per Asset + op.create_table( + "asset_cache_state", + sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True), + sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="CASCADE"), nullable=False), + sa.Column("file_path", sa.Text(), nullable=False), # absolute local path to cached file + sa.Column("mtime_ns", sa.BigInteger(), nullable=True), + sa.Column("needs_verify", sa.Boolean(), nullable=False, server_default=sa.text("false")), + sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"), + sa.UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"), + ) + op.create_index("ix_asset_cache_state_file_path", "asset_cache_state", ["file_path"]) + op.create_index("ix_asset_cache_state_asset_id", "asset_cache_state", ["asset_id"]) + + # ASSET_INFO_META: typed KV projection of user_metadata for filtering/sorting + op.create_table( + "asset_info_meta", + sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False), + sa.Column("key", sa.String(length=256), nullable=False), + sa.Column("ordinal", sa.Integer(), nullable=False, server_default="0"), + sa.Column("val_str", sa.String(length=2048), nullable=True), + sa.Column("val_num", sa.Numeric(38, 10), nullable=True), + sa.Column("val_bool", sa.Boolean(), nullable=True), + sa.Column("val_json", sa.JSON().with_variant(postgresql.JSONB(), 'postgresql'), nullable=True), + sa.PrimaryKeyConstraint("asset_info_id", "key", "ordinal", name="pk_asset_info_meta"), + ) + op.create_index("ix_asset_info_meta_key", "asset_info_meta", ["key"]) + op.create_index("ix_asset_info_meta_key_val_str", "asset_info_meta", ["key", "val_str"]) + op.create_index("ix_asset_info_meta_key_val_num", "asset_info_meta", ["key", "val_num"]) + op.create_index("ix_asset_info_meta_key_val_bool", "asset_info_meta", ["key", "val_bool"]) + + # Tags vocabulary + tags_table = sa.table( + "tags", + sa.column("name", sa.String(length=512)), + sa.column("tag_type", sa.String()), + ) + op.bulk_insert( + tags_table, + [ + {"name": "models", "tag_type": "system"}, + {"name": "input", "tag_type": "system"}, + {"name": "output", "tag_type": "system"}, + + {"name": "configs", "tag_type": "system"}, + {"name": "checkpoints", "tag_type": "system"}, + {"name": "loras", "tag_type": "system"}, + {"name": "vae", "tag_type": "system"}, + {"name": "text_encoders", "tag_type": "system"}, + {"name": "diffusion_models", "tag_type": "system"}, + {"name": "clip_vision", "tag_type": "system"}, + {"name": "style_models", "tag_type": "system"}, + {"name": "embeddings", "tag_type": "system"}, + {"name": "diffusers", "tag_type": "system"}, + {"name": "vae_approx", "tag_type": "system"}, + {"name": "controlnet", "tag_type": "system"}, + {"name": "gligen", "tag_type": "system"}, + {"name": "upscale_models", "tag_type": "system"}, + {"name": "hypernetworks", "tag_type": "system"}, + {"name": "photomaker", "tag_type": "system"}, + {"name": "classifiers", "tag_type": "system"}, + + {"name": "encoder", "tag_type": "system"}, + {"name": "decoder", "tag_type": "system"}, + + 
{"name": "missing", "tag_type": "system"}, + {"name": "rescan", "tag_type": "system"}, + ], + ) + + +def downgrade() -> None: + op.drop_index("ix_asset_info_meta_key_val_bool", table_name="asset_info_meta") + op.drop_index("ix_asset_info_meta_key_val_num", table_name="asset_info_meta") + op.drop_index("ix_asset_info_meta_key_val_str", table_name="asset_info_meta") + op.drop_index("ix_asset_info_meta_key", table_name="asset_info_meta") + op.drop_table("asset_info_meta") + + op.drop_index("ix_asset_cache_state_asset_id", table_name="asset_cache_state") + op.drop_index("ix_asset_cache_state_file_path", table_name="asset_cache_state") + op.drop_constraint("uq_asset_cache_state_file_path", table_name="asset_cache_state") + op.drop_table("asset_cache_state") + + op.drop_index("ix_asset_info_tags_asset_info_id", table_name="asset_info_tags") + op.drop_index("ix_asset_info_tags_tag_name", table_name="asset_info_tags") + op.drop_table("asset_info_tags") + + op.drop_index("ix_tags_tag_type", table_name="tags") + op.drop_table("tags") + + op.drop_constraint("uq_assets_info_asset_owner_name", table_name="assets_info") + op.drop_index("ix_assets_info_owner_name", table_name="assets_info") + op.drop_index("ix_assets_info_last_access_time", table_name="assets_info") + op.drop_index("ix_assets_info_created_at", table_name="assets_info") + op.drop_index("ix_assets_info_name", table_name="assets_info") + op.drop_index("ix_assets_info_asset_id", table_name="assets_info") + op.drop_index("ix_assets_info_owner_id", table_name="assets_info") + op.drop_table("assets_info") + + op.drop_index("uq_assets_hash", table_name="assets") + op.drop_index("ix_assets_mime_type", table_name="assets") + op.drop_table("assets") diff --git a/app/assets/__init__.py b/app/assets/__init__.py new file mode 100644 index 000000000000..28020a2935fb --- /dev/null +++ b/app/assets/__init__.py @@ -0,0 +1,4 @@ +from .api.routes import register_assets_system +from .scanner import sync_seed_assets + +__all__ = ["sync_seed_assets", "register_assets_system"] diff --git a/app/assets/_helpers.py b/app/assets/_helpers.py new file mode 100644 index 000000000000..59141e99707e --- /dev/null +++ b/app/assets/_helpers.py @@ -0,0 +1,225 @@ +import contextlib +import os +import uuid +from datetime import datetime, timezone +from pathlib import Path +from typing import Literal, Optional, Sequence + +import folder_paths + +from .api import schemas_in + + +def get_comfy_models_folders() -> list[tuple[str, list[str]]]: + """Build a list of (folder_name, base_paths[]) categories that are configured for model locations. + + We trust `folder_paths.folder_names_and_paths` and include a category if + *any* of its base paths lies under the Comfy `models_dir`. 
+ """ + targets: list[tuple[str, list[str]]] = [] + models_root = os.path.abspath(folder_paths.models_dir) + for name, (paths, _exts) in folder_paths.folder_names_and_paths.items(): + if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths): + targets.append((name, paths)) + return targets + + +def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]: + """Given an absolute or relative file path, determine which root category the path belongs to: + - 'input' if the file resides under `folder_paths.get_input_directory()` + - 'output' if the file resides under `folder_paths.get_output_directory()` + - 'models' if the file resides under any base path of categories returned by `get_comfy_models_folders()` + + Returns: + (root_category, relative_path_inside_that_root) + For 'models', the relative path is prefixed with the category name: + e.g. ('models', 'vae/test/sub/ae.safetensors') + + Raises: + ValueError: if the path does not belong to input, output, or configured model bases. + """ + fp_abs = os.path.abspath(file_path) + + def _is_within(child: str, parent: str) -> bool: + try: + return os.path.commonpath([child, parent]) == parent + except Exception: + return False + + def _rel(child: str, parent: str) -> str: + return os.path.relpath(os.path.join(os.sep, os.path.relpath(child, parent)), os.sep) + + # 1) input + input_base = os.path.abspath(folder_paths.get_input_directory()) + if _is_within(fp_abs, input_base): + return "input", _rel(fp_abs, input_base) + + # 2) output + output_base = os.path.abspath(folder_paths.get_output_directory()) + if _is_within(fp_abs, output_base): + return "output", _rel(fp_abs, output_base) + + # 3) models (check deepest matching base to avoid ambiguity) + best: Optional[tuple[int, str, str]] = None # (base_len, bucket, rel_inside_bucket) + for bucket, bases in get_comfy_models_folders(): + for b in bases: + base_abs = os.path.abspath(b) + if not _is_within(fp_abs, base_abs): + continue + cand = (len(base_abs), bucket, _rel(fp_abs, base_abs)) + if best is None or cand[0] > best[0]: + best = cand + + if best is not None: + _, bucket, rel_inside = best + combined = os.path.join(bucket, rel_inside) + return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep) + + raise ValueError(f"Path is not within input, output, or configured model bases: {file_path}") + + +def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]: + """Return a tuple (name, tags) derived from a filesystem path. + + Semantics: + - Root category is determined by `get_relative_to_root_category_path_of_asset`. + - The returned `name` is the base filename with extension from the relative path. + - The returned `tags` are: + [root_category] + parent folders of the relative path (in order) + For 'models', this means: + file '/.../ModelsDir/vae/test_tag/ae.safetensors' + -> root_category='models', some_path='vae/test_tag/ae.safetensors' + -> name='ae.safetensors', tags=['models', 'vae', 'test_tag'] + + Raises: + ValueError: if the path does not belong to input, output, or configured model bases. 
+ """ + root_category, some_path = get_relative_to_root_category_path_of_asset(file_path) + p = Path(some_path) + parent_parts = [part for part in p.parent.parts if part not in (".", "..", p.anchor)] + return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts]))) + + +def normalize_tags(tags: Optional[Sequence[str]]) -> list[str]: + return [t.strip().lower() for t in (tags or []) if (t or "").strip()] + + +def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]: + """Validates and maps tags -> (base_dir, subdirs_for_fs)""" + root = tags[0] + if root == "models": + if len(tags) < 2: + raise ValueError("at least two tags required for model asset") + try: + bases = folder_paths.folder_names_and_paths[tags[1]][0] + except KeyError: + raise ValueError(f"unknown model category '{tags[1]}'") + if not bases: + raise ValueError(f"no base path configured for category '{tags[1]}'") + base_dir = os.path.abspath(bases[0]) + raw_subdirs = tags[2:] + else: + base_dir = os.path.abspath( + folder_paths.get_input_directory() if root == "input" else folder_paths.get_output_directory() + ) + raw_subdirs = tags[1:] + for i in raw_subdirs: + if i in (".", ".."): + raise ValueError("invalid path component in tags") + + return base_dir, raw_subdirs if raw_subdirs else [] + + +def ensure_within_base(candidate: str, base: str) -> None: + cand_abs = os.path.abspath(candidate) + base_abs = os.path.abspath(base) + try: + if os.path.commonpath([cand_abs, base_abs]) != base_abs: + raise ValueError("destination escapes base directory") + except Exception: + raise ValueError("invalid destination path") + + +def compute_relative_filename(file_path: str) -> Optional[str]: + """ + Return the model's path relative to the last well-known folder (the model category), + using forward slashes, eg: + /.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors" + /.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors" + + For non-model paths, returns None. + NOTE: this is a temporary helper, used only for initializing metadata["filename"] field. 
+ """ + try: + root_category, rel_path = get_relative_to_root_category_path_of_asset(file_path) + except ValueError: + return None + + p = Path(rel_path) + parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)] + if not parts: + return None + + if root_category == "models": + # parts[0] is the category ("checkpoints", "vae", etc) – drop it + inside = parts[1:] if len(parts) > 1 else [parts[0]] + return "/".join(inside) + return "/".join(parts) # input/output: keep all parts + + +def list_tree(base_dir: str) -> list[str]: + out: list[str] = [] + base_abs = os.path.abspath(base_dir) + if not os.path.isdir(base_abs): + return out + for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False): + for name in filenames: + out.append(os.path.abspath(os.path.join(dirpath, name))) + return out + + +def prefixes_for_root(root: schemas_in.RootType) -> list[str]: + if root == "models": + bases: list[str] = [] + for _bucket, paths in get_comfy_models_folders(): + bases.extend(paths) + return [os.path.abspath(p) for p in bases] + if root == "input": + return [os.path.abspath(folder_paths.get_input_directory())] + if root == "output": + return [os.path.abspath(folder_paths.get_output_directory())] + return [] + + +def ts_to_iso(ts: Optional[float]) -> Optional[str]: + if ts is None: + return None + try: + return datetime.fromtimestamp(float(ts), tz=timezone.utc).replace(tzinfo=None).isoformat() + except Exception: + return None + + +def new_scan_id(root: schemas_in.RootType) -> str: + return f"scan-{root}-{uuid.uuid4().hex[:8]}" + + +def collect_models_files() -> list[str]: + out: list[str] = [] + for folder_name, bases in get_comfy_models_folders(): + rel_files = folder_paths.get_filename_list(folder_name) or [] + for rel_path in rel_files: + abs_path = folder_paths.get_full_path(folder_name, rel_path) + if not abs_path: + continue + abs_path = os.path.abspath(abs_path) + allowed = False + for b in bases: + base_abs = os.path.abspath(b) + with contextlib.suppress(Exception): + if os.path.commonpath([abs_path, base_abs]) == base_abs: + allowed = True + break + if allowed: + out.append(abs_path) + return out diff --git a/app/__init__.py b/app/assets/api/__init__.py similarity index 100% rename from app/__init__.py rename to app/assets/api/__init__.py diff --git a/app/assets/api/routes.py b/app/assets/api/routes.py new file mode 100644 index 000000000000..4ca7467750fd --- /dev/null +++ b/app/assets/api/routes.py @@ -0,0 +1,544 @@ +import contextlib +import logging +import os +import urllib.parse +import uuid +from typing import Optional + +from aiohttp import web +from pydantic import ValidationError + +import folder_paths + +from ... import user_manager +from .. import manager, scanner +from . 
import schemas_in, schemas_out + +ROUTES = web.RouteTableDef() +USER_MANAGER: Optional[user_manager.UserManager] = None +LOGGER = logging.getLogger(__name__) + +# UUID regex (canonical hyphenated form, case-insensitive) +UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}" + + +@ROUTES.head("/api/assets/hash/{hash}") +async def head_asset_by_hash(request: web.Request) -> web.Response: + hash_str = request.match_info.get("hash", "").strip().lower() + if not hash_str or ":" not in hash_str: + return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") + algo, digest = hash_str.split(":", 1) + if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"): + return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") + exists = await manager.asset_exists(asset_hash=hash_str) + return web.Response(status=200 if exists else 404) + + +@ROUTES.get("/api/assets") +async def list_assets(request: web.Request) -> web.Response: + qp = request.rel_url.query + query_dict = {} + if "include_tags" in qp: + query_dict["include_tags"] = qp.getall("include_tags") + if "exclude_tags" in qp: + query_dict["exclude_tags"] = qp.getall("exclude_tags") + for k in ("name_contains", "metadata_filter", "limit", "offset", "sort", "order"): + v = qp.get(k) + if v is not None: + query_dict[k] = v + + try: + q = schemas_in.ListAssetsQuery.model_validate(query_dict) + except ValidationError as ve: + return _validation_error_response("INVALID_QUERY", ve) + + payload = await manager.list_assets( + include_tags=q.include_tags, + exclude_tags=q.exclude_tags, + name_contains=q.name_contains, + metadata_filter=q.metadata_filter, + limit=q.limit, + offset=q.offset, + sort=q.sort, + order=q.order, + owner_id=USER_MANAGER.get_request_user_id(request), + ) + return web.json_response(payload.model_dump(mode="json")) + + +@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content") +async def download_asset_content(request: web.Request) -> web.Response: + disposition = request.query.get("disposition", "attachment").lower().strip() + if disposition not in {"inline", "attachment"}: + disposition = "attachment" + + try: + abs_path, content_type, filename = await manager.resolve_asset_content_for_download( + asset_info_id=str(uuid.UUID(request.match_info["id"])), + owner_id=USER_MANAGER.get_request_user_id(request), + ) + except ValueError as ve: + return _error_response(404, "ASSET_NOT_FOUND", str(ve)) + except NotImplementedError as nie: + return _error_response(501, "BACKEND_UNSUPPORTED", str(nie)) + except FileNotFoundError: + return _error_response(404, "FILE_NOT_FOUND", "Underlying file not found on disk.") + + quoted = (filename or "").replace("\r", "").replace("\n", "").replace('"', "'") + cd = f'{disposition}; filename="{quoted}"; filename*=UTF-8\'\'{urllib.parse.quote(filename)}' + + resp = web.FileResponse(abs_path) + resp.content_type = content_type + resp.headers["Content-Disposition"] = cd + return resp + + +@ROUTES.post("/api/assets/from-hash") +async def create_asset_from_hash(request: web.Request) -> web.Response: + try: + payload = await request.json() + body = schemas_in.CreateFromHashBody.model_validate(payload) + except ValidationError as ve: + return _validation_error_response("INVALID_BODY", ve) + except Exception: + return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") + + result = await manager.create_asset_from_hash( + hash_str=body.hash, + name=body.name, + tags=body.tags, + 
user_metadata=body.user_metadata, + owner_id=USER_MANAGER.get_request_user_id(request), + ) + if result is None: + return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist") + return web.json_response(result.model_dump(mode="json"), status=201) + + +@ROUTES.post("/api/assets") +async def upload_asset(request: web.Request) -> web.Response: + """Multipart/form-data endpoint for Asset uploads.""" + + if not (request.content_type or "").lower().startswith("multipart/"): + return _error_response(415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads.") + + reader = await request.multipart() + + file_present = False + file_client_name: Optional[str] = None + tags_raw: list[str] = [] + provided_name: Optional[str] = None + user_metadata_raw: Optional[str] = None + provided_hash: Optional[str] = None + provided_hash_exists: Optional[bool] = None + + file_written = 0 + tmp_path: Optional[str] = None + while True: + field = await reader.next() + if field is None: + break + + fname = getattr(field, "name", "") or "" + + if fname == "hash": + try: + s = ((await field.text()) or "").strip().lower() + except Exception: + return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") + + if s: + if ":" not in s: + return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") + algo, digest = s.split(":", 1) + if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"): + return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") + provided_hash = f"{algo}:{digest}" + try: + provided_hash_exists = await manager.asset_exists(asset_hash=provided_hash) + except Exception: + provided_hash_exists = None # do not fail the whole request here + + elif fname == "file": + file_present = True + file_client_name = (field.filename or "").strip() + + if provided_hash and provided_hash_exists is True: + # If client supplied a hash that we know exists, drain but do not write to disk + try: + while True: + chunk = await field.read_chunk(8 * 1024 * 1024) + if not chunk: + break + file_written += len(chunk) + except Exception: + return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file.") + continue # Do not create temp file; we will create AssetInfo from the existing content + + # Otherwise, store to temp for hashing/ingest + uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads") + unique_dir = os.path.join(uploads_root, uuid.uuid4().hex) + os.makedirs(unique_dir, exist_ok=True) + tmp_path = os.path.join(unique_dir, ".upload.part") + + try: + with open(tmp_path, "wb") as f: + while True: + chunk = await field.read_chunk(8 * 1024 * 1024) + if not chunk: + break + f.write(chunk) + file_written += len(chunk) + except Exception: + try: + if os.path.exists(tmp_path or ""): + os.remove(tmp_path) + finally: + return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file.") + elif fname == "tags": + tags_raw.append((await field.text()) or "") + elif fname == "name": + provided_name = (await field.text()) or None + elif fname == "user_metadata": + user_metadata_raw = (await field.text()) or None + + # If client did not send file, and we are not doing a from-hash fast path -> error + if not file_present and not (provided_hash and provided_hash_exists): + return _error_response(400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'.") + + if file_present and file_written == 0 and not (provided_hash and provided_hash_exists): + # 
Empty upload is only acceptable if we are fast-pathing from existing hash + try: + if tmp_path and os.path.exists(tmp_path): + os.remove(tmp_path) + finally: + return _error_response(400, "EMPTY_UPLOAD", "Uploaded file is empty.") + + try: + spec = schemas_in.UploadAssetSpec.model_validate({ + "tags": tags_raw, + "name": provided_name, + "user_metadata": user_metadata_raw, + "hash": provided_hash, + }) + except ValidationError as ve: + try: + if tmp_path and os.path.exists(tmp_path): + os.remove(tmp_path) + finally: + return _validation_error_response("INVALID_BODY", ve) + + # Validate models category against configured folders (consistent with previous behavior) + if spec.tags and spec.tags[0] == "models": + if len(spec.tags) < 2 or spec.tags[1] not in folder_paths.folder_names_and_paths: + if tmp_path and os.path.exists(tmp_path): + os.remove(tmp_path) + return _error_response( + 400, "INVALID_BODY", f"unknown models category '{spec.tags[1] if len(spec.tags) >= 2 else ''}'" + ) + + owner_id = USER_MANAGER.get_request_user_id(request) + + # Fast path: if a valid provided hash exists, create AssetInfo without writing anything + if spec.hash and provided_hash_exists is True: + try: + result = await manager.create_asset_from_hash( + hash_str=spec.hash, + name=spec.name or (spec.hash.split(":", 1)[1]), + tags=spec.tags, + user_metadata=spec.user_metadata or {}, + owner_id=owner_id, + ) + except Exception: + LOGGER.exception("create_asset_from_hash failed for hash=%s, owner_id=%s", spec.hash, owner_id) + return _error_response(500, "INTERNAL", "Unexpected server error.") + + if result is None: + return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {spec.hash} does not exist") + + # Drain temp if we accidentally saved (e.g., hash field came after file) + if tmp_path and os.path.exists(tmp_path): + with contextlib.suppress(Exception): + os.remove(tmp_path) + + status = 200 if (not result.created_new) else 201 + return web.json_response(result.model_dump(mode="json"), status=status) + + # Otherwise, we must have a temp file path to ingest + if not tmp_path or not os.path.exists(tmp_path): + # The only case we reach here without a temp file is: client sent a hash that does not exist and no file + return _error_response(404, "ASSET_NOT_FOUND", "Provided hash not found and no file uploaded.") + + try: + created = await manager.upload_asset_from_temp_path( + spec, + temp_path=tmp_path, + client_filename=file_client_name, + owner_id=owner_id, + expected_asset_hash=spec.hash, + ) + status = 201 if created.created_new else 200 + return web.json_response(created.model_dump(mode="json"), status=status) + except ValueError as e: + if tmp_path and os.path.exists(tmp_path): + os.remove(tmp_path) + msg = str(e) + if "HASH_MISMATCH" in msg or msg.strip().upper() == "HASH_MISMATCH": + return _error_response( + 400, + "HASH_MISMATCH", + "Uploaded file hash does not match provided hash.", + ) + return _error_response(400, "BAD_REQUEST", "Invalid inputs.") + except Exception: + if tmp_path and os.path.exists(tmp_path): + os.remove(tmp_path) + LOGGER.exception("upload_asset_from_temp_path failed for tmp_path=%s, owner_id=%s", tmp_path, owner_id) + return _error_response(500, "INTERNAL", "Unexpected server error.") + + +@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}") +async def get_asset(request: web.Request) -> web.Response: + asset_info_id = str(uuid.UUID(request.match_info["id"])) + try: + result = await manager.get_asset( + asset_info_id=asset_info_id, + owner_id=USER_MANAGER.get_request_user_id(request), 
+ ) + except ValueError as ve: + return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) + except Exception: + LOGGER.exception( + "get_asset failed for asset_info_id=%s, owner_id=%s", + asset_info_id, + USER_MANAGER.get_request_user_id(request), + ) + return _error_response(500, "INTERNAL", "Unexpected server error.") + return web.json_response(result.model_dump(mode="json"), status=200) + + +@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}") +async def update_asset(request: web.Request) -> web.Response: + asset_info_id = str(uuid.UUID(request.match_info["id"])) + try: + body = schemas_in.UpdateAssetBody.model_validate(await request.json()) + except ValidationError as ve: + return _validation_error_response("INVALID_BODY", ve) + except Exception: + return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") + + try: + result = await manager.update_asset( + asset_info_id=asset_info_id, + name=body.name, + tags=body.tags, + user_metadata=body.user_metadata, + owner_id=USER_MANAGER.get_request_user_id(request), + ) + except (ValueError, PermissionError) as ve: + return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) + except Exception: + LOGGER.exception( + "update_asset failed for asset_info_id=%s, owner_id=%s", + asset_info_id, + USER_MANAGER.get_request_user_id(request), + ) + return _error_response(500, "INTERNAL", "Unexpected server error.") + return web.json_response(result.model_dump(mode="json"), status=200) + + +@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}/preview") +async def set_asset_preview(request: web.Request) -> web.Response: + asset_info_id = str(uuid.UUID(request.match_info["id"])) + try: + body = schemas_in.SetPreviewBody.model_validate(await request.json()) + except ValidationError as ve: + return _validation_error_response("INVALID_BODY", ve) + except Exception: + return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") + + try: + result = await manager.set_asset_preview( + asset_info_id=asset_info_id, + preview_asset_id=body.preview_id, + owner_id=USER_MANAGER.get_request_user_id(request), + ) + except (PermissionError, ValueError) as ve: + return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) + except Exception: + LOGGER.exception( + "set_asset_preview failed for asset_info_id=%s, owner_id=%s", + asset_info_id, + USER_MANAGER.get_request_user_id(request), + ) + return _error_response(500, "INTERNAL", "Unexpected server error.") + return web.json_response(result.model_dump(mode="json"), status=200) + + +@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}") +async def delete_asset(request: web.Request) -> web.Response: + asset_info_id = str(uuid.UUID(request.match_info["id"])) + delete_content = request.query.get("delete_content") + delete_content = True if delete_content is None else delete_content.lower() not in {"0", "false", "no"} + + try: + deleted = await manager.delete_asset_reference( + asset_info_id=asset_info_id, + owner_id=USER_MANAGER.get_request_user_id(request), + delete_content_if_orphan=delete_content, + ) + except Exception: + LOGGER.exception( + "delete_asset_reference failed for asset_info_id=%s, owner_id=%s", + asset_info_id, + USER_MANAGER.get_request_user_id(request), + ) + return _error_response(500, "INTERNAL", "Unexpected server error.") + + if not deleted: + return _error_response(404, "ASSET_NOT_FOUND", f"AssetInfo {asset_info_id} not found.") + return web.Response(status=204) + + +@ROUTES.get("/api/tags") +async def get_tags(request: 
web.Request) -> web.Response: + query_map = dict(request.rel_url.query) + + try: + query = schemas_in.TagsListQuery.model_validate(query_map) + except ValidationError as ve: + return web.json_response( + {"error": {"code": "INVALID_QUERY", "message": "Invalid query parameters", "details": ve.errors()}}, + status=400, + ) + + result = await manager.list_tags( + prefix=query.prefix, + limit=query.limit, + offset=query.offset, + order=query.order, + include_zero=query.include_zero, + owner_id=USER_MANAGER.get_request_user_id(request), + ) + return web.json_response(result.model_dump(mode="json")) + + +@ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags") +async def add_asset_tags(request: web.Request) -> web.Response: + asset_info_id = str(uuid.UUID(request.match_info["id"])) + try: + payload = await request.json() + data = schemas_in.TagsAdd.model_validate(payload) + except ValidationError as ve: + return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags add.", {"errors": ve.errors()}) + except Exception: + return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") + + try: + result = await manager.add_tags_to_asset( + asset_info_id=asset_info_id, + tags=data.tags, + origin="manual", + owner_id=USER_MANAGER.get_request_user_id(request), + ) + except (ValueError, PermissionError) as ve: + return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) + except Exception: + LOGGER.exception( + "add_tags_to_asset failed for asset_info_id=%s, owner_id=%s", + asset_info_id, + USER_MANAGER.get_request_user_id(request), + ) + return _error_response(500, "INTERNAL", "Unexpected server error.") + + return web.json_response(result.model_dump(mode="json"), status=200) + + +@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags") +async def delete_asset_tags(request: web.Request) -> web.Response: + asset_info_id = str(uuid.UUID(request.match_info["id"])) + try: + payload = await request.json() + data = schemas_in.TagsRemove.model_validate(payload) + except ValidationError as ve: + return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags remove.", {"errors": ve.errors()}) + except Exception: + return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") + + try: + result = await manager.remove_tags_from_asset( + asset_info_id=asset_info_id, + tags=data.tags, + owner_id=USER_MANAGER.get_request_user_id(request), + ) + except ValueError as ve: + return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) + except Exception: + LOGGER.exception( + "remove_tags_from_asset failed for asset_info_id=%s, owner_id=%s", + asset_info_id, + USER_MANAGER.get_request_user_id(request), + ) + return _error_response(500, "INTERNAL", "Unexpected server error.") + + return web.json_response(result.model_dump(mode="json"), status=200) + + +@ROUTES.post("/api/assets/scan/seed") +async def seed_assets(request: web.Request) -> web.Response: + try: + payload = await request.json() + except Exception: + payload = {} + + try: + body = schemas_in.ScheduleAssetScanBody.model_validate(payload) + except ValidationError as ve: + return _validation_error_response("INVALID_BODY", ve) + + try: + await scanner.sync_seed_assets(body.roots) + except Exception: + LOGGER.exception("sync_seed_assets failed for roots=%s", body.roots) + return _error_response(500, "INTERNAL", "Unexpected server error.") + return web.json_response({"synced": True, "roots": body.roots}, status=200) + + +@ROUTES.post("/api/assets/scan/schedule") +async def 
schedule_asset_scan(request: web.Request) -> web.Response: + try: + payload = await request.json() + except Exception: + payload = {} + + try: + body = schemas_in.ScheduleAssetScanBody.model_validate(payload) + except ValidationError as ve: + return _validation_error_response("INVALID_BODY", ve) + + states = await scanner.schedule_scans(body.roots) + return web.json_response(states.model_dump(mode="json"), status=202) + + +@ROUTES.get("/api/assets/scan") +async def get_asset_scan_status(request: web.Request) -> web.Response: + root = request.query.get("root", "").strip().lower() + states = scanner.current_statuses() + if root in {"models", "input", "output"}: + states = [s for s in states.scans if s.root == root] # type: ignore + states = schemas_out.AssetScanStatusResponse(scans=states) + return web.json_response(states.model_dump(mode="json"), status=200) + + +def register_assets_system(app: web.Application, user_manager_instance: user_manager.UserManager) -> None: + global USER_MANAGER + USER_MANAGER = user_manager_instance + app.add_routes(ROUTES) + + +def _error_response(status: int, code: str, message: str, details: Optional[dict] = None) -> web.Response: + return web.json_response({"error": {"code": code, "message": message, "details": details or {}}}, status=status) + + +def _validation_error_response(code: str, ve: ValidationError) -> web.Response: + return _error_response(400, code, "Validation failed.", {"errors": ve.json()}) diff --git a/app/assets/api/schemas_in.py b/app/assets/api/schemas_in.py new file mode 100644 index 000000000000..1469d325d8e1 --- /dev/null +++ b/app/assets/api/schemas_in.py @@ -0,0 +1,297 @@ +import json +import uuid +from typing import Any, Literal, Optional + +from pydantic import ( + BaseModel, + ConfigDict, + Field, + conint, + field_validator, + model_validator, +) + + +class ListAssetsQuery(BaseModel): + include_tags: list[str] = Field(default_factory=list) + exclude_tags: list[str] = Field(default_factory=list) + name_contains: Optional[str] = None + + # Accept either a JSON string (query param) or a dict + metadata_filter: Optional[dict[str, Any]] = None + + limit: conint(ge=1, le=500) = 20 + offset: conint(ge=0) = 0 + + sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = "created_at" + order: Literal["asc", "desc"] = "desc" + + @field_validator("include_tags", "exclude_tags", mode="before") + @classmethod + def _split_csv_tags(cls, v): + # Accept "a,b,c" or ["a","b"] (we are liberal in what we accept) + if v is None: + return [] + if isinstance(v, str): + return [t.strip() for t in v.split(",") if t.strip()] + if isinstance(v, list): + out: list[str] = [] + for item in v: + if isinstance(item, str): + out.extend([t.strip() for t in item.split(",") if t.strip()]) + return out + return v + + @field_validator("metadata_filter", mode="before") + @classmethod + def _parse_metadata_json(cls, v): + if v is None or isinstance(v, dict): + return v + if isinstance(v, str) and v.strip(): + try: + parsed = json.loads(v) + except Exception as e: + raise ValueError(f"metadata_filter must be JSON: {e}") from e + if not isinstance(parsed, dict): + raise ValueError("metadata_filter must be a JSON object") + return parsed + return None + + +class UpdateAssetBody(BaseModel): + name: Optional[str] = None + tags: Optional[list[str]] = None + user_metadata: Optional[dict[str, Any]] = None + + @model_validator(mode="after") + def _at_least_one(self): + if self.name is None and self.tags is None and self.user_metadata is None: + raise 
ValueError("Provide at least one of: name, tags, user_metadata.") + if self.tags is not None: + if not isinstance(self.tags, list) or not all(isinstance(t, str) for t in self.tags): + raise ValueError("Field 'tags' must be an array of strings.") + return self + + +class CreateFromHashBody(BaseModel): + model_config = ConfigDict(extra="ignore", str_strip_whitespace=True) + + hash: str + name: str + tags: list[str] = Field(default_factory=list) + user_metadata: dict[str, Any] = Field(default_factory=dict) + + @field_validator("hash") + @classmethod + def _require_blake3(cls, v): + s = (v or "").strip().lower() + if ":" not in s: + raise ValueError("hash must be 'blake3:'") + algo, digest = s.split(":", 1) + if algo != "blake3": + raise ValueError("only canonical 'blake3:' is accepted here") + if not digest or any(c for c in digest if c not in "0123456789abcdef"): + raise ValueError("hash digest must be lowercase hex") + return s + + @field_validator("tags", mode="before") + @classmethod + def _tags_norm(cls, v): + if v is None: + return [] + if isinstance(v, list): + out = [str(t).strip().lower() for t in v if str(t).strip()] + seen = set() + dedup = [] + for t in out: + if t not in seen: + seen.add(t) + dedup.append(t) + return dedup + if isinstance(v, str): + return [t.strip().lower() for t in v.split(",") if t.strip()] + return [] + + +class TagsListQuery(BaseModel): + model_config = ConfigDict(extra="ignore", str_strip_whitespace=True) + + prefix: Optional[str] = Field(None, min_length=1, max_length=256) + limit: int = Field(100, ge=1, le=1000) + offset: int = Field(0, ge=0, le=10_000_000) + order: Literal["count_desc", "name_asc"] = "count_desc" + include_zero: bool = True + + @field_validator("prefix") + @classmethod + def normalize_prefix(cls, v: Optional[str]) -> Optional[str]: + if v is None: + return v + v = v.strip() + return v.lower() or None + + +class TagsAdd(BaseModel): + model_config = ConfigDict(extra="ignore") + tags: list[str] = Field(..., min_length=1) + + @field_validator("tags") + @classmethod + def normalize_tags(cls, v: list[str]) -> list[str]: + out = [] + for t in v: + if not isinstance(t, str): + raise TypeError("tags must be strings") + tnorm = t.strip().lower() + if tnorm: + out.append(tnorm) + seen = set() + deduplicated = [] + for x in out: + if x not in seen: + seen.add(x) + deduplicated.append(x) + return deduplicated + + +class TagsRemove(TagsAdd): + pass + + +RootType = Literal["models", "input", "output"] +ALLOWED_ROOTS: tuple[RootType, ...] = ("models", "input", "output") + + +class ScheduleAssetScanBody(BaseModel): + roots: list[RootType] = Field(..., min_length=1) + + +class UploadAssetSpec(BaseModel): + """Upload Asset operation. + - tags: ordered; first is root ('models'|'input'|'output'); + if root == 'models', second must be a valid category from folder_paths.folder_names_and_paths + - name: display name + - user_metadata: arbitrary JSON object (optional) + - hash: optional canonical 'blake3:' provided by the client for validation / fast-path + + Files created via this endpoint are stored on disk using the **content hash** as the filename stem + and the original extension is preserved when available. 
+ """ + model_config = ConfigDict(extra="ignore", str_strip_whitespace=True) + + tags: list[str] = Field(..., min_length=1) + name: Optional[str] = Field(default=None, max_length=512, description="Display Name") + user_metadata: dict[str, Any] = Field(default_factory=dict) + hash: Optional[str] = Field(default=None) + + @field_validator("hash", mode="before") + @classmethod + def _parse_hash(cls, v): + if v is None: + return None + s = str(v).strip().lower() + if not s: + return None + if ":" not in s: + raise ValueError("hash must be 'blake3:'") + algo, digest = s.split(":", 1) + if algo != "blake3": + raise ValueError("only canonical 'blake3:' is accepted here") + if not digest or any(c for c in digest if c not in "0123456789abcdef"): + raise ValueError("hash digest must be lowercase hex") + return f"{algo}:{digest}" + + @field_validator("tags", mode="before") + @classmethod + def _parse_tags(cls, v): + """ + Accepts a list of strings (possibly multiple form fields), + where each string can be: + - JSON array (e.g., '["models","loras","foo"]') + - comma-separated ('models, loras, foo') + - single token ('models') + Returns a normalized, deduplicated, ordered list. + """ + items: list[str] = [] + if v is None: + return [] + if isinstance(v, str): + v = [v] + + if isinstance(v, list): + for item in v: + if item is None: + continue + s = str(item).strip() + if not s: + continue + if s.startswith("["): + try: + arr = json.loads(s) + if isinstance(arr, list): + items.extend(str(x) for x in arr) + continue + except Exception: + pass # fallback to CSV parse below + items.extend([p for p in s.split(",") if p.strip()]) + else: + return [] + + # normalize + dedupe + norm = [] + seen = set() + for t in items: + tnorm = str(t).strip().lower() + if tnorm and tnorm not in seen: + seen.add(tnorm) + norm.append(tnorm) + return norm + + @field_validator("user_metadata", mode="before") + @classmethod + def _parse_metadata_json(cls, v): + if v is None or isinstance(v, dict): + return v or {} + if isinstance(v, str): + s = v.strip() + if not s: + return {} + try: + parsed = json.loads(s) + except Exception as e: + raise ValueError(f"user_metadata must be JSON: {e}") from e + if not isinstance(parsed, dict): + raise ValueError("user_metadata must be a JSON object") + return parsed + return {} + + @model_validator(mode="after") + def _validate_order(self): + if not self.tags: + raise ValueError("tags must be provided and non-empty") + root = self.tags[0] + if root not in {"models", "input", "output"}: + raise ValueError("first tag must be one of: models, input, output") + if root == "models": + if len(self.tags) < 2: + raise ValueError("models uploads require a category tag as the second tag") + return self + + +class SetPreviewBody(BaseModel): + """Set or clear the preview for an AssetInfo. 
Provide an Asset.id or null.""" + preview_id: Optional[str] = None + + @field_validator("preview_id", mode="before") + @classmethod + def _norm_uuid(cls, v): + if v is None: + return None + s = str(v).strip() + if not s: + return None + try: + uuid.UUID(s) + except Exception: + raise ValueError("preview_id must be a UUID") + return s diff --git a/app/assets/api/schemas_out.py b/app/assets/api/schemas_out.py new file mode 100644 index 000000000000..cc7e9572be4a --- /dev/null +++ b/app/assets/api/schemas_out.py @@ -0,0 +1,115 @@ +from datetime import datetime +from typing import Any, Literal, Optional + +from pydantic import BaseModel, ConfigDict, Field, field_serializer + + +class AssetSummary(BaseModel): + id: str + name: str + asset_hash: Optional[str] + size: Optional[int] = None + mime_type: Optional[str] = None + tags: list[str] = Field(default_factory=list) + preview_url: Optional[str] = None + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + last_access_time: Optional[datetime] = None + + model_config = ConfigDict(from_attributes=True) + + @field_serializer("created_at", "updated_at", "last_access_time") + def _ser_dt(self, v: Optional[datetime], _info): + return v.isoformat() if v else None + + +class AssetsList(BaseModel): + assets: list[AssetSummary] + total: int + has_more: bool + + +class AssetUpdated(BaseModel): + id: str + name: str + asset_hash: Optional[str] + tags: list[str] = Field(default_factory=list) + user_metadata: dict[str, Any] = Field(default_factory=dict) + updated_at: Optional[datetime] = None + + model_config = ConfigDict(from_attributes=True) + + @field_serializer("updated_at") + def _ser_updated(self, v: Optional[datetime], _info): + return v.isoformat() if v else None + + +class AssetDetail(BaseModel): + id: str + name: str + asset_hash: Optional[str] + size: Optional[int] = None + mime_type: Optional[str] = None + tags: list[str] = Field(default_factory=list) + user_metadata: dict[str, Any] = Field(default_factory=dict) + preview_id: Optional[str] = None + created_at: Optional[datetime] = None + last_access_time: Optional[datetime] = None + + model_config = ConfigDict(from_attributes=True) + + @field_serializer("created_at", "last_access_time") + def _ser_dt(self, v: Optional[datetime], _info): + return v.isoformat() if v else None + + +class AssetCreated(AssetDetail): + created_new: bool + + +class TagUsage(BaseModel): + name: str + count: int + type: str + + +class TagsList(BaseModel): + tags: list[TagUsage] = Field(default_factory=list) + total: int + has_more: bool + + +class TagsAdd(BaseModel): + model_config = ConfigDict(str_strip_whitespace=True) + added: list[str] = Field(default_factory=list) + already_present: list[str] = Field(default_factory=list) + total_tags: list[str] = Field(default_factory=list) + + +class TagsRemove(BaseModel): + model_config = ConfigDict(str_strip_whitespace=True) + removed: list[str] = Field(default_factory=list) + not_present: list[str] = Field(default_factory=list) + total_tags: list[str] = Field(default_factory=list) + + +class AssetScanError(BaseModel): + path: str + message: str + at: Optional[str] = Field(None, description="ISO timestamp") + + +class AssetScanStatus(BaseModel): + scan_id: str + root: Literal["models", "input", "output"] + status: Literal["scheduled", "running", "completed", "failed", "cancelled"] + scheduled_at: Optional[str] = None + started_at: Optional[str] = None + finished_at: Optional[str] = None + discovered: int = 0 + processed: int = 0 + file_errors: 
list[AssetScanError] = Field(default_factory=list) + + +class AssetScanStatusResponse(BaseModel): + scans: list[AssetScanStatus] = Field(default_factory=list) diff --git a/app/assets/database/__init__.py b/app/assets/database/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/app/assets/database/helpers/__init__.py b/app/assets/database/helpers/__init__.py new file mode 100644 index 000000000000..9ae13cd02e61 --- /dev/null +++ b/app/assets/database/helpers/__init__.py @@ -0,0 +1,25 @@ +from .bulk_ops import seed_from_paths_batch +from .escape_like import escape_like_prefix +from .fast_check import fast_asset_file_check +from .filters import apply_metadata_filter, apply_tag_filters +from .ownership import visible_owner_clause +from .projection import is_scalar, project_kv +from .tags import ( + add_missing_tag_for_asset_id, + ensure_tags_exist, + remove_missing_tag_for_asset_id, +) + +__all__ = [ + "apply_tag_filters", + "apply_metadata_filter", + "escape_like_prefix", + "fast_asset_file_check", + "is_scalar", + "project_kv", + "ensure_tags_exist", + "add_missing_tag_for_asset_id", + "remove_missing_tag_for_asset_id", + "seed_from_paths_batch", + "visible_owner_clause", +] diff --git a/app/assets/database/helpers/bulk_ops.py b/app/assets/database/helpers/bulk_ops.py new file mode 100644 index 000000000000..feefbb5bd63b --- /dev/null +++ b/app/assets/database/helpers/bulk_ops.py @@ -0,0 +1,230 @@ +import os +import uuid +from typing import Iterable, Sequence + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql as d_pg +from sqlalchemy.dialects import sqlite as d_sqlite +from sqlalchemy.ext.asyncio import AsyncSession + +from ..models import Asset, AssetCacheState, AssetInfo, AssetInfoMeta, AssetInfoTag +from ..timeutil import utcnow + +MAX_BIND_PARAMS = 800 + + +async def seed_from_paths_batch( + session: AsyncSession, + *, + specs: Sequence[dict], + owner_id: str = "", +) -> dict: + """Each spec is a dict with keys: + - abs_path: str + - size_bytes: int + - mtime_ns: int + - info_name: str + - tags: list[str] + - fname: Optional[str] + """ + if not specs: + return {"inserted_infos": 0, "won_states": 0, "lost_states": 0} + + now = utcnow() + dialect = session.bind.dialect.name + if dialect not in ("sqlite", "postgresql"): + raise NotImplementedError(f"Unsupported database dialect: {dialect}") + + asset_rows: list[dict] = [] + state_rows: list[dict] = [] + path_to_asset: dict[str, str] = {} + asset_to_info: dict[str, dict] = {} # asset_id -> prepared info row + path_list: list[str] = [] + + for sp in specs: + ap = os.path.abspath(sp["abs_path"]) + aid = str(uuid.uuid4()) + iid = str(uuid.uuid4()) + path_list.append(ap) + path_to_asset[ap] = aid + + asset_rows.append( + { + "id": aid, + "hash": None, + "size_bytes": sp["size_bytes"], + "mime_type": None, + "created_at": now, + } + ) + state_rows.append( + { + "asset_id": aid, + "file_path": ap, + "mtime_ns": sp["mtime_ns"], + } + ) + asset_to_info[aid] = { + "id": iid, + "owner_id": owner_id, + "name": sp["info_name"], + "asset_id": aid, + "preview_id": None, + "user_metadata": {"filename": sp["fname"]} if sp["fname"] else None, + "created_at": now, + "updated_at": now, + "last_access_time": now, + "_tags": sp["tags"], + "_filename": sp["fname"], + } + + # insert all seed Assets (hash=NULL) + ins_asset = d_sqlite.insert(Asset) if dialect == "sqlite" else d_pg.insert(Asset) + for chunk in _iter_chunks(asset_rows, _rows_per_stmt(5)): + await session.execute(ins_asset, chunk) + + # try to claim 
AssetCacheState (file_path) + winners_by_path: set[str] = set() + if dialect == "sqlite": + ins_state = ( + d_sqlite.insert(AssetCacheState) + .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path]) + .returning(AssetCacheState.file_path) + ) + else: + ins_state = ( + d_pg.insert(AssetCacheState) + .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path]) + .returning(AssetCacheState.file_path) + ) + for chunk in _iter_chunks(state_rows, _rows_per_stmt(3)): + winners_by_path.update((await session.execute(ins_state, chunk)).scalars().all()) + + all_paths_set = set(path_list) + losers_by_path = all_paths_set - winners_by_path + lost_assets = [path_to_asset[p] for p in losers_by_path] + if lost_assets: # losers get their Asset removed + for id_chunk in _iter_chunks(lost_assets, MAX_BIND_PARAMS): + await session.execute(sa.delete(Asset).where(Asset.id.in_(id_chunk))) + + if not winners_by_path: + return {"inserted_infos": 0, "won_states": 0, "lost_states": len(losers_by_path)} + + # insert AssetInfo only for winners + winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path] + if dialect == "sqlite": + ins_info = ( + d_sqlite.insert(AssetInfo) + .on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name]) + .returning(AssetInfo.id) + ) + else: + ins_info = ( + d_pg.insert(AssetInfo) + .on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name]) + .returning(AssetInfo.id) + ) + + inserted_info_ids: set[str] = set() + for chunk in _iter_chunks(winner_info_rows, _rows_per_stmt(9)): + inserted_info_ids.update((await session.execute(ins_info, chunk)).scalars().all()) + + # build and insert tag + meta rows for the AssetInfo + tag_rows: list[dict] = [] + meta_rows: list[dict] = [] + if inserted_info_ids: + for row in winner_info_rows: + iid = row["id"] + if iid not in inserted_info_ids: + continue + for t in row["_tags"]: + tag_rows.append({ + "asset_info_id": iid, + "tag_name": t, + "origin": "automatic", + "added_at": now, + }) + if row["_filename"]: + meta_rows.append( + { + "asset_info_id": iid, + "key": "filename", + "ordinal": 0, + "val_str": row["_filename"], + "val_num": None, + "val_bool": None, + "val_json": None, + } + ) + + await bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=meta_rows, max_bind_params=MAX_BIND_PARAMS) + return { + "inserted_infos": len(inserted_info_ids), + "won_states": len(winners_by_path), + "lost_states": len(losers_by_path), + } + + +async def bulk_insert_tags_and_meta( + session: AsyncSession, + *, + tag_rows: list[dict], + meta_rows: list[dict], + max_bind_params: int, +) -> None: + """Batch insert into asset_info_tags and asset_info_meta with ON CONFLICT DO NOTHING. 
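For orientation, a standalone sketch of the bind-parameter arithmetic that seed_from_paths_batch and this helper rely on. MAX_BIND_PARAMS mirrors the constant at the top of this module; rows_per_stmt below is a local illustrative stand-in for the module's _rows_per_stmt / _chunk_rows helpers, not an import.

# Illustrative copy of the chunking arithmetic used by bulk_ops
# (assumption: 800 bind parameters per statement, as defined in this module).
MAX_BIND_PARAMS = 800

def rows_per_stmt(cols_per_row: int, max_bind_params: int = MAX_BIND_PARAMS) -> int:
    # same math as _rows_per_stmt / _chunk_rows: how many rows fit in one INSERT
    return max(1, max_bind_params // max(1, cols_per_row))

# tag rows bind 4 parameters each (asset_info_id, tag_name, origin, added_at)
assert rows_per_stmt(4) == 200
# meta rows bind 7 parameters each, so fewer rows fit per statement
assert rows_per_stmt(7) == 114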
+ - tag_rows keys: asset_info_id, tag_name, origin, added_at + - meta_rows keys: asset_info_id, key, ordinal, val_str, val_num, val_bool, val_json + """ + dialect = session.bind.dialect.name + if tag_rows: + if dialect == "sqlite": + ins_links = ( + d_sqlite.insert(AssetInfoTag) + .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name]) + ) + elif dialect == "postgresql": + ins_links = ( + d_pg.insert(AssetInfoTag) + .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name]) + ) + else: + raise NotImplementedError(f"Unsupported database dialect: {dialect}") + for chunk in _chunk_rows(tag_rows, cols_per_row=4, max_bind_params=max_bind_params): + await session.execute(ins_links, chunk) + if meta_rows: + if dialect == "sqlite": + ins_meta = ( + d_sqlite.insert(AssetInfoMeta) + .on_conflict_do_nothing( + index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal] + ) + ) + elif dialect == "postgresql": + ins_meta = ( + d_pg.insert(AssetInfoMeta) + .on_conflict_do_nothing( + index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal] + ) + ) + else: + raise NotImplementedError(f"Unsupported database dialect: {dialect}") + for chunk in _chunk_rows(meta_rows, cols_per_row=7, max_bind_params=max_bind_params): + await session.execute(ins_meta, chunk) + + +def _chunk_rows(rows: list[dict], cols_per_row: int, max_bind_params: int) -> Iterable[list[dict]]: + if not rows: + return [] + rows_per_stmt = max(1, max_bind_params // max(1, cols_per_row)) + for i in range(0, len(rows), rows_per_stmt): + yield rows[i:i + rows_per_stmt] + + +def _iter_chunks(seq, n: int): + for i in range(0, len(seq), n): + yield seq[i:i + n] + + +def _rows_per_stmt(cols: int) -> int: + return max(1, MAX_BIND_PARAMS // max(1, cols)) diff --git a/app/assets/database/helpers/escape_like.py b/app/assets/database/helpers/escape_like.py new file mode 100644 index 000000000000..f905bd40b52c --- /dev/null +++ b/app/assets/database/helpers/escape_like.py @@ -0,0 +1,7 @@ +def escape_like_prefix(s: str, escape: str = "!") -> tuple[str, str]: + """Escapes %, _ and the escape char itself in a LIKE prefix. + Returns (escaped_prefix, escape_char). Caller should append '%' and pass escape=escape_char to .like(). 
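A hedged usage sketch for escape_like_prefix: the function body is repeated here so the snippet runs standalone, and the commented .like() call mirrors how the query code later in this patch consumes the result.

def escape_like_prefix(s: str, escape: str = "!") -> tuple[str, str]:
    # repeated from this helper so the example is self-contained
    s = s.replace(escape, escape + escape)
    s = s.replace("%", escape + "%").replace("_", escape + "_")
    return s, escape

escaped, esc = escape_like_prefix("50%_off!")
assert (escaped, esc) == ("50!%!_off!!", "!")
# Typical caller pattern used by the services in this patch:
#   stmt = stmt.where(AssetCacheState.file_path.like(escaped + "%", escape=esc))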
+ """ + s = s.replace(escape, escape + escape) # escape the escape char first + s = s.replace("%", escape + "%").replace("_", escape + "_") # escape LIKE wildcards + return s, escape diff --git a/app/assets/database/helpers/fast_check.py b/app/assets/database/helpers/fast_check.py new file mode 100644 index 000000000000..940d6984f535 --- /dev/null +++ b/app/assets/database/helpers/fast_check.py @@ -0,0 +1,19 @@ +import os +from typing import Optional + + +def fast_asset_file_check( + *, + mtime_db: Optional[int], + size_db: Optional[int], + stat_result: os.stat_result, +) -> bool: + if mtime_db is None: + return False + actual_mtime_ns = getattr(stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000)) + if int(mtime_db) != int(actual_mtime_ns): + return False + sz = int(size_db or 0) + if sz > 0: + return int(stat_result.st_size) == sz + return True diff --git a/app/assets/database/helpers/filters.py b/app/assets/database/helpers/filters.py new file mode 100644 index 000000000000..0edc0c66d88c --- /dev/null +++ b/app/assets/database/helpers/filters.py @@ -0,0 +1,87 @@ +from typing import Optional, Sequence + +import sqlalchemy as sa +from sqlalchemy import exists + +from ..._helpers import normalize_tags +from ..models import AssetInfo, AssetInfoMeta, AssetInfoTag + + +def apply_tag_filters( + stmt: sa.sql.Select, + include_tags: Optional[Sequence[str]], + exclude_tags: Optional[Sequence[str]], +) -> sa.sql.Select: + """include_tags: every tag must be present; exclude_tags: none may be present.""" + include_tags = normalize_tags(include_tags) + exclude_tags = normalize_tags(exclude_tags) + + if include_tags: + for tag_name in include_tags: + stmt = stmt.where( + exists().where( + (AssetInfoTag.asset_info_id == AssetInfo.id) + & (AssetInfoTag.tag_name == tag_name) + ) + ) + + if exclude_tags: + stmt = stmt.where( + ~exists().where( + (AssetInfoTag.asset_info_id == AssetInfo.id) + & (AssetInfoTag.tag_name.in_(exclude_tags)) + ) + ) + return stmt + + +def apply_metadata_filter( + stmt: sa.sql.Select, + metadata_filter: Optional[dict], +) -> sa.sql.Select: + """Apply filters using asset_info_meta projection table.""" + if not metadata_filter: + return stmt + + def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement: + return sa.exists().where( + AssetInfoMeta.asset_info_id == AssetInfo.id, + AssetInfoMeta.key == key, + *preds, + ) + + def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement: + if value is None: + no_row_for_key = sa.not_( + sa.exists().where( + AssetInfoMeta.asset_info_id == AssetInfo.id, + AssetInfoMeta.key == key, + ) + ) + null_row = _exists_for_pred( + key, + AssetInfoMeta.val_json.is_(None), + AssetInfoMeta.val_str.is_(None), + AssetInfoMeta.val_num.is_(None), + AssetInfoMeta.val_bool.is_(None), + ) + return sa.or_(no_row_for_key, null_row) + + if isinstance(value, bool): + return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value)) + if isinstance(value, (int, float)): + from decimal import Decimal + num = value if isinstance(value, Decimal) else Decimal(str(value)) + return _exists_for_pred(key, AssetInfoMeta.val_num == num) + if isinstance(value, str): + return _exists_for_pred(key, AssetInfoMeta.val_str == value) + return _exists_for_pred(key, AssetInfoMeta.val_json == value) + + for k, v in metadata_filter.items(): + if isinstance(v, list): + ors = [_exists_clause_for_value(k, elem) for elem in v] + if ors: + stmt = stmt.where(sa.or_(*ors)) + else: + stmt = stmt.where(_exists_clause_for_value(k, v)) + return stmt diff --git 
a/app/assets/database/helpers/ownership.py b/app/assets/database/helpers/ownership.py new file mode 100644 index 000000000000..c0073160831e --- /dev/null +++ b/app/assets/database/helpers/ownership.py @@ -0,0 +1,12 @@ +import sqlalchemy as sa + +from ..models import AssetInfo + + +def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement: + """Build owner visibility predicate for reads. Owner-less rows are visible to everyone.""" + + owner_id = (owner_id or "").strip() + if owner_id == "": + return AssetInfo.owner_id == "" + return AssetInfo.owner_id.in_(["", owner_id]) diff --git a/app/assets/database/helpers/projection.py b/app/assets/database/helpers/projection.py new file mode 100644 index 000000000000..687802d1803c --- /dev/null +++ b/app/assets/database/helpers/projection.py @@ -0,0 +1,64 @@ +from decimal import Decimal + + +def is_scalar(v): + if v is None: + return True + if isinstance(v, bool): + return True + if isinstance(v, (int, float, Decimal, str)): + return True + return False + + +def project_kv(key: str, value): + """ + Turn a metadata key/value into typed projection rows. + Returns list[dict] with keys: + key, ordinal, and one of val_str / val_num / val_bool / val_json (others None) + """ + rows: list[dict] = [] + + def _null_row(ordinal: int) -> dict: + return { + "key": key, "ordinal": ordinal, + "val_str": None, "val_num": None, "val_bool": None, "val_json": None + } + + if value is None: + rows.append(_null_row(0)) + return rows + + if is_scalar(value): + if isinstance(value, bool): + rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)}) + elif isinstance(value, (int, float, Decimal)): + num = value if isinstance(value, Decimal) else Decimal(str(value)) + rows.append({"key": key, "ordinal": 0, "val_num": num}) + elif isinstance(value, str): + rows.append({"key": key, "ordinal": 0, "val_str": value}) + else: + rows.append({"key": key, "ordinal": 0, "val_json": value}) + return rows + + if isinstance(value, list): + if all(is_scalar(x) for x in value): + for i, x in enumerate(value): + if x is None: + rows.append(_null_row(i)) + elif isinstance(x, bool): + rows.append({"key": key, "ordinal": i, "val_bool": bool(x)}) + elif isinstance(x, (int, float, Decimal)): + num = x if isinstance(x, Decimal) else Decimal(str(x)) + rows.append({"key": key, "ordinal": i, "val_num": num}) + elif isinstance(x, str): + rows.append({"key": key, "ordinal": i, "val_str": x}) + else: + rows.append({"key": key, "ordinal": i, "val_json": x}) + return rows + for i, x in enumerate(value): + rows.append({"key": key, "ordinal": i, "val_json": x}) + return rows + + rows.append({"key": key, "ordinal": 0, "val_json": value}) + return rows diff --git a/app/assets/database/helpers/tags.py b/app/assets/database/helpers/tags.py new file mode 100644 index 000000000000..402dc346d430 --- /dev/null +++ b/app/assets/database/helpers/tags.py @@ -0,0 +1,90 @@ +from typing import Iterable + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql as d_pg +from sqlalchemy.dialects import sqlite as d_sqlite +from sqlalchemy.ext.asyncio import AsyncSession + +from ..._helpers import normalize_tags +from ..models import AssetInfo, AssetInfoTag, Tag +from ..timeutil import utcnow + + +async def ensure_tags_exist(session: AsyncSession, names: Iterable[str], tag_type: str = "user") -> None: + wanted = normalize_tags(list(names)) + if not wanted: + return + rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))] + dialect = session.bind.dialect.name + if dialect == 
"sqlite": + ins = ( + d_sqlite.insert(Tag) + .values(rows) + .on_conflict_do_nothing(index_elements=[Tag.name]) + ) + elif dialect == "postgresql": + ins = ( + d_pg.insert(Tag) + .values(rows) + .on_conflict_do_nothing(index_elements=[Tag.name]) + ) + else: + raise NotImplementedError(f"Unsupported database dialect: {dialect}") + await session.execute(ins) + + +async def add_missing_tag_for_asset_id( + session: AsyncSession, + *, + asset_id: str, + origin: str = "automatic", +) -> None: + select_rows = ( + sa.select( + AssetInfo.id.label("asset_info_id"), + sa.literal("missing").label("tag_name"), + sa.literal(origin).label("origin"), + sa.literal(utcnow()).label("added_at"), + ) + .where(AssetInfo.asset_id == asset_id) + .where( + sa.not_( + sa.exists().where((AssetInfoTag.asset_info_id == AssetInfo.id) & (AssetInfoTag.tag_name == "missing")) + ) + ) + ) + dialect = session.bind.dialect.name + if dialect == "sqlite": + ins = ( + d_sqlite.insert(AssetInfoTag) + .from_select( + ["asset_info_id", "tag_name", "origin", "added_at"], + select_rows, + ) + .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name]) + ) + elif dialect == "postgresql": + ins = ( + d_pg.insert(AssetInfoTag) + .from_select( + ["asset_info_id", "tag_name", "origin", "added_at"], + select_rows, + ) + .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name]) + ) + else: + raise NotImplementedError(f"Unsupported database dialect: {dialect}") + await session.execute(ins) + + +async def remove_missing_tag_for_asset_id( + session: AsyncSession, + *, + asset_id: str, +) -> None: + await session.execute( + sa.delete(AssetInfoTag).where( + AssetInfoTag.asset_info_id.in_(sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)), + AssetInfoTag.tag_name == "missing", + ) + ) diff --git a/app/assets/database/models.py b/app/assets/database/models.py new file mode 100644 index 000000000000..c6555fa61732 --- /dev/null +++ b/app/assets/database/models.py @@ -0,0 +1,251 @@ +import uuid +from datetime import datetime +from typing import Any, Optional + +from sqlalchemy import ( + JSON, + BigInteger, + Boolean, + CheckConstraint, + DateTime, + ForeignKey, + Index, + Integer, + Numeric, + String, + Text, + UniqueConstraint, +) +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.orm import DeclarativeBase, Mapped, foreign, mapped_column, relationship + +from .timeutil import utcnow + +JSONB_V = JSON(none_as_null=True).with_variant(JSONB(none_as_null=True), 'postgresql') + + +class Base(DeclarativeBase): + pass + + +def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]: + fields = obj.__table__.columns.keys() + out: dict[str, Any] = {} + for field in fields: + val = getattr(obj, field) + if val is None and not include_none: + continue + if isinstance(val, datetime): + out[field] = val.isoformat() + else: + out[field] = val + return out + + +class Asset(Base): + __tablename__ = "assets" + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + hash: Mapped[Optional[str]] = mapped_column(String(256), nullable=True) + size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0) + mime_type: Mapped[Optional[str]] = mapped_column(String(255)) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=False), nullable=False, default=utcnow + ) + + infos: Mapped[list["AssetInfo"]] = relationship( + "AssetInfo", + back_populates="asset", + primaryjoin=lambda: Asset.id == 
foreign(AssetInfo.asset_id), + foreign_keys=lambda: [AssetInfo.asset_id], + cascade="all,delete-orphan", + passive_deletes=True, + ) + + preview_of: Mapped[list["AssetInfo"]] = relationship( + "AssetInfo", + back_populates="preview_asset", + primaryjoin=lambda: Asset.id == foreign(AssetInfo.preview_id), + foreign_keys=lambda: [AssetInfo.preview_id], + viewonly=True, + ) + + cache_states: Mapped[list["AssetCacheState"]] = relationship( + back_populates="asset", + cascade="all, delete-orphan", + passive_deletes=True, + ) + + __table_args__ = ( + Index("uq_assets_hash", "hash", unique=True), + Index("ix_assets_mime_type", "mime_type"), + CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"), + ) + + def to_dict(self, include_none: bool = False) -> dict[str, Any]: + return to_dict(self, include_none=include_none) + + def __repr__(self) -> str: + return f"" + + +class AssetCacheState(Base): + __tablename__ = "asset_cache_state" + + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False) + file_path: Mapped[str] = mapped_column(Text, nullable=False) + mtime_ns: Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True) + needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + + asset: Mapped["Asset"] = relationship(back_populates="cache_states") + + __table_args__ = ( + Index("ix_asset_cache_state_file_path", "file_path"), + Index("ix_asset_cache_state_asset_id", "asset_id"), + CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"), + UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"), + ) + + def to_dict(self, include_none: bool = False) -> dict[str, Any]: + return to_dict(self, include_none=include_none) + + def __repr__(self) -> str: + return f"" + + +class AssetInfo(Base): + __tablename__ = "assets_info" + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="") + name: Mapped[str] = mapped_column(String(512), nullable=False) + asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False) + preview_id: Mapped[Optional[str]] = mapped_column(String(36), ForeignKey("assets.id", ondelete="SET NULL")) + user_metadata: Mapped[Optional[dict[str, Any]]] = mapped_column(JSON(none_as_null=True)) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow) + updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow) + last_access_time: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow) + + asset: Mapped[Asset] = relationship( + "Asset", + back_populates="infos", + foreign_keys=[asset_id], + lazy="selectin", + ) + preview_asset: Mapped[Optional[Asset]] = relationship( + "Asset", + back_populates="preview_of", + foreign_keys=[preview_id], + ) + + metadata_entries: Mapped[list["AssetInfoMeta"]] = relationship( + back_populates="asset_info", + cascade="all,delete-orphan", + passive_deletes=True, + ) + + tag_links: Mapped[list["AssetInfoTag"]] = relationship( + back_populates="asset_info", + cascade="all,delete-orphan", + passive_deletes=True, + overlaps="tags,asset_infos", + ) + + tags: Mapped[list["Tag"]] = relationship( + secondary="asset_info_tags", + back_populates="asset_infos", + 
lazy="selectin", + viewonly=True, + overlaps="tag_links,asset_info_links,asset_infos,tag", + ) + + __table_args__ = ( + UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"), + Index("ix_assets_info_owner_name", "owner_id", "name"), + Index("ix_assets_info_owner_id", "owner_id"), + Index("ix_assets_info_asset_id", "asset_id"), + Index("ix_assets_info_name", "name"), + Index("ix_assets_info_created_at", "created_at"), + Index("ix_assets_info_last_access_time", "last_access_time"), + ) + + def to_dict(self, include_none: bool = False) -> dict[str, Any]: + data = to_dict(self, include_none=include_none) + data["tags"] = [t.name for t in self.tags] + return data + + def __repr__(self) -> str: + return f"" + + +class AssetInfoMeta(Base): + __tablename__ = "asset_info_meta" + + asset_info_id: Mapped[str] = mapped_column( + String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True + ) + key: Mapped[str] = mapped_column(String(256), primary_key=True) + ordinal: Mapped[int] = mapped_column(Integer, primary_key=True, default=0) + + val_str: Mapped[Optional[str]] = mapped_column(String(2048), nullable=True) + val_num: Mapped[Optional[float]] = mapped_column(Numeric(38, 10), nullable=True) + val_bool: Mapped[Optional[bool]] = mapped_column(Boolean, nullable=True) + val_json: Mapped[Optional[Any]] = mapped_column(JSONB_V, nullable=True) + + asset_info: Mapped["AssetInfo"] = relationship(back_populates="metadata_entries") + + __table_args__ = ( + Index("ix_asset_info_meta_key", "key"), + Index("ix_asset_info_meta_key_val_str", "key", "val_str"), + Index("ix_asset_info_meta_key_val_num", "key", "val_num"), + Index("ix_asset_info_meta_key_val_bool", "key", "val_bool"), + ) + + +class AssetInfoTag(Base): + __tablename__ = "asset_info_tags" + + asset_info_id: Mapped[str] = mapped_column( + String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True + ) + tag_name: Mapped[str] = mapped_column( + String(512), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True + ) + origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual") + added_at: Mapped[datetime] = mapped_column( + DateTime(timezone=False), nullable=False, default=utcnow + ) + + asset_info: Mapped["AssetInfo"] = relationship(back_populates="tag_links") + tag: Mapped["Tag"] = relationship(back_populates="asset_info_links") + + __table_args__ = ( + Index("ix_asset_info_tags_tag_name", "tag_name"), + Index("ix_asset_info_tags_asset_info_id", "asset_info_id"), + ) + + +class Tag(Base): + __tablename__ = "tags" + + name: Mapped[str] = mapped_column(String(512), primary_key=True) + tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user") + + asset_info_links: Mapped[list["AssetInfoTag"]] = relationship( + back_populates="tag", + overlaps="asset_infos,tags", + ) + asset_infos: Mapped[list["AssetInfo"]] = relationship( + secondary="asset_info_tags", + back_populates="tags", + viewonly=True, + overlaps="asset_info_links,tag_links,tags,asset_info", + ) + + __table_args__ = ( + Index("ix_tags_tag_type", "tag_type"), + ) + + def __repr__(self) -> str: + return f"" diff --git a/app/assets/database/services/__init__.py b/app/assets/database/services/__init__.py new file mode 100644 index 000000000000..6c6f26e514d7 --- /dev/null +++ b/app/assets/database/services/__init__.py @@ -0,0 +1,57 @@ +from .content import ( + check_fs_asset_exists_quick, + compute_hash_and_dedup_for_cache_state, + ingest_fs_asset, + 
list_cache_states_with_asset_under_prefixes, + list_unhashed_candidates_under_prefixes, + list_verify_candidates_under_prefixes, + redirect_all_references_then_delete_asset, + touch_asset_infos_by_fs_path, +) +from .info import ( + add_tags_to_asset_info, + create_asset_info_for_existing_asset, + delete_asset_info_by_id, + fetch_asset_info_and_asset, + fetch_asset_info_asset_and_tags, + get_asset_tags, + list_asset_infos_page, + list_tags_with_usage, + remove_tags_from_asset_info, + replace_asset_info_metadata_projection, + set_asset_info_preview, + set_asset_info_tags, + touch_asset_info_by_id, + update_asset_info_full, +) +from .queries import ( + asset_exists_by_hash, + asset_info_exists_for_asset_id, + get_asset_by_hash, + get_asset_info_by_id, + get_cache_state_by_asset_id, + list_cache_states_by_asset_id, + pick_best_live_path, +) + +__all__ = [ + # queries + "asset_exists_by_hash", "get_asset_by_hash", "get_asset_info_by_id", "asset_info_exists_for_asset_id", + "get_cache_state_by_asset_id", + "list_cache_states_by_asset_id", + "pick_best_live_path", + # info + "list_asset_infos_page", "create_asset_info_for_existing_asset", "set_asset_info_tags", + "update_asset_info_full", "replace_asset_info_metadata_projection", + "touch_asset_info_by_id", "delete_asset_info_by_id", + "add_tags_to_asset_info", "remove_tags_from_asset_info", + "get_asset_tags", "list_tags_with_usage", "set_asset_info_preview", + "fetch_asset_info_and_asset", "fetch_asset_info_asset_and_tags", + # content + "check_fs_asset_exists_quick", + "redirect_all_references_then_delete_asset", + "compute_hash_and_dedup_for_cache_state", + "list_unhashed_candidates_under_prefixes", "list_verify_candidates_under_prefixes", + "ingest_fs_asset", "touch_asset_infos_by_fs_path", + "list_cache_states_with_asset_under_prefixes", +] diff --git a/app/assets/database/services/content.py b/app/assets/database/services/content.py new file mode 100644 index 000000000000..864c190442cb --- /dev/null +++ b/app/assets/database/services/content.py @@ -0,0 +1,721 @@ +import contextlib +import logging +import os +from datetime import datetime +from typing import Any, Optional, Sequence, Union + +import sqlalchemy as sa +from sqlalchemy import select +from sqlalchemy.dialects import postgresql as d_pg +from sqlalchemy.dialects import sqlite as d_sqlite +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import noload + +from ..._helpers import compute_relative_filename +from ...storage import hashing as hashing_mod +from ..helpers import ( + ensure_tags_exist, + escape_like_prefix, + remove_missing_tag_for_asset_id, +) +from ..models import Asset, AssetCacheState, AssetInfo, AssetInfoTag, Tag +from ..timeutil import utcnow +from .info import replace_asset_info_metadata_projection +from .queries import list_cache_states_by_asset_id, pick_best_live_path + + +async def check_fs_asset_exists_quick( + session: AsyncSession, + *, + file_path: str, + size_bytes: Optional[int] = None, + mtime_ns: Optional[int] = None, +) -> bool: + """Returns True if we already track this absolute path with a HASHED asset and the cached mtime/size match.""" + locator = os.path.abspath(file_path) + + stmt = ( + sa.select(sa.literal(True)) + .select_from(AssetCacheState) + .join(Asset, Asset.id == AssetCacheState.asset_id) + .where( + AssetCacheState.file_path == locator, + Asset.hash.isnot(None), + AssetCacheState.needs_verify.is_(False), + ) + .limit(1) + ) + + conds = [] + if mtime_ns is not None: + 
conds.append(AssetCacheState.mtime_ns == int(mtime_ns)) + if size_bytes is not None: + conds.append(sa.or_(Asset.size_bytes == 0, Asset.size_bytes == int(size_bytes))) + if conds: + stmt = stmt.where(*conds) + return (await session.execute(stmt)).first() is not None + + +async def redirect_all_references_then_delete_asset( + session: AsyncSession, + *, + duplicate_asset_id: str, + canonical_asset_id: str, +) -> None: + """ + Safely migrate all references from duplicate_asset_id to canonical_asset_id. + + - If an AssetInfo for (owner_id, name) already exists on the canonical asset, + merge tags, metadata, times, and preview, then delete the duplicate AssetInfo. + - Otherwise, simply repoint the AssetInfo.asset_id. + - Always retarget AssetCacheState rows. + - Finally delete the duplicate Asset row. + """ + if duplicate_asset_id == canonical_asset_id: + return + + # 1) Migrate AssetInfo rows one-by-one to avoid UNIQUE conflicts. + dup_infos = ( + await session.execute( + select(AssetInfo).options(noload(AssetInfo.tags)).where(AssetInfo.asset_id == duplicate_asset_id) + ) + ).unique().scalars().all() + + for info in dup_infos: + # Try to find an existing collision on canonical + existing = ( + await session.execute( + select(AssetInfo) + .options(noload(AssetInfo.tags)) + .where( + AssetInfo.asset_id == canonical_asset_id, + AssetInfo.owner_id == info.owner_id, + AssetInfo.name == info.name, + ) + .limit(1) + ) + ).unique().scalars().first() + + if existing: + merged_meta = dict(existing.user_metadata or {}) + other_meta = info.user_metadata or {} + for k, v in other_meta.items(): + if k not in merged_meta: + merged_meta[k] = v + if merged_meta != (existing.user_metadata or {}): + await replace_asset_info_metadata_projection( + session, + asset_info_id=existing.id, + user_metadata=merged_meta, + ) + + existing_tags = { + t for (t,) in ( + await session.execute( + select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == existing.id) + ) + ).all() + } + from_tags = { + t for (t,) in ( + await session.execute( + select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == info.id) + ) + ).all() + } + to_add = sorted(from_tags - existing_tags) + if to_add: + await ensure_tags_exist(session, to_add, tag_type="user") + now = utcnow() + session.add_all([ + AssetInfoTag(asset_info_id=existing.id, tag_name=t, origin="automatic", added_at=now) + for t in to_add + ]) + await session.flush() + + if existing.preview_id is None and info.preview_id is not None: + existing.preview_id = info.preview_id + if info.last_access_time and ( + existing.last_access_time is None or info.last_access_time > existing.last_access_time + ): + existing.last_access_time = info.last_access_time + existing.updated_at = utcnow() + await session.flush() + + # Delete the duplicate AssetInfo (cascades will clean its tags/meta) + await session.delete(info) + await session.flush() + else: + # Simple retarget + info.asset_id = canonical_asset_id + info.updated_at = utcnow() + await session.flush() + + # 2) Repoint cache states and previews + await session.execute( + sa.update(AssetCacheState) + .where(AssetCacheState.asset_id == duplicate_asset_id) + .values(asset_id=canonical_asset_id) + ) + await session.execute( + sa.update(AssetInfo) + .where(AssetInfo.preview_id == duplicate_asset_id) + .values(preview_id=canonical_asset_id) + ) + + # 3) Remove duplicate Asset + dup = await session.get(Asset, duplicate_asset_id) + if dup: + await session.delete(dup) + await session.flush() + + +async def 
compute_hash_and_dedup_for_cache_state( + session: AsyncSession, + *, + state_id: int, +) -> Optional[str]: + """ + Compute hash for the given cache state, deduplicate, and settle verify cases. + + Returns the asset_id that this state ends up pointing to, or None if file disappeared. + """ + state = await session.get(AssetCacheState, state_id) + if not state: + return None + + path = state.file_path + try: + if not os.path.isfile(path): + # File vanished: drop the state. If the Asset has hash=NULL and has no other states, drop the Asset too. + asset = await session.get(Asset, state.asset_id) + await session.delete(state) + await session.flush() + + if asset and asset.hash is None: + remaining = ( + await session.execute( + sa.select(sa.func.count()) + .select_from(AssetCacheState) + .where(AssetCacheState.asset_id == asset.id) + ) + ).scalar_one() + if int(remaining or 0) == 0: + await session.delete(asset) + await session.flush() + else: + await _recompute_and_apply_filename_for_asset(session, asset_id=asset.id) + return None + + digest = await hashing_mod.blake3_hash(path) + new_hash = f"blake3:{digest}" + + st = os.stat(path, follow_symlinks=True) + new_size = int(st.st_size) + mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)) + + # Current asset of this state + this_asset = await session.get(Asset, state.asset_id) + + # If the state got orphaned somehow (race), just reattach appropriately. + if not this_asset: + canonical = ( + await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1)) + ).scalars().first() + if canonical: + state.asset_id = canonical.id + else: + now = utcnow() + new_asset = Asset(hash=new_hash, size_bytes=new_size, mime_type=None, created_at=now) + session.add(new_asset) + await session.flush() + state.asset_id = new_asset.id + state.mtime_ns = mtime_ns + state.needs_verify = False + with contextlib.suppress(Exception): + await remove_missing_tag_for_asset_id(session, asset_id=state.asset_id) + await session.flush() + return state.asset_id + + # 1) Seed asset case (hash is NULL): claim or merge into canonical + if this_asset.hash is None: + canonical = ( + await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1)) + ).scalars().first() + + if canonical and canonical.id != this_asset.id: + # Merge seed asset into canonical (safe, collision-aware) + await redirect_all_references_then_delete_asset( + session, + duplicate_asset_id=this_asset.id, + canonical_asset_id=canonical.id, + ) + state = await session.get(AssetCacheState, state_id) + if state: + state.mtime_ns = mtime_ns + state.needs_verify = False + with contextlib.suppress(Exception): + await remove_missing_tag_for_asset_id(session, asset_id=canonical.id) + await _recompute_and_apply_filename_for_asset(session, asset_id=canonical.id) + await session.flush() + return canonical.id + + # No canonical: try to claim the hash; handle races with a SAVEPOINT + try: + async with session.begin_nested(): + this_asset.hash = new_hash + if int(this_asset.size_bytes or 0) == 0 and new_size > 0: + this_asset.size_bytes = new_size + await session.flush() + except IntegrityError: + # Someone else claimed it concurrently; fetch canonical and merge + canonical = ( + await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1)) + ).scalars().first() + if canonical and canonical.id != this_asset.id: + await redirect_all_references_then_delete_asset( + session, + duplicate_asset_id=this_asset.id, + canonical_asset_id=canonical.id, + ) + state = await 
session.get(AssetCacheState, state_id) + if state: + state.mtime_ns = mtime_ns + state.needs_verify = False + with contextlib.suppress(Exception): + await remove_missing_tag_for_asset_id(session, asset_id=canonical.id) + await _recompute_and_apply_filename_for_asset(session, asset_id=canonical.id) + await session.flush() + return canonical.id + # If we got here, the integrity error was not about hash uniqueness + raise + + # Claimed successfully + state.mtime_ns = mtime_ns + state.needs_verify = False + with contextlib.suppress(Exception): + await remove_missing_tag_for_asset_id(session, asset_id=this_asset.id) + await _recompute_and_apply_filename_for_asset(session, asset_id=this_asset.id) + await session.flush() + return this_asset.id + + # 2) Verify case for hashed assets + if this_asset.hash == new_hash: + if int(this_asset.size_bytes or 0) == 0 and new_size > 0: + this_asset.size_bytes = new_size + state.mtime_ns = mtime_ns + state.needs_verify = False + with contextlib.suppress(Exception): + await remove_missing_tag_for_asset_id(session, asset_id=this_asset.id) + await _recompute_and_apply_filename_for_asset(session, asset_id=this_asset.id) + await session.flush() + return this_asset.id + + # Content changed on this path only: retarget THIS state, do not move AssetInfo rows + canonical = ( + await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1)) + ).scalars().first() + if canonical: + target_id = canonical.id + else: + now = utcnow() + new_asset = Asset(hash=new_hash, size_bytes=new_size, mime_type=None, created_at=now) + session.add(new_asset) + await session.flush() + target_id = new_asset.id + + state.asset_id = target_id + state.mtime_ns = mtime_ns + state.needs_verify = False + with contextlib.suppress(Exception): + await remove_missing_tag_for_asset_id(session, asset_id=target_id) + await _recompute_and_apply_filename_for_asset(session, asset_id=target_id) + await session.flush() + return target_id + except Exception: + raise + + +async def list_unhashed_candidates_under_prefixes(session: AsyncSession, *, prefixes: list[str]) -> list[int]: + if not prefixes: + return [] + + conds = [] + for p in prefixes: + base = os.path.abspath(p) + if not base.endswith(os.sep): + base += os.sep + escaped, esc = escape_like_prefix(base) + conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc)) + + path_filter = sa.or_(*conds) if len(conds) > 1 else conds[0] + if session.bind.dialect.name == "postgresql": + stmt = ( + sa.select(AssetCacheState.id) + .join(Asset, Asset.id == AssetCacheState.asset_id) + .where(Asset.hash.is_(None), path_filter) + .order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc()) + .distinct(AssetCacheState.asset_id) + ) + else: + first_id = sa.func.min(AssetCacheState.id).label("first_id") + stmt = ( + sa.select(first_id) + .join(Asset, Asset.id == AssetCacheState.asset_id) + .where(Asset.hash.is_(None), path_filter) + .group_by(AssetCacheState.asset_id) + .order_by(first_id.asc()) + ) + return [int(x) for x in (await session.execute(stmt)).scalars().all()] + + +async def list_verify_candidates_under_prefixes( + session: AsyncSession, *, prefixes: Sequence[str] +) -> Union[list[int], Sequence[int]]: + if not prefixes: + return [] + conds = [] + for p in prefixes: + base = os.path.abspath(p) + if not base.endswith(os.sep): + base += os.sep + escaped, esc = escape_like_prefix(base) + conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc)) + + return ( + await session.execute( + sa.select(AssetCacheState.id) 
+ .where(AssetCacheState.needs_verify.is_(True)) + .where(sa.or_(*conds)) + .order_by(AssetCacheState.id.asc()) + ) + ).scalars().all() + + +async def ingest_fs_asset( + session: AsyncSession, + *, + asset_hash: str, + abs_path: str, + size_bytes: int, + mtime_ns: int, + mime_type: Optional[str] = None, + info_name: Optional[str] = None, + owner_id: str = "", + preview_id: Optional[str] = None, + user_metadata: Optional[dict] = None, + tags: Sequence[str] = (), + tag_origin: str = "manual", + require_existing_tags: bool = False, +) -> dict: + """ + Idempotently upsert: + - Asset by content hash (create if missing) + - AssetCacheState(file_path) pointing to asset_id + - Optionally AssetInfo + tag links and metadata projection + Returns flags and ids. + """ + locator = os.path.abspath(abs_path) + now = utcnow() + dialect = session.bind.dialect.name + + if preview_id: + if not await session.get(Asset, preview_id): + preview_id = None + + out: dict[str, Any] = { + "asset_created": False, + "asset_updated": False, + "state_created": False, + "state_updated": False, + "asset_info_id": None, + } + + # 1) Asset by hash + asset = ( + await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1)) + ).scalars().first() + if not asset: + vals = { + "hash": asset_hash, + "size_bytes": int(size_bytes), + "mime_type": mime_type, + "created_at": now, + } + if dialect == "sqlite": + res = await session.execute( + d_sqlite.insert(Asset) + .values(**vals) + .on_conflict_do_nothing(index_elements=[Asset.hash]) + ) + if int(res.rowcount or 0) > 0: + out["asset_created"] = True + asset = ( + await session.execute( + select(Asset).where(Asset.hash == asset_hash).limit(1) + ) + ).scalars().first() + elif dialect == "postgresql": + res = await session.execute( + d_pg.insert(Asset) + .values(**vals) + .on_conflict_do_nothing( + index_elements=[Asset.hash], + index_where=Asset.__table__.c.hash.isnot(None), + ) + .returning(Asset.id) + ) + inserted_id = res.scalar_one_or_none() + if inserted_id: + out["asset_created"] = True + asset = await session.get(Asset, inserted_id) + else: + asset = ( + await session.execute( + select(Asset).where(Asset.hash == asset_hash).limit(1) + ) + ).scalars().first() + else: + raise NotImplementedError(f"Unsupported database dialect: {dialect}") + if not asset: + raise RuntimeError("Asset row not found after upsert.") + else: + changed = False + if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0: + asset.size_bytes = int(size_bytes) + changed = True + if mime_type and asset.mime_type != mime_type: + asset.mime_type = mime_type + changed = True + if changed: + out["asset_updated"] = True + + # 2) AssetCacheState upsert by file_path (unique) + vals = { + "asset_id": asset.id, + "file_path": locator, + "mtime_ns": int(mtime_ns), + } + if dialect == "sqlite": + ins = ( + d_sqlite.insert(AssetCacheState) + .values(**vals) + .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path]) + ) + elif dialect == "postgresql": + ins = ( + d_pg.insert(AssetCacheState) + .values(**vals) + .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path]) + ) + else: + raise NotImplementedError(f"Unsupported database dialect: {dialect}") + + res = await session.execute(ins) + if int(res.rowcount or 0) > 0: + out["state_created"] = True + else: + upd = ( + sa.update(AssetCacheState) + .where(AssetCacheState.file_path == locator) + .where( + sa.or_( + AssetCacheState.asset_id != asset.id, + AssetCacheState.mtime_ns.is_(None), + AssetCacheState.mtime_ns != int(mtime_ns), 
+ ) + ) + .values(asset_id=asset.id, mtime_ns=int(mtime_ns)) + ) + res2 = await session.execute(upd) + if int(res2.rowcount or 0) > 0: + out["state_updated"] = True + + # 3) Optional AssetInfo + tags + metadata + if info_name: + try: + async with session.begin_nested(): + info = AssetInfo( + owner_id=owner_id, + name=info_name, + asset_id=asset.id, + preview_id=preview_id, + created_at=now, + updated_at=now, + last_access_time=now, + ) + session.add(info) + await session.flush() + out["asset_info_id"] = info.id + except IntegrityError: + pass + + existing_info = ( + await session.execute( + select(AssetInfo) + .where( + AssetInfo.asset_id == asset.id, + AssetInfo.name == info_name, + (AssetInfo.owner_id == owner_id), + ) + .limit(1) + ) + ).unique().scalar_one_or_none() + if not existing_info: + raise RuntimeError("Failed to update or insert AssetInfo.") + + if preview_id and existing_info.preview_id != preview_id: + existing_info.preview_id = preview_id + + existing_info.updated_at = now + if existing_info.last_access_time < now: + existing_info.last_access_time = now + await session.flush() + out["asset_info_id"] = existing_info.id + + norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()] + if norm and out["asset_info_id"] is not None: + if not require_existing_tags: + await ensure_tags_exist(session, norm, tag_type="user") + + existing_tag_names = set( + name for (name,) in (await session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all() + ) + missing = [t for t in norm if t not in existing_tag_names] + if missing and require_existing_tags: + raise ValueError(f"Unknown tags: {missing}") + + existing_links = set( + tag_name + for (tag_name,) in ( + await session.execute( + select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"]) + ) + ).all() + ) + to_add = [t for t in norm if t in existing_tag_names and t not in existing_links] + if to_add: + session.add_all( + [ + AssetInfoTag( + asset_info_id=out["asset_info_id"], + tag_name=t, + origin=tag_origin, + added_at=now, + ) + for t in to_add + ] + ) + await session.flush() + + # metadata["filename"] hack + if out["asset_info_id"] is not None: + primary_path = pick_best_live_path(await list_cache_states_by_asset_id(session, asset_id=asset.id)) + computed_filename = compute_relative_filename(primary_path) if primary_path else None + + current_meta = existing_info.user_metadata or {} + new_meta = dict(current_meta) + if user_metadata is not None: + for k, v in user_metadata.items(): + new_meta[k] = v + if computed_filename: + new_meta["filename"] = computed_filename + + if new_meta != current_meta: + await replace_asset_info_metadata_projection( + session, + asset_info_id=out["asset_info_id"], + user_metadata=new_meta, + ) + + try: + await remove_missing_tag_for_asset_id(session, asset_id=asset.id) + except Exception: + logging.exception("Failed to clear 'missing' tag for asset %s", asset.id) + return out + + +async def touch_asset_infos_by_fs_path( + session: AsyncSession, + *, + file_path: str, + ts: Optional[datetime] = None, + only_if_newer: bool = True, +) -> None: + locator = os.path.abspath(file_path) + ts = ts or utcnow() + stmt = sa.update(AssetInfo).where( + sa.exists( + sa.select(sa.literal(1)) + .select_from(AssetCacheState) + .where( + AssetCacheState.asset_id == AssetInfo.asset_id, + AssetCacheState.file_path == locator, + ) + ) + ) + if only_if_newer: + stmt = stmt.where( + sa.or_( + AssetInfo.last_access_time.is_(None), + AssetInfo.last_access_time < ts, + ) + ) + await 
session.execute(stmt.values(last_access_time=ts)) + + +async def list_cache_states_with_asset_under_prefixes( + session: AsyncSession, + *, + prefixes: Sequence[str], +) -> list[tuple[AssetCacheState, Optional[str], int]]: + """Return (AssetCacheState, asset_hash, size_bytes) for rows under any prefix.""" + if not prefixes: + return [] + + conds = [] + for p in prefixes: + if not p: + continue + base = os.path.abspath(p) + if not base.endswith(os.sep): + base = base + os.sep + escaped, esc = escape_like_prefix(base) + conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc)) + + if not conds: + return [] + + rows = ( + await session.execute( + select(AssetCacheState, Asset.hash, Asset.size_bytes) + .join(Asset, Asset.id == AssetCacheState.asset_id) + .where(sa.or_(*conds)) + .order_by(AssetCacheState.id.asc()) + ) + ).all() + return [(r[0], r[1], int(r[2] or 0)) for r in rows] + + +async def _recompute_and_apply_filename_for_asset(session: AsyncSession, *, asset_id: str) -> None: + """Compute filename from the first *existing* cache state path and apply it to all AssetInfo (if changed).""" + try: + primary_path = pick_best_live_path(await list_cache_states_by_asset_id(session, asset_id=asset_id)) + if not primary_path: + return + new_filename = compute_relative_filename(primary_path) + if not new_filename: + return + infos = ( + await session.execute(select(AssetInfo).where(AssetInfo.asset_id == asset_id)) + ).scalars().all() + for info in infos: + current_meta = info.user_metadata or {} + if current_meta.get("filename") == new_filename: + continue + updated = dict(current_meta) + updated["filename"] = new_filename + await replace_asset_info_metadata_projection(session, asset_info_id=info.id, user_metadata=updated) + except Exception: + logging.exception("Failed to recompute filename metadata for asset %s", asset_id) diff --git a/app/assets/database/services/info.py b/app/assets/database/services/info.py new file mode 100644 index 000000000000..b499557418c8 --- /dev/null +++ b/app/assets/database/services/info.py @@ -0,0 +1,586 @@ +from collections import defaultdict +from datetime import datetime +from typing import Any, Optional, Sequence + +import sqlalchemy as sa +from sqlalchemy import delete, func, select +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import contains_eager, noload + +from ..._helpers import compute_relative_filename, normalize_tags +from ..helpers import ( + apply_metadata_filter, + apply_tag_filters, + ensure_tags_exist, + escape_like_prefix, + project_kv, + visible_owner_clause, +) +from ..models import Asset, AssetInfo, AssetInfoMeta, AssetInfoTag, Tag +from ..timeutil import utcnow +from .queries import ( + get_asset_by_hash, + list_cache_states_by_asset_id, + pick_best_live_path, +) + + +async def list_asset_infos_page( + session: AsyncSession, + *, + owner_id: str = "", + include_tags: Optional[Sequence[str]] = None, + exclude_tags: Optional[Sequence[str]] = None, + name_contains: Optional[str] = None, + metadata_filter: Optional[dict] = None, + limit: int = 20, + offset: int = 0, + sort: str = "created_at", + order: str = "desc", +) -> tuple[list[AssetInfo], dict[str, list[str]], int]: + base = ( + select(AssetInfo) + .join(Asset, Asset.id == AssetInfo.asset_id) + .options(contains_eager(AssetInfo.asset), noload(AssetInfo.tags)) + .where(visible_owner_clause(owner_id)) + ) + + if name_contains: + escaped, esc = escape_like_prefix(name_contains) + base = 
base.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc)) + + base = apply_tag_filters(base, include_tags, exclude_tags) + base = apply_metadata_filter(base, metadata_filter) + + sort = (sort or "created_at").lower() + order = (order or "desc").lower() + sort_map = { + "name": AssetInfo.name, + "created_at": AssetInfo.created_at, + "updated_at": AssetInfo.updated_at, + "last_access_time": AssetInfo.last_access_time, + "size": Asset.size_bytes, + } + sort_col = sort_map.get(sort, AssetInfo.created_at) + sort_exp = sort_col.desc() if order == "desc" else sort_col.asc() + + base = base.order_by(sort_exp).limit(limit).offset(offset) + + count_stmt = ( + select(func.count()) + .select_from(AssetInfo) + .join(Asset, Asset.id == AssetInfo.asset_id) + .where(visible_owner_clause(owner_id)) + ) + if name_contains: + escaped, esc = escape_like_prefix(name_contains) + count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc)) + count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags) + count_stmt = apply_metadata_filter(count_stmt, metadata_filter) + + total = int((await session.execute(count_stmt)).scalar_one() or 0) + + infos = (await session.execute(base)).unique().scalars().all() + + id_list: list[str] = [i.id for i in infos] + tag_map: dict[str, list[str]] = defaultdict(list) + if id_list: + rows = await session.execute( + select(AssetInfoTag.asset_info_id, Tag.name) + .join(Tag, Tag.name == AssetInfoTag.tag_name) + .where(AssetInfoTag.asset_info_id.in_(id_list)) + ) + for aid, tag_name in rows.all(): + tag_map[aid].append(tag_name) + + return infos, tag_map, total + + +async def fetch_asset_info_and_asset( + session: AsyncSession, + *, + asset_info_id: str, + owner_id: str = "", +) -> Optional[tuple[AssetInfo, Asset]]: + stmt = ( + select(AssetInfo, Asset) + .join(Asset, Asset.id == AssetInfo.asset_id) + .where( + AssetInfo.id == asset_info_id, + visible_owner_clause(owner_id), + ) + .limit(1) + .options(noload(AssetInfo.tags)) + ) + row = await session.execute(stmt) + pair = row.first() + if not pair: + return None + return pair[0], pair[1] + + +async def fetch_asset_info_asset_and_tags( + session: AsyncSession, + *, + asset_info_id: str, + owner_id: str = "", +) -> Optional[tuple[AssetInfo, Asset, list[str]]]: + stmt = ( + select(AssetInfo, Asset, Tag.name) + .join(Asset, Asset.id == AssetInfo.asset_id) + .join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True) + .join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True) + .where( + AssetInfo.id == asset_info_id, + visible_owner_clause(owner_id), + ) + .options(noload(AssetInfo.tags)) + .order_by(Tag.name.asc()) + ) + + rows = (await session.execute(stmt)).all() + if not rows: + return None + + first_info, first_asset, _ = rows[0] + tags: list[str] = [] + seen: set[str] = set() + for _info, _asset, tag_name in rows: + if tag_name and tag_name not in seen: + seen.add(tag_name) + tags.append(tag_name) + return first_info, first_asset, tags + + +async def create_asset_info_for_existing_asset( + session: AsyncSession, + *, + asset_hash: str, + name: str, + user_metadata: Optional[dict] = None, + tags: Optional[Sequence[str]] = None, + tag_origin: str = "manual", + owner_id: str = "", +) -> AssetInfo: + """Create or return an existing AssetInfo for an Asset identified by asset_hash.""" + now = utcnow() + asset = await get_asset_by_hash(session, asset_hash=asset_hash) + if not asset: + raise ValueError(f"Unknown asset hash {asset_hash}") + + info = AssetInfo( + owner_id=owner_id, + 
name=name, + asset_id=asset.id, + preview_id=None, + created_at=now, + updated_at=now, + last_access_time=now, + ) + try: + async with session.begin_nested(): + session.add(info) + await session.flush() + except IntegrityError: + existing = ( + await session.execute( + select(AssetInfo) + .options(noload(AssetInfo.tags)) + .where( + AssetInfo.asset_id == asset.id, + AssetInfo.name == name, + AssetInfo.owner_id == owner_id, + ) + .limit(1) + ) + ).unique().scalars().first() + if not existing: + raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.") + return existing + + # metadata["filename"] hack + new_meta = dict(user_metadata or {}) + computed_filename = None + try: + p = pick_best_live_path(await list_cache_states_by_asset_id(session, asset_id=asset.id)) + if p: + computed_filename = compute_relative_filename(p) + except Exception: + computed_filename = None + if computed_filename: + new_meta["filename"] = computed_filename + if new_meta: + await replace_asset_info_metadata_projection( + session, + asset_info_id=info.id, + user_metadata=new_meta, + ) + + if tags is not None: + await set_asset_info_tags( + session, + asset_info_id=info.id, + tags=tags, + origin=tag_origin, + ) + return info + + +async def set_asset_info_tags( + session: AsyncSession, + *, + asset_info_id: str, + tags: Sequence[str], + origin: str = "manual", +) -> dict: + desired = normalize_tags(tags) + + current = set( + tag_name for (tag_name,) in ( + await session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)) + ).all() + ) + + to_add = [t for t in desired if t not in current] + to_remove = [t for t in current if t not in desired] + + if to_add: + await ensure_tags_exist(session, to_add, tag_type="user") + session.add_all([ + AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow()) + for t in to_add + ]) + await session.flush() + + if to_remove: + await session.execute( + delete(AssetInfoTag) + .where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove)) + ) + await session.flush() + + return {"added": to_add, "removed": to_remove, "total": desired} + + +async def update_asset_info_full( + session: AsyncSession, + *, + asset_info_id: str, + name: Optional[str] = None, + tags: Optional[Sequence[str]] = None, + user_metadata: Optional[dict] = None, + tag_origin: str = "manual", + asset_info_row: Any = None, +) -> AssetInfo: + if not asset_info_row: + info = await session.get(AssetInfo, asset_info_id) + if not info: + raise ValueError(f"AssetInfo {asset_info_id} not found") + else: + info = asset_info_row + + touched = False + if name is not None and name != info.name: + info.name = name + touched = True + + computed_filename = None + try: + p = pick_best_live_path(await list_cache_states_by_asset_id(session, asset_id=info.asset_id)) + if p: + computed_filename = compute_relative_filename(p) + except Exception: + computed_filename = None + + if user_metadata is not None: + new_meta = dict(user_metadata) + if computed_filename: + new_meta["filename"] = computed_filename + await replace_asset_info_metadata_projection( + session, asset_info_id=asset_info_id, user_metadata=new_meta + ) + touched = True + else: + if computed_filename: + current_meta = info.user_metadata or {} + if current_meta.get("filename") != computed_filename: + new_meta = dict(current_meta) + new_meta["filename"] = computed_filename + await replace_asset_info_metadata_projection( + session, asset_info_id=asset_info_id, 
user_metadata=new_meta + ) + touched = True + + if tags is not None: + await set_asset_info_tags( + session, + asset_info_id=asset_info_id, + tags=tags, + origin=tag_origin, + ) + touched = True + + if touched and user_metadata is None: + info.updated_at = utcnow() + await session.flush() + + return info + + +async def replace_asset_info_metadata_projection( + session: AsyncSession, + *, + asset_info_id: str, + user_metadata: Optional[dict], +) -> None: + info = await session.get(AssetInfo, asset_info_id) + if not info: + raise ValueError(f"AssetInfo {asset_info_id} not found") + + info.user_metadata = user_metadata or {} + info.updated_at = utcnow() + await session.flush() + + await session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id)) + await session.flush() + + if not user_metadata: + return + + rows: list[AssetInfoMeta] = [] + for k, v in user_metadata.items(): + for r in project_kv(k, v): + rows.append( + AssetInfoMeta( + asset_info_id=asset_info_id, + key=r["key"], + ordinal=int(r["ordinal"]), + val_str=r.get("val_str"), + val_num=r.get("val_num"), + val_bool=r.get("val_bool"), + val_json=r.get("val_json"), + ) + ) + if rows: + session.add_all(rows) + await session.flush() + + +async def touch_asset_info_by_id( + session: AsyncSession, + *, + asset_info_id: str, + ts: Optional[datetime] = None, + only_if_newer: bool = True, +) -> None: + ts = ts or utcnow() + stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id) + if only_if_newer: + stmt = stmt.where( + sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts) + ) + await session.execute(stmt.values(last_access_time=ts)) + + +async def delete_asset_info_by_id(session: AsyncSession, *, asset_info_id: str, owner_id: str) -> bool: + stmt = sa.delete(AssetInfo).where( + AssetInfo.id == asset_info_id, + visible_owner_clause(owner_id), + ) + return int((await session.execute(stmt)).rowcount or 0) > 0 + + +async def add_tags_to_asset_info( + session: AsyncSession, + *, + asset_info_id: str, + tags: Sequence[str], + origin: str = "manual", + create_if_missing: bool = True, + asset_info_row: Any = None, +) -> dict: + if not asset_info_row: + info = await session.get(AssetInfo, asset_info_id) + if not info: + raise ValueError(f"AssetInfo {asset_info_id} not found") + + norm = normalize_tags(tags) + if not norm: + total = await get_asset_tags(session, asset_info_id=asset_info_id) + return {"added": [], "already_present": [], "total_tags": total} + + if create_if_missing: + await ensure_tags_exist(session, norm, tag_type="user") + + current = { + tag_name + for (tag_name,) in ( + await session.execute( + sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) + ) + ).all() + } + + want = set(norm) + to_add = sorted(want - current) + + if to_add: + async with session.begin_nested() as nested: + try: + session.add_all( + [ + AssetInfoTag( + asset_info_id=asset_info_id, + tag_name=t, + origin=origin, + added_at=utcnow(), + ) + for t in to_add + ] + ) + await session.flush() + except IntegrityError: + await nested.rollback() + + after = set(await get_asset_tags(session, asset_info_id=asset_info_id)) + return { + "added": sorted(((after - current) & want)), + "already_present": sorted(want & current), + "total_tags": sorted(after), + } + + +async def remove_tags_from_asset_info( + session: AsyncSession, + *, + asset_info_id: str, + tags: Sequence[str], +) -> dict: + info = await session.get(AssetInfo, asset_info_id) + if not info: + raise 
ValueError(f"AssetInfo {asset_info_id} not found") + + norm = normalize_tags(tags) + if not norm: + total = await get_asset_tags(session, asset_info_id=asset_info_id) + return {"removed": [], "not_present": [], "total_tags": total} + + existing = { + tag_name + for (tag_name,) in ( + await session.execute( + sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) + ) + ).all() + } + + to_remove = sorted(set(t for t in norm if t in existing)) + not_present = sorted(set(t for t in norm if t not in existing)) + + if to_remove: + await session.execute( + delete(AssetInfoTag) + .where( + AssetInfoTag.asset_info_id == asset_info_id, + AssetInfoTag.tag_name.in_(to_remove), + ) + ) + await session.flush() + + total = await get_asset_tags(session, asset_info_id=asset_info_id) + return {"removed": to_remove, "not_present": not_present, "total_tags": total} + + +async def list_tags_with_usage( + session: AsyncSession, + *, + prefix: Optional[str] = None, + limit: int = 100, + offset: int = 0, + include_zero: bool = True, + order: str = "count_desc", + owner_id: str = "", +) -> tuple[list[tuple[str, str, int]], int]: + counts_sq = ( + select( + AssetInfoTag.tag_name.label("tag_name"), + func.count(AssetInfoTag.asset_info_id).label("cnt"), + ) + .select_from(AssetInfoTag) + .join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id) + .where(visible_owner_clause(owner_id)) + .group_by(AssetInfoTag.tag_name) + .subquery() + ) + + q = ( + select( + Tag.name, + Tag.tag_type, + func.coalesce(counts_sq.c.cnt, 0).label("count"), + ) + .select_from(Tag) + .join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True) + ) + + if prefix: + escaped, esc = escape_like_prefix(prefix.strip().lower()) + q = q.where(Tag.name.like(escaped + "%", escape=esc)) + + if not include_zero: + q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0) + + if order == "name_asc": + q = q.order_by(Tag.name.asc()) + else: + q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc()) + + total_q = select(func.count()).select_from(Tag) + if prefix: + escaped, esc = escape_like_prefix(prefix.strip().lower()) + total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc)) + if not include_zero: + total_q = total_q.where( + Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name)) + ) + + rows = (await session.execute(q.limit(limit).offset(offset))).all() + total = (await session.execute(total_q)).scalar_one() + + rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows] + return rows_norm, int(total or 0) + + +async def get_asset_tags(session: AsyncSession, *, asset_info_id: str) -> list[str]: + return [ + tag_name + for (tag_name,) in ( + await session.execute( + sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) + ) + ).all() + ] + + +async def set_asset_info_preview( + session: AsyncSession, + *, + asset_info_id: str, + preview_asset_id: Optional[str], +) -> None: + """Set or clear preview_id and bump updated_at. 
Raises on unknown IDs.""" + info = await session.get(AssetInfo, asset_info_id) + if not info: + raise ValueError(f"AssetInfo {asset_info_id} not found") + + if preview_asset_id is None: + info.preview_id = None + else: + # validate preview asset exists + if not await session.get(Asset, preview_asset_id): + raise ValueError(f"Preview Asset {preview_asset_id} not found") + info.preview_id = preview_asset_id + + info.updated_at = utcnow() + await session.flush() diff --git a/app/assets/database/services/queries.py b/app/assets/database/services/queries.py new file mode 100644 index 000000000000..fc05e5cbf2a0 --- /dev/null +++ b/app/assets/database/services/queries.py @@ -0,0 +1,76 @@ +import os +from typing import Optional, Sequence, Union + +import sqlalchemy as sa +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from ..models import Asset, AssetCacheState, AssetInfo + + +async def asset_exists_by_hash(session: AsyncSession, *, asset_hash: str) -> bool: + row = ( + await session.execute( + select(sa.literal(True)).select_from(Asset).where(Asset.hash == asset_hash).limit(1) + ) + ).first() + return row is not None + + +async def get_asset_by_hash(session: AsyncSession, *, asset_hash: str) -> Optional[Asset]: + return ( + await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1)) + ).scalars().first() + + +async def get_asset_info_by_id(session: AsyncSession, *, asset_info_id: str) -> Optional[AssetInfo]: + return await session.get(AssetInfo, asset_info_id) + + +async def asset_info_exists_for_asset_id(session: AsyncSession, *, asset_id: str) -> bool: + q = ( + select(sa.literal(True)) + .select_from(AssetInfo) + .where(AssetInfo.asset_id == asset_id) + .limit(1) + ) + return (await session.execute(q)).first() is not None + + +async def get_cache_state_by_asset_id(session: AsyncSession, *, asset_id: str) -> Optional[AssetCacheState]: + return ( + await session.execute( + select(AssetCacheState) + .where(AssetCacheState.asset_id == asset_id) + .order_by(AssetCacheState.id.asc()) + .limit(1) + ) + ).scalars().first() + + +async def list_cache_states_by_asset_id( + session: AsyncSession, *, asset_id: str +) -> Union[list[AssetCacheState], Sequence[AssetCacheState]]: + return ( + await session.execute( + select(AssetCacheState) + .where(AssetCacheState.asset_id == asset_id) + .order_by(AssetCacheState.id.asc()) + ) + ).scalars().all() + + +def pick_best_live_path(states: Union[list[AssetCacheState], Sequence[AssetCacheState]]) -> str: + """ + Return the best on-disk path among cache states: + 1) Prefer a path that exists with needs_verify == False (already verified). + 2) Otherwise, pick the first path that exists. + 3) Otherwise return empty string. + """ + alive = [s for s in states if getattr(s, "file_path", None) and os.path.isfile(s.file_path)] + if not alive: + return "" + for s in alive: + if not getattr(s, "needs_verify", False): + return s.file_path + return alive[0].file_path diff --git a/app/assets/database/timeutil.py b/app/assets/database/timeutil.py new file mode 100644 index 000000000000..e8fab12ee7c1 --- /dev/null +++ b/app/assets/database/timeutil.py @@ -0,0 +1,6 @@ +from datetime import datetime, timezone + + +def utcnow() -> datetime: + """Naive UTC timestamp (no tzinfo). 
We always treat DB datetimes as UTC.""" + return datetime.now(timezone.utc).replace(tzinfo=None) diff --git a/app/assets/manager.py b/app/assets/manager.py new file mode 100644 index 000000000000..50cf146d26a3 --- /dev/null +++ b/app/assets/manager.py @@ -0,0 +1,556 @@ +import contextlib +import logging +import mimetypes +import os +from typing import Optional, Sequence + +from comfy_api.internal import async_to_sync + +from ..db import create_session +from ._helpers import ( + ensure_within_base, + get_name_and_tags_from_asset_path, + resolve_destination_from_tags, +) +from .api import schemas_in, schemas_out +from .database.models import Asset +from .database.services import ( + add_tags_to_asset_info, + asset_exists_by_hash, + asset_info_exists_for_asset_id, + check_fs_asset_exists_quick, + create_asset_info_for_existing_asset, + delete_asset_info_by_id, + fetch_asset_info_and_asset, + fetch_asset_info_asset_and_tags, + get_asset_by_hash, + get_asset_info_by_id, + get_asset_tags, + ingest_fs_asset, + list_asset_infos_page, + list_cache_states_by_asset_id, + list_tags_with_usage, + pick_best_live_path, + remove_tags_from_asset_info, + set_asset_info_preview, + touch_asset_info_by_id, + touch_asset_infos_by_fs_path, + update_asset_info_full, +) +from .storage import hashing + + +async def asset_exists(*, asset_hash: str) -> bool: + async with await create_session() as session: + return await asset_exists_by_hash(session, asset_hash=asset_hash) + + +def populate_db_with_asset(file_path: str, tags: Optional[list[str]] = None) -> None: + if tags is None: + tags = [] + try: + asset_name, path_tags = get_name_and_tags_from_asset_path(file_path) + async_to_sync.AsyncToSyncConverter.run_async_in_thread( + add_local_asset, + tags=list(dict.fromkeys([*path_tags, *tags])), + file_name=asset_name, + file_path=file_path, + ) + except ValueError as e: + logging.warning("Skipping non-asset path %s: %s", file_path, e) + + +async def add_local_asset(tags: list[str], file_name: str, file_path: str) -> None: + abs_path = os.path.abspath(file_path) + size_bytes, mtime_ns = _get_size_mtime_ns(abs_path) + if not size_bytes: + return + + async with await create_session() as session: + if await check_fs_asset_exists_quick(session, file_path=abs_path, size_bytes=size_bytes, mtime_ns=mtime_ns): + await touch_asset_infos_by_fs_path(session, file_path=abs_path) + await session.commit() + return + + asset_hash = hashing.blake3_hash_sync(abs_path) + + async with await create_session() as session: + await ingest_fs_asset( + session, + asset_hash="blake3:" + asset_hash, + abs_path=abs_path, + size_bytes=size_bytes, + mtime_ns=mtime_ns, + mime_type=None, + info_name=file_name, + tag_origin="automatic", + tags=tags, + ) + await session.commit() + + +async def list_assets( + *, + include_tags: Optional[Sequence[str]] = None, + exclude_tags: Optional[Sequence[str]] = None, + name_contains: Optional[str] = None, + metadata_filter: Optional[dict] = None, + limit: int = 20, + offset: int = 0, + sort: str = "created_at", + order: str = "desc", + owner_id: str = "", +) -> schemas_out.AssetsList: + sort = _safe_sort_field(sort) + order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower() + + async with await create_session() as session: + infos, tag_map, total = await list_asset_infos_page( + session, + owner_id=owner_id, + include_tags=include_tags, + exclude_tags=exclude_tags, + name_contains=name_contains, + metadata_filter=metadata_filter, + limit=limit, + offset=offset, + sort=sort, + order=order, 
+ ) + + summaries: list[schemas_out.AssetSummary] = [] + for info in infos: + asset = info.asset + tags = tag_map.get(info.id, []) + summaries.append( + schemas_out.AssetSummary( + id=info.id, + name=info.name, + asset_hash=asset.hash if asset else None, + size=int(asset.size_bytes) if asset else None, + mime_type=asset.mime_type if asset else None, + tags=tags, + preview_url=f"/api/assets/{info.id}/content", + created_at=info.created_at, + updated_at=info.updated_at, + last_access_time=info.last_access_time, + ) + ) + + return schemas_out.AssetsList( + assets=summaries, + total=total, + has_more=(offset + len(summaries)) < total, + ) + + +async def get_asset(*, asset_info_id: str, owner_id: str = "") -> schemas_out.AssetDetail: + async with await create_session() as session: + res = await fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id) + if not res: + raise ValueError(f"AssetInfo {asset_info_id} not found") + info, asset, tag_names = res + preview_id = info.preview_id + + return schemas_out.AssetDetail( + id=info.id, + name=info.name, + asset_hash=asset.hash if asset else None, + size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None, + mime_type=asset.mime_type if asset else None, + tags=tag_names, + user_metadata=info.user_metadata or {}, + preview_id=preview_id, + created_at=info.created_at, + last_access_time=info.last_access_time, + ) + + +async def resolve_asset_content_for_download( + *, + asset_info_id: str, + owner_id: str = "", +) -> tuple[str, str, str]: + async with await create_session() as session: + pair = await fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id) + if not pair: + raise ValueError(f"AssetInfo {asset_info_id} not found") + + info, asset = pair + states = await list_cache_states_by_asset_id(session, asset_id=asset.id) + abs_path = pick_best_live_path(states) + if not abs_path: + raise FileNotFoundError + + await touch_asset_info_by_id(session, asset_info_id=asset_info_id) + await session.commit() + + ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream" + download_name = info.name or os.path.basename(abs_path) + return abs_path, ctype, download_name + + +async def upload_asset_from_temp_path( + spec: schemas_in.UploadAssetSpec, + *, + temp_path: str, + client_filename: Optional[str] = None, + owner_id: str = "", + expected_asset_hash: Optional[str] = None, +) -> schemas_out.AssetCreated: + try: + digest = await hashing.blake3_hash(temp_path) + except Exception as e: + raise RuntimeError(f"failed to hash uploaded file: {e}") + asset_hash = "blake3:" + digest + + if expected_asset_hash and asset_hash != expected_asset_hash.strip().lower(): + raise ValueError("HASH_MISMATCH") + + async with await create_session() as session: + existing = await get_asset_by_hash(session, asset_hash=asset_hash) + if existing is not None: + with contextlib.suppress(Exception): + if temp_path and os.path.exists(temp_path): + os.remove(temp_path) + + display_name = _safe_filename(spec.name or (client_filename or ""), fallback=digest) + info = await create_asset_info_for_existing_asset( + session, + asset_hash=asset_hash, + name=display_name, + user_metadata=spec.user_metadata or {}, + tags=spec.tags or [], + tag_origin="manual", + owner_id=owner_id, + ) + tag_names = await get_asset_tags(session, asset_info_id=info.id) + await session.commit() + + return schemas_out.AssetCreated( + id=info.id, + name=info.name, + asset_hash=existing.hash, + 
size=int(existing.size_bytes) if existing.size_bytes is not None else None, + mime_type=existing.mime_type, + tags=tag_names, + user_metadata=info.user_metadata or {}, + preview_id=info.preview_id, + created_at=info.created_at, + last_access_time=info.last_access_time, + created_new=False, + ) + + base_dir, subdirs = resolve_destination_from_tags(spec.tags) + dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir + os.makedirs(dest_dir, exist_ok=True) + + src_for_ext = (client_filename or spec.name or "").strip() + _ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else "" + ext = _ext if 0 < len(_ext) <= 16 else "" + hashed_basename = f"{digest}{ext}" + dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename)) + ensure_within_base(dest_abs, base_dir) + + content_type = ( + mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0] + or mimetypes.guess_type(hashed_basename, strict=False)[0] + or "application/octet-stream" + ) + + try: + os.replace(temp_path, dest_abs) + except Exception as e: + raise RuntimeError(f"failed to move uploaded file into place: {e}") + + try: + size_bytes, mtime_ns = _get_size_mtime_ns(dest_abs) + except OSError as e: + raise RuntimeError(f"failed to stat destination file: {e}") + + async with await create_session() as session: + result = await ingest_fs_asset( + session, + asset_hash=asset_hash, + abs_path=dest_abs, + size_bytes=size_bytes, + mtime_ns=mtime_ns, + mime_type=content_type, + info_name=_safe_filename(spec.name or (client_filename or ""), fallback=digest), + owner_id=owner_id, + preview_id=None, + user_metadata=spec.user_metadata or {}, + tags=spec.tags, + tag_origin="manual", + require_existing_tags=False, + ) + info_id = result["asset_info_id"] + if not info_id: + raise RuntimeError("failed to create asset metadata") + + pair = await fetch_asset_info_and_asset(session, asset_info_id=info_id, owner_id=owner_id) + if not pair: + raise RuntimeError("inconsistent DB state after ingest") + info, asset = pair + tag_names = await get_asset_tags(session, asset_info_id=info.id) + await session.commit() + + return schemas_out.AssetCreated( + id=info.id, + name=info.name, + asset_hash=asset.hash, + size=int(asset.size_bytes), + mime_type=asset.mime_type, + tags=tag_names, + user_metadata=info.user_metadata or {}, + preview_id=info.preview_id, + created_at=info.created_at, + last_access_time=info.last_access_time, + created_new=result["asset_created"], + ) + + +async def update_asset( + *, + asset_info_id: str, + name: Optional[str] = None, + tags: Optional[list[str]] = None, + user_metadata: Optional[dict] = None, + owner_id: str = "", +) -> schemas_out.AssetUpdated: + async with await create_session() as session: + info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id) + if not info_row: + raise ValueError(f"AssetInfo {asset_info_id} not found") + if info_row.owner_id and info_row.owner_id != owner_id: + raise PermissionError("not owner") + + info = await update_asset_info_full( + session, + asset_info_id=asset_info_id, + name=name, + tags=tags, + user_metadata=user_metadata, + tag_origin="manual", + asset_info_row=info_row, + ) + + tag_names = await get_asset_tags(session, asset_info_id=asset_info_id) + await session.commit() + + return schemas_out.AssetUpdated( + id=info.id, + name=info.name, + asset_hash=info.asset.hash if info.asset else None, + tags=tag_names, + user_metadata=info.user_metadata or {}, + updated_at=info.updated_at, + ) + + +async def set_asset_preview( + *, + 
asset_info_id: str, + preview_asset_id: Optional[str], + owner_id: str = "", +) -> schemas_out.AssetDetail: + async with await create_session() as session: + info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id) + if not info_row: + raise ValueError(f"AssetInfo {asset_info_id} not found") + if info_row.owner_id and info_row.owner_id != owner_id: + raise PermissionError("not owner") + + await set_asset_info_preview( + session, + asset_info_id=asset_info_id, + preview_asset_id=preview_asset_id, + ) + + res = await fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id) + if not res: + raise RuntimeError("State changed during preview update") + info, asset, tags = res + await session.commit() + + return schemas_out.AssetDetail( + id=info.id, + name=info.name, + asset_hash=asset.hash if asset else None, + size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None, + mime_type=asset.mime_type if asset else None, + tags=tags, + user_metadata=info.user_metadata or {}, + preview_id=info.preview_id, + created_at=info.created_at, + last_access_time=info.last_access_time, + ) + + +async def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool: + async with await create_session() as session: + info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id) + asset_id = info_row.asset_id if info_row else None + deleted = await delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id) + if not deleted: + await session.commit() + return False + + if not delete_content_if_orphan or not asset_id: + await session.commit() + return True + + still_exists = await asset_info_exists_for_asset_id(session, asset_id=asset_id) + if still_exists: + await session.commit() + return True + + states = await list_cache_states_by_asset_id(session, asset_id=asset_id) + file_paths = [s.file_path for s in (states or []) if getattr(s, "file_path", None)] + + asset_row = await session.get(Asset, asset_id) + if asset_row is not None: + await session.delete(asset_row) + + await session.commit() + for p in file_paths: + with contextlib.suppress(Exception): + if p and os.path.isfile(p): + os.remove(p) + return True + + +async def create_asset_from_hash( + *, + hash_str: str, + name: str, + tags: Optional[list[str]] = None, + user_metadata: Optional[dict] = None, + owner_id: str = "", +) -> Optional[schemas_out.AssetCreated]: + canonical = hash_str.strip().lower() + async with await create_session() as session: + asset = await get_asset_by_hash(session, asset_hash=canonical) + if not asset: + return None + + info = await create_asset_info_for_existing_asset( + session, + asset_hash=canonical, + name=_safe_filename(name, fallback=canonical.split(":", 1)[1]), + user_metadata=user_metadata or {}, + tags=tags or [], + tag_origin="manual", + owner_id=owner_id, + ) + tag_names = await get_asset_tags(session, asset_info_id=info.id) + await session.commit() + + return schemas_out.AssetCreated( + id=info.id, + name=info.name, + asset_hash=asset.hash, + size=int(asset.size_bytes), + mime_type=asset.mime_type, + tags=tag_names, + user_metadata=info.user_metadata or {}, + preview_id=info.preview_id, + created_at=info.created_at, + last_access_time=info.last_access_time, + created_new=False, + ) + + +async def list_tags( + *, + prefix: Optional[str] = None, + limit: int = 100, + offset: int = 0, + order: str = "count_desc", + include_zero: bool = True, + owner_id: str = "", +) -> 
schemas_out.TagsList: + limit = max(1, min(1000, limit)) + offset = max(0, offset) + + async with await create_session() as session: + rows, total = await list_tags_with_usage( + session, + prefix=prefix, + limit=limit, + offset=offset, + include_zero=include_zero, + order=order, + owner_id=owner_id, + ) + + tags = [schemas_out.TagUsage(name=name, count=count, type=tag_type) for (name, tag_type, count) in rows] + return schemas_out.TagsList(tags=tags, total=total, has_more=(offset + len(tags)) < total) + + +async def add_tags_to_asset( + *, + asset_info_id: str, + tags: list[str], + origin: str = "manual", + owner_id: str = "", +) -> schemas_out.TagsAdd: + async with await create_session() as session: + info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id) + if not info_row: + raise ValueError(f"AssetInfo {asset_info_id} not found") + if info_row.owner_id and info_row.owner_id != owner_id: + raise PermissionError("not owner") + data = await add_tags_to_asset_info( + session, + asset_info_id=asset_info_id, + tags=tags, + origin=origin, + create_if_missing=True, + asset_info_row=info_row, + ) + await session.commit() + return schemas_out.TagsAdd(**data) + + +async def remove_tags_from_asset( + *, + asset_info_id: str, + tags: list[str], + owner_id: str = "", +) -> schemas_out.TagsRemove: + async with await create_session() as session: + info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id) + if not info_row: + raise ValueError(f"AssetInfo {asset_info_id} not found") + if info_row.owner_id and info_row.owner_id != owner_id: + raise PermissionError("not owner") + + data = await remove_tags_from_asset_info( + session, + asset_info_id=asset_info_id, + tags=tags, + ) + await session.commit() + return schemas_out.TagsRemove(**data) + + +def _safe_sort_field(requested: Optional[str]) -> str: + if not requested: + return "created_at" + v = requested.lower() + if v in {"name", "created_at", "updated_at", "size", "last_access_time"}: + return v + return "created_at" + + +def _get_size_mtime_ns(path: str) -> tuple[int, int]: + st = os.stat(path, follow_symlinks=True) + return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)) + + +def _safe_filename(name: Optional[str], fallback: str) -> str: + n = os.path.basename((name or "").strip() or fallback) + if n: + return n + return fallback diff --git a/app/assets/scanner.py b/app/assets/scanner.py new file mode 100644 index 000000000000..aa8123a49c5a --- /dev/null +++ b/app/assets/scanner.py @@ -0,0 +1,501 @@ +import asyncio +import contextlib +import logging +import os +import time +from dataclasses import dataclass, field +from typing import Literal, Optional + +import sqlalchemy as sa + +import folder_paths + +from ..db import create_session +from ._helpers import ( + collect_models_files, + compute_relative_filename, + get_comfy_models_folders, + get_name_and_tags_from_asset_path, + list_tree, + new_scan_id, + prefixes_for_root, + ts_to_iso, +) +from .api import schemas_in, schemas_out +from .database.helpers import ( + add_missing_tag_for_asset_id, + ensure_tags_exist, + escape_like_prefix, + fast_asset_file_check, + remove_missing_tag_for_asset_id, + seed_from_paths_batch, +) +from .database.models import Asset, AssetCacheState, AssetInfo +from .database.services import ( + compute_hash_and_dedup_for_cache_state, + list_cache_states_by_asset_id, + list_cache_states_with_asset_under_prefixes, + list_unhashed_candidates_under_prefixes, + list_verify_candidates_under_prefixes, +) + +LOGGER 
= logging.getLogger(__name__) + +SLOW_HASH_CONCURRENCY = 1 + + +@dataclass +class ScanProgress: + scan_id: str + root: schemas_in.RootType + status: Literal["scheduled", "running", "completed", "failed", "cancelled"] = "scheduled" + scheduled_at: float = field(default_factory=lambda: time.time()) + started_at: Optional[float] = None + finished_at: Optional[float] = None + discovered: int = 0 + processed: int = 0 + file_errors: list[dict] = field(default_factory=list) + + +@dataclass +class SlowQueueState: + queue: asyncio.Queue + workers: list[asyncio.Task] = field(default_factory=list) + closed: bool = False + + +RUNNING_TASKS: dict[schemas_in.RootType, asyncio.Task] = {} +PROGRESS_BY_ROOT: dict[schemas_in.RootType, ScanProgress] = {} +SLOW_STATE_BY_ROOT: dict[schemas_in.RootType, SlowQueueState] = {} + + +def current_statuses() -> schemas_out.AssetScanStatusResponse: + scans = [] + for root in schemas_in.ALLOWED_ROOTS: + prog = PROGRESS_BY_ROOT.get(root) + if not prog: + continue + scans.append(_scan_progress_to_scan_status_model(prog)) + return schemas_out.AssetScanStatusResponse(scans=scans) + + +async def schedule_scans(roots: list[schemas_in.RootType]) -> schemas_out.AssetScanStatusResponse: + results: list[ScanProgress] = [] + for root in roots: + if root in RUNNING_TASKS and not RUNNING_TASKS[root].done(): + results.append(PROGRESS_BY_ROOT[root]) + continue + + prog = ScanProgress(scan_id=new_scan_id(root), root=root, status="scheduled") + PROGRESS_BY_ROOT[root] = prog + state = SlowQueueState(queue=asyncio.Queue()) + SLOW_STATE_BY_ROOT[root] = state + RUNNING_TASKS[root] = asyncio.create_task( + _run_hash_verify_pipeline(root, prog, state), + name=f"asset-scan:{root}", + ) + results.append(prog) + return _status_response_for(results) + + +async def sync_seed_assets(roots: list[schemas_in.RootType]) -> None: + t_total = time.perf_counter() + created = 0 + skipped_existing = 0 + paths: list[str] = [] + try: + existing_paths: set[str] = set() + for r in roots: + try: + survivors = await _fast_db_consistency_pass(r, collect_existing_paths=True, update_missing_tags=True) + if survivors: + existing_paths.update(survivors) + except Exception as ex: + LOGGER.exception("fast DB reconciliation failed for %s: %s", r, ex) + + if "models" in roots: + paths.extend(collect_models_files()) + if "input" in roots: + paths.extend(list_tree(folder_paths.get_input_directory())) + if "output" in roots: + paths.extend(list_tree(folder_paths.get_output_directory())) + + specs: list[dict] = [] + tag_pool: set[str] = set() + for p in paths: + ap = os.path.abspath(p) + if ap in existing_paths: + skipped_existing += 1 + continue + try: + st = os.stat(ap, follow_symlinks=True) + except OSError: + continue + if not st.st_size: + continue + name, tags = get_name_and_tags_from_asset_path(ap) + specs.append( + { + "abs_path": ap, + "size_bytes": st.st_size, + "mtime_ns": getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)), + "info_name": name, + "tags": tags, + "fname": compute_relative_filename(ap), + } + ) + for t in tags: + tag_pool.add(t) + + if not specs: + return + async with await create_session() as sess: + if tag_pool: + await ensure_tags_exist(sess, tag_pool, tag_type="user") + + result = await seed_from_paths_batch(sess, specs=specs, owner_id="") + created += result["inserted_infos"] + await sess.commit() + finally: + LOGGER.info( + "Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, total_seen=%d)", + roots, + time.perf_counter() - t_total, + created, + 
skipped_existing, + len(paths), + ) + + +def _status_response_for(progresses: list[ScanProgress]) -> schemas_out.AssetScanStatusResponse: + return schemas_out.AssetScanStatusResponse(scans=[_scan_progress_to_scan_status_model(p) for p in progresses]) + + +def _scan_progress_to_scan_status_model(progress: ScanProgress) -> schemas_out.AssetScanStatus: + return schemas_out.AssetScanStatus( + scan_id=progress.scan_id, + root=progress.root, + status=progress.status, + scheduled_at=ts_to_iso(progress.scheduled_at), + started_at=ts_to_iso(progress.started_at), + finished_at=ts_to_iso(progress.finished_at), + discovered=progress.discovered, + processed=progress.processed, + file_errors=[ + schemas_out.AssetScanError( + path=e.get("path", ""), + message=e.get("message", ""), + at=e.get("at"), + ) + for e in (progress.file_errors or []) + ], + ) + + +async def _run_hash_verify_pipeline(root: schemas_in.RootType, prog: ScanProgress, state: SlowQueueState) -> None: + prog.status = "running" + prog.started_at = time.time() + try: + prefixes = prefixes_for_root(root) + + await _fast_db_consistency_pass(root) + + # collect candidates from DB + async with await create_session() as sess: + verify_ids = await list_verify_candidates_under_prefixes(sess, prefixes=prefixes) + unhashed_ids = await list_unhashed_candidates_under_prefixes(sess, prefixes=prefixes) + # dedupe: prioritize verification first + seen = set() + ordered: list[int] = [] + for lst in (verify_ids, unhashed_ids): + for sid in lst: + if sid not in seen: + seen.add(sid) + ordered.append(sid) + + prog.discovered = len(ordered) + + # queue up work + for sid in ordered: + await state.queue.put(sid) + state.closed = True + _start_state_workers(root, prog, state) + await _await_state_workers_then_finish(root, prog, state) + except asyncio.CancelledError: + prog.status = "cancelled" + raise + except Exception as exc: + _append_error(prog, path="", message=str(exc)) + prog.status = "failed" + prog.finished_at = time.time() + LOGGER.exception("Asset scan failed for %s", root) + finally: + RUNNING_TASKS.pop(root, None) + + +async def _reconcile_missing_tags_for_root(root: schemas_in.RootType, prog: ScanProgress) -> None: + """ + Detect missing files quickly and toggle 'missing' tag per asset_id. + + Rules: + - Only hashed assets (assets.hash != NULL) participate in missing tagging. + - We consider ALL cache states of the asset (across roots) before tagging. 
+ """ + if root == "models": + bases: list[str] = [] + for _bucket, paths in get_comfy_models_folders(): + bases.extend(paths) + elif root == "input": + bases = [folder_paths.get_input_directory()] + else: + bases = [folder_paths.get_output_directory()] + + try: + async with await create_session() as sess: + # state + hash + size for the current root + rows = await list_cache_states_with_asset_under_prefixes(sess, prefixes=bases) + + # Track fast_ok within the scanned root and whether the asset is hashed + by_asset: dict[str, dict[str, bool]] = {} + for state, a_hash, size_db in rows: + aid = state.asset_id + acc = by_asset.get(aid) + if acc is None: + acc = {"any_fast_ok_here": False, "hashed": (a_hash is not None), "size_db": int(size_db or 0)} + by_asset[aid] = acc + try: + if acc["hashed"]: + st = os.stat(state.file_path, follow_symlinks=True) + if fast_asset_file_check(mtime_db=state.mtime_ns, size_db=acc["size_db"], stat_result=st): + acc["any_fast_ok_here"] = True + except FileNotFoundError: + pass + except OSError as e: + _append_error(prog, path=state.file_path, message=str(e)) + + # Decide per asset, considering ALL its states (not just this root) + for aid, acc in by_asset.items(): + try: + if not acc["hashed"]: + # Never tag seed assets as missing + continue + + any_fast_ok_global = acc["any_fast_ok_here"] + if not any_fast_ok_global: + # Check other states outside this root + others = await list_cache_states_by_asset_id(sess, asset_id=aid) + for st in others: + try: + any_fast_ok_global = fast_asset_file_check( + mtime_db=st.mtime_ns, + size_db=acc["size_db"], + stat_result=os.stat(st.file_path, follow_symlinks=True), + ) + except OSError: + continue + + if any_fast_ok_global: + await remove_missing_tag_for_asset_id(sess, asset_id=aid) + else: + await add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic") + except Exception as ex: + _append_error(prog, path="", message=f"reconcile {aid[:8]}: {ex}") + + await sess.commit() + except Exception as e: + _append_error(prog, path="", message=f"reconcile failed: {e}") + + +def _start_state_workers(root: schemas_in.RootType, prog: ScanProgress, state: SlowQueueState) -> None: + if state.workers: + return + + async def _worker(_wid: int): + while True: + sid = await state.queue.get() + try: + if sid is None: + return + try: + async with await create_session() as sess: + # Optional: fetch path for better error messages + st = await sess.get(AssetCacheState, sid) + try: + await compute_hash_and_dedup_for_cache_state(sess, state_id=sid) + await sess.commit() + except Exception as e: + path = st.file_path if st else f"state:{sid}" + _append_error(prog, path=path, message=str(e)) + raise + except Exception: + pass + finally: + prog.processed += 1 + finally: + state.queue.task_done() + + state.workers = [ + asyncio.create_task(_worker(i), name=f"asset-hash:{root}:{i}") + for i in range(SLOW_HASH_CONCURRENCY) + ] + + async def _close_when_ready(): + while not state.closed: + await asyncio.sleep(0.05) + for _ in range(SLOW_HASH_CONCURRENCY): + await state.queue.put(None) + + asyncio.create_task(_close_when_ready()) + + +async def _await_state_workers_then_finish( + root: schemas_in.RootType, prog: ScanProgress, state: SlowQueueState +) -> None: + if state.workers: + await asyncio.gather(*state.workers, return_exceptions=True) + await _reconcile_missing_tags_for_root(root, prog) + prog.finished_at = time.time() + prog.status = "completed" + + +def _append_error(prog: ScanProgress, *, path: str, message: str) -> None: + 
prog.file_errors.append({ + "path": path, + "message": message, + "at": ts_to_iso(time.time()), + }) + + +async def _fast_db_consistency_pass( + root: schemas_in.RootType, + *, + collect_existing_paths: bool = False, + update_missing_tags: bool = False, +) -> Optional[set[str]]: + """Fast DB+FS pass for a root: + - Toggle needs_verify per state using fast check + - For hashed assets with at least one fast-ok state in this root: delete stale missing states + - For seed assets with all states missing: delete Asset and its AssetInfos + - Optionally add/remove 'missing' tags based on fast-ok in this root + - Optionally return surviving absolute paths + """ + prefixes = prefixes_for_root(root) + if not prefixes: + return set() if collect_existing_paths else None + + conds = [] + for p in prefixes: + base = os.path.abspath(p) + if not base.endswith(os.sep): + base += os.sep + escaped, esc = escape_like_prefix(base) + conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc)) + + async with await create_session() as sess: + rows = ( + await sess.execute( + sa.select( + AssetCacheState.id, + AssetCacheState.file_path, + AssetCacheState.mtime_ns, + AssetCacheState.needs_verify, + AssetCacheState.asset_id, + Asset.hash, + Asset.size_bytes, + ) + .join(Asset, Asset.id == AssetCacheState.asset_id) + .where(sa.or_(*conds)) + .order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc()) + ) + ).all() + + by_asset: dict[str, dict] = {} + for sid, fp, mtime_db, needs_verify, aid, a_hash, a_size in rows: + acc = by_asset.get(aid) + if acc is None: + acc = {"hash": a_hash, "size_db": int(a_size or 0), "states": []} + by_asset[aid] = acc + + fast_ok = False + try: + exists = True + fast_ok = fast_asset_file_check( + mtime_db=mtime_db, + size_db=acc["size_db"], + stat_result=os.stat(fp, follow_symlinks=True), + ) + except FileNotFoundError: + exists = False + except OSError: + exists = False + + acc["states"].append({ + "sid": sid, + "fp": fp, + "exists": exists, + "fast_ok": fast_ok, + "needs_verify": bool(needs_verify), + }) + + to_set_verify: list[int] = [] + to_clear_verify: list[int] = [] + stale_state_ids: list[int] = [] + survivors: set[str] = set() + + for aid, acc in by_asset.items(): + a_hash = acc["hash"] + states = acc["states"] + any_fast_ok = any(s["fast_ok"] for s in states) + all_missing = all(not s["exists"] for s in states) + + for s in states: + if not s["exists"]: + continue + if s["fast_ok"] and s["needs_verify"]: + to_clear_verify.append(s["sid"]) + if not s["fast_ok"] and not s["needs_verify"]: + to_set_verify.append(s["sid"]) + + if a_hash is None: + if states and all_missing: # remove seed Asset completely, if no valid AssetCache exists + await sess.execute(sa.delete(AssetInfo).where(AssetInfo.asset_id == aid)) + asset = await sess.get(Asset, aid) + if asset: + await sess.delete(asset) + else: + for s in states: + if s["exists"]: + survivors.add(os.path.abspath(s["fp"])) + continue + + if any_fast_ok: # if Asset has at least one valid AssetCache record, remove any invalid AssetCache records + for s in states: + if not s["exists"]: + stale_state_ids.append(s["sid"]) + if update_missing_tags: + with contextlib.suppress(Exception): + await remove_missing_tag_for_asset_id(sess, asset_id=aid) + elif update_missing_tags: + with contextlib.suppress(Exception): + await add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic") + + for s in states: + if s["exists"]: + survivors.add(os.path.abspath(s["fp"])) + + if stale_state_ids: + await 
sess.execute(sa.delete(AssetCacheState).where(AssetCacheState.id.in_(stale_state_ids))) + if to_set_verify: + await sess.execute( + sa.update(AssetCacheState) + .where(AssetCacheState.id.in_(to_set_verify)) + .values(needs_verify=True) + ) + if to_clear_verify: + await sess.execute( + sa.update(AssetCacheState) + .where(AssetCacheState.id.in_(to_clear_verify)) + .values(needs_verify=False) + ) + await sess.commit() + return survivors if collect_existing_paths else None diff --git a/app/assets/storage/__init__.py b/app/assets/storage/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/app/assets/storage/hashing.py b/app/assets/storage/hashing.py new file mode 100644 index 000000000000..3eaed77a33eb --- /dev/null +++ b/app/assets/storage/hashing.py @@ -0,0 +1,72 @@ +import asyncio +import os +from typing import IO, Union + +from blake3 import blake3 + +DEFAULT_CHUNK = 8 * 1024 * 1024 # 8 MiB + + +def _hash_file_obj_sync(file_obj: IO[bytes], chunk_size: int) -> str: + """Hash an already-open binary file object by streaming in chunks. + - Seeks to the beginning before reading (if supported). + - Restores the original position afterward (if tell/seek are supported). + """ + if chunk_size <= 0: + chunk_size = DEFAULT_CHUNK + + orig_pos = None + if hasattr(file_obj, "tell"): + orig_pos = file_obj.tell() + + try: + if hasattr(file_obj, "seek"): + file_obj.seek(0) + + h = blake3() + while True: + chunk = file_obj.read(chunk_size) + if not chunk: + break + h.update(chunk) + return h.hexdigest() + finally: + if hasattr(file_obj, "seek") and orig_pos is not None: + file_obj.seek(orig_pos) + + +def blake3_hash_sync( + fp: Union[str, bytes, os.PathLike[str], os.PathLike[bytes], IO[bytes]], + chunk_size: int = DEFAULT_CHUNK, +) -> str: + """Returns a BLAKE3 hex digest for ``fp``, which may be: + - a filename (str/bytes) or PathLike + - an open binary file object + + If ``fp`` is a file object, it must be opened in **binary** mode and support + ``read``, ``seek``, and ``tell``. The function will seek to the start before + reading and will attempt to restore the original position afterward. + """ + if hasattr(fp, "read"): + return _hash_file_obj_sync(fp, chunk_size) + + with open(os.fspath(fp), "rb") as f: + return _hash_file_obj_sync(f, chunk_size) + + +async def blake3_hash( + fp: Union[str, bytes, os.PathLike[str], os.PathLike[bytes], IO[bytes]], + chunk_size: int = DEFAULT_CHUNK, +) -> str: + """Async wrapper for ``blake3_hash_sync``. + Uses a worker thread so the event loop remains responsive. + """ + # If it is a path, open inside the worker thread to keep I/O off the loop. 
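# A minimal usage sketch (illustrative only — the file name is hypothetical; the
# "blake3:" prefix mirrors how manager.py canonicalizes hashes elsewhere in this patch):
#
#     digest = blake3_hash_sync("checkpoints/model.safetensors")   # blocking helper
#     digest = await blake3_hash("checkpoints/model.safetensors")  # off-loads hashing to a worker thread
#     asset_hash = "blake3:" + digest
#
# Per the docstrings above, both helpers also accept an already-open binary file
# object and restore its position afterwards, so callers can hash a stream they
# are still using.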
+ if hasattr(fp, "read"): + return await asyncio.to_thread(blake3_hash_sync, fp, chunk_size) + + def _worker() -> str: + with open(os.fspath(fp), "rb") as f: + return _hash_file_obj_sync(f, chunk_size) + + return await asyncio.to_thread(_worker) diff --git a/app/database/db.py b/app/database/db.py deleted file mode 100644 index 1de8b80edd8a..000000000000 --- a/app/database/db.py +++ /dev/null @@ -1,112 +0,0 @@ -import logging -import os -import shutil -from app.logger import log_startup_warning -from utils.install_util import get_missing_requirements_message -from comfy.cli_args import args - -_DB_AVAILABLE = False -Session = None - - -try: - from alembic import command - from alembic.config import Config - from alembic.runtime.migration import MigrationContext - from alembic.script import ScriptDirectory - from sqlalchemy import create_engine - from sqlalchemy.orm import sessionmaker - - _DB_AVAILABLE = True -except ImportError as e: - log_startup_warning( - f""" ------------------------------------------------------------------------- -Error importing dependencies: {e} -{get_missing_requirements_message()} -This error is happening because ComfyUI now uses a local sqlite database. ------------------------------------------------------------------------- -""".strip() - ) - - -def dependencies_available(): - """ - Temporary function to check if the dependencies are available - """ - return _DB_AVAILABLE - - -def can_create_session(): - """ - Temporary function to check if the database is available to create a session - During initial release there may be environmental issues (or missing dependencies) that prevent the database from being created - """ - return dependencies_available() and Session is not None - - -def get_alembic_config(): - root_path = os.path.join(os.path.dirname(__file__), "../..") - config_path = os.path.abspath(os.path.join(root_path, "alembic.ini")) - scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db")) - - config = Config(config_path) - config.set_main_option("script_location", scripts_path) - config.set_main_option("sqlalchemy.url", args.database_url) - - return config - - -def get_db_path(): - url = args.database_url - if url.startswith("sqlite:///"): - return url.split("///")[1] - else: - raise ValueError(f"Unsupported database URL '{url}'.") - - -def init_db(): - db_url = args.database_url - logging.debug(f"Database URL: {db_url}") - db_path = get_db_path() - db_exists = os.path.exists(db_path) - - config = get_alembic_config() - - # Check if we need to upgrade - engine = create_engine(db_url) - conn = engine.connect() - - context = MigrationContext.configure(conn) - current_rev = context.get_current_revision() - - script = ScriptDirectory.from_config(config) - target_rev = script.get_current_head() - - if target_rev is None: - logging.warning("No target revision found.") - elif current_rev != target_rev: - # Backup the database pre upgrade - backup_path = db_path + ".bkp" - if db_exists: - shutil.copy(db_path, backup_path) - else: - backup_path = None - - try: - command.upgrade(config, target_rev) - logging.info(f"Database upgraded from {current_rev} to {target_rev}") - except Exception as e: - if backup_path: - # Restore the database from backup if upgrade fails - shutil.copy(backup_path, db_path) - os.remove(backup_path) - logging.exception("Error upgrading database: ") - raise e - - global Session - Session = sessionmaker(bind=engine) - - -def create_session(): - return Session() diff --git a/app/database/models.py b/app/database/models.py 
deleted file mode 100644 index 6facfb8f2b5e..000000000000 --- a/app/database/models.py +++ /dev/null @@ -1,14 +0,0 @@ -from sqlalchemy.orm import declarative_base - -Base = declarative_base() - - -def to_dict(obj): - fields = obj.__table__.columns.keys() - return { - field: (val.to_dict() if hasattr(val, "to_dict") else val) - for field in fields - if (val := getattr(obj, field)) - } - -# TODO: Define models here diff --git a/app/db.py b/app/db.py new file mode 100644 index 000000000000..f125706f0189 --- /dev/null +++ b/app/db.py @@ -0,0 +1,255 @@ +import logging +import os +import shutil +from contextlib import asynccontextmanager +from typing import Optional + +from alembic import command +from alembic.config import Config +from alembic.runtime.migration import MigrationContext +from alembic.script import ScriptDirectory +from sqlalchemy import create_engine, text +from sqlalchemy.engine import make_url +from sqlalchemy.ext.asyncio import ( + AsyncEngine, + AsyncSession, + async_sessionmaker, + create_async_engine, +) + +from comfy.cli_args import args + +LOGGER = logging.getLogger(__name__) +ENGINE: Optional[AsyncEngine] = None +SESSION: Optional[async_sessionmaker] = None + + +def _root_paths(): + """Resolve alembic.ini and migrations script folder.""" + root_path = os.path.abspath(os.path.dirname(__file__)) + config_path = os.path.abspath(os.path.join(root_path, "../alembic.ini")) + scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db")) + return config_path, scripts_path + + +def _absolutize_sqlite_url(db_url: str) -> str: + """Make SQLite database path absolute. No-op for non-SQLite URLs.""" + try: + u = make_url(db_url) + except Exception: + return db_url + + if not u.drivername.startswith("sqlite"): + return db_url + + db_path: str = u.database or "" + if isinstance(db_path, str) and db_path.startswith("file:"): + return str(u) # Do not touch SQLite URI databases like: "file:xxx?mode=memory&cache=shared" + if not os.path.isabs(db_path): + db_path = os.path.abspath(os.path.join(os.getcwd(), db_path)) + u = u.set(database=db_path) + return str(u) + + +def _normalize_sqlite_memory_url(db_url: str) -> tuple[str, bool]: + """ + If db_url points at an in-memory SQLite DB (":memory:" or file:... mode=memory), + rewrite it to a *named* shared in-memory URI and ensure 'uri=true' is present. + Returns: (normalized_url, is_memory) + """ + try: + u = make_url(db_url) + except Exception: + return db_url, False + if not u.drivername.startswith("sqlite"): + return db_url, False + + db = u.database or "" + if db == ":memory:": + u = u.set(database=f"file:comfyui_db_{os.getpid()}?mode=memory&cache=shared&uri=true") + return str(u), True + if isinstance(db, str) and db.startswith("file:") and "mode=memory" in db: + if "uri=true" not in db: + u = u.set(database=(db + ("&" if "?" in db else "?") + "uri=true")) + return str(u), True + return str(u), False + + +def _get_sqlite_file_path(sync_url: str) -> Optional[str]: + """Return the on-disk path for a SQLite URL, else None.""" + try: + u = make_url(sync_url) + except Exception: + return None + + if not u.drivername.startswith("sqlite"): + return None + db_path = u.database + if isinstance(db_path, str) and db_path.startswith("file:"): + return None # Not a real file if it is a URI like "file:...?" 
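# Illustrative expectations for this helper (hypothetical values, not part of the patch):
#
#     _get_sqlite_file_path("sqlite:///user/comfyui.db")                                   -> "user/comfyui.db"
#     _get_sqlite_file_path("sqlite:///file:comfyui_db?mode=memory&cache=shared&uri=true") -> None
#     _get_sqlite_file_path("postgresql+psycopg://postgres:postgres@localhost/assets")     -> None
#
# A None result simply means no pre-migration backup file is attempted below.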
+ return db_path + + +def _get_alembic_config(sync_url: str) -> Config: + """Prepare Alembic Config with script location and DB URL.""" + config_path, scripts_path = _root_paths() + cfg = Config(config_path) + cfg.set_main_option("script_location", scripts_path) + cfg.set_main_option("sqlalchemy.url", sync_url) + return cfg + + +async def init_db_engine() -> None: + """Initialize async engine + sessionmaker and run migrations to head. + + This must be called once on application startup before any DB usage. + """ + global ENGINE, SESSION + + if ENGINE is not None: + return + + raw_url = args.database_url + if not raw_url: + raise RuntimeError("Database URL is not configured.") + + db_url, is_mem = _normalize_sqlite_memory_url(raw_url) + db_url = _absolutize_sqlite_url(db_url) + + # Prepare async engine + connect_args = {} + if db_url.startswith("sqlite"): + connect_args = { + "check_same_thread": False, + "timeout": 12, + } + if is_mem: + connect_args["uri"] = True + + ENGINE = create_async_engine( + db_url, + connect_args=connect_args, + pool_pre_ping=True, + future=True, + ) + + # Enforce SQLite pragmas on the async engine + if db_url.startswith("sqlite"): + async with ENGINE.begin() as conn: + if not is_mem: + # WAL for concurrency and durability, Foreign Keys for referential integrity + current_mode = (await conn.execute(text("PRAGMA journal_mode;"))).scalar() + if str(current_mode).lower() != "wal": + new_mode = (await conn.execute(text("PRAGMA journal_mode=WAL;"))).scalar() + if str(new_mode).lower() != "wal": + raise RuntimeError("Failed to set SQLite journal mode to WAL.") + LOGGER.info("SQLite journal mode set to WAL.") + + await conn.execute(text("PRAGMA foreign_keys = ON;")) + await conn.execute(text("PRAGMA synchronous = NORMAL;")) + + await _run_migrations(database_url=db_url, connect_args=connect_args) + + SESSION = async_sessionmaker( + bind=ENGINE, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False, + autocommit=False, + ) + + +async def _run_migrations(database_url: str, connect_args: dict) -> None: + if database_url.find("postgresql+psycopg") == -1: + """SQLite: Convert an async SQLAlchemy URL to a sync URL for Alembic.""" + u = make_url(database_url) + driver = u.drivername + if not driver.startswith("sqlite+aiosqlite"): + raise ValueError(f"Unsupported DB driver: {driver}") + database_url, is_mem = _normalize_sqlite_memory_url(str(u.set(drivername="sqlite"))) + database_url = _absolutize_sqlite_url(database_url) + + cfg = _get_alembic_config(database_url) + engine = create_engine(database_url, future=True, connect_args=connect_args) + with engine.connect() as conn: + context = MigrationContext.configure(conn) + current_rev = context.get_current_revision() + + script = ScriptDirectory.from_config(cfg) + target_rev = script.get_current_head() + + if target_rev is None: + LOGGER.warning("Alembic: no target revision found.") + return + + if current_rev == target_rev: + LOGGER.debug("Alembic: database already at head %s", target_rev) + return + + LOGGER.info("Alembic: upgrading database from %s to %s", current_rev, target_rev) + + # Optional backup for SQLite file DBs + backup_path = None + sqlite_path = _get_sqlite_file_path(database_url) + if sqlite_path and os.path.exists(sqlite_path): + backup_path = sqlite_path + ".bkp" + try: + shutil.copy(sqlite_path, backup_path) + except Exception as exc: + LOGGER.warning("Failed to create SQLite backup before migration: %s", exc) + + try: + command.upgrade(cfg, target_rev) + except Exception: + if backup_path and 
os.path.exists(backup_path): + LOGGER.exception("Error upgrading database, attempting restore from backup.") + try: + shutil.copy(backup_path, sqlite_path) # restore + os.remove(backup_path) + except Exception as re: + LOGGER.error("Failed to restore SQLite backup: %s", re) + else: + LOGGER.exception("Error upgrading database, backup is not available.") + raise + + +def get_engine(): + """Return the global async engine (initialized after init_db_engine()).""" + if ENGINE is None: + raise RuntimeError("Engine is not initialized. Call init_db_engine() first.") + return ENGINE + + +def get_session_maker(): + """Return the global async_sessionmaker (initialized after init_db_engine()).""" + if SESSION is None: + raise RuntimeError("Session maker is not initialized. Call init_db_engine() first.") + return SESSION + + +@asynccontextmanager +async def session_scope(): + """Async context manager for a unit of work: + + async with session_scope() as sess: + ... use sess ... + """ + maker = get_session_maker() + async with maker() as sess: + try: + yield sess + await sess.commit() + except Exception: + await sess.rollback() + raise + + +async def create_session(): + """Convenience helper to acquire a single AsyncSession instance. + + Typical usage: + async with (await create_session()) as sess: + ... + """ + maker = get_session_maker() + return maker() diff --git a/app/frontend_management.py b/app/frontend_management.py index 0bee73685b93..75660fe189cb 100644 --- a/app/frontend_management.py +++ b/app/frontend_management.py @@ -196,6 +196,17 @@ def download_release_asset_zip(release: Release, destination_path: str) -> None: class FrontendManager: + """ + A class to manage ComfyUI frontend versions and installations. + + This class handles the initialization and management of different frontend versions, + including the default frontend from the pip package and custom frontend versions + from GitHub repositories. + + Attributes: + CUSTOM_FRONTENDS_ROOT (str): The root directory where custom frontend versions are stored. + """ + CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions") @classmethod @@ -205,6 +216,15 @@ def get_required_frontend_version(cls) -> str: @classmethod def default_frontend_path(cls) -> str: + """ + Get the path to the default frontend installation from the pip package. + + Returns: + str: The path to the default frontend static files. + + Raises: + SystemExit: If the comfyui-frontend-package is not installed. + """ try: import comfyui_frontend_package @@ -225,6 +245,15 @@ def default_frontend_path(cls) -> str: @classmethod def templates_path(cls) -> str: + """ + Get the path to the workflow templates. + + Returns: + str: The path to the workflow templates directory. + + Raises: + SystemExit: If the comfyui-workflow-templates package is not installed. + """ try: import comfyui_workflow_templates @@ -260,11 +289,16 @@ def embedded_docs_path(cls) -> str: @classmethod def parse_version_string(cls, value: str) -> tuple[str, str, str]: """ + Parse a version string into its components. + + The version string should be in the format: 'owner/repo@version' + where version can be either a semantic version (v1.2.3) or 'latest'. + Args: value (str): The version string to parse. Returns: - tuple[str, str]: A tuple containing provider name and version. + tuple[str, str, str]: A tuple containing (owner, repo, version). Raises: argparse.ArgumentTypeError: If the version string is invalid. 
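For reference, a minimal sketch of the documented 'owner/repo@version' format (the repository name below is only an example, not taken from the patch):

    from app.frontend_management import FrontendManager

    owner, repo, version = FrontendManager.parse_version_string("Comfy-Org/ComfyUI_frontend@v1.2.3")
    # -> ("Comfy-Org", "ComfyUI_frontend", "v1.2.3"); "latest" is also accepted as the version part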
@@ -281,18 +315,22 @@ def init_frontend_unsafe( cls, version_string: str, provider: Optional[FrontEndProvider] = None ) -> str: """ - Initializes the frontend for the specified version. + Initialize a frontend version without error handling. + + This method attempts to initialize a specific frontend version, either from + the default pip package or from a custom GitHub repository. It will download + and extract the frontend files if necessary. Args: - version_string (str): The version string. - provider (FrontEndProvider, optional): The provider to use. Defaults to None. + version_string (str): The version string specifying which frontend to use. + provider (FrontEndProvider, optional): The provider to use for custom frontends. Returns: str: The path to the initialized frontend. Raises: - Exception: If there is an error during the initialization process. - main error source might be request timeout or invalid URL. + Exception: If there is an error during initialization (e.g., network timeout, + invalid URL, or missing assets). """ if version_string == DEFAULT_VERSION_STRING: check_frontend_version() @@ -344,13 +382,17 @@ def init_frontend_unsafe( @classmethod def init_frontend(cls, version_string: str) -> str: """ - Initializes the frontend with the specified version string. + Initialize a frontend version with error handling. + + This is the main method to initialize a frontend version. It wraps init_frontend_unsafe + with error handling, falling back to the default frontend if initialization fails. Args: - version_string (str): The version string to initialize the frontend with. + version_string (str): The version string specifying which frontend to use. Returns: - str: The path of the initialized frontend. + str: The path to the initialized frontend. If initialization fails, + returns the path to the default frontend. """ try: return cls.init_frontend_unsafe(version_string) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index cc1f12482e9f..7955cc763180 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -212,7 +212,8 @@ def is_valid_directory(path: str) -> str: database_default_path = os.path.abspath( os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db") ) -parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.") +parser.add_argument("--database-url", type=str, default=f"sqlite+aiosqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite+aiosqlite:///:memory:'.") +parser.add_argument("--disable-assets-autoscan", action="store_true", help="Disable asset scanning on startup for database synchronization.") if comfy.options.args_parsing: args = parser.parse_args() diff --git a/comfy/utils.py b/comfy/utils.py index fab28cf088c0..73461b8761fb 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -50,10 +50,16 @@ class ModelCheckpoint: else: logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. 
Upgrading to 2.4 or above is recommended.") +def is_html_file(file_path): + with open(file_path, "rb") as f: + content = f.read(100) + return b"<!DOCTYPE html>" in content or b"<html" in content if len(e.args) > 0: message = e.args[0] if "HeaderTooLarge" in message: @@ -93,6 +101,8 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False): sd = pl_sd else: sd = pl_sd + + # populate_db_with_asset(ckpt) # surprise tool that can help us later - performs hashing on model file return (sd, metadata) if return_metadata else sd def save_torch_file(sd, ckpt, metadata=None): diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 4826818df860..7986f490d5cd 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -392,6 +392,20 @@ def as_dict(self): }) return to_return +@comfytype(io_type="ASSET") +class Asset(ComfyTypeI): + class Input(WidgetInput): + def __init__(self, id: str, query_tags: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, + default: str=None, socketless: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless) + self.query_tags = query_tags + + def as_dict(self): + to_return = super().as_dict() | prune_dict({ + "query_tags": self.query_tags + }) + return to_return + @comfytype(io_type="IMAGE") class Image(ComfyTypeIO): Type = torch.Tensor diff --git a/folder_paths.py b/folder_paths.py index f110d832bb23..8839fac784ba 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -279,10 +279,7 @@ def filter_files_extensions(files: Collection[str], extensions: Collection[str]) -def get_full_path(folder_name: str, filename: str) -> str | None: - """ - Get the full path of a file in a folder, has to be a file - """ +def get_full_path(folder_name: str, filename: str, allow_missing: bool = False) -> str | None: global folder_names_and_paths folder_name = map_legacy(folder_name) if folder_name not in folder_names_and_paths: @@ -295,6 +292,8 @@ def get_full_path(folder_name: str, filename: str) -> str | None: return full_path elif os.path.islink(full_path): logging.warning("WARNING path {} exists but doesn't link anywhere, skipping.".format(full_path)) + elif allow_missing: + return full_path return None @@ -309,6 +308,27 @@ def get_full_path_or_raise(folder_name: str, filename: str) -> str: return full_path +def get_relative_path(full_path: str) -> tuple[str, str] | None: + """Convert a full path back to a type-relative path.
+ + Args: + full_path: The full path to the file + + Returns: + tuple[str, str] | None: A tuple of (model_type, relative_path) if found, None otherwise + """ + global folder_names_and_paths + full_path = os.path.normpath(full_path) + + for model_type, (paths, _) in folder_names_and_paths.items(): + for base_path in paths: + base_path = os.path.normpath(base_path) + if full_path.startswith(base_path): + relative_path = os.path.relpath(full_path, base_path) + return model_type, relative_path + + return None + def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]: folder_name = map_legacy(folder_name) global folder_names_and_paths diff --git a/main.py b/main.py index c33f0e17bf30..9bc9ac9ed527 100644 --- a/main.py +++ b/main.py @@ -164,7 +164,6 @@ def cuda_malloc_warning(): if cuda_malloc_warning: logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n") - def prompt_worker(q, server_instance): current_time: float = 0.0 cache_type = execution.CacheType.CLASSIC @@ -279,14 +278,13 @@ def cleanup_temp(): if os.path.exists(temp_dir): shutil.rmtree(temp_dir, ignore_errors=True) +async def setup_database(): + from app.assets import sync_seed_assets + from app.db import init_db_engine -def setup_database(): - try: - from app.database.db import init_db, dependencies_available - if dependencies_available(): - init_db() - except Exception as e: - logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}") + await init_db_engine() + if not args.disable_assets_autoscan: + await sync_seed_assets(["models"]) def start_comfyui(asyncio_loop=None): @@ -312,6 +310,8 @@ def start_comfyui(asyncio_loop=None): asyncio.set_event_loop(asyncio_loop) prompt_server = server.PromptServer(asyncio_loop) + asyncio_loop.run_until_complete(setup_database()) + hook_breaker_ac10a0.save_functions() asyncio_loop.run_until_complete(nodes.init_extra_nodes( init_custom_nodes=(not args.disable_all_custom_nodes) or len(args.whitelist_custom_nodes) > 0, @@ -320,7 +320,6 @@ def start_comfyui(asyncio_loop=None): hook_breaker_ac10a0.restore_functions() cuda_malloc_warning() - setup_database() prompt_server.add_routes() hijack_progress(prompt_server) diff --git a/requirements.txt b/requirements.txt index 2980bebdd909..f9497c93a492 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,7 +20,9 @@ tqdm psutil alembic SQLAlchemy +aiosqlite av>=14.2.0 +blake3 #non essential dependencies: kornia>=0.7.1 diff --git a/server.py b/server.py index 6036773974ae..424ca9b593d7 100644 --- a/server.py +++ b/server.py @@ -37,6 +37,7 @@ from app.custom_node_manager import CustomNodeManager from typing import Optional, Union from api_server.routes.internal.internal_routes import InternalRoutes +from app.assets import sync_seed_assets, register_assets_system from protocol import BinaryEventTypes # Import cache control middleware @@ -178,6 +179,7 @@ def __init__(self, loop): else args.front_end_root ) logging.info(f"[Prompt Server] web root: {self.web_root}") + register_assets_system(self.app, self.user_manager) routes = web.RouteTableDef() self.routes = routes self.last_node_id = None @@ -622,6 +624,7 @@ def node_info(node_class): @routes.get("/object_info") async def get_object_info(request): + await sync_seed_assets(["models"]) with folder_paths.cache_helper: out = {} for x in 
nodes.NODE_CLASS_MAPPINGS: diff --git a/tests-assets/conftest.py b/tests-assets/conftest.py new file mode 100644 index 000000000000..7d1ea5acb570 --- /dev/null +++ b/tests-assets/conftest.py @@ -0,0 +1,307 @@ +import asyncio +import contextlib +import json +import os +import socket +import subprocess +import sys +import tempfile +import time +from pathlib import Path +from typing import AsyncIterator, Callable, Optional + +import aiohttp +import pytest +import pytest_asyncio + + +def pytest_addoption(parser: pytest.Parser) -> None: + """ + Allow overriding the database URL used by the spawned ComfyUI process. + Priority: + 1) --db-url command line option + 2) ASSETS_TEST_DB_URL environment variable (used by CI) + 3) default: sqlite in-memory + """ + parser.addoption( + "--db-url", + action="store", + default=os.environ.get("ASSETS_TEST_DB_URL", "sqlite+aiosqlite:///:memory:"), + help="Async SQLAlchemy DB URL (e.g. sqlite+aiosqlite:///:memory: or postgresql+psycopg://user:pass@host/db)", + ) + + +def _free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +def _make_base_dirs(root: Path) -> None: + for sub in ("models", "custom_nodes", "input", "output", "temp", "user"): + (root / sub).mkdir(parents=True, exist_ok=True) + + +async def _wait_http_ready(base: str, session: aiohttp.ClientSession, timeout: float = 90.0) -> None: + start = time.time() + last_err = None + while time.time() - start < timeout: + try: + async with session.get(base + "/api/assets") as r: + if r.status in (200, 400): + return + except Exception as e: + last_err = e + await asyncio.sleep(0.25) + raise RuntimeError(f"ComfyUI HTTP did not become ready: {last_err}") + + +@pytest.fixture(scope="session") +def event_loop(): + loop = asyncio.new_event_loop() + yield loop + loop.close() + + +@pytest.fixture(scope="session") +def comfy_tmp_base_dir() -> Path: + env_base = os.environ.get("ASSETS_TEST_BASE_DIR") + created_by_fixture = False + if env_base: + tmp = Path(env_base) + tmp.mkdir(parents=True, exist_ok=True) + else: + tmp = Path(tempfile.mkdtemp(prefix="comfyui-assets-tests-")) + created_by_fixture = True + _make_base_dirs(tmp) + yield tmp + if created_by_fixture: + with contextlib.suppress(Exception): + for p in sorted(tmp.rglob("*"), reverse=True): + if p.is_file() or p.is_symlink(): + p.unlink(missing_ok=True) + for p in sorted(tmp.glob("**/*"), reverse=True): + with contextlib.suppress(Exception): + p.rmdir() + tmp.rmdir() + + +@pytest.fixture(scope="session") +def comfy_url_and_proc(comfy_tmp_base_dir: Path, request: pytest.FixtureRequest): + """ + Boot ComfyUI subprocess with: + - sandbox base dir + - sqlite memory DB (default) + - autoscan disabled + Returns (base_url, process, port) + """ + port = _free_port() + db_url = request.config.getoption("--db-url") + + logs_dir = comfy_tmp_base_dir / "logs" + logs_dir.mkdir(exist_ok=True) + out_log = open(logs_dir / "stdout.log", "w", buffering=1) + err_log = open(logs_dir / "stderr.log", "w", buffering=1) + + comfy_root = Path(__file__).resolve().parent.parent + if not (comfy_root / "main.py").is_file(): + raise FileNotFoundError(f"main.py not found under {comfy_root}") + + proc = subprocess.Popen( + args=[ + sys.executable, + "main.py", + f"--base-directory={str(comfy_tmp_base_dir)}", + f"--database-url={db_url}", + "--disable-assets-autoscan", + "--listen", + "127.0.0.1", + "--port", + str(port), + "--cpu", + ], + stdout=out_log, + stderr=err_log, + cwd=str(comfy_root), + 
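+        # inherit the parent environment unchanged so interpreter and runner settings apply inside the subprocess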
env={**os.environ}, + ) + + for _ in range(50): + if proc.poll() is not None: + out_log.flush() + err_log.flush() + raise RuntimeError(f"ComfyUI exited early with code {proc.returncode}") + time.sleep(0.1) + + base_url = f"http://127.0.0.1:{port}" + try: + async def _probe(): + async with aiohttp.ClientSession() as s: + await _wait_http_ready(base_url, s, timeout=90.0) + + asyncio.run(_probe()) + yield base_url, proc, port + except Exception as e: + with contextlib.suppress(Exception): + proc.terminate() + proc.wait(timeout=10) + with contextlib.suppress(Exception): + out_log.flush() + err_log.flush() + raise RuntimeError(f"ComfyUI did not become ready: {e}") + + if proc and proc.poll() is None: + with contextlib.suppress(Exception): + proc.terminate() + proc.wait(timeout=15) + out_log.close() + err_log.close() + + +@pytest_asyncio.fixture +async def http() -> AsyncIterator[aiohttp.ClientSession]: + timeout = aiohttp.ClientTimeout(total=120) + async with aiohttp.ClientSession(timeout=timeout) as s: + yield s + + +@pytest.fixture +def api_base(comfy_url_and_proc) -> str: + base_url, _proc, _port = comfy_url_and_proc + return base_url + + +async def _post_multipart_asset( + session: aiohttp.ClientSession, + base: str, + *, + name: str, + tags: list[str], + meta: dict, + data: bytes, + extra_fields: Optional[dict] = None, +) -> tuple[int, dict]: + form = aiohttp.FormData() + form.add_field("file", data, filename=name, content_type="application/octet-stream") + form.add_field("tags", json.dumps(tags)) + form.add_field("name", name) + form.add_field("user_metadata", json.dumps(meta)) + if extra_fields: + for k, v in extra_fields.items(): + form.add_field(k, v) + async with session.post(base + "/api/assets", data=form) as r: + body = await r.json() + return r.status, body + + +@pytest.fixture +def make_asset_bytes() -> Callable[[str, int], bytes]: + def _make(name: str, size: int = 8192) -> bytes: + seed = sum(ord(c) for c in name) % 251 + return bytes((i * 31 + seed) % 256 for i in range(size)) + return _make + + +@pytest_asyncio.fixture +async def asset_factory(http: aiohttp.ClientSession, api_base: str): + """ + Returns create(name, tags, meta, data) -> response dict + Tracks created ids and deletes them after the test. + """ + created: list[str] = [] + + async def create(name: str, tags: list[str], meta: dict, data: bytes) -> dict: + status, body = await _post_multipart_asset(http, api_base, name=name, tags=tags, meta=meta, data=data) + assert status in (200, 201), body + created.append(body["id"]) + return body + + yield create + + # cleanup by id + for aid in created: + with contextlib.suppress(Exception): + async with http.delete(f"{api_base}/api/assets/{aid}") as r: + await r.read() + + +@pytest_asyncio.fixture +async def seeded_asset(request: pytest.FixtureRequest, http: aiohttp.ClientSession, api_base: str) -> dict: + """ + Upload one asset with ".safetensors" extension into models/checkpoints/unit-tests/. + Returns response dict with id, asset_hash, tags, etc. 
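+    The default tags can be overridden per test via indirect parametrization, e.g.:
+
+        @pytest.mark.parametrize("seeded_asset", [{"tags": ["models", "checkpoints"]}], indirect=True)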
+ """ + name = "unit_1_example.safetensors" + p = getattr(request, "param", {}) or {} + tags: Optional[list[str]] = p.get("tags") + if tags is None: + tags = ["models", "checkpoints", "unit-tests", "alpha"] + meta = {"purpose": "test", "epoch": 1, "flags": ["x", "y"], "nullable": None} + form = aiohttp.FormData() + form.add_field("file", b"A" * 4096, filename=name, content_type="application/octet-stream") + form.add_field("tags", json.dumps(tags)) + form.add_field("name", name) + form.add_field("user_metadata", json.dumps(meta)) + async with http.post(api_base + "/api/assets", data=form) as r: + body = await r.json() + assert r.status == 201, body + return body + + +@pytest_asyncio.fixture(autouse=True) +async def autoclean_unit_test_assets(http: aiohttp.ClientSession, api_base: str): + """Ensure isolation by removing all AssetInfo rows tagged with 'unit-tests' after each test.""" + yield + + while True: + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests", "limit": "500", "sort": "name"}, + ) as r: + body = await r.json() + if r.status != 200: + break + ids = [a["id"] for a in body.get("assets", [])] + if not ids: + break + for aid in ids: + with contextlib.suppress(Exception): + async with http.delete(f"{api_base}/api/assets/{aid}") as dr: + await dr.read() + + +async def trigger_sync_seed_assets(session: aiohttp.ClientSession, base_url: str) -> None: + """Force a fast sync/seed pass by calling the ComfyUI '/object_info' endpoint.""" + async with session.post(base_url + "/api/assets/scan/seed", json={"roots": ["models", "input", "output"]}) as r: + await r.read() + await asyncio.sleep(0.1) # tiny yield to the event loop to let any final DB commits flush + + +@pytest_asyncio.fixture +async def run_scan_and_wait(http: aiohttp.ClientSession, api_base: str): + """Schedule an asset scan for a given root and wait until it finishes.""" + async def _run(root: str, timeout: float = 120.0): + async with http.post(api_base + "/api/assets/scan/schedule", json={"roots": [root]}) as r: + # we ignore body; scheduling returns 202 with a status payload + await r.read() + + start = time.time() + while True: + async with http.get(api_base + "/api/assets/scan", params={"root": root}) as st: + body = await st.json() + scans = (body or {}).get("scans", []) + status = None + if scans: + status = scans[-1].get("status") + if status in {"completed", "failed", "cancelled"}: + if status != "completed": + raise RuntimeError(f"Scan for root={root} finished with status={status}") + return + if time.time() - start > timeout: + raise TimeoutError(f"Timed out waiting for scan of root={root}") + await asyncio.sleep(0.1) + return _run + + +def get_asset_filename(asset_hash: str, extension: str) -> str: + return asset_hash.removeprefix("blake3:") + extension diff --git a/tests-assets/test_assets_missing_sync.py b/tests-assets/test_assets_missing_sync.py new file mode 100644 index 000000000000..b959e33f0a33 --- /dev/null +++ b/tests-assets/test_assets_missing_sync.py @@ -0,0 +1,347 @@ +import os +import uuid +from pathlib import Path + +import aiohttp +import pytest +from conftest import get_asset_filename, trigger_sync_seed_assets + + +@pytest.mark.asyncio +@pytest.mark.parametrize("root", ["input", "output"]) +async def test_seed_asset_removed_when_file_is_deleted( + root: str, + http: aiohttp.ClientSession, + api_base: str, + comfy_tmp_base_dir: Path, +): + """Asset without hash (seed) whose file disappears: + after triggering sync_seed_assets, Asset + AssetInfo disappear. 
+ """ + # Create a file directly under input/unit-tests/ so tags include "unit-tests" + case_dir = comfy_tmp_base_dir / root / "unit-tests" / "syncseed" + case_dir.mkdir(parents=True, exist_ok=True) + name = f"seed_{uuid.uuid4().hex[:8]}.bin" + fp = case_dir / name + fp.write_bytes(b"Z" * 2048) + + # Trigger a seed sync so DB sees this path (seed asset => hash is NULL) + await trigger_sync_seed_assets(http, api_base) + + # Verify it is visible via API and carries no hash (seed) + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,syncseed", "name_contains": name}, + ) as r1: + body1 = await r1.json() + assert r1.status == 200 + # there should be exactly one with that name + matches = [a for a in body1.get("assets", []) if a.get("name") == name] + assert matches + assert matches[0].get("asset_hash") is None + asset_info_id = matches[0]["id"] + + # Remove the underlying file and sync again + if fp.exists(): + fp.unlink() + + await trigger_sync_seed_assets(http, api_base) + + # It should disappear (AssetInfo and seed Asset gone) + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,syncseed", "name_contains": name}, + ) as r2: + body2 = await r2.json() + assert r2.status == 200 + matches2 = [a for a in body2.get("assets", []) if a.get("name") == name] + assert not matches2, f"Seed asset {asset_info_id} should be gone after sync" + + +@pytest.mark.asyncio +async def test_hashed_asset_missing_tag_added_then_removed_after_scan( + http: aiohttp.ClientSession, + api_base: str, + comfy_tmp_base_dir: Path, + asset_factory, + make_asset_bytes, + run_scan_and_wait, +): + """Hashed asset with a single cache_state: + 1. delete its file -> sync adds 'missing' + 2. restore file -> scan removes 'missing' + """ + name = "missing_tag_test.png" + tags = ["input", "unit-tests", "msync2"] + data = make_asset_bytes(name, 4096) + a = await asset_factory(name, tags, {}, data) + + # Compute its on-disk path and remove it + dest = comfy_tmp_base_dir / "input" / "unit-tests" / "msync2" / get_asset_filename(a["asset_hash"], ".png") + assert dest.exists(), f"Expected asset file at {dest}" + dest.unlink() + + # Fast sync should add 'missing' to the AssetInfo + await trigger_sync_seed_assets(http, api_base) + + async with http.get(f"{api_base}/api/assets/{a['id']}") as g1: + d1 = await g1.json() + assert g1.status == 200, d1 + assert "missing" in set(d1.get("tags", [])), "Expected 'missing' tag after deletion" + + # Restore the file with the exact same content and re-hash/verify via scan + dest.parent.mkdir(parents=True, exist_ok=True) + dest.write_bytes(data) + + await run_scan_and_wait("input") + + async with http.get(f"{api_base}/api/assets/{a['id']}") as g2: + d2 = await g2.json() + assert g2.status == 200, d2 + assert "missing" not in set(d2.get("tags", [])), "Missing tag should be cleared after verify" + + +@pytest.mark.asyncio +async def test_hashed_asset_two_asset_infos_both_get_missing( + http: aiohttp.ClientSession, + api_base: str, + comfy_tmp_base_dir: Path, + asset_factory, +): + """Hashed asset with a single cache_state, but two AssetInfo rows: + deleting the single file then syncing should add 'missing' to both infos. 
+ """ + # Upload one hashed asset + name = "two_infos_one_path.png" + base_tags = ["input", "unit-tests", "multiinfo"] + created = await asset_factory(name, base_tags, {}, b"A" * 2048) + + # Create second AssetInfo for the same Asset via from-hash + payload = { + "hash": created["asset_hash"], + "name": "two_infos_one_path_copy.png", + "tags": base_tags, # keep it in our unit-tests scope for cleanup + "user_metadata": {"k": "v"}, + } + async with http.post(api_base + "/api/assets/from-hash", json=payload) as r2: + b2 = await r2.json() + assert r2.status == 201, b2 + second_id = b2["id"] + + # Remove the single underlying file + p = comfy_tmp_base_dir / "input" / "unit-tests" / "multiinfo" / get_asset_filename(b2["asset_hash"], ".png") + assert p.exists() + p.unlink() + + async with http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}) as r0: + tags0 = await r0.json() + assert r0.status == 200, tags0 + byname0 = {t["name"]: t for t in tags0.get("tags", [])} + old_missing = int(byname0.get("missing", {}).get("count", 0)) + + # Sync -> both AssetInfos for this asset must receive 'missing' + await trigger_sync_seed_assets(http, api_base) + + async with http.get(f"{api_base}/api/assets/{created['id']}") as ga: + da = await ga.json() + assert ga.status == 200, da + assert "missing" in set(da.get("tags", [])) + + async with http.get(f"{api_base}/api/assets/{second_id}") as gb: + db = await gb.json() + assert gb.status == 200, db + assert "missing" in set(db.get("tags", [])) + + # Tag usage for 'missing' increased by exactly 2 (two AssetInfos) + async with http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}) as r1: + tags1 = await r1.json() + assert r1.status == 200, tags1 + byname1 = {t["name"]: t for t in tags1.get("tags", [])} + new_missing = int(byname1.get("missing", {}).get("count", 0)) + assert new_missing == old_missing + 2 + + +@pytest.mark.asyncio +async def test_hashed_asset_two_cache_states_partial_delete_then_full_delete( + http: aiohttp.ClientSession, + api_base: str, + comfy_tmp_base_dir: Path, + asset_factory, + make_asset_bytes, + run_scan_and_wait, +): + """Hashed asset with two cache_state rows: + 1. delete one file -> sync should NOT add 'missing' + 2. delete second file -> sync should add 'missing' + """ + name = "two_cache_states_partial_delete.png" + tags = ["input", "unit-tests", "dual"] + data = make_asset_bytes(name, 3072) + + created = await asset_factory(name, tags, {}, data) + path1 = comfy_tmp_base_dir / "input" / "unit-tests" / "dual" / get_asset_filename(created["asset_hash"], ".png") + assert path1.exists() + + # Create a second on-disk copy under the same root but different subfolder + path2 = comfy_tmp_base_dir / "input" / "unit-tests" / "dual_copy" / name + path2.parent.mkdir(parents=True, exist_ok=True) + path2.write_bytes(data) + + # Fast seed so the second path appears (as a seed initially) + await trigger_sync_seed_assets(http, api_base) + + # Deduplication of AssetInfo-s will not happen as first AssetInfo has owner='default' and second has empty owner. 
+ await run_scan_and_wait("input") + + # Remove only one file and sync -> asset should still be healthy (no 'missing') + path1.unlink() + await trigger_sync_seed_assets(http, api_base) + + async with http.get(f"{api_base}/api/assets/{created['id']}") as g1: + d1 = await g1.json() + assert g1.status == 200, d1 + assert "missing" not in set(d1.get("tags", [])), "Should not be missing while one valid path remains" + + # Baseline 'missing' usage count just before last file removal + async with http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}) as r0: + tags0 = await r0.json() + assert r0.status == 200, tags0 + old_missing = int({t["name"]: t for t in tags0.get("tags", [])}.get("missing", {}).get("count", 0)) + + # Remove the second (last) file and sync -> now we expect 'missing' on this AssetInfo + path2.unlink() + await trigger_sync_seed_assets(http, api_base) + + async with http.get(f"{api_base}/api/assets/{created['id']}") as g2: + d2 = await g2.json() + assert g2.status == 200, d2 + assert "missing" in set(d2.get("tags", [])), "Missing must be set once no valid paths remain" + + # Tag usage for 'missing' increased by exactly 2 (two AssetInfo for one Asset) + async with http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}) as r1: + tags1 = await r1.json() + assert r1.status == 200, tags1 + new_missing = int({t["name"]: t for t in tags1.get("tags", [])}.get("missing", {}).get("count", 0)) + assert new_missing == old_missing + 2 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("root", ["input", "output"]) +async def test_missing_tag_clears_on_fastpass_when_mtime_and_size_match( + root: str, + http: aiohttp.ClientSession, + api_base: str, + comfy_tmp_base_dir: Path, + asset_factory, + make_asset_bytes, +): + """ + Fast pass alone clears 'missing' when size and mtime match exactly: + 1) upload (hashed), record original mtime_ns + 2) delete -> fast pass adds 'missing' + 3) restore same bytes and set mtime back to the original value + 4) run fast pass again -> 'missing' is removed (no slow scan) + """ + scope = f"fastclear-{uuid.uuid4().hex[:6]}" + name = "fastpass_clear.bin" + data = make_asset_bytes(name, 3072) + + a = await asset_factory(name, [root, "unit-tests", scope], {}, data) + aid = a["id"] + base = comfy_tmp_base_dir / root / "unit-tests" / scope + p = base / get_asset_filename(a["asset_hash"], ".bin") + st0 = p.stat() + orig_mtime_ns = getattr(st0, "st_mtime_ns", int(st0.st_mtime * 1_000_000_000)) + + # Delete -> fast pass adds 'missing' + p.unlink() + await trigger_sync_seed_assets(http, api_base) + async with http.get(f"{api_base}/api/assets/{aid}") as g1: + d1 = await g1.json() + assert g1.status == 200, d1 + assert "missing" in set(d1.get("tags", [])) + + # Restore same bytes and revert mtime to the original value + p.parent.mkdir(parents=True, exist_ok=True) + p.write_bytes(data) + # set both atime and mtime in ns to ensure exact match + os.utime(p, ns=(orig_mtime_ns, orig_mtime_ns)) + + # Fast pass should clear 'missing' without a scan + await trigger_sync_seed_assets(http, api_base) + async with http.get(f"{api_base}/api/assets/{aid}") as g2: + d2 = await g2.json() + assert g2.status == 200, d2 + assert "missing" not in set(d2.get("tags", [])), "Fast pass should clear 'missing' when size+mtime match" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("root", ["input", "output"]) +async def test_fastpass_removes_stale_state_row_no_missing( + root: str, + http: aiohttp.ClientSession, + api_base: str, + 
comfy_tmp_base_dir: Path, + asset_factory, + make_asset_bytes, + run_scan_and_wait, +): + """ + Hashed asset with two states: + - delete one file + - run fast pass only + Expect: + - asset stays healthy (no 'missing') + - stale AssetCacheState row for the deleted path is removed. + We verify this behaviorally by recreating the deleted path and running fast pass again: + a new *seed* AssetInfo is created, which proves the old state row was not reused. + """ + scope = f"stale-{uuid.uuid4().hex[:6]}" + name = "two_states.bin" + data = make_asset_bytes(name, 2048) + + # Upload hashed asset at path1 + a = await asset_factory(name, [root, "unit-tests", scope], {}, data) + base = comfy_tmp_base_dir / root / "unit-tests" / scope + a1_filename = get_asset_filename(a["asset_hash"], ".bin") + p1 = base / a1_filename + assert p1.exists() + + aid = a["id"] + h = a["asset_hash"] + + # Create second state path2, seed+scan to dedupe into the same Asset + p2 = base / "copy" / name + p2.parent.mkdir(parents=True, exist_ok=True) + p2.write_bytes(data) + await trigger_sync_seed_assets(http, api_base) + await run_scan_and_wait(root) + + # Delete path1 and run fast pass -> no 'missing' and stale state row should be removed + p1.unlink() + await trigger_sync_seed_assets(http, api_base) + async with http.get(f"{api_base}/api/assets/{aid}") as g1: + d1 = await g1.json() + assert g1.status == 200, d1 + assert "missing" not in set(d1.get("tags", [])) + + # Recreate path1 and run fast pass again. + # If the stale state row was removed, a NEW seed AssetInfo will appear for this path. + p1.write_bytes(data) + await trigger_sync_seed_assets(http, api_base) + + async with http.get( + api_base + "/api/assets", + params={"include_tags": f"unit-tests,{scope}"}, + ) as rl: + bl = await rl.json() + assert rl.status == 200, bl + items = bl.get("assets", []) + # one hashed AssetInfo (asset_hash == h) + one seed AssetInfo (asset_hash == null) + hashes = [it.get("asset_hash") for it in items if it.get("name") in (name, a1_filename)] + assert h in hashes + assert any(x is None for x in hashes), "Expected a new seed AssetInfo for the recreated path" + + # Asset identity still healthy + async with http.head(f"{api_base}/api/assets/hash/{h}") as rh: + assert rh.status == 200 diff --git a/tests-assets/test_crud.py b/tests-assets/test_crud.py new file mode 100644 index 000000000000..f2e4c2699dd7 --- /dev/null +++ b/tests-assets/test_crud.py @@ -0,0 +1,316 @@ +import asyncio +import uuid +from pathlib import Path + +import aiohttp +import pytest +from conftest import get_asset_filename, trigger_sync_seed_assets + + +@pytest.mark.asyncio +async def test_create_from_hash_success( + http: aiohttp.ClientSession, api_base: str, seeded_asset: dict +): + h = seeded_asset["asset_hash"] + payload = { + "hash": h, + "name": "from_hash_ok.safetensors", + "tags": ["models", "checkpoints", "unit-tests", "from-hash"], + "user_metadata": {"k": "v"}, + } + async with http.post(f"{api_base}/api/assets/from-hash", json=payload) as r1: + b1 = await r1.json() + assert r1.status == 201, b1 + assert b1["asset_hash"] == h + assert b1["created_new"] is False + aid = b1["id"] + + # Calling again with the same name should return the same AssetInfo id + async with http.post(f"{api_base}/api/assets/from-hash", json=payload) as r2: + b2 = await r2.json() + assert r2.status == 201, b2 + assert b2["id"] == aid + + +@pytest.mark.asyncio +async def test_get_and_delete_asset(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict): + aid = seeded_asset["id"] + + 
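+    # Walk the basic lifecycle: fetch the detail, delete the asset, then expect 404 on a second fetch.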
# GET detail + async with http.get(f"{api_base}/api/assets/{aid}") as rg: + detail = await rg.json() + assert rg.status == 200, detail + assert detail["id"] == aid + assert "user_metadata" in detail + assert "filename" in detail["user_metadata"] + + # DELETE + async with http.delete(f"{api_base}/api/assets/{aid}") as rd: + assert rd.status == 204 + + # GET again -> 404 + async with http.get(f"{api_base}/api/assets/{aid}") as rg2: + body = await rg2.json() + assert rg2.status == 404 + assert body["error"]["code"] == "ASSET_NOT_FOUND" + + +@pytest.mark.asyncio +async def test_delete_upon_reference_count( + http: aiohttp.ClientSession, api_base: str, seeded_asset: dict +): + # Create a second reference to the same asset via from-hash + src_hash = seeded_asset["asset_hash"] + payload = { + "hash": src_hash, + "name": "unit_ref_copy.safetensors", + "tags": ["models", "checkpoints", "unit-tests", "del-flow"], + "user_metadata": {"note": "copy"}, + } + async with http.post(f"{api_base}/api/assets/from-hash", json=payload) as r2: + copy = await r2.json() + assert r2.status == 201, copy + assert copy["asset_hash"] == src_hash + assert copy["created_new"] is False + + # Delete original reference -> asset identity must remain + aid1 = seeded_asset["id"] + async with http.delete(f"{api_base}/api/assets/{aid1}") as rd1: + assert rd1.status == 204 + + async with http.head(f"{api_base}/api/assets/hash/{src_hash}") as rh1: + assert rh1.status == 200 # identity still present + + # Delete the last reference with default semantics -> identity and cached files removed + aid2 = copy["id"] + async with http.delete(f"{api_base}/api/assets/{aid2}") as rd2: + assert rd2.status == 204 + + async with http.head(f"{api_base}/api/assets/hash/{src_hash}") as rh2: + assert rh2.status == 404 # orphan content removed + + +@pytest.mark.asyncio +async def test_update_asset_fields(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict): + aid = seeded_asset["id"] + + payload = { + "name": "unit_1_renamed.safetensors", + "tags": ["models", "checkpoints", "unit-tests", "beta"], + "user_metadata": {"purpose": "updated", "epoch": 2}, + } + async with http.put(f"{api_base}/api/assets/{aid}", json=payload) as ru: + body = await ru.json() + assert ru.status == 200, body + assert body["name"] == payload["name"] + assert "beta" in body["tags"] + assert body["user_metadata"]["purpose"] == "updated" + # filename should still be present and normalized by server + assert "filename" in body["user_metadata"] + + +@pytest.mark.asyncio +async def test_head_asset_by_hash(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict): + h = seeded_asset["asset_hash"] + + # Existing + async with http.head(f"{api_base}/api/assets/hash/{h}") as rh1: + assert rh1.status == 200 + + # Non-existent + async with http.head(f"{api_base}/api/assets/hash/blake3:{'0'*64}") as rh2: + assert rh2.status == 404 + + +@pytest.mark.asyncio +async def test_head_asset_bad_hash_returns_400_and_no_body(http: aiohttp.ClientSession, api_base: str): + # Invalid format; handler returns a JSON error, but HEAD responses must not carry a payload. + # aiohttp exposes an empty body for HEAD, so validate status and that there is no payload. 
+ async with http.head(f"{api_base}/api/assets/hash/not_a_hash") as rh: + assert rh.status == 400 + body = await rh.read() + assert body == b"" + + +@pytest.mark.asyncio +async def test_delete_nonexistent_returns_404(http: aiohttp.ClientSession, api_base: str): + bogus = str(uuid.uuid4()) + async with http.delete(f"{api_base}/api/assets/{bogus}") as r: + body = await r.json() + assert r.status == 404 + assert body["error"]["code"] == "ASSET_NOT_FOUND" + + +@pytest.mark.asyncio +async def test_create_from_hash_invalids(http: aiohttp.ClientSession, api_base: str): + # Bad hash algorithm + bad = { + "hash": "sha256:" + "0" * 64, + "name": "x.bin", + "tags": ["models", "checkpoints", "unit-tests"], + } + async with http.post(f"{api_base}/api/assets/from-hash", json=bad) as r1: + b1 = await r1.json() + assert r1.status == 400 + assert b1["error"]["code"] == "INVALID_BODY" + + # Invalid JSON body + async with http.post(f"{api_base}/api/assets/from-hash", data=b"{not json}") as r2: + b2 = await r2.json() + assert r2.status == 400 + assert b2["error"]["code"] == "INVALID_JSON" + + +@pytest.mark.asyncio +async def test_get_update_download_bad_ids(http: aiohttp.ClientSession, api_base: str): + # All endpoints should be not found, as we UUID regex directly in the route definition. + bad_id = "not-a-uuid" + + async with http.get(f"{api_base}/api/assets/{bad_id}") as r1: + assert r1.status == 404 + + async with http.get(f"{api_base}/api/assets/{bad_id}/content") as r3: + assert r3.status == 404 + + +@pytest.mark.asyncio +async def test_update_requires_at_least_one_field(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict): + aid = seeded_asset["id"] + async with http.put(f"{api_base}/api/assets/{aid}", json={}) as r: + body = await r.json() + assert r.status == 400 + assert body["error"]["code"] == "INVALID_BODY" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("root", ["input", "output"]) +async def test_concurrent_delete_same_asset_info_single_204( + root: str, + http: aiohttp.ClientSession, + api_base: str, + asset_factory, + make_asset_bytes, +): + """ + Many concurrent DELETE for the same AssetInfo should result in: + - exactly one 204 No Content (the one that actually deleted) + - all others 404 Not Found (row already gone) + """ + scope = f"conc-del-{uuid.uuid4().hex[:6]}" + name = "to_delete.bin" + data = make_asset_bytes(name, 1536) + + created = await asset_factory(name, [root, "unit-tests", scope], {}, data) + aid = created["id"] + + # Hit the same endpoint N times in parallel. + n_tests = 4 + url = f"{api_base}/api/assets/{aid}?delete_content=false" + tasks = [asyncio.create_task(http.delete(url)) for _ in range(n_tests)] + responses = await asyncio.gather(*tasks) + + statuses = [r.status for r in responses] + # Drain bodies to close connections (optional but avoids warnings). + await asyncio.gather(*[r.read() for r in responses]) + + # Exactly one actual delete, the rest must be 404 + assert statuses.count(204) == 1, f"Expected exactly one 204; got: {statuses}" + assert statuses.count(404) == n_tests - 1, f"Expected {n_tests-1} 404; got: {statuses}" + + # The resource must be gone. 
+ async with http.get(f"{api_base}/api/assets/{aid}") as rg: + assert rg.status == 404 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("root", ["input", "output"]) +async def test_metadata_filename_is_set_for_seed_asset_without_hash( + root: str, + http: aiohttp.ClientSession, + api_base: str, + comfy_tmp_base_dir: Path, +): + """Seed ingest (no hash yet) must compute user_metadata['filename'] immediately.""" + scope = f"seedmeta-{uuid.uuid4().hex[:6]}" + name = "seed_filename.bin" + + base = comfy_tmp_base_dir / root / "unit-tests" / scope / "a" / "b" + base.mkdir(parents=True, exist_ok=True) + fp = base / name + fp.write_bytes(b"Z" * 2048) + + await trigger_sync_seed_assets(http, api_base) + + async with http.get( + api_base + "/api/assets", + params={"include_tags": f"unit-tests,{scope}", "name_contains": name}, + ) as r1: + body = await r1.json() + assert r1.status == 200, body + matches = [a for a in body.get("assets", []) if a.get("name") == name] + assert matches, "Seed asset should be visible after sync" + assert matches[0].get("asset_hash") is None # still a seed + aid = matches[0]["id"] + + async with http.get(f"{api_base}/api/assets/{aid}") as r2: + detail = await r2.json() + assert r2.status == 200, detail + filename = (detail.get("user_metadata") or {}).get("filename") + expected = str(fp.relative_to(comfy_tmp_base_dir / root)).replace("\\", "/") + assert filename == expected, f"expected filename={expected}, got {filename!r}" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("root", ["input", "output"]) +async def test_metadata_filename_computed_and_updated_on_retarget( + root: str, + http: aiohttp.ClientSession, + api_base: str, + comfy_tmp_base_dir: Path, + asset_factory, + make_asset_bytes, + run_scan_and_wait, +): + """ + 1) Ingest under {root}/unit-tests//a/b/ -> filename reflects relative path. + 2) Retarget by copying to {root}/unit-tests//x/, remove old file, + run fast pass + scan -> filename updates to new relative path. 
+ """ + scope = f"meta-fn-{uuid.uuid4().hex[:6]}" + name1 = "compute_metadata_filename.png" + name2 = "compute_changed_metadata_filename.png" + data = make_asset_bytes(name1, 2100) + + # Upload into nested path a/b + a = await asset_factory(name1, [root, "unit-tests", scope, "a", "b"], {}, data) + aid = a["id"] + + root_base = comfy_tmp_base_dir / root + p1 = (root_base / "unit-tests" / scope / "a" / "b" / get_asset_filename(a["asset_hash"], ".png")) + assert p1.exists() + + # filename at ingest should be the path relative to root + rel1 = str(p1.relative_to(root_base)).replace("\\", "/") + async with http.get(f"{api_base}/api/assets/{aid}") as g1: + d1 = await g1.json() + assert g1.status == 200, d1 + fn1 = d1["user_metadata"].get("filename") + assert fn1 == rel1 + + # Retarget: copy to x/, remove old, then sync+scan + p2 = root_base / "unit-tests" / scope / "x" / name2 + p2.parent.mkdir(parents=True, exist_ok=True) + p2.write_bytes(data) + if p1.exists(): + p1.unlink() + + await trigger_sync_seed_assets(http, api_base) # seed the new path + await run_scan_and_wait(root) # verify/hash and reconcile + + # filename should now point at x/ + rel2 = str(p2.relative_to(root_base)).replace("\\", "/") + async with http.get(f"{api_base}/api/assets/{aid}") as g2: + d2 = await g2.json() + assert g2.status == 200, d2 + fn2 = d2["user_metadata"].get("filename") + assert fn2 == rel2 diff --git a/tests-assets/test_downloads.py b/tests-assets/test_downloads.py new file mode 100644 index 000000000000..181aad6f60fa --- /dev/null +++ b/tests-assets/test_downloads.py @@ -0,0 +1,168 @@ +import asyncio +import uuid +from datetime import datetime +from pathlib import Path +from typing import Optional + +import aiohttp +import pytest +from conftest import get_asset_filename, trigger_sync_seed_assets + + +@pytest.mark.asyncio +async def test_download_attachment_and_inline(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict): + aid = seeded_asset["id"] + + # default attachment + async with http.get(f"{api_base}/api/assets/{aid}/content") as r1: + data = await r1.read() + assert r1.status == 200 + cd = r1.headers.get("Content-Disposition", "") + assert "attachment" in cd + assert data and len(data) == 4096 + + # inline requested + async with http.get(f"{api_base}/api/assets/{aid}/content?disposition=inline") as r2: + await r2.read() + assert r2.status == 200 + cd2 = r2.headers.get("Content-Disposition", "") + assert "inline" in cd2 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("root", ["input", "output"]) +async def test_download_chooses_existing_state_and_updates_access_time( + root: str, + http: aiohttp.ClientSession, + api_base: str, + comfy_tmp_base_dir: Path, + asset_factory, + make_asset_bytes, + run_scan_and_wait, +): + """ + Hashed asset with two state paths: if the first one disappears, + GET /content still serves from the remaining path and bumps last_access_time. 
+ """ + scope = f"dl-first-{uuid.uuid4().hex[:6]}" + name = "first_existing_state.bin" + data = make_asset_bytes(name, 3072) + + # Upload -> path1 + a = await asset_factory(name, [root, "unit-tests", scope], {}, data) + aid = a["id"] + + base = comfy_tmp_base_dir / root / "unit-tests" / scope + path1 = base / get_asset_filename(a["asset_hash"], ".bin") + assert path1.exists() + + # Seed path2 by copying, then scan to dedupe into a second state + path2 = base / "alt" / name + path2.parent.mkdir(parents=True, exist_ok=True) + path2.write_bytes(data) + await trigger_sync_seed_assets(http, api_base) + await run_scan_and_wait(root) + + # Remove path1 so server must fall back to path2 + path1.unlink() + + # last_access_time before + async with http.get(f"{api_base}/api/assets/{aid}") as rg0: + d0 = await rg0.json() + assert rg0.status == 200, d0 + ts0 = d0.get("last_access_time") + + await asyncio.sleep(0.05) + async with http.get(f"{api_base}/api/assets/{aid}/content") as r: + blob = await r.read() + assert r.status == 200 + assert blob == data # must serve from the surviving state (same bytes) + + async with http.get(f"{api_base}/api/assets/{aid}") as rg1: + d1 = await rg1.json() + assert rg1.status == 200, d1 + ts1 = d1.get("last_access_time") + + def _parse_iso8601(s: Optional[str]) -> Optional[float]: + if not s: + return None + s = s[:-1] if s.endswith("Z") else s + return datetime.fromisoformat(s).timestamp() + + t0 = _parse_iso8601(ts0) + t1 = _parse_iso8601(ts1) + assert t1 is not None + if t0 is not None: + assert t1 > t0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("seeded_asset", [{"tags": ["models", "checkpoints"]}], indirect=True) +async def test_download_missing_file_returns_404( + http: aiohttp.ClientSession, api_base: str, comfy_tmp_base_dir: Path, seeded_asset: dict +): + # Remove the underlying file then attempt download. + # We initialize fixture without additional tags to know exactly the asset file path. + try: + aid = seeded_asset["id"] + async with http.get(f"{api_base}/api/assets/{aid}") as rg: + detail = await rg.json() + assert rg.status == 200 + asset_filename = get_asset_filename(detail["asset_hash"], ".safetensors") + abs_path = comfy_tmp_base_dir / "models" / "checkpoints" / asset_filename + assert abs_path.exists() + abs_path.unlink() + + async with http.get(f"{api_base}/api/assets/{aid}/content") as r2: + assert r2.status == 404 + body = await r2.json() + assert body["error"]["code"] == "FILE_NOT_FOUND" + finally: + # We created asset without the "unit-tests" tag(see `autoclean_unit_test_assets`), we need to clear it manually. 
+ async with http.delete(f"{api_base}/api/assets/{aid}") as dr: + await dr.read() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("root", ["input", "output"]) +async def test_download_404_if_all_states_missing( + root: str, + http: aiohttp.ClientSession, + api_base: str, + comfy_tmp_base_dir: Path, + asset_factory, + make_asset_bytes, + run_scan_and_wait, +): + """Multi-state asset: after the last remaining on-disk file is removed, download must return 404.""" + scope = f"dl-404-{uuid.uuid4().hex[:6]}" + name = "missing_all_states.bin" + data = make_asset_bytes(name, 2048) + + # Upload -> path1 + a = await asset_factory(name, [root, "unit-tests", scope], {}, data) + aid = a["id"] + + base = comfy_tmp_base_dir / root / "unit-tests" / scope + p1 = base / get_asset_filename(a["asset_hash"], ".bin") + assert p1.exists() + + # Seed a second state and dedupe + p2 = base / "copy" / name + p2.parent.mkdir(parents=True, exist_ok=True) + p2.write_bytes(data) + await trigger_sync_seed_assets(http, api_base) + await run_scan_and_wait(root) + + # Remove first file -> download should still work via the second state + p1.unlink() + async with http.get(f"{api_base}/api/assets/{aid}/content") as ok1: + b1 = await ok1.read() + assert ok1.status == 200 and b1 == data + + # Remove the last file -> download must 404 + p2.unlink() + async with http.get(f"{api_base}/api/assets/{aid}/content") as r2: + body = await r2.json() + assert r2.status == 404 + assert body["error"]["code"] == "FILE_NOT_FOUND" diff --git a/tests-assets/test_list_filter.py b/tests-assets/test_list_filter.py new file mode 100644 index 000000000000..835de0367058 --- /dev/null +++ b/tests-assets/test_list_filter.py @@ -0,0 +1,337 @@ +import asyncio +import uuid + +import aiohttp +import pytest + + +@pytest.mark.asyncio +async def test_list_assets_paging_and_sort(http: aiohttp.ClientSession, api_base: str, asset_factory, make_asset_bytes): + names = ["a1_u.safetensors", "a2_u.safetensors", "a3_u.safetensors"] + for n in names: + await asset_factory( + n, + ["models", "checkpoints", "unit-tests", "paging"], + {"epoch": 1}, + make_asset_bytes(n, size=2048), + ) + + # name ascending for stable order + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,paging", "sort": "name", "order": "asc", "limit": "2", "offset": "0"}, + ) as r1: + b1 = await r1.json() + assert r1.status == 200 + got1 = [a["name"] for a in b1["assets"]] + assert got1 == sorted(names)[:2] + assert b1["has_more"] is True + + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,paging", "sort": "name", "order": "asc", "limit": "2", "offset": "2"}, + ) as r2: + b2 = await r2.json() + assert r2.status == 200 + got2 = [a["name"] for a in b2["assets"]] + assert got2 == sorted(names)[2:] + assert b2["has_more"] is False + + +@pytest.mark.asyncio +async def test_list_assets_include_exclude_and_name_contains(http: aiohttp.ClientSession, api_base: str, asset_factory): + a = await asset_factory("inc_a.safetensors", ["models", "checkpoints", "unit-tests", "alpha"], {}, b"X" * 1024) + b = await asset_factory("inc_b.safetensors", ["models", "checkpoints", "unit-tests", "beta"], {}, b"Y" * 1024) + + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,alpha", "exclude_tags": "beta", "limit": "50"}, + ) as r: + body = await r.json() + assert r.status == 200 + names = [x["name"] for x in body["assets"]] + assert a["name"] in names + assert b["name"] not in names + + async with http.get( + 
api_base + "/api/assets", + params={"include_tags": "unit-tests", "name_contains": "inc_"}, + ) as r2: + body2 = await r2.json() + assert r2.status == 200 + names2 = [x["name"] for x in body2["assets"]] + assert a["name"] in names2 + assert b["name"] in names2 + + async with http.get( + api_base + "/api/assets", + params={"include_tags": "non-existing-tag"}, + ) as r2: + body3 = await r2.json() + assert r2.status == 200 + assert not body3["assets"] + + +@pytest.mark.asyncio +async def test_list_assets_sort_by_size_both_orders(http, api_base, asset_factory, make_asset_bytes): + t = ["models", "checkpoints", "unit-tests", "lf-size"] + n1, n2, n3 = "sz1.safetensors", "sz2.safetensors", "sz3.safetensors" + await asset_factory(n1, t, {}, make_asset_bytes(n1, 1024)) + await asset_factory(n2, t, {}, make_asset_bytes(n2, 2048)) + await asset_factory(n3, t, {}, make_asset_bytes(n3, 3072)) + + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,lf-size", "sort": "size", "order": "asc"}, + ) as r1: + b1 = await r1.json() + names = [a["name"] for a in b1["assets"]] + assert names[:3] == [n1, n2, n3] + + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,lf-size", "sort": "size", "order": "desc"}, + ) as r2: + b2 = await r2.json() + names2 = [a["name"] for a in b2["assets"]] + assert names2[:3] == [n3, n2, n1] + + + +@pytest.mark.asyncio +async def test_list_assets_sort_by_updated_at_desc(http, api_base, asset_factory, make_asset_bytes): + t = ["models", "checkpoints", "unit-tests", "lf-upd"] + a1 = await asset_factory("upd_a.safetensors", t, {}, make_asset_bytes("upd_a", 1200)) + a2 = await asset_factory("upd_b.safetensors", t, {}, make_asset_bytes("upd_b", 1200)) + + # Rename the second asset to bump updated_at + async with http.put(f"{api_base}/api/assets/{a2['id']}", json={"name": "upd_b_renamed.safetensors"}) as rp: + upd = await rp.json() + assert rp.status == 200, upd + + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,lf-upd", "sort": "updated_at", "order": "desc"}, + ) as r: + body = await r.json() + assert r.status == 200 + names = [x["name"] for x in body["assets"]] + assert names[0] == "upd_b_renamed.safetensors" + assert a1["name"] in names + + + +@pytest.mark.asyncio +async def test_list_assets_sort_by_last_access_time_desc(http, api_base, asset_factory, make_asset_bytes): + t = ["models", "checkpoints", "unit-tests", "lf-access"] + await asset_factory("acc_a.safetensors", t, {}, make_asset_bytes("acc_a", 1100)) + await asyncio.sleep(0.02) + a2 = await asset_factory("acc_b.safetensors", t, {}, make_asset_bytes("acc_b", 1100)) + + # Touch last_access_time of b by downloading its content + await asyncio.sleep(0.02) + async with http.get(f"{api_base}/api/assets/{a2['id']}/content") as dl: + assert dl.status == 200 + await dl.read() + + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,lf-access", "sort": "last_access_time", "order": "desc"}, + ) as r: + body = await r.json() + assert r.status == 200 + names = [x["name"] for x in body["assets"]] + assert names[0] == a2["name"] + + +@pytest.mark.asyncio +async def test_list_assets_include_tags_variants_and_case(http, api_base, asset_factory, make_asset_bytes): + t = ["models", "checkpoints", "unit-tests", "lf-include"] + a = await asset_factory("incvar_alpha.safetensors", [*t, "alpha"], {}, make_asset_bytes("iva")) + await asset_factory("incvar_beta.safetensors", [*t, "beta"], {}, 
make_asset_bytes("ivb")) + + # CSV + case-insensitive + async with http.get( + api_base + "/api/assets", + params={"include_tags": "UNIT-TESTS,LF-INCLUDE,alpha"}, + ) as r1: + b1 = await r1.json() + assert r1.status == 200 + names1 = [x["name"] for x in b1["assets"]] + assert a["name"] in names1 + assert not any("beta" in x for x in names1) + + # Repeated query params for include_tags + params_multi = [ + ("include_tags", "unit-tests"), + ("include_tags", "lf-include"), + ("include_tags", "alpha"), + ] + async with http.get(api_base + "/api/assets", params=params_multi) as r2: + b2 = await r2.json() + assert r2.status == 200 + names2 = [x["name"] for x in b2["assets"]] + assert a["name"] in names2 + assert not any("beta" in x for x in names2) + + # Duplicates and spaces in CSV + async with http.get( + api_base + "/api/assets", + params={"include_tags": " unit-tests , lf-include , alpha , alpha "}, + ) as r3: + b3 = await r3.json() + assert r3.status == 200 + names3 = [x["name"] for x in b3["assets"]] + assert a["name"] in names3 + + +@pytest.mark.asyncio +async def test_list_assets_exclude_tags_dedup_and_case(http, api_base, asset_factory, make_asset_bytes): + t = ["models", "checkpoints", "unit-tests", "lf-exclude"] + a = await asset_factory("ex_a_alpha.safetensors", [*t, "alpha"], {}, make_asset_bytes("exa", 900)) + await asset_factory("ex_b_beta.safetensors", [*t, "beta"], {}, make_asset_bytes("exb", 900)) + + # Exclude uppercase should work + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,lf-exclude", "exclude_tags": "BETA"}, + ) as r1: + b1 = await r1.json() + assert r1.status == 200 + names1 = [x["name"] for x in b1["assets"]] + assert a["name"] in names1 + # Repeated excludes with duplicates + params_multi = [ + ("include_tags", "unit-tests"), + ("include_tags", "lf-exclude"), + ("exclude_tags", "beta"), + ("exclude_tags", "beta"), + ] + async with http.get(api_base + "/api/assets", params=params_multi) as r2: + b2 = await r2.json() + assert r2.status == 200 + names2 = [x["name"] for x in b2["assets"]] + assert all("beta" not in x for x in names2) + + +@pytest.mark.asyncio +async def test_list_assets_name_contains_case_and_specials(http, api_base, asset_factory, make_asset_bytes): + t = ["models", "checkpoints", "unit-tests", "lf-name"] + a1 = await asset_factory("CaseMix.SAFE", t, {}, make_asset_bytes("cm", 800)) + a2 = await asset_factory("case-other.safetensors", t, {}, make_asset_bytes("co", 800)) + + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,lf-name", "name_contains": "casemix"}, + ) as r1: + b1 = await r1.json() + assert r1.status == 200 + names1 = [x["name"] for x in b1["assets"]] + assert a1["name"] in names1 + + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,lf-name", "name_contains": ".SAFE"}, + ) as r2: + b2 = await r2.json() + assert r2.status == 200 + names2 = [x["name"] for x in b2["assets"]] + assert a1["name"] in names2 + + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,lf-name", "name_contains": "case-"}, + ) as r3: + b3 = await r3.json() + assert r3.status == 200 + names3 = [x["name"] for x in b3["assets"]] + assert a2["name"] in names3 + + +@pytest.mark.asyncio +async def test_list_assets_offset_beyond_total_and_limit_boundary(http, api_base, asset_factory, make_asset_bytes): + t = ["models", "checkpoints", "unit-tests", "lf-pagelimits"] + await asset_factory("pl1.safetensors", t, {}, 
make_asset_bytes("pl1", 600)) + await asset_factory("pl2.safetensors", t, {}, make_asset_bytes("pl2", 600)) + await asset_factory("pl3.safetensors", t, {}, make_asset_bytes("pl3", 600)) + + # Offset far beyond total + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,lf-pagelimits", "limit": "2", "offset": "10"}, + ) as r1: + b1 = await r1.json() + assert r1.status == 200 + assert not b1["assets"] + assert b1["has_more"] is False + + # Boundary large limit (<=500 is valid) + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,lf-pagelimits", "limit": "500"}, + ) as r2: + b2 = await r2.json() + assert r2.status == 200 + assert len(b2["assets"]) == 3 + assert b2["has_more"] is False + + +@pytest.mark.asyncio +async def test_list_assets_offset_negative_and_limit_nonint_rejected(http, api_base): + async with http.get(api_base + "/api/assets", params={"offset": "-1"}) as r1: + b1 = await r1.json() + assert r1.status == 400 + assert b1["error"]["code"] == "INVALID_QUERY" + + async with http.get(api_base + "/api/assets", params={"limit": "abc"}) as r2: + b2 = await r2.json() + assert r2.status == 400 + assert b2["error"]["code"] == "INVALID_QUERY" + + +@pytest.mark.asyncio +async def test_list_assets_invalid_query_rejected(http: aiohttp.ClientSession, api_base: str): + # limit too small + async with http.get(api_base + "/api/assets", params={"limit": "0"}) as r1: + b1 = await r1.json() + assert r1.status == 400 + assert b1["error"]["code"] == "INVALID_QUERY" + + # bad metadata JSON + async with http.get(api_base + "/api/assets", params={"metadata_filter": "{not json"}) as r2: + b2 = await r2.json() + assert r2.status == 400 + assert b2["error"]["code"] == "INVALID_QUERY" + + +@pytest.mark.asyncio +async def test_list_assets_name_contains_literal_underscore( + http, + api_base, + asset_factory, + make_asset_bytes, +): + """'name_contains' must treat '_' literally, not as a SQL wildcard. 
+ We create: + - foo_bar.safetensors (should match) + - fooxbar.safetensors (must NOT match if '_' is escaped) + - foobar.safetensors (must NOT match) + """ + scope = f"lf-underscore-{uuid.uuid4().hex[:6]}" + tags = ["models", "checkpoints", "unit-tests", scope] + + a = await asset_factory("foo_bar.safetensors", tags, {}, make_asset_bytes("a", 700)) + b = await asset_factory("fooxbar.safetensors", tags, {}, make_asset_bytes("b", 700)) + c = await asset_factory("foobar.safetensors", tags, {}, make_asset_bytes("c", 700)) + + async with http.get( + api_base + "/api/assets", + params={"include_tags": f"unit-tests,{scope}", "name_contains": "foo_bar"}, + ) as r: + body = await r.json() + assert r.status == 200, body + names = [x["name"] for x in body["assets"]] + assert a["name"] in names, f"Expected literal underscore match to include {a['name']}" + assert b["name"] not in names, "Underscore must be escaped — should not match 'fooxbar'" + assert c["name"] not in names, "Underscore must be escaped — should not match 'foobar'" + assert body["total"] == 1 diff --git a/tests-assets/test_metadata_filters.py b/tests-assets/test_metadata_filters.py new file mode 100644 index 000000000000..4c4c8f946314 --- /dev/null +++ b/tests-assets/test_metadata_filters.py @@ -0,0 +1,387 @@ +import json + +import aiohttp +import pytest + + +@pytest.mark.asyncio +async def test_meta_and_across_keys_and_types( + http: aiohttp.ClientSession, api_base: str, asset_factory, make_asset_bytes +): + name = "mf_and_mix.safetensors" + tags = ["models", "checkpoints", "unit-tests", "mf-and"] + meta = {"purpose": "mix", "epoch": 1, "active": True, "score": 1.23} + await asset_factory(name, tags, meta, make_asset_bytes(name, 4096)) + + # All keys must match (AND semantics) + f_ok = {"purpose": "mix", "epoch": 1, "active": True, "score": 1.23} + async with http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-and", + "metadata_filter": json.dumps(f_ok), + }, + ) as r1: + b1 = await r1.json() + assert r1.status == 200 + names = [a["name"] for a in b1["assets"]] + assert name in names + + # One key mismatched -> no result + f_bad = {"purpose": "mix", "epoch": 2, "active": True} + async with http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-and", + "metadata_filter": json.dumps(f_bad), + }, + ) as r2: + b2 = await r2.json() + assert r2.status == 200 + assert not b2["assets"] + + +@pytest.mark.asyncio +async def test_meta_type_strictness_int_vs_str_and_bool(http, api_base, asset_factory, make_asset_bytes): + name = "mf_types.safetensors" + tags = ["models", "checkpoints", "unit-tests", "mf-types"] + meta = {"epoch": 1, "active": True} + await asset_factory(name, tags, meta, make_asset_bytes(name)) + + # int filter matches numeric + async with http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-types", + "metadata_filter": json.dumps({"epoch": 1}), + }, + ) as r1: + b1 = await r1.json() + assert r1.status == 200 and any(a["name"] == name for a in b1["assets"]) + + # string "1" must NOT match numeric 1 + async with http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-types", + "metadata_filter": json.dumps({"epoch": "1"}), + }, + ) as r2: + b2 = await r2.json() + assert r2.status == 200 and not b2["assets"] + + # bool True matches, string "true" must NOT match + async with http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-types", + "metadata_filter": json.dumps({"active": True}), + }, + 
) as r3: + b3 = await r3.json() + assert r3.status == 200 and any(a["name"] == name for a in b3["assets"]) + + async with http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-types", + "metadata_filter": json.dumps({"active": "true"}), + }, + ) as r4: + b4 = await r4.json() + assert r4.status == 200 and not b4["assets"] + + +@pytest.mark.asyncio +async def test_meta_any_of_list_of_scalars(http, api_base, asset_factory, make_asset_bytes): + name = "mf_list_scalars.safetensors" + tags = ["models", "checkpoints", "unit-tests", "mf-list"] + meta = {"flags": ["red", "green"]} + await asset_factory(name, tags, meta, make_asset_bytes(name, 3000)) + + # Any-of should match because "green" is present + filt_ok = {"flags": ["blue", "green"]} + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-list", "metadata_filter": json.dumps(filt_ok)}, + ) as r1: + b1 = await r1.json() + assert r1.status == 200 and any(a["name"] == name for a in b1["assets"]) + + # None of provided flags present -> no match + filt_miss = {"flags": ["blue", "yellow"]} + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-list", "metadata_filter": json.dumps(filt_miss)}, + ) as r2: + b2 = await r2.json() + assert r2.status == 200 and not b2["assets"] + + # Duplicates in list should not break matching + filt_dup = {"flags": ["green", "green", "green"]} + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-list", "metadata_filter": json.dumps(filt_dup)}, + ) as r3: + b3 = await r3.json() + assert r3.status == 200 and any(a["name"] == name for a in b3["assets"]) + + +@pytest.mark.asyncio +async def test_meta_none_semantics_missing_or_null_and_any_of_with_none( + http, api_base, asset_factory, make_asset_bytes +): + # a1: key missing; a2: explicit null; a3: concrete value + t = ["models", "checkpoints", "unit-tests", "mf-none"] + a1 = await asset_factory("mf_none_missing.safetensors", t, {"x": 1}, make_asset_bytes("a1")) + a2 = await asset_factory("mf_none_null.safetensors", t, {"maybe": None}, make_asset_bytes("a2")) + a3 = await asset_factory("mf_none_value.safetensors", t, {"maybe": "x"}, make_asset_bytes("a3")) + + # Filter {maybe: None} must match a1 and a2, not a3 + filt = {"maybe": None} + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-none", "metadata_filter": json.dumps(filt), "sort": "name"}, + ) as r1: + b1 = await r1.json() + assert r1.status == 200 + got = [a["name"] for a in b1["assets"]] + assert a1["name"] in got and a2["name"] in got and a3["name"] not in got + + # Any-of with None should include missing/null plus value matches + filt_any = {"maybe": [None, "x"]} + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-none", "metadata_filter": json.dumps(filt_any), "sort": "name"}, + ) as r2: + b2 = await r2.json() + assert r2.status == 200 + got2 = [a["name"] for a in b2["assets"]] + assert a1["name"] in got2 and a2["name"] in got2 and a3["name"] in got2 + + +@pytest.mark.asyncio +async def test_meta_nested_json_object_equality(http, api_base, asset_factory, make_asset_bytes): + name = "mf_nested_json.safetensors" + tags = ["models", "checkpoints", "unit-tests", "mf-nested"] + cfg = {"optimizer": "adam", "lr": 0.001, "schedule": {"type": "cosine", "warmup": 100}} + await asset_factory(name, tags, {"config": cfg}, make_asset_bytes(name, 2200)) + + # Exact JSON object equality (same structure) + 
async with http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-nested", + "metadata_filter": json.dumps({"config": cfg}), + }, + ) as r1: + b1 = await r1.json() + assert r1.status == 200 and any(a["name"] == name for a in b1["assets"]) + + # Different JSON object should not match + async with http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-nested", + "metadata_filter": json.dumps({"config": {"optimizer": "sgd"}}), + }, + ) as r2: + b2 = await r2.json() + assert r2.status == 200 and not b2["assets"] + + +@pytest.mark.asyncio +async def test_meta_list_of_objects_any_of(http, api_base, asset_factory, make_asset_bytes): + name = "mf_list_objects.safetensors" + tags = ["models", "checkpoints", "unit-tests", "mf-objlist"] + transforms = [{"type": "crop", "size": 128}, {"type": "flip", "p": 0.5}] + await asset_factory(name, tags, {"transforms": transforms}, make_asset_bytes(name, 2048)) + + # Any-of for list of objects should match when one element equals the filter object + async with http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-objlist", + "metadata_filter": json.dumps({"transforms": {"type": "flip", "p": 0.5}}), + }, + ) as r1: + b1 = await r1.json() + assert r1.status == 200 and any(a["name"] == name for a in b1["assets"]) + + # Non-matching object -> no match + async with http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-objlist", + "metadata_filter": json.dumps({"transforms": {"type": "rotate", "deg": 90}}), + }, + ) as r2: + b2 = await r2.json() + assert r2.status == 200 and not b2["assets"] + + +@pytest.mark.asyncio +async def test_meta_with_special_and_unicode_keys(http, api_base, asset_factory, make_asset_bytes): + name = "mf_keys_unicode.safetensors" + tags = ["models", "checkpoints", "unit-tests", "mf-keys"] + meta = { + "weird.key": "v1", + "path/like": 7, + "with:colon": True, + "ключ": "значение", + "emoji": "🐍", + } + await asset_factory(name, tags, meta, make_asset_bytes(name, 1500)) + + # Match all the special keys + filt = {"weird.key": "v1", "path/like": 7, "with:colon": True, "emoji": "🐍"} + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-keys", "metadata_filter": json.dumps(filt)}, + ) as r1: + b1 = await r1.json() + assert r1.status == 200 and any(a["name"] == name for a in b1["assets"]) + + # Unicode key match + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-keys", "metadata_filter": json.dumps({"ключ": "значение"})}, + ) as r2: + b2 = await r2.json() + assert r2.status == 200 and any(a["name"] == name for a in b2["assets"]) + + +@pytest.mark.asyncio +async def test_meta_with_zero_and_boolean_lists(http, api_base, asset_factory, make_asset_bytes): + t = ["models", "checkpoints", "unit-tests", "mf-zero-bool"] + a0 = await asset_factory("mf_zero_count.safetensors", t, {"count": 0}, make_asset_bytes("z", 1025)) + a1 = await asset_factory("mf_bool_list.safetensors", t, {"choices": [True, False]}, make_asset_bytes("b", 1026)) + + # count == 0 must match only a0 + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-zero-bool", "metadata_filter": json.dumps({"count": 0})}, + ) as r1: + b1 = await r1.json() + assert r1.status == 200 + names1 = [a["name"] for a in b1["assets"]] + assert a0["name"] in names1 and a1["name"] not in names1 + + # Any-of list of booleans: True matches second asset + async with http.get( + api_base + 
"/api/assets", + params={"include_tags": "unit-tests,mf-zero-bool", "metadata_filter": json.dumps({"choices": True})}, + ) as r2: + b2 = await r2.json() + assert r2.status == 200 and any(a["name"] == a1["name"] for a in b2["assets"]) + + +@pytest.mark.asyncio +async def test_meta_mixed_list_types_and_strictness(http, api_base, asset_factory, make_asset_bytes): + name = "mf_mixed_list.safetensors" + tags = ["models", "checkpoints", "unit-tests", "mf-mixed"] + meta = {"mix": ["1", 1, True, None]} + await asset_factory(name, tags, meta, make_asset_bytes(name, 1999)) + + # Should match because 1 is present + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-mixed", "metadata_filter": json.dumps({"mix": [2, 1]})}, + ) as r1: + b1 = await r1.json() + assert r1.status == 200 and any(a["name"] == name for a in b1["assets"]) + + # Should NOT match for False + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-mixed", "metadata_filter": json.dumps({"mix": False})}, + ) as r2: + b2 = await r2.json() + assert r2.status == 200 and not b2["assets"] + + +@pytest.mark.asyncio +async def test_meta_unknown_key_and_none_behavior_with_scope_tags(http, api_base, asset_factory, make_asset_bytes): + # Use a unique scope tag to avoid interference + t = ["models", "checkpoints", "unit-tests", "mf-unknown-scope"] + x = await asset_factory("mf_unknown_a.safetensors", t, {"k1": 1}, make_asset_bytes("ua")) + y = await asset_factory("mf_unknown_b.safetensors", t, {"k2": 2}, make_asset_bytes("ub")) + + # Filtering by unknown key with None should return both (missing key OR null) + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-unknown-scope", "metadata_filter": json.dumps({"unknown": None})}, + ) as r1: + b1 = await r1.json() + assert r1.status == 200 + names = {a["name"] for a in b1["assets"]} + assert x["name"] in names and y["name"] in names + + # Filtering by unknown key with concrete value should return none + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-unknown-scope", "metadata_filter": json.dumps({"unknown": "x"})}, + ) as r2: + b2 = await r2.json() + assert r2.status == 200 and not b2["assets"] + + +@pytest.mark.asyncio +async def test_meta_with_tags_include_exclude_and_name_contains(http, api_base, asset_factory, make_asset_bytes): + # alpha matches epoch=1; beta has epoch=2 + a = await asset_factory( + "mf_tag_alpha.safetensors", + ["models", "checkpoints", "unit-tests", "mf-tag", "alpha"], + {"epoch": 1}, + make_asset_bytes("alpha"), + ) + b = await asset_factory( + "mf_tag_beta.safetensors", + ["models", "checkpoints", "unit-tests", "mf-tag", "beta"], + {"epoch": 2}, + make_asset_bytes("beta"), + ) + + params = { + "include_tags": "unit-tests,mf-tag,alpha", + "exclude_tags": "beta", + "name_contains": "mf_tag_", + "metadata_filter": json.dumps({"epoch": 1}), + } + async with http.get(api_base + "/api/assets", params=params) as r: + body = await r.json() + assert r.status == 200 + names = [x["name"] for x in body["assets"]] + assert a["name"] in names + assert b["name"] not in names + + +@pytest.mark.asyncio +async def test_meta_sort_and_paging_under_filter(http, api_base, asset_factory, make_asset_bytes): + # Three assets in same scope with different sizes and a common filter key + t = ["models", "checkpoints", "unit-tests", "mf-sort"] + n1, n2, n3 = "mf_sort_1.safetensors", "mf_sort_2.safetensors", "mf_sort_3.safetensors" + await 
asset_factory(n1, t, {"group": "g"}, make_asset_bytes(n1, 1024)) + await asset_factory(n2, t, {"group": "g"}, make_asset_bytes(n2, 2048)) + await asset_factory(n3, t, {"group": "g"}, make_asset_bytes(n3, 3072)) + + # Sort by size ascending with paging + q = { + "include_tags": "unit-tests,mf-sort", + "metadata_filter": json.dumps({"group": "g"}), + "sort": "size", "order": "asc", "limit": "2", + } + async with http.get(api_base + "/api/assets", params=q) as r1: + b1 = await r1.json() + assert r1.status == 200 + got1 = [a["name"] for a in b1["assets"]] + assert got1 == [n1, n2] + assert b1["has_more"] is True + + q2 = {**q, "offset": "2"} + async with http.get(api_base + "/api/assets", params=q2) as r2: + b2 = await r2.json() + assert r2.status == 200 + got2 = [a["name"] for a in b2["assets"]] + assert got2 == [n3] + assert b2["has_more"] is False diff --git a/tests-assets/test_scans.py b/tests-assets/test_scans.py new file mode 100644 index 000000000000..e82ae5f6d264 --- /dev/null +++ b/tests-assets/test_scans.py @@ -0,0 +1,510 @@ +import asyncio +import os +import uuid +from pathlib import Path + +import aiohttp +import pytest +from conftest import get_asset_filename, trigger_sync_seed_assets + + +def _base_for(root: str, comfy_tmp_base_dir: Path) -> Path: + assert root in ("input", "output") + return comfy_tmp_base_dir / root + + +def _mkbytes(label: str, size: int) -> bytes: + seed = sum(label.encode("utf-8")) % 251 + return bytes((i * 31 + seed) % 256 for i in range(size)) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("root", ["input", "output"]) +async def test_scan_schedule_idempotent_while_running( + root: str, + http, + api_base: str, + comfy_tmp_base_dir: Path, + run_scan_and_wait, +): + """Idempotent schedule while running.""" + scope = f"idem-{uuid.uuid4().hex[:6]}" + base = _base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope + base.mkdir(parents=True, exist_ok=True) + + # Create several seed files (non-zero) to ensure the scan runs long enough + for i in range(8): + (base / f"f{i}.bin").write_bytes(_mkbytes(f"{scope}-{i}", 2 * 1024 * 1024)) # ~2 MiB each + + # Seed -> states with hash=NULL + await trigger_sync_seed_assets(http, api_base) + + # Schedule once + async with http.post(api_base + "/api/assets/scan/schedule", json={"roots": [root]}) as r1: + b1 = await r1.json() + assert r1.status == 202, b1 + scans1 = {s["root"]: s for s in b1.get("scans", [])} + s1 = scans1.get(root) + assert s1 and s1["status"] in {"scheduled", "running"} + sid1 = s1["scan_id"] + + # Schedule again immediately — must return the same scan entry (no new worker) + async with http.post(api_base + "/api/assets/scan/schedule", json={"roots": [root]}) as r2: + b2 = await r2.json() + assert r2.status == 202, b2 + scans2 = {s["root"]: s for s in b2.get("scans", [])} + s2 = scans2.get(root) + assert s2 and s2["scan_id"] == sid1 + + # Filtered GET must show exactly one scan for this root + async with http.get(api_base + "/api/assets/scan", params={"root": root}) as gs: + bs = await gs.json() + assert gs.status == 200, bs + scans = bs.get("scans", []) + assert len(scans) == 1 and scans[0]["scan_id"] == sid1 + + # Let it finish to avoid cross-test interference + await run_scan_and_wait(root) + + +@pytest.mark.asyncio +async def test_scan_status_filter_by_root_and_file_errors( + http, + api_base: str, + comfy_tmp_base_dir: Path, + run_scan_and_wait, + asset_factory, +): + """Filtering get scan status by root (schedule for both input and output) + file_errors presence.""" + # Create one hashed asset 
in input under a dir we will chmod to 000 to force PermissionError in reconcile stage + in_scope = f"filter-in-{uuid.uuid4().hex[:6]}" + protected_dir = _base_for("input", comfy_tmp_base_dir) / "unit-tests" / in_scope / "deny" + protected_dir.mkdir(parents=True, exist_ok=True) + name_in = "protected.bin" + + data = b"A" * 4096 + await asset_factory(name_in, ["input", "unit-tests", in_scope, "deny"], {}, data) + try: + os.chmod(protected_dir, 0o000) + + # Also schedule a scan for output root (no errors there) + out_scope = f"filter-out-{uuid.uuid4().hex[:6]}" + out_dir = _base_for("output", comfy_tmp_base_dir) / "unit-tests" / out_scope + out_dir.mkdir(parents=True, exist_ok=True) + (out_dir / "ok.bin").write_bytes(b"B" * 1024) + await trigger_sync_seed_assets(http, api_base) # seed output file + + # Schedule both roots + async with http.post(api_base + "/api/assets/scan/schedule", json={"roots": ["input"]}) as r_in: + assert r_in.status == 202 + async with http.post(api_base + "/api/assets/scan/schedule", json={"roots": ["output"]}) as r_out: + assert r_out.status == 202 + + # Wait both to complete, input last (we want its errors) + await run_scan_and_wait("output") + await run_scan_and_wait("input") + + # Filter by root=input: only input scan listed and must have file_errors + async with http.get(api_base + "/api/assets/scan", params={"root": "input"}) as gs: + body = await gs.json() + assert gs.status == 200, body + scans = body.get("scans", []) + assert len(scans) == 1 + errs = scans[0].get("file_errors", []) + # Must contain at least one error with a message + assert errs and any(e.get("message") for e in errs) + finally: + # Restore perms so cleanup can remove files/dirs + try: + os.chmod(protected_dir, 0o755) + except Exception: + pass + + +@pytest.mark.asyncio +@pytest.mark.parametrize("root", ["input", "output"]) +@pytest.mark.skipif(os.name == "nt", reason="Permission-based file_errors are unreliable on Windows") +async def test_scan_records_file_errors_permission_denied( + root: str, + http, + api_base: str, + comfy_tmp_base_dir: Path, + asset_factory, + run_scan_and_wait, +): + """file_errors recording (permission denied) for input/output""" + scope = f"errs-{uuid.uuid4().hex[:6]}" + deny_dir = _base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope / "deny" + deny_dir.mkdir(parents=True, exist_ok=True) + name = "deny.bin" + + a1 = await asset_factory(name, [root, "unit-tests", scope, "deny"], {}, b"X" * 2048) + asset_filename = get_asset_filename(a1["asset_hash"], ".bin") + try: + os.chmod(deny_dir, 0o000) + async with http.post(api_base + "/api/assets/scan/schedule", json={"roots": [root]}) as r: + assert r.status == 202 + await run_scan_and_wait(root) + + async with http.get(api_base + "/api/assets/scan", params={"root": root}) as gs: + body = await gs.json() + assert gs.status == 200, body + scans = body.get("scans", []) + assert len(scans) == 1 + errs = scans[0].get("file_errors", []) + # Should contain at least one PermissionError-like record + assert errs + assert any(e.get("path", "").endswith(asset_filename) and e.get("message") for e in errs) + finally: + try: + os.chmod(deny_dir, 0o755) + except Exception: + pass + + +@pytest.mark.asyncio +@pytest.mark.parametrize("root", ["input", "output"]) +async def test_missing_tag_created_and_visible_in_tags( + root: str, + http, + api_base: str, + comfy_tmp_base_dir: Path, + asset_factory, +): + """Missing tag appears in tags list and increments count (input/output)""" + # Baseline count of 'missing' tag (may be absent) + 
async with http.get(api_base + "/api/tags", params={"limit": "1000"}) as r0: + t0 = await r0.json() + assert r0.status == 200, t0 + byname = {t["name"]: t for t in t0.get("tags", [])} + old_count = int(byname.get("missing", {}).get("count", 0)) + + scope = f"miss-{uuid.uuid4().hex[:6]}" + name = "missing_me.bin" + created = await asset_factory(name, [root, "unit-tests", scope], {}, b"Y" * 4096) + + # Remove the only file and trigger fast pass + p = _base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope / get_asset_filename(created["asset_hash"], ".bin") + assert p.exists() + p.unlink() + await trigger_sync_seed_assets(http, api_base) + + # Asset has 'missing' tag + async with http.get(f"{api_base}/api/assets/{created['id']}") as g1: + d1 = await g1.json() + assert g1.status == 200, d1 + assert "missing" in set(d1.get("tags", [])) + + # Tag list now contains 'missing' with increased count + async with http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}) as r1: + t1 = await r1.json() + assert r1.status == 200, t1 + byname1 = {t["name"]: t for t in t1.get("tags", [])} + assert "missing" in byname1 + assert int(byname1["missing"]["count"]) >= old_count + 1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("root", ["input", "output"]) +async def test_missing_reapplies_after_manual_removal( + root: str, + http, + api_base: str, + comfy_tmp_base_dir: Path, + asset_factory, +): + """Manual removal of 'missing' does not block automatic re-apply (input/output)""" + scope = f"reapply-{uuid.uuid4().hex[:6]}" + name = "reapply.bin" + created = await asset_factory(name, [root, "unit-tests", scope], {}, b"Z" * 1024) + + # Make it missing + p = _base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope / get_asset_filename(created["asset_hash"], ".bin") + p.unlink() + await trigger_sync_seed_assets(http, api_base) + + # Remove the 'missing' tag manually + async with http.delete(f"{api_base}/api/assets/{created['id']}/tags", json={"tags": ["missing"]}) as rdel: + b = await rdel.json() + assert rdel.status == 200, b + assert "missing" in set(b.get("removed", [])) + + # Next sync must re-add it + await trigger_sync_seed_assets(http, api_base) + async with http.get(f"{api_base}/api/assets/{created['id']}") as g2: + d2 = await g2.json() + assert g2.status == 200, d2 + assert "missing" in set(d2.get("tags", [])) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("root", ["input", "output"]) +async def test_delete_one_asset_info_of_missing_asset_keeps_identity( + root: str, + http, + api_base: str, + comfy_tmp_base_dir: Path, + asset_factory, +): + """Delete one AssetInfo of a missing asset while another exists (input/output)""" + scope = f"twoinfos-{uuid.uuid4().hex[:6]}" + name = "twoinfos.bin" + a1 = await asset_factory(name, [root, "unit-tests", scope], {}, b"W" * 2048) + + # Second AssetInfo for the same content under same root (different name to avoid collision) + a2 = await asset_factory("copy_" + name, [root, "unit-tests", scope], {}, b"W" * 2048) + + # Remove file of the first (both point to the same Asset, but we know on-disk path name for a1) + p1 = _base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope / get_asset_filename(a1["asset_hash"], ".bin") + p1.unlink() + await trigger_sync_seed_assets(http, api_base) + + # Both infos should be marked missing + async with http.get(f"{api_base}/api/assets/{a1['id']}") as g1: + d1 = await g1.json() + assert "missing" in set(d1.get("tags", [])) + async with http.get(f"{api_base}/api/assets/{a2['id']}") as g2: + d2 = 
await g2.json() + assert "missing" in set(d2.get("tags", [])) + + # Delete one info + async with http.delete(f"{api_base}/api/assets/{a1['id']}") as rd: + assert rd.status == 204 + + # Asset identity still exists (by hash) + h = a1["asset_hash"] + async with http.head(f"{api_base}/api/assets/hash/{h}") as rh: + assert rh.status == 200 + + # Remaining info still reflects 'missing' + async with http.get(f"{api_base}/api/assets/{a2['id']}") as g3: + d3 = await g3.json() + assert g3.status == 200 and "missing" in set(d3.get("tags", [])) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("keep_root", ["input", "output"]) +async def test_delete_last_asset_info_false_keeps_asset_and_states_multiroot( + keep_root: str, + http, + api_base: str, + comfy_tmp_base_dir: Path, + make_asset_bytes, + asset_factory, +): + """Deleting the last AssetInfo with delete_content=false keeps the asset and the underlying on-disk content.""" + other_root = "output" if keep_root == "input" else "input" + scope = f"delfalse-{uuid.uuid4().hex[:6]}" + data = make_asset_bytes(scope, 3072) + + # First upload creates the physical file + a1 = await asset_factory("keep1.bin", [keep_root, "unit-tests", scope], {}, data) + # Second upload (other root) is deduped to the same content; no new file on disk + a2 = await asset_factory("keep2.bin", [other_root, "unit-tests", scope], {}, data) + + h = a1["asset_hash"] + p1 = _base_for(keep_root, comfy_tmp_base_dir) / "unit-tests" / scope / get_asset_filename(h, ".bin") + + # De-dup semantics: only the first physical file exists + assert p1.exists(), "Expected the first physical file to exist" + + # Delete both AssetInfos; keep content on the very last delete + async with http.delete(f"{api_base}/api/assets/{a2['id']}") as rfirst: + assert rfirst.status == 204 + async with http.delete(f"{api_base}/api/assets/{a1['id']}?delete_content=false") as rlast: + assert rlast.status == 204 + + # Asset identity remains and physical content is still present + async with http.head(f"{api_base}/api/assets/hash/{h}") as rh: + assert rh.status == 200 + assert p1.exists(), "Content file should remain after keep-content delete" + + # Cleanup: re-create a reference by hash and then delete to purge content + payload = { + "hash": h, + "name": "cleanup.bin", + "tags": [keep_root, "unit-tests", scope, "cleanup"], + "user_metadata": {}, + } + async with http.post(f"{api_base}/api/assets/from-hash", json=payload) as rfh: + ref = await rfh.json() + assert rfh.status == 201, ref + cid = ref["id"] + async with http.delete(f"{api_base}/api/assets/{cid}") as rdel: + assert rdel.status == 204 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("root", ["input", "output"]) +async def test_sync_seed_ignores_zero_byte_files( + root: str, + http, + api_base: str, + comfy_tmp_base_dir: Path, +): + scope = f"zero-{uuid.uuid4().hex[:6]}" + base = _base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope + base.mkdir(parents=True, exist_ok=True) + z = base / "empty.dat" + z.write_bytes(b"") # zero bytes + + await trigger_sync_seed_assets(http, api_base) + + # No AssetInfo created for this zero-byte file + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests," + scope, "name_contains": "empty.dat"}, + ) as r: + body = await r.json() + assert r.status == 200, body + assert not [a for a in body.get("assets", []) if a.get("name") == "empty.dat"] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("root", ["input", "output"]) +async def test_sync_seed_idempotency( + root: str, + http, + 
api_base: str, + comfy_tmp_base_dir: Path, +): + scope = f"idemseed-{uuid.uuid4().hex[:6]}" + base = _base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope + base.mkdir(parents=True, exist_ok=True) + files = [f"f{i}.dat" for i in range(3)] + for i, n in enumerate(files): + (base / n).write_bytes(_mkbytes(n, 1500 + i * 10)) + + await trigger_sync_seed_assets(http, api_base) + async with http.get(api_base + "/api/assets", params={"include_tags": "unit-tests," + scope}) as r1: + b1 = await r1.json() + assert r1.status == 200, b1 + c1 = len(b1.get("assets", [])) + + # Seed again -> count must stay the same + await trigger_sync_seed_assets(http, api_base) + async with http.get(api_base + "/api/assets", params={"include_tags": "unit-tests," + scope}) as r2: + b2 = await r2.json() + assert r2.status == 200, b2 + c2 = len(b2.get("assets", [])) + assert c1 == c2 == len(files) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("root", ["input", "output"]) +async def test_sync_seed_nested_dirs_produce_parent_tags( + root: str, + http, + api_base: str, + comfy_tmp_base_dir: Path, +): + scope = f"nest-{uuid.uuid4().hex[:6]}" + # nested: unit-tests / scope / a / b / c / deep.txt + deep_dir = _base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope / "a" / "b" / "c" + deep_dir.mkdir(parents=True, exist_ok=True) + (deep_dir / "deep.txt").write_bytes(scope.encode()) + + await trigger_sync_seed_assets(http, api_base) + + async with http.get( + api_base + "/api/assets", + params={"include_tags": f"unit-tests,{scope}", "name_contains": "deep.txt"}, + ) as r: + body = await r.json() + assert r.status == 200, body + assets = body.get("assets", []) + assert assets, "seeded asset not found" + tags = set(assets[0].get("tags", [])) + # Must include all parent parts as tags + the root + for must in {root, "unit-tests", scope, "a", "b", "c"}: + assert must in tags + + +@pytest.mark.asyncio +@pytest.mark.parametrize("root", ["input", "output"]) +async def test_concurrent_seed_hashing_same_file_no_dupes( + root: str, + http: aiohttp.ClientSession, + api_base: str, + comfy_tmp_base_dir: Path, + run_scan_and_wait, +): + """ + Create a single seed file, then schedule two scans back-to-back. + Expect: no duplicate AssetInfos, a single hashed asset, and no scan failure. 
+ """ + scope = f"conc-seed-{uuid.uuid4().hex[:6]}" + name = "seed_concurrent.bin" + + base = _base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope + base.mkdir(parents=True, exist_ok=True) + (base / name).write_bytes(b"Z" * 2048) + + await trigger_sync_seed_assets(http, api_base) + + s1, s2 = await asyncio.gather( + http.post(api_base + "/api/assets/scan/schedule", json={"roots": [root]}), + http.post(api_base + "/api/assets/scan/schedule", json={"roots": [root]}), + ) + await s1.read() + await s2.read() + assert s1.status in (200, 202) + assert s2.status in (200, 202) + + await run_scan_and_wait(root) + + async with http.get( + api_base + "/api/assets", + params={"include_tags": f"unit-tests,{scope}", "name_contains": name}, + ) as r: + b = await r.json() + assert r.status == 200, b + matches = [a for a in b.get("assets", []) if a.get("name") == name] + assert len(matches) == 1 + assert matches[0].get("asset_hash"), "Seed should have been hashed into an Asset" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("root", ["input", "output"]) +async def test_cache_state_retarget_on_content_change_asset_info_stays( + root: str, + http, + api_base: str, + comfy_tmp_base_dir: Path, + asset_factory, + make_asset_bytes, + run_scan_and_wait, +): + """ + Start with hashed H1 (AssetInfo A1). Replace file bytes on disk to become H2. + After scan: AssetCacheState points to H2; A1 still references H1; downloading A1 -> 404. + """ + scope = f"retarget-{uuid.uuid4().hex[:6]}" + name = "content_change.bin" + d1 = make_asset_bytes("v1-" + scope, 2048) + + a1 = await asset_factory(name, [root, "unit-tests", scope], {}, d1) + aid = a1["id"] + h1 = a1["asset_hash"] + + p = comfy_tmp_base_dir / root / "unit-tests" / scope / get_asset_filename(a1["asset_hash"], ".bin") + assert p.exists() + + # Change the bytes in place to force a new content hash (H2) + d2 = make_asset_bytes("v2-" + scope, 3072) + p.write_bytes(d2) + + # Scan to verify and retarget the state; reconcilers run after scan + await run_scan_and_wait(root) + + # AssetInfo still on the old content identity (H1) + async with http.get(f"{api_base}/api/assets/{aid}") as rg: + g = await rg.json() + assert rg.status == 200, g + assert g.get("asset_hash") == h1 + + # Download must fail until a state exists for H1 (we removed the only one by retarget) + async with http.get(f"{api_base}/api/assets/{aid}/content") as dl: + body = await dl.json() + assert dl.status == 404, body + assert body["error"]["code"] == "FILE_NOT_FOUND" diff --git a/tests-assets/test_tags.py b/tests-assets/test_tags.py new file mode 100644 index 000000000000..9bdf770c4c54 --- /dev/null +++ b/tests-assets/test_tags.py @@ -0,0 +1,228 @@ +import json +import uuid + +import aiohttp +import pytest + + +@pytest.mark.asyncio +async def test_tags_present(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict): + # Include zero-usage tags by default + async with http.get(api_base + "/api/tags", params={"limit": "50"}) as r1: + body1 = await r1.json() + assert r1.status == 200 + names = [t["name"] for t in body1["tags"]] + # A few system tags from migration should exist: + assert "models" in names + assert "checkpoints" in names + + # Only used tags before we add anything new from this test cycle + async with http.get(api_base + "/api/tags", params={"include_zero": "false"}) as r2: + body2 = await r2.json() + assert r2.status == 200 + # We already seeded one asset via fixture, so used tags must be non-empty + used_names = [t["name"] for t in body2["tags"]] + assert "models" in 
used_names + assert "checkpoints" in used_names + + # Prefix filter should refine the list + async with http.get(api_base + "/api/tags", params={"include_zero": "false", "prefix": "uni"}) as r3: + b3 = await r3.json() + assert r3.status == 200 + names3 = [t["name"] for t in b3["tags"]] + assert "unit-tests" in names3 + assert "models" not in names3 # filtered out by prefix + + # Order by name ascending should be stable + async with http.get(api_base + "/api/tags", params={"include_zero": "false", "order": "name_asc"}) as r4: + b4 = await r4.json() + assert r4.status == 200 + names4 = [t["name"] for t in b4["tags"]] + assert names4 == sorted(names4) + + +@pytest.mark.asyncio +async def test_tags_empty_usage(http: aiohttp.ClientSession, api_base: str, asset_factory, make_asset_bytes): + # Baseline: system tags exist when include_zero (default) is true + async with http.get(api_base + "/api/tags", params={"limit": "500"}) as r1: + body1 = await r1.json() + assert r1.status == 200 + names = [t["name"] for t in body1["tags"]] + assert "models" in names and "checkpoints" in names + + # Create a short-lived asset under input with a unique custom tag + scope = f"tags-empty-usage-{uuid.uuid4().hex[:6]}" + custom_tag = f"temp-{uuid.uuid4().hex[:8]}" + name = "tag_seed.bin" + _asset = await asset_factory( + name, + ["input", "unit-tests", scope, custom_tag], + {}, + make_asset_bytes(name, 512), + ) + + # While the asset exists, the custom tag must appear when include_zero=false + async with http.get( + api_base + "/api/tags", + params={"include_zero": "false", "prefix": custom_tag, "limit": "50"}, + ) as r2: + body2 = await r2.json() + assert r2.status == 200 + used_names = [t["name"] for t in body2["tags"]] + assert custom_tag in used_names + + # Delete the asset so the tag usage drops to zero + async with http.delete(f"{api_base}/api/assets/{_asset['id']}") as rd: + assert rd.status == 204 + + # Now the custom tag must not be returned when include_zero=false + async with http.get( + api_base + "/api/tags", + params={"include_zero": "false", "prefix": custom_tag, "limit": "50"}, + ) as r3: + body3 = await r3.json() + assert r3.status == 200 + names_after = [t["name"] for t in body3["tags"]] + assert custom_tag not in names_after + assert not names_after # filtered view should be empty now + + +@pytest.mark.asyncio +async def test_add_and_remove_tags(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict): + aid = seeded_asset["id"] + + # Add tags with duplicates and mixed case + payload_add = {"tags": ["NewTag", "unit-tests", "newtag", "BETA"]} + async with http.post(f"{api_base}/api/assets/{aid}/tags", json=payload_add) as r1: + b1 = await r1.json() + assert r1.status == 200, b1 + # normalized, deduplicated; 'unit-tests' was already present from the seed + assert set(b1["added"]) == {"newtag", "beta"} + assert set(b1["already_present"]) == {"unit-tests"} + assert "newtag" in b1["total_tags"] and "beta" in b1["total_tags"] + + async with http.get(f"{api_base}/api/assets/{aid}") as rg: + g = await rg.json() + assert rg.status == 200 + tags_now = set(g["tags"]) + assert {"newtag", "beta"}.issubset(tags_now) + + # Remove a tag and a non-existent tag + payload_del = {"tags": ["newtag", "does-not-exist"]} + async with http.delete(f"{api_base}/api/assets/{aid}/tags", json=payload_del) as r2: + b2 = await r2.json() + assert r2.status == 200 + assert set(b2["removed"]) == {"newtag"} + assert set(b2["not_present"]) == {"does-not-exist"} + + # Verify remaining tags after deletion + async with 
http.get(f"{api_base}/api/assets/{aid}") as rg2: + g2 = await rg2.json() + assert rg2.status == 200 + tags_later = set(g2["tags"]) + assert "newtag" not in tags_later + assert "beta" in tags_later # still present + + +@pytest.mark.asyncio +async def test_tags_list_order_and_prefix(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict): + aid = seeded_asset["id"] + h = seeded_asset["asset_hash"] + + # Add both tags to the seeded asset (usage: orderaaa=1, orderbbb=1) + async with http.post(f"{api_base}/api/assets/{aid}/tags", json={"tags": ["orderaaa", "orderbbb"]}) as r_add: + add_body = await r_add.json() + assert r_add.status == 200, add_body + + # Create another AssetInfo from the same content but tagged ONLY with 'orderbbb'. + payload = { + "hash": h, + "name": "order_only_bbb.safetensors", + "tags": ["input", "unit-tests", "orderbbb"], + "user_metadata": {}, + } + async with http.post(f"{api_base}/api/assets/from-hash", json=payload) as r_copy: + copy_body = await r_copy.json() + assert r_copy.status == 201, copy_body + + # 1) Default order (count_desc): 'orderbbb' should come before 'orderaaa' + # because it has higher usage (2 vs 1). + async with http.get(api_base + "/api/tags", params={"prefix": "order", "include_zero": "false"}) as r1: + b1 = await r1.json() + assert r1.status == 200, b1 + names1 = [t["name"] for t in b1["tags"]] + counts1 = {t["name"]: t["count"] for t in b1["tags"]} + # Both must be present within the prefix subset + assert "orderaaa" in names1 and "orderbbb" in names1 + # Usage of 'orderbbb' must be >= 'orderaaa'; in our setup it's 2 vs 1 + assert counts1["orderbbb"] >= counts1["orderaaa"] + # And with count_desc, 'orderbbb' appears earlier than 'orderaaa' + assert names1.index("orderbbb") < names1.index("orderaaa") + + # 2) name_asc: lexical order should flip the relative order + async with http.get( + api_base + "/api/tags", + params={"prefix": "order", "include_zero": "false", "order": "name_asc"}, + ) as r2: + b2 = await r2.json() + assert r2.status == 200, b2 + names2 = [t["name"] for t in b2["tags"]] + assert "orderaaa" in names2 and "orderbbb" in names2 + assert names2.index("orderaaa") < names2.index("orderbbb") + + # 3) invalid limit rejected (existing negative case retained) + async with http.get(api_base + "/api/tags", params={"limit": "1001"}) as r3: + b3 = await r3.json() + assert r3.status == 400 + assert b3["error"]["code"] == "INVALID_QUERY" + + +@pytest.mark.asyncio +async def test_tags_endpoints_invalid_bodies(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict): + aid = seeded_asset["id"] + + # Add with empty list + async with http.post(f"{api_base}/api/assets/{aid}/tags", json={"tags": []}) as r1: + b1 = await r1.json() + assert r1.status == 400 + assert b1["error"]["code"] == "INVALID_BODY" + + # Remove with wrong type + async with http.delete(f"{api_base}/api/assets/{aid}/tags", json={"tags": [123]}) as r2: + b2 = await r2.json() + assert r2.status == 400 + assert b2["error"]["code"] == "INVALID_BODY" + + # metadata_filter provided as JSON array should be rejected (must be object) + async with http.get( + api_base + "/api/assets", + params={"metadata_filter": json.dumps([{"x": 1}])}, + ) as r3: + b3 = await r3.json() + assert r3.status == 400 + assert b3["error"]["code"] == "INVALID_QUERY" + + +@pytest.mark.asyncio +async def test_tags_prefix_treats_underscore_literal( + http, + api_base, + asset_factory, + make_asset_bytes, +): + """'prefix' for /api/tags must treat '_' literally, not as a wildcard.""" + base = 
f"pref_{uuid.uuid4().hex[:6]}" + tag_ok = f"{base}_ok" # should match prefix=f"{base}_" + tag_bad = f"{base}xok" # must NOT match if '_' is escaped + scope = f"tags-underscore-{uuid.uuid4().hex[:6]}" + + await asset_factory("t1.bin", ["input", "unit-tests", scope, tag_ok], {}, make_asset_bytes("t1", 512)) + await asset_factory("t2.bin", ["input", "unit-tests", scope, tag_bad], {}, make_asset_bytes("t2", 512)) + + async with http.get(api_base + "/api/tags", params={"include_zero": "false", "prefix": f"{base}_"}) as r: + body = await r.json() + assert r.status == 200, body + names = [t["name"] for t in body["tags"]] + assert tag_ok in names, f"Expected {tag_ok} to be returned for prefix '{base}_'" + assert tag_bad not in names, f"'{tag_bad}' must not match — '_' is not a wildcard" + assert body["total"] == 1 diff --git a/tests-assets/test_uploads.py b/tests-assets/test_uploads.py new file mode 100644 index 000000000000..f1b116c1aca2 --- /dev/null +++ b/tests-assets/test_uploads.py @@ -0,0 +1,325 @@ +import asyncio +import json +import uuid + +import aiohttp +import pytest + + +@pytest.mark.asyncio +async def test_upload_ok_duplicate_reference(http: aiohttp.ClientSession, api_base: str, make_asset_bytes): + name = "dup_a.safetensors" + tags = ["models", "checkpoints", "unit-tests", "alpha"] + meta = {"purpose": "dup"} + data = make_asset_bytes(name) + form1 = aiohttp.FormData() + form1.add_field("file", data, filename=name, content_type="application/octet-stream") + form1.add_field("tags", json.dumps(tags)) + form1.add_field("name", name) + form1.add_field("user_metadata", json.dumps(meta)) + async with http.post(api_base + "/api/assets", data=form1) as r1: + a1 = await r1.json() + assert r1.status == 201, a1 + assert a1["created_new"] is True + + # Second upload with the same data and name should return created_new == False and the same asset + form2 = aiohttp.FormData() + form2.add_field("file", data, filename=name, content_type="application/octet-stream") + form2.add_field("tags", json.dumps(tags)) + form2.add_field("name", name) + form2.add_field("user_metadata", json.dumps(meta)) + async with http.post(api_base + "/api/assets", data=form2) as r2: + a2 = await r2.json() + assert r2.status == 200, a2 + assert a2["created_new"] is False + assert a2["asset_hash"] == a1["asset_hash"] + assert a2["id"] == a1["id"] # old reference + + # Third upload with the same data but new name should return created_new == False and the new AssetReference + form3 = aiohttp.FormData() + form3.add_field("file", data, filename=name, content_type="application/octet-stream") + form3.add_field("tags", json.dumps(tags)) + form3.add_field("name", name + "_d") + form3.add_field("user_metadata", json.dumps(meta)) + async with http.post(api_base + "/api/assets", data=form3) as r2: + a3 = await r2.json() + assert r2.status == 200, a3 + assert a3["created_new"] is False + assert a3["asset_hash"] == a1["asset_hash"] + assert a3["id"] != a1["id"] # old reference + + +@pytest.mark.asyncio +async def test_upload_fastpath_from_existing_hash_no_file(http: aiohttp.ClientSession, api_base: str): + # Seed a small file first + name = "fastpath_seed.safetensors" + tags = ["models", "checkpoints", "unit-tests"] + meta = {} + form1 = aiohttp.FormData() + form1.add_field("file", b"B" * 1024, filename=name, content_type="application/octet-stream") + form1.add_field("tags", json.dumps(tags)) + form1.add_field("name", name) + form1.add_field("user_metadata", json.dumps(meta)) + async with http.post(api_base + "/api/assets", data=form1) 
as r1: + b1 = await r1.json() + assert r1.status == 201, b1 + h = b1["asset_hash"] + + # Now POST /api/assets with only hash and no file + form2 = aiohttp.FormData(default_to_multipart=True) + form2.add_field("hash", h) + form2.add_field("tags", json.dumps(tags)) + form2.add_field("name", "fastpath_copy.safetensors") + form2.add_field("user_metadata", json.dumps({"purpose": "copy"})) + async with http.post(api_base + "/api/assets", data=form2) as r2: + b2 = await r2.json() + assert r2.status == 200, b2 # fast path returns 200 with created_new == False + assert b2["created_new"] is False + assert b2["asset_hash"] == h + + +@pytest.mark.asyncio +async def test_upload_fastpath_with_known_hash_and_file( + http: aiohttp.ClientSession, api_base: str +): + # Seed + form1 = aiohttp.FormData() + form1.add_field("file", b"C" * 128, filename="seed.safetensors", content_type="application/octet-stream") + form1.add_field("tags", json.dumps(["models", "checkpoints", "unit-tests", "fp"])) + form1.add_field("name", "seed.safetensors") + form1.add_field("user_metadata", json.dumps({})) + async with http.post(api_base + "/api/assets", data=form1) as r1: + b1 = await r1.json() + assert r1.status == 201, b1 + h = b1["asset_hash"] + + # Send both file and hash of existing content -> server must drain file and create from hash (200) + form2 = aiohttp.FormData() + form2.add_field("file", b"ignored" * 10, filename="ignored.bin", content_type="application/octet-stream") + form2.add_field("hash", h) + form2.add_field("tags", json.dumps(["models", "checkpoints", "unit-tests", "fp"])) + form2.add_field("name", "copy_from_hash.safetensors") + form2.add_field("user_metadata", json.dumps({})) + async with http.post(api_base + "/api/assets", data=form2) as r2: + b2 = await r2.json() + assert r2.status == 200, b2 + assert b2["created_new"] is False + assert b2["asset_hash"] == h + + +@pytest.mark.asyncio +async def test_upload_multiple_tags_fields_are_merged(http: aiohttp.ClientSession, api_base: str): + form = aiohttp.FormData() + form.add_field("file", b"B" * 256, filename="merge.safetensors", content_type="application/octet-stream") + form.add_field("tags", "models,checkpoints") # CSV + form.add_field("tags", json.dumps(["unit-tests", "alpha"])) # JSON array in second field + form.add_field("name", "merge.safetensors") + form.add_field("user_metadata", json.dumps({"u": 1})) + async with http.post(api_base + "/api/assets", data=form) as r1: + created = await r1.json() + assert r1.status in (200, 201), created + aid = created["id"] + + # Verify all tags are present on the resource + async with http.get(f"{api_base}/api/assets/{aid}") as rg: + detail = await rg.json() + assert rg.status == 200, detail + tags = set(detail["tags"]) + assert {"models", "checkpoints", "unit-tests", "alpha"}.issubset(tags) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("root", ["input", "output"]) +async def test_concurrent_upload_identical_bytes_different_names( + root: str, + http: aiohttp.ClientSession, + api_base: str, + make_asset_bytes, +): + """ + Two concurrent uploads of identical bytes but different names. + Expect a single Asset (same hash), two AssetInfo rows, and exactly one created_new=True. 
+ """ + scope = f"concupload-{uuid.uuid4().hex[:6]}" + name1, name2 = "cu_a.bin", "cu_b.bin" + data = make_asset_bytes("concurrent", 4096) + tags = [root, "unit-tests", scope] + + def _form(name: str) -> aiohttp.FormData: + f = aiohttp.FormData() + f.add_field("file", data, filename=name, content_type="application/octet-stream") + f.add_field("tags", json.dumps(tags)) + f.add_field("name", name) + f.add_field("user_metadata", json.dumps({})) + return f + + r1, r2 = await asyncio.gather( + http.post(api_base + "/api/assets", data=_form(name1)), + http.post(api_base + "/api/assets", data=_form(name2)), + ) + b1, b2 = await r1.json(), await r2.json() + assert r1.status in (200, 201), b1 + assert r2.status in (200, 201), b2 + assert b1["asset_hash"] == b2["asset_hash"] + assert b1["id"] != b2["id"] + + created_flags = sorted([bool(b1.get("created_new")), bool(b2.get("created_new"))]) + assert created_flags == [False, True] + + async with http.get( + api_base + "/api/assets", + params={"include_tags": f"unit-tests,{scope}", "sort": "name"}, + ) as rl: + bl = await rl.json() + assert rl.status == 200, bl + names = [a["name"] for a in bl.get("assets", [])] + assert set([name1, name2]).issubset(names) + + +@pytest.mark.asyncio +async def test_create_from_hash_endpoint_404(http: aiohttp.ClientSession, api_base: str): + payload = { + "hash": "blake3:" + "0" * 64, + "name": "nonexistent.bin", + "tags": ["models", "checkpoints", "unit-tests"], + } + async with http.post(api_base + "/api/assets/from-hash", json=payload) as r: + body = await r.json() + assert r.status == 404 + assert body["error"]["code"] == "ASSET_NOT_FOUND" + + +@pytest.mark.asyncio +async def test_upload_zero_byte_rejected(http: aiohttp.ClientSession, api_base: str): + form = aiohttp.FormData() + form.add_field("file", b"", filename="empty.safetensors", content_type="application/octet-stream") + form.add_field("tags", json.dumps(["models", "checkpoints", "unit-tests", "edge"])) + form.add_field("name", "empty.safetensors") + form.add_field("user_metadata", json.dumps({})) + async with http.post(api_base + "/api/assets", data=form) as r: + body = await r.json() + assert r.status == 400 + assert body["error"]["code"] == "EMPTY_UPLOAD" + + +@pytest.mark.asyncio +async def test_upload_invalid_root_tag_rejected(http: aiohttp.ClientSession, api_base: str): + form = aiohttp.FormData() + form.add_field("file", b"A" * 64, filename="badroot.bin", content_type="application/octet-stream") + form.add_field("tags", json.dumps(["not-a-root", "whatever"])) + form.add_field("name", "badroot.bin") + form.add_field("user_metadata", json.dumps({})) + async with http.post(api_base + "/api/assets", data=form) as r: + body = await r.json() + assert r.status == 400 + assert body["error"]["code"] == "INVALID_BODY" + + +@pytest.mark.asyncio +async def test_upload_user_metadata_must_be_json(http: aiohttp.ClientSession, api_base: str): + form = aiohttp.FormData() + form.add_field("file", b"A" * 128, filename="badmeta.bin", content_type="application/octet-stream") + form.add_field("tags", json.dumps(["models", "checkpoints", "unit-tests", "edge"])) + form.add_field("name", "badmeta.bin") + form.add_field("user_metadata", "{not json}") # invalid + async with http.post(api_base + "/api/assets", data=form) as r: + body = await r.json() + assert r.status == 400 + assert body["error"]["code"] == "INVALID_BODY" + + +@pytest.mark.asyncio +async def test_upload_requires_multipart(http: aiohttp.ClientSession, api_base: str): + async with http.post(api_base + 
"/api/assets", json={"foo": "bar"}) as r: + body = await r.json() + assert r.status == 415 + assert body["error"]["code"] == "UNSUPPORTED_MEDIA_TYPE" + + +@pytest.mark.asyncio +async def test_upload_missing_file_and_hash(http: aiohttp.ClientSession, api_base: str): + form = aiohttp.FormData(default_to_multipart=True) + form.add_field("tags", json.dumps(["models", "checkpoints", "unit-tests"])) + form.add_field("name", "x.safetensors") + async with http.post(api_base + "/api/assets", data=form) as r: + body = await r.json() + assert r.status == 400 + assert body["error"]["code"] == "MISSING_FILE" + + +@pytest.mark.asyncio +async def test_upload_models_unknown_category(http: aiohttp.ClientSession, api_base: str): + form = aiohttp.FormData() + form.add_field("file", b"A" * 128, filename="m.safetensors", content_type="application/octet-stream") + form.add_field("tags", json.dumps(["models", "no_such_category", "unit-tests"])) + form.add_field("name", "m.safetensors") + async with http.post(api_base + "/api/assets", data=form) as r: + body = await r.json() + assert r.status == 400 + assert body["error"]["code"] == "INVALID_BODY" + assert body["error"]["message"].startswith("unknown models category") + + +@pytest.mark.asyncio +async def test_upload_models_requires_category(http: aiohttp.ClientSession, api_base: str): + form = aiohttp.FormData() + form.add_field("file", b"A" * 64, filename="nocat.safetensors", content_type="application/octet-stream") + form.add_field("tags", json.dumps(["models"])) # missing category + form.add_field("name", "nocat.safetensors") + form.add_field("user_metadata", json.dumps({})) + async with http.post(api_base + "/api/assets", data=form) as r: + body = await r.json() + assert r.status == 400 + assert body["error"]["code"] == "INVALID_BODY" + + +@pytest.mark.asyncio +async def test_upload_tags_traversal_guard(http: aiohttp.ClientSession, api_base: str): + form = aiohttp.FormData() + form.add_field("file", b"A" * 256, filename="evil.safetensors", content_type="application/octet-stream") + # '..' should be rejected by destination resolver + form.add_field("tags", json.dumps(["models", "checkpoints", "unit-tests", "..", "zzz"])) + form.add_field("name", "evil.safetensors") + async with http.post(api_base + "/api/assets", data=form) as r: + body = await r.json() + assert r.status == 400 + assert body["error"]["code"] in ("BAD_REQUEST", "INVALID_BODY") + + +@pytest.mark.asyncio +@pytest.mark.parametrize("root", ["input", "output"]) +async def test_duplicate_upload_same_display_name_does_not_clobber( + root: str, + http, + api_base: str, + asset_factory, + make_asset_bytes, +): + """ + Two uploads use the same tags and the same display name but different bytes. + With hash-based filenames, they must NOT overwrite each other. Both assets + remain accessible and serve their original content. 
+ """ + scope = f"dup-path-{uuid.uuid4().hex[:6]}" + display_name = "same_display.bin" + + d1 = make_asset_bytes(scope + "-v1", 1536) + d2 = make_asset_bytes(scope + "-v2", 2048) + tags = [root, "unit-tests", scope] + + first = await asset_factory(display_name, tags, {}, d1) + second = await asset_factory(display_name, tags, {}, d2) + + assert first["id"] != second["id"] + assert first["asset_hash"] != second["asset_hash"] # different content + assert first["name"] == second["name"] == display_name + + # Both must be independently retrievable + async with http.get(f"{api_base}/api/assets/{first['id']}/content") as r1: + b1 = await r1.read() + assert r1.status == 200 + assert b1 == d1 + async with http.get(f"{api_base}/api/assets/{second['id']}/content") as r2: + b2 = await r2.read() + assert r2.status == 200 + assert b2 == d2 diff --git a/tests-unit/app_test/model_manager_test.py b/tests-unit/app_test/model_manager_test.py deleted file mode 100644 index ae59206f6563..000000000000 --- a/tests-unit/app_test/model_manager_test.py +++ /dev/null @@ -1,62 +0,0 @@ -import pytest -import base64 -import json -import struct -from io import BytesIO -from PIL import Image -from aiohttp import web -from unittest.mock import patch -from app.model_manager import ModelFileManager - -pytestmark = ( - pytest.mark.asyncio -) # This applies the asyncio mark to all test functions in the module - -@pytest.fixture -def model_manager(): - return ModelFileManager() - -@pytest.fixture -def app(model_manager): - app = web.Application() - routes = web.RouteTableDef() - model_manager.add_routes(routes) - app.add_routes(routes) - return app - -async def test_get_model_preview_safetensors(aiohttp_client, app, tmp_path): - img = Image.new('RGB', (100, 100), 'white') - img_byte_arr = BytesIO() - img.save(img_byte_arr, format='PNG') - img_byte_arr.seek(0) - img_b64 = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8') - - safetensors_file = tmp_path / "test_model.safetensors" - header_bytes = json.dumps({ - "__metadata__": { - "ssmd_cover_images": json.dumps([img_b64]) - } - }).encode('utf-8') - length_bytes = struct.pack('