From 1cb3c98947c36acc14103312c432805d46570a3c Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Sun, 1 Jun 2025 15:32:02 +0100 Subject: [PATCH 01/82] Implement database & model hashing --- alembic.ini | 84 ++++++++++++++++ alembic_db/README.md | 4 + alembic_db/env.py | 69 +++++++++++++ alembic_db/script.py.mako | 28 ++++++ alembic_db/versions/565b08122d00_init.py | 34 +++++++ app/database/db.py | 90 +++++++++++++++++ app/database/models.py | 50 ++++++++++ app/frontend_management.py | 91 +++++++++++------ app/model_processor.py | 122 +++++++++++++++++++++++ comfy/cli_args.py | 6 ++ comfy/utils.py | 4 + folder_paths.py | 21 ++++ main.py | 8 +- requirements.txt | 2 + utils/install_util.py | 19 ++++ 15 files changed, 601 insertions(+), 31 deletions(-) create mode 100644 alembic.ini create mode 100644 alembic_db/README.md create mode 100644 alembic_db/env.py create mode 100644 alembic_db/script.py.mako create mode 100644 alembic_db/versions/565b08122d00_init.py create mode 100644 app/database/db.py create mode 100644 app/database/models.py create mode 100644 app/model_processor.py create mode 100644 utils/install_util.py diff --git a/alembic.ini b/alembic.ini new file mode 100644 index 000000000000..12f18712f430 --- /dev/null +++ b/alembic.ini @@ -0,0 +1,84 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts +# Use forward slashes (/) also on windows to provide an os agnostic path +script_location = alembic_db + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file +# for all available tokens +# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. +prepend_sys_path = . + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library. +# Any required deps can installed by adding `alembic[tz]` to the pip requirements +# string value is passed to ZoneInfo() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to alembic_db/versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "version_path_separator" below. +# version_locations = %(here)s/bar:%(here)s/bat:alembic_db/versions + +# version path separator; As mentioned above, this is the character used to split +# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. +# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. 
+# Valid values for version_path_separator are: +# +# version_path_separator = : +# version_path_separator = ; +# version_path_separator = space +# version_path_separator = newline +# +# Use os.pathsep. Default configuration used for new projects. +version_path_separator = os + +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +sqlalchemy.url = sqlite:///user/comfyui.db + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# lint with attempts to fix using "ruff" - use the exec runner, execute a binary +# hooks = ruff +# ruff.type = exec +# ruff.executable = %(here)s/.venv/bin/ruff +# ruff.options = check --fix REVISION_SCRIPT_FILENAME diff --git a/alembic_db/README.md b/alembic_db/README.md new file mode 100644 index 000000000000..3b808c7cab30 --- /dev/null +++ b/alembic_db/README.md @@ -0,0 +1,4 @@ +## Generate new revision + +1. Update models in `/app/database/models.py` +2. Run `alembic revision --autogenerate -m "{your message}"` diff --git a/alembic_db/env.py b/alembic_db/env.py new file mode 100644 index 000000000000..d278cfc5304a --- /dev/null +++ b/alembic_db/env.py @@ -0,0 +1,69 @@ +from sqlalchemy import engine_from_config +from sqlalchemy import pool + +from alembic import context + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + + +from app.database.models import Base +target_metadata = Base.metadata + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. 
+ + """ + connectable = engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure( + connection=connection, target_metadata=target_metadata + ) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/alembic_db/script.py.mako b/alembic_db/script.py.mako new file mode 100644 index 000000000000..480b130d632c --- /dev/null +++ b/alembic_db/script.py.mako @@ -0,0 +1,28 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision: str = ${repr(up_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + """Upgrade schema.""" + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + """Downgrade schema.""" + ${downgrades if downgrades else "pass"} diff --git a/alembic_db/versions/565b08122d00_init.py b/alembic_db/versions/565b08122d00_init.py new file mode 100644 index 000000000000..9a8a51fbc133 --- /dev/null +++ b/alembic_db/versions/565b08122d00_init.py @@ -0,0 +1,34 @@ +"""init + +Revision ID: 565b08122d00 +Revises: +Create Date: 2025-05-29 19:15:56.230322 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
+revision: str = '565b08122d00' +down_revision: Union[str, None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + op.create_table('model', + sa.Column('type', sa.Text(), nullable=False), + sa.Column('path', sa.Text(), nullable=False), + sa.Column('hash', sa.Text(), nullable=True), + sa.Column('date_added', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True), + sa.PrimaryKeyConstraint('type', 'path') + ) + + +def downgrade() -> None: + """Downgrade schema.""" + op.drop_table('model') diff --git a/app/database/db.py b/app/database/db.py new file mode 100644 index 000000000000..d17fa4f1f038 --- /dev/null +++ b/app/database/db.py @@ -0,0 +1,90 @@ +import logging +import os +import shutil +from utils.install_util import get_missing_requirements_message +from comfy.cli_args import args + +Session = None + + +def can_create_session(): + return Session is not None + + +try: + import alembic + import sqlalchemy +except ImportError as e: + logging.error(get_missing_requirements_message()) + raise e + +from alembic import command +from alembic.config import Config +from alembic.runtime.migration import MigrationContext +from alembic.script import ScriptDirectory +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + + +def get_alembic_config(): + root_path = os.path.join(os.path.dirname(__file__), "../..") + config_path = os.path.abspath(os.path.join(root_path, "alembic.ini")) + scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db")) + + config = Config(config_path) + config.set_main_option("script_location", scripts_path) + config.set_main_option("sqlalchemy.url", args.database_url) + + return config + + +def get_db_path(): + url = args.database_url + if url.startswith("sqlite:///"): + return url.split("///")[1] + else: + raise ValueError(f"Unsupported database URL '{url}'.") + + +def init_db(): + db_url = args.database_url + logging.debug(f"Database URL: {db_url}") + + config = get_alembic_config() + + # Check if we need to upgrade + engine = create_engine(db_url) + conn = engine.connect() + + context = MigrationContext.configure(conn) + current_rev = context.get_current_revision() + + script = ScriptDirectory.from_config(config) + target_rev = script.get_current_head() + + if current_rev != target_rev: + # Backup the database pre upgrade + db_path = get_db_path() + backup_path = db_path + ".bkp" + if os.path.exists(db_path): + shutil.copy(db_path, backup_path) + else: + backup_path = None + + try: + command.upgrade(config, target_rev) + logging.info(f"Database upgraded from {current_rev} to {target_rev}") + except Exception as e: + if backup_path: + # Restore the database from backup if upgrade fails + shutil.copy(backup_path, db_path) + os.remove(backup_path) + logging.error(f"Error upgrading database: {e}") + raise e + + global Session + Session = sessionmaker(bind=engine) + + +def create_session(): + return Session() diff --git a/app/database/models.py b/app/database/models.py new file mode 100644 index 000000000000..d2c1e042d85d --- /dev/null +++ b/app/database/models.py @@ -0,0 +1,50 @@ +from sqlalchemy import ( + Column, + Text, + DateTime, +) +from sqlalchemy.orm import declarative_base +from sqlalchemy.sql import func + +Base = declarative_base() + + +def to_dict(obj): + fields = obj.__table__.columns.keys() + return { + field: (val.to_dict() if hasattr(val, "to_dict") else val) + for field in fields + if 
(val := getattr(obj, field)) + } + + +class Model(Base): + """ + SQLAlchemy model representing a model file in the system. + + This class defines the database schema for storing information about model files, + including their type, path, hash, and when they were added to the system. + + Attributes: + type (Text): The type of the model, this is the name of the folder in the models folder (primary key) + path (Text): The file path of the model relative to the type folder (primary key) + hash (Text): A sha256 hash of the model file + date_added (DateTime): Timestamp of when the model was added to the system + """ + + __tablename__ = "model" + + type = Column(Text, primary_key=True) + path = Column(Text, primary_key=True) + hash = Column(Text) + date_added = Column(DateTime, server_default=func.now()) + + def to_dict(self): + """ + Convert the model instance to a dictionary representation. + + Returns: + dict: A dictionary containing the attributes of the model + """ + dict = to_dict(self) + return dict diff --git a/app/frontend_management.py b/app/frontend_management.py index d9ef8c9213bc..3e54e4d512c8 100644 --- a/app/frontend_management.py +++ b/app/frontend_management.py @@ -16,26 +16,15 @@ import requests from typing_extensions import NotRequired +from utils.install_util import get_missing_requirements_message, requirements_path from comfy.cli_args import DEFAULT_VERSION_STRING import app.logger -# The path to the requirements.txt file -req_path = Path(__file__).parents[1] / "requirements.txt" - - def frontend_install_warning_message(): - """The warning message to display when the frontend version is not up to date.""" - - extra = "" - if sys.flags.no_user_site: - extra = "-s " return f""" -Please install the updated requirements.txt file by running: -{sys.executable} {extra}-m pip install -r {req_path} +{get_missing_requirements_message()} This error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead. - -If you are on the portable package you can run: update\\update_comfyui.bat to solve this problem """.strip() @@ -48,7 +37,7 @@ def parse_version(version: str) -> tuple[int, int, int]: try: frontend_version_str = version("comfyui-frontend-package") frontend_version = parse_version(frontend_version_str) - with open(req_path, "r", encoding="utf-8") as f: + with open(requirements_path, "r", encoding="utf-8") as f: required_frontend = parse_version(f.readline().split("=")[-1]) if frontend_version < required_frontend: app.logger.log_startup_warning( @@ -162,10 +151,30 @@ def download_release_asset_zip(release: Release, destination_path: str) -> None: class FrontendManager: + """ + A class to manage ComfyUI frontend versions and installations. + + This class handles the initialization and management of different frontend versions, + including the default frontend from the pip package and custom frontend versions + from GitHub repositories. + + Attributes: + CUSTOM_FRONTENDS_ROOT (str): The root directory where custom frontend versions are stored. + """ + CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions") @classmethod def default_frontend_path(cls) -> str: + """ + Get the path to the default frontend installation from the pip package. + + Returns: + str: The path to the default frontend static files. + + Raises: + SystemExit: If the comfyui-frontend-package is not installed. 
+ """ try: import comfyui_frontend_package @@ -186,6 +195,15 @@ def default_frontend_path(cls) -> str: @classmethod def templates_path(cls) -> str: + """ + Get the path to the workflow templates. + + Returns: + str: The path to the workflow templates directory. + + Raises: + SystemExit: If the comfyui-workflow-templates package is not installed. + """ try: import comfyui_workflow_templates @@ -221,12 +239,17 @@ def embedded_docs_path(cls) -> str: @classmethod def parse_version_string(cls, value: str) -> tuple[str, str, str]: """ + Parse a version string into its components. + + The version string should be in the format: 'owner/repo@version' + where version can be either a semantic version (v1.2.3) or 'latest'. + Args: value (str): The version string to parse. - + Returns: - tuple[str, str]: A tuple containing provider name and version. - + tuple[str, str, str]: A tuple containing (owner, repo, version). + Raises: argparse.ArgumentTypeError: If the version string is invalid. """ @@ -242,18 +265,22 @@ def init_frontend_unsafe( cls, version_string: str, provider: Optional[FrontEndProvider] = None ) -> str: """ - Initializes the frontend for the specified version. - + Initialize a frontend version without error handling. + + This method attempts to initialize a specific frontend version, either from + the default pip package or from a custom GitHub repository. It will download + and extract the frontend files if necessary. + Args: - version_string (str): The version string. - provider (FrontEndProvider, optional): The provider to use. Defaults to None. - + version_string (str): The version string specifying which frontend to use. + provider (FrontEndProvider, optional): The provider to use for custom frontends. + Returns: str: The path to the initialized frontend. - + Raises: - Exception: If there is an error during the initialization process. - main error source might be request timeout or invalid URL. + Exception: If there is an error during initialization (e.g., network timeout, + invalid URL, or missing assets). """ if version_string == DEFAULT_VERSION_STRING: check_frontend_version() @@ -305,13 +332,17 @@ def init_frontend_unsafe( @classmethod def init_frontend(cls, version_string: str) -> str: """ - Initializes the frontend with the specified version string. - + Initialize a frontend version with error handling. + + This is the main method to initialize a frontend version. It wraps init_frontend_unsafe + with error handling, falling back to the default frontend if initialization fails. + Args: - version_string (str): The version string to initialize the frontend with. - + version_string (str): The version string specifying which frontend to use. + Returns: - str: The path of the initialized frontend. + str: The path to the initialized frontend. If initialization fails, + returns the path to the default frontend. 
""" try: return cls.init_frontend_unsafe(version_string) diff --git a/app/model_processor.py b/app/model_processor.py new file mode 100644 index 000000000000..980940262fc5 --- /dev/null +++ b/app/model_processor.py @@ -0,0 +1,122 @@ +import hashlib +import os +import logging +import time +from app.database.models import Model +from app.database.db import create_session +from folder_paths import get_relative_path + + +class ModelProcessor: + def _validate_path(self, model_path): + try: + if not os.path.exists(model_path): + logging.error(f"Model file not found: {model_path}") + return None + + result = get_relative_path(model_path) + if not result: + logging.error( + f"Model file not in a recognized model directory: {model_path}" + ) + return None + + return result + except Exception as e: + logging.error(f"Error validating model path {model_path}: {str(e)}") + return None + + def _hash_file(self, model_path): + try: + h = hashlib.sha256() + with open(model_path, "rb", buffering=0) as f: + b = bytearray(128 * 1024) + mv = memoryview(b) + while n := f.readinto(mv): + h.update(mv[:n]) + return h.hexdigest() + except Exception as e: + logging.error(f"Error hashing file {model_path}: {str(e)}") + return None + + def _get_existing_model(self, session, model_type, model_relative_path): + return ( + session.query(Model) + .filter(Model.type == model_type) + .filter(Model.path == model_relative_path) + .first() + ) + + def _update_database( + self, session, model_type, model_relative_path, model_hash, model=None + ): + try: + if not model: + model = self._get_existing_model( + session, model_type, model_relative_path + ) + + if not model: + model = Model( + path=model_relative_path, + type=model_type, + ) + session.add(model) + + model.hash = model_hash + session.commit() + return model + except Exception as e: + logging.error( + f"Error updating database for {model_relative_path}: {str(e)}" + ) + + def process_file(self, model_path): + try: + result = self._validate_path(model_path) + if not result: + return + model_type, model_relative_path = result + + with create_session() as session: + existing_model = self._get_existing_model( + session, model_type, model_relative_path + ) + if existing_model and existing_model.hash: + # File exists with hash, no need to process + return existing_model + + start_time = time.time() + logging.info(f"Hashing model {model_relative_path}") + model_hash = self._hash_file(model_path) + if not model_hash: + return + logging.info( + f"Model hash: {model_hash} (duration: {time.time() - start_time} seconds)" + ) + + return self._update_database(session, model_type, model_relative_path, model_hash) + except Exception as e: + logging.error(f"Error processing model file {model_path}: {str(e)}") + + def retrieve_hash(self, model_path, model_type=None): + try: + if model_type is not None: + result = self._validate_path(model_path) + if not result: + return None + model_type, model_relative_path = result + + with create_session() as session: + model = self._get_existing_model( + session, model_type, model_relative_path + ) + if model and model.hash: + return model.hash + return None + except Exception as e: + logging.error(f"Error retrieving hash for {model_path}: {str(e)}") + return None + + +model_processor = ModelProcessor() diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 4fb675f990b0..154491fe07d3 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -203,6 +203,12 @@ def is_valid_directory(path: str) -> str: help="Set the base URL for the ComfyUI API. 
(default: https://api.comfy.org)", ) +database_default_path = os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db") +) +parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.") +parser.add_argument("--disable-model-processing", action="store_true", help="Disable model file processing, e.g. computing hashes and extracting metadata.") + if comfy.options.args_parsing: args = parser.parse_args() else: diff --git a/comfy/utils.py b/comfy/utils.py index 1f8d71292511..547ce9fc9247 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -20,6 +20,7 @@ import torch import math import struct +from app.model_processor import model_processor import comfy.checkpoint_pickle import safetensors.torch import numpy as np @@ -53,6 +54,9 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False): if device is None: device = torch.device("cpu") metadata = None + + model_processor.process_file(ckpt) + if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"): try: with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f: diff --git a/folder_paths.py b/folder_paths.py index f0b3fd103739..452409bf04b2 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -299,6 +299,27 @@ def get_full_path_or_raise(folder_name: str, filename: str) -> str: return full_path +def get_relative_path(full_path: str) -> tuple[str, str] | None: + """Convert a full path back to a type-relative path. + + Args: + full_path: The full path to the file + + Returns: + tuple[str, str] | None: A tuple of (model_type, relative_path) if found, None otherwise + """ + global folder_names_and_paths + full_path = os.path.normpath(full_path) + + for model_type, (paths, _) in folder_names_and_paths.items(): + for base_path in paths: + base_path = os.path.normpath(base_path) + if full_path.startswith(base_path): + relative_path = os.path.relpath(full_path, base_path) + return model_type, relative_path + + return None + def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]: folder_name = map_legacy(folder_name) global folder_names_and_paths diff --git a/main.py b/main.py index fb1f8d20bfb2..d6f8193c4dd7 100644 --- a/main.py +++ b/main.py @@ -147,7 +147,6 @@ def cuda_malloc_warning(): if cuda_malloc_warning: logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n") - def prompt_worker(q, server_instance): current_time: float = 0.0 cache_type = execution.CacheType.CLASSIC @@ -237,6 +236,12 @@ def cleanup_temp(): if os.path.exists(temp_dir): shutil.rmtree(temp_dir, ignore_errors=True) +def setup_database(): + try: + from app.database.db import init_db + init_db() + except Exception as e: + logging.error(f"Failed to initialize database. 
Please report this error as in future the database will be required: {e}") def start_comfyui(asyncio_loop=None): """ @@ -266,6 +271,7 @@ def start_comfyui(asyncio_loop=None): hook_breaker_ac10a0.restore_functions() cuda_malloc_warning() + setup_database() prompt_server.add_routes() hijack_progress(prompt_server) diff --git a/requirements.txt b/requirements.txt index b98dc1268056..ea51f24aba1b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,6 +18,8 @@ Pillow scipy tqdm psutil +alembic +SQLAlchemy #non essential dependencies: kornia>=0.7.1 diff --git a/utils/install_util.py b/utils/install_util.py new file mode 100644 index 000000000000..5e6d51a2d8e3 --- /dev/null +++ b/utils/install_util.py @@ -0,0 +1,19 @@ +from pathlib import Path +import sys + +# The path to the requirements.txt file +requirements_path = Path(__file__).parents[1] / "requirements.txt" + + +def get_missing_requirements_message(): + """The warning message to display when a package is missing.""" + + extra = "" + if sys.flags.no_user_site: + extra = "-s " + return f""" +Please install the updated requirements.txt file by running: +{sys.executable} {extra}-m pip install -r {requirements_path} + +If you are on the portable package you can run: update\\update_comfyui.bat to solve this problem. +""".strip() From 9da6aca0d01f23bdcc218346497a48935434016c Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Sun, 1 Jun 2025 13:34:26 +0100 Subject: [PATCH 02/82] Add additional db model metadata fields and model downloading function --- ...b08122d00_init.py => e9c714da8d57_init.py} | 14 +- app/database/db.py | 54 ++-- app/database/models.py | 13 +- app/model_processor.py | 251 +++++++++++++++-- comfy/utils.py | 17 +- folder_paths.py | 4 +- main.py | 7 +- requirements.txt | 1 + tests-unit/app_test/model_processor_test.py | 253 ++++++++++++++++++ 9 files changed, 565 insertions(+), 49 deletions(-) rename alembic_db/versions/{565b08122d00_init.py => e9c714da8d57_init.py} (59%) create mode 100644 tests-unit/app_test/model_processor_test.py diff --git a/alembic_db/versions/565b08122d00_init.py b/alembic_db/versions/e9c714da8d57_init.py similarity index 59% rename from alembic_db/versions/565b08122d00_init.py rename to alembic_db/versions/e9c714da8d57_init.py index 9a8a51fbc133..1a296104436f 100644 --- a/alembic_db/versions/565b08122d00_init.py +++ b/alembic_db/versions/e9c714da8d57_init.py @@ -1,8 +1,8 @@ """init -Revision ID: 565b08122d00 +Revision ID: e9c714da8d57 Revises: -Create Date: 2025-05-29 19:15:56.230322 +Create Date: 2025-05-30 20:14:33.772039 """ from typing import Sequence, Union @@ -12,7 +12,7 @@ # revision identifiers, used by Alembic. -revision: str = '565b08122d00' +revision: str = 'e9c714da8d57' down_revision: Union[str, None] = None branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -20,15 +20,23 @@ def upgrade() -> None: """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! 
### op.create_table('model', sa.Column('type', sa.Text(), nullable=False), sa.Column('path', sa.Text(), nullable=False), + sa.Column('file_name', sa.Text(), nullable=True), + sa.Column('file_size', sa.Integer(), nullable=True), sa.Column('hash', sa.Text(), nullable=True), + sa.Column('hash_algorithm', sa.Text(), nullable=True), + sa.Column('source_url', sa.Text(), nullable=True), sa.Column('date_added', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True), sa.PrimaryKeyConstraint('type', 'path') ) + # ### end Alembic commands ### def downgrade() -> None: """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### op.drop_table('model') + # ### end Alembic commands ### diff --git a/app/database/db.py b/app/database/db.py index d17fa4f1f038..45bcfcba6f4d 100644 --- a/app/database/db.py +++ b/app/database/db.py @@ -1,29 +1,50 @@ import logging import os import shutil +from app.logger import log_startup_warning from utils.install_util import get_missing_requirements_message from comfy.cli_args import args +_DB_AVAILABLE = False Session = None -def can_create_session(): - return Session is not None - - try: - import alembic - import sqlalchemy + from alembic import command + from alembic.config import Config + from alembic.runtime.migration import MigrationContext + from alembic.script import ScriptDirectory + from sqlalchemy import create_engine + from sqlalchemy.orm import sessionmaker + + _DB_AVAILABLE = True except ImportError as e: - logging.error(get_missing_requirements_message()) - raise e + log_startup_warning( + f""" +------------------------------------------------------------------------ +Error importing dependencies: {e} + +{get_missing_requirements_message()} + +This error is happening because ComfyUI now uses a local sqlite database. +------------------------------------------------------------------------ +""".strip() + ) -from alembic import command -from alembic.config import Config -from alembic.runtime.migration import MigrationContext -from alembic.script import ScriptDirectory -from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker + +def dependencies_available(): + """ + Temporary function to check if the dependencies are available + """ + return _DB_AVAILABLE + + +def can_create_session(): + """ + Temporary function to check if the database is available to create a session + During initial release there may be environmental issues (or missing dependencies) that prevent the database from being created + """ + return dependencies_available() and Session is not None def get_alembic_config(): @@ -49,6 +70,8 @@ def get_db_path(): def init_db(): db_url = args.database_url logging.debug(f"Database URL: {db_url}") + db_path = get_db_path() + db_exists = os.path.exists(db_path) config = get_alembic_config() @@ -64,9 +87,8 @@ def init_db(): if current_rev != target_rev: # Backup the database pre upgrade - db_path = get_db_path() backup_path = db_path + ".bkp" - if os.path.exists(db_path): + if db_exists: shutil.copy(db_path, backup_path) else: backup_path = None diff --git a/app/database/models.py b/app/database/models.py index d2c1e042d85d..b0225c41204c 100644 --- a/app/database/models.py +++ b/app/database/models.py @@ -1,5 +1,6 @@ from sqlalchemy import ( Column, + Integer, Text, DateTime, ) @@ -20,7 +21,7 @@ def to_dict(obj): class Model(Base): """ - SQLAlchemy model representing a model file in the system. + sqlalchemy model representing a model file in the system. 
This class defines the database schema for storing information about model files, including their type, path, hash, and when they were added to the system. @@ -28,7 +29,11 @@ class Model(Base): Attributes: type (Text): The type of the model, this is the name of the folder in the models folder (primary key) path (Text): The file path of the model relative to the type folder (primary key) - hash (Text): A sha256 hash of the model file + file_name (Text): The name of the model file + file_size (Integer): The size of the model file in bytes + hash (Text): A hash of the model file + hash_algorithm (Text): The algorithm used to generate the hash + source_url (Text): The URL of the model file date_added (DateTime): Timestamp of when the model was added to the system """ @@ -36,7 +41,11 @@ class Model(Base): type = Column(Text, primary_key=True) path = Column(Text, primary_key=True) + file_name = Column(Text) + file_size = Column(Integer) hash = Column(Text) + hash_algorithm = Column(Text) + source_url = Column(Text) date_added = Column(DateTime, server_default=func.now()) def to_dict(self): diff --git a/app/model_processor.py b/app/model_processor.py index 980940262fc5..6cf8fd6fae7a 100644 --- a/app/model_processor.py +++ b/app/model_processor.py @@ -1,16 +1,23 @@ -import hashlib import os import logging import time -from app.database.models import Model -from app.database.db import create_session -from folder_paths import get_relative_path + +import requests +from tqdm import tqdm +from folder_paths import get_relative_path, get_full_path +from app.database.db import create_session, dependencies_available, can_create_session +import blake3 +import comfy.utils + + +if dependencies_available(): + from app.database.models import Model class ModelProcessor: def _validate_path(self, model_path): try: - if not os.path.exists(model_path): + if not self._file_exists(model_path): logging.error(f"Model file not found: {model_path}") return None @@ -26,15 +33,26 @@ def _validate_path(self, model_path): logging.error(f"Error validating model path {model_path}: {str(e)}") return None + def _file_exists(self, path): + """Check if a file exists.""" + return os.path.exists(path) + + def _get_file_size(self, path): + """Get file size.""" + return os.path.getsize(path) + + def _get_hasher(self): + return blake3.blake3() + def _hash_file(self, model_path): try: - h = hashlib.sha256() + hasher = self._get_hasher() with open(model_path, "rb", buffering=0) as f: b = bytearray(128 * 1024) mv = memoryview(b) while n := f.readinto(mv): - h.update(mv[:n]) - return h.hexdigest() + hasher.update(mv[:n]) + return hasher.hexdigest() except Exception as e: logging.error(f"Error hashing file {model_path}: {str(e)}") return None @@ -46,9 +64,21 @@ def _get_existing_model(self, session, model_type, model_relative_path): .filter(Model.path == model_relative_path) .first() ) + + def _ensure_source_url(self, session, model, source_url): + if model.source_url is None: + model.source_url = source_url + session.commit() def _update_database( - self, session, model_type, model_relative_path, model_hash, model=None + self, + session, + model_type, + model_path, + model_relative_path, + model_hash, + model, + source_url, ): try: if not model: @@ -60,10 +90,16 @@ def _update_database( model = Model( path=model_relative_path, type=model_type, + file_name=os.path.basename(model_path), ) session.add(model) + model.file_size = self._get_file_size(model_path) model.hash = model_hash + if model_hash: + model.hash_algorithm = "blake3" + 
model.source_url = source_url + session.commit() return model except Exception as e: @@ -71,36 +107,97 @@ def _update_database( f"Error updating database for {model_relative_path}: {str(e)}" ) - def process_file(self, model_path): + def process_file(self, model_path, source_url=None, model_hash=None): + """ + Process a model file and update the database with metadata. + If the file already exists and matches the database, it will not be processed again. + Returns the model object or if an error occurs, returns None. + """ try: + if not can_create_session(): + return + result = self._validate_path(model_path) if not result: return model_type, model_relative_path = result with create_session() as session: + session.expire_on_commit = False + existing_model = self._get_existing_model( session, model_type, model_relative_path ) - if existing_model and existing_model.hash: - # File exists with hash, no need to process + if ( + existing_model + and existing_model.hash + and existing_model.file_size == self._get_file_size(model_path) + ): + # File exists with hash and same size, no need to process + self._ensure_source_url(session, existing_model, source_url) return existing_model - start_time = time.time() - logging.info(f"Hashing model {model_relative_path}") - model_hash = self._hash_file(model_path) - if not model_hash: - return - logging.info( - f"Model hash: {model_hash} (duration: {time.time() - start_time} seconds)" - ) + if model_hash: + model_hash = model_hash.lower() + logging.info(f"Using provided hash: {model_hash}") + else: + start_time = time.time() + logging.info(f"Hashing model {model_relative_path}") + model_hash = self._hash_file(model_path) + if not model_hash: + return + logging.info( + f"Model hash: {model_hash} (duration: {time.time() - start_time} seconds)" + ) - return self._update_database(session, model_type, model_relative_path, model_hash) + return self._update_database( + session, + model_type, + model_path, + model_relative_path, + model_hash, + existing_model, + source_url, + ) except Exception as e: logging.error(f"Error processing model file {model_path}: {str(e)}") + return None + + def retrieve_model_by_hash(self, model_hash, model_type=None, session=None): + """ + Retrieve a model file from the database by hash and optionally by model type. + Returns the model object or None if the model doesnt exist or an error occurs. + """ + try: + if not can_create_session(): + return + + dispose_session = False + + if session is None: + session = create_session() + dispose_session = True + + model = session.query(Model).filter(Model.hash == model_hash) + if model_type is not None: + model = model.filter(Model.type == model_type) + return model.first() + except Exception as e: + logging.error(f"Error retrieving model by hash {model_hash}: {str(e)}") + return None + finally: + if dispose_session: + session.close() def retrieve_hash(self, model_path, model_type=None): + """ + Retrieve the hash of a model file from the database. + Returns the hash or None if the model doesnt exist or an error occurs. 
+ """ try: + if not can_create_session(): + return + if model_type is not None: result = self._validate_path(model_path) if not result: @@ -118,5 +215,117 @@ def retrieve_hash(self, model_path, model_type=None): logging.error(f"Error retrieving hash for {model_path}: {str(e)}") return None + def _validate_file_extension(self, file_name): + """Validate that the file extension is supported.""" + extension = os.path.splitext(file_name)[1] + if extension not in (".safetensors", ".sft", ".txt", ".csv", ".json", ".yaml"): + raise ValueError(f"Unsupported unsafe file for download: {file_name}") + + def _check_existing_file(self, model_type, file_name, expected_hash): + """Check if file exists and has correct hash.""" + destination_path = get_full_path(model_type, file_name, allow_missing=True) + if self._file_exists(destination_path): + model = self.process_file(destination_path) + if model and (expected_hash is None or model.hash == expected_hash): + logging.debug( + f"File {destination_path} already exists in the database and has the correct hash or no hash was provided." + ) + return destination_path + else: + raise ValueError( + f"File {destination_path} exists with hash {model.hash if model else 'unknown'} but expected {expected_hash}. Please delete the file and try again." + ) + return None + + def _check_existing_file_by_hash(self, hash, type, url): + """Check if a file with the given hash exists in the database and on disk.""" + hash = hash.lower() + with create_session() as session: + model = self.retrieve_model_by_hash(hash, type, session) + if model: + existing_path = get_full_path(type, model.path) + if existing_path: + logging.debug( + f"File {model.path} already exists in the database at {existing_path}" + ) + self._ensure_source_url(session, model, url) + return existing_path + else: + logging.debug( + f"File {model.path} exists in the database but not on disk" + ) + return None + + def _download_file(self, url, destination_path, hasher): + """Download a file and update the hasher with its contents.""" + response = requests.get(url, stream=True) + logging.info(f"Downloading {url} to {destination_path}") + + with open(destination_path, "wb") as f: + total_size = int(response.headers.get("content-length", 0)) + if total_size > 0: + pbar = comfy.utils.ProgressBar(total_size) + else: + pbar = None + with tqdm(total=total_size, unit="B", unit_scale=True) as progress_bar: + for chunk in response.iter_content(chunk_size=128 * 1024): + if chunk: + f.write(chunk) + hasher.update(chunk) + progress_bar.update(len(chunk)) + if pbar: + pbar.update(len(chunk)) + + def _verify_downloaded_hash(self, calculated_hash, expected_hash, destination_path): + """Verify that the downloaded file has the expected hash.""" + if expected_hash is not None and calculated_hash != expected_hash: + self._remove_file(destination_path) + raise ValueError( + f"Downloaded file hash {calculated_hash} does not match expected hash {expected_hash}" + ) + + def _remove_file(self, file_path): + """Remove a file from disk.""" + os.remove(file_path) + + def ensure_downloaded(self, type, url, desired_file_name, hash=None): + """ + Ensure a model file is downloaded and has the correct hash. + Returns the path to the downloaded file. + """ + logging.debug( + f"Ensuring {type} file is downloaded. 
URL='{url}' Destination='{desired_file_name}' Hash='{hash}'" + ) + + # Validate file extension + self._validate_file_extension(desired_file_name) + + # Check if file exists with correct hash + if hash: + existing_path = self._check_existing_file_by_hash(hash, type, url) + if existing_path: + return existing_path + + # Check if file exists locally + destination_path = get_full_path(type, desired_file_name, allow_missing=True) + existing_path = self._check_existing_file(type, desired_file_name, hash) + if existing_path: + return existing_path + + # Download the file + hasher = self._get_hasher() + self._download_file(url, destination_path, hasher) + + # Verify hash + calculated_hash = hasher.hexdigest() + self._verify_downloaded_hash(calculated_hash, hash, destination_path) + + # Update database + self.process_file(destination_path, url, calculated_hash) + + # TODO: Notify frontend to reload models + + return destination_path + model_processor = ModelProcessor() diff --git a/comfy/utils.py b/comfy/utils.py index 547ce9fc9247..7768f363cf0a 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -20,7 +20,6 @@ import torch import math import struct -from app.model_processor import model_processor import comfy.checkpoint_pickle import safetensors.torch import numpy as np @@ -50,13 +49,16 @@ class ModelCheckpoint: else: logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.") +def is_html_file(file_path): + with open(file_path, "rb") as f: + content = f.read(100) + return b"" in content or b" 0: message = e.args[0] if "HeaderTooLarge" in message: @@ -92,6 +96,13 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False): sd = pl_sd else: sd = pl_sd + + try: + from app.model_processor import model_processor + model_processor.process_file(ckpt) + except Exception as e: + logging.error(f"Error processing file {ckpt}: {e}") + return (sd, metadata) if return_metadata else sd def save_torch_file(sd, ckpt, metadata=None): diff --git a/folder_paths.py b/folder_paths.py index 452409bf04b2..5b5554a30669 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -275,7 +275,7 @@ def filter_files_extensions(files: Collection[str], extensions: Collection[str]) -def get_full_path(folder_name: str, filename: str) -> str | None: +def get_full_path(folder_name: str, filename: str, allow_missing: bool = False) -> str | None: global folder_names_and_paths folder_name = map_legacy(folder_name) if folder_name not in folder_names_and_paths: @@ -288,6 +288,8 @@ def get_full_path(folder_name: str, filename: str) -> str | None: return full_path elif os.path.islink(full_path): logging.warning("WARNING path {} exists but doesn't link anywhere, skipping.".format(full_path)) + elif allow_missing: + return full_path return None diff --git a/main.py b/main.py index d6f8193c4dd7..17581a42f33d 100644 --- a/main.py +++ b/main.py @@ -238,10 +238,11 @@ def cleanup_temp(): def setup_database(): try: - from app.database.db import init_db - init_db() + from app.database.db import init_db, dependencies_available + if dependencies_available(): + init_db() except Exception as e: - logging.error(f"Failed to initialize database. Please report this error as in future the database will be required: {e}") + logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. 
If the error persists, please report this as in future the database will be required: {e}") def start_comfyui(asyncio_loop=None): """ diff --git a/requirements.txt b/requirements.txt index ea51f24aba1b..1ae6de3e5c5b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,6 +20,7 @@ tqdm psutil alembic SQLAlchemy +blake3 #non essential dependencies: kornia>=0.7.1 diff --git a/tests-unit/app_test/model_processor_test.py b/tests-unit/app_test/model_processor_test.py new file mode 100644 index 000000000000..d1e43d375e7e --- /dev/null +++ b/tests-unit/app_test/model_processor_test.py @@ -0,0 +1,253 @@ +import pytest +from unittest.mock import patch, MagicMock +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker +from app.model_processor import ModelProcessor +from app.database.models import Model, Base +import os + +# Test data constants +TEST_MODEL_TYPE = "checkpoints" +TEST_URL = "http://example.com/model.safetensors" +TEST_FILE_NAME = "model.safetensors" +TEST_EXPECTED_HASH = "abc123" +TEST_DESTINATION_PATH = "/path/to/model.safetensors" + + +def create_test_model(session, file_name, model_type, hash_value, file_size=1000, source_url=None): + """Helper to create a test model in the database.""" + model = Model(path=file_name, type=model_type, hash=hash_value, file_size=file_size, source_url=source_url) + session.add(model) + session.commit() + return model + + +def setup_mock_hash_calculation(model_processor, hash_value): + """Helper to setup hash calculation mocks.""" + mock_hash = MagicMock() + mock_hash.hexdigest.return_value = hash_value + return patch.object(model_processor, "_get_hasher", return_value=mock_hash) + + +def verify_model_in_db(session, file_name, expected_hash=None, expected_type=None): + """Helper to verify model exists in database with correct attributes.""" + db_model = session.query(Model).filter_by(path=file_name).first() + assert db_model is not None + if expected_hash: + assert db_model.hash == expected_hash + if expected_type: + assert db_model.type == expected_type + return db_model + + +@pytest.fixture +def db_engine(): + # Configure in-memory database + engine = create_engine("sqlite:///:memory:") + Base.metadata.create_all(engine) + yield engine + Base.metadata.drop_all(engine) + + +@pytest.fixture +def db_session(db_engine): + Session = sessionmaker(bind=db_engine) + session = Session() + yield session + session.close() + + +@pytest.fixture +def mock_get_relative_path(): + with patch("app.model_processor.get_relative_path") as mock: + mock.side_effect = lambda path: (TEST_MODEL_TYPE, os.path.basename(path)) + yield mock + + +@pytest.fixture +def mock_get_full_path(): + with patch("app.model_processor.get_full_path") as mock: + mock.return_value = TEST_DESTINATION_PATH + yield mock + + +@pytest.fixture +def model_processor(db_session, mock_get_relative_path, mock_get_full_path): + with patch("app.model_processor.create_session", return_value=db_session): + with patch("app.model_processor.can_create_session", return_value=True): + processor = ModelProcessor() + # Setup test state + processor.removed_files = [] + processor.downloaded_files = [] + processor.file_exists = {} + + def mock_download_file(url, destination_path, hasher): + processor.downloaded_files.append((url, destination_path)) + processor.file_exists[destination_path] = True + # Simulate writing some data to the file + test_data = b"test data" + hasher.update(test_data) + + def mock_remove_file(file_path): + processor.removed_files.append(file_path) + if file_path in 
processor.file_exists: + del processor.file_exists[file_path] + + # Setup common patches + file_exists_patch = patch.object( + processor, + "_file_exists", + side_effect=lambda path: processor.file_exists.get(path, False), + ) + file_size_patch = patch.object( + processor, + "_get_file_size", + side_effect=lambda path: ( + 1000 if processor.file_exists.get(path, False) else 0 + ), + ) + download_file_patch = patch.object( + processor, "_download_file", side_effect=mock_download_file + ) + remove_file_patch = patch.object( + processor, "_remove_file", side_effect=mock_remove_file + ) + + with ( + file_exists_patch, + file_size_patch, + download_file_patch, + remove_file_patch, + ): + yield processor + + +def test_ensure_downloaded_invalid_extension(model_processor): + # Ensure that an unsupported file extension raises an error to prevent unsafe file downloads + with pytest.raises(ValueError, match="Unsupported unsafe file for download"): + model_processor.ensure_downloaded(TEST_MODEL_TYPE, TEST_URL, "model.exe") + + +def test_ensure_downloaded_existing_file_with_hash(model_processor, db_session): + # Ensure that a file with the same hash but from a different source is not downloaded again + SOURCE_URL = "https://example.com/other.sft" + create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH, source_url=SOURCE_URL) + model_processor.file_exists[TEST_DESTINATION_PATH] = True + + result = model_processor.ensure_downloaded( + TEST_MODEL_TYPE, TEST_URL, TEST_FILE_NAME, TEST_EXPECTED_HASH + ) + + assert result == TEST_DESTINATION_PATH + model = verify_model_in_db(db_session, TEST_FILE_NAME, TEST_EXPECTED_HASH, TEST_MODEL_TYPE) + assert model.source_url == SOURCE_URL # Ensure the source URL is not overwritten + + +def test_ensure_downloaded_existing_file_hash_mismatch(model_processor, db_session): + # Ensure that a file with a different hash raises an error + create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, "different_hash") + model_processor.file_exists[TEST_DESTINATION_PATH] = True + + with pytest.raises(ValueError, match="File .* exists with hash .* but expected .*"): + model_processor.ensure_downloaded( + TEST_MODEL_TYPE, TEST_URL, TEST_FILE_NAME, TEST_EXPECTED_HASH + ) + + +def test_ensure_downloaded_new_file(model_processor, db_session): + # Ensure that a new file is downloaded + model_processor.file_exists[TEST_DESTINATION_PATH] = False + + with setup_mock_hash_calculation(model_processor, TEST_EXPECTED_HASH): + result = model_processor.ensure_downloaded( + TEST_MODEL_TYPE, TEST_URL, TEST_FILE_NAME, TEST_EXPECTED_HASH + ) + + assert result == TEST_DESTINATION_PATH + assert len(model_processor.downloaded_files) == 1 + assert model_processor.downloaded_files[0] == (TEST_URL, TEST_DESTINATION_PATH) + assert model_processor.file_exists[TEST_DESTINATION_PATH] + verify_model_in_db(db_session, TEST_FILE_NAME, TEST_EXPECTED_HASH, TEST_MODEL_TYPE) + + +def test_ensure_downloaded_hash_mismatch(model_processor, db_session): + # Ensure that download that results in a different hash raises an error + model_processor.file_exists[TEST_DESTINATION_PATH] = False + + with setup_mock_hash_calculation(model_processor, "different_hash"): + with pytest.raises( + ValueError, + match="Downloaded file hash .* does not match expected hash .*", + ): + model_processor.ensure_downloaded( + TEST_MODEL_TYPE, + TEST_URL, + TEST_FILE_NAME, + TEST_EXPECTED_HASH, + ) + + assert len(model_processor.removed_files) == 1 + assert model_processor.removed_files[0] == TEST_DESTINATION_PATH + 
assert TEST_DESTINATION_PATH not in model_processor.file_exists + assert db_session.query(Model).filter_by(path=TEST_FILE_NAME).first() is None + + +def test_process_file_without_hash(model_processor, db_session): + # Test processing file without provided hash + model_processor.file_exists[TEST_DESTINATION_PATH] = True + + with patch.object(model_processor, "_hash_file", return_value=TEST_EXPECTED_HASH): + result = model_processor.process_file(TEST_DESTINATION_PATH) + assert result is not None + assert result.hash == TEST_EXPECTED_HASH + + +def test_retrieve_model_by_hash(model_processor, db_session): + # Test retrieving model by hash + create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH) + result = model_processor.retrieve_model_by_hash(TEST_EXPECTED_HASH) + assert result is not None + assert result.hash == TEST_EXPECTED_HASH + + +def test_retrieve_model_by_hash_and_type(model_processor, db_session): + # Test retrieving model by hash and type + create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH) + result = model_processor.retrieve_model_by_hash(TEST_EXPECTED_HASH, TEST_MODEL_TYPE) + assert result is not None + assert result.hash == TEST_EXPECTED_HASH + assert result.type == TEST_MODEL_TYPE + + +def test_retrieve_hash(model_processor, db_session): + # Test retrieving hash for existing model + create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH) + with patch.object( + model_processor, + "_validate_path", + return_value=(TEST_MODEL_TYPE, TEST_FILE_NAME), + ): + result = model_processor.retrieve_hash(TEST_DESTINATION_PATH, TEST_MODEL_TYPE) + assert result == TEST_EXPECTED_HASH + + +def test_validate_file_extension_valid_extensions(model_processor): + # Test all valid file extensions + valid_extensions = [".safetensors", ".sft", ".txt", ".csv", ".json", ".yaml"] + for ext in valid_extensions: + model_processor._validate_file_extension(f"test{ext}") # Should not raise + + +def test_process_file_existing_without_source_url(model_processor, db_session): + # Test processing an existing file that needs its source URL updated + model_processor.file_exists[TEST_DESTINATION_PATH] = True + + create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH) + result = model_processor.process_file(TEST_DESTINATION_PATH, source_url=TEST_URL) + + assert result is not None + assert result.hash == TEST_EXPECTED_HASH + assert result.source_url == TEST_URL + + db_model = db_session.query(Model).filter_by(path=TEST_FILE_NAME).first() + assert db_model.source_url == TEST_URL From 7f7b3f1695a132ecd92e6c1a45e8571a59031b42 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Sun, 1 Jun 2025 15:41:00 +0100 Subject: [PATCH 03/82] tidy --- app/frontend_management.py | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/app/frontend_management.py b/app/frontend_management.py index 3e54e4d512c8..9b4d0794794d 100644 --- a/app/frontend_management.py +++ b/app/frontend_management.py @@ -153,11 +153,11 @@ def download_release_asset_zip(release: Release, destination_path: str) -> None: class FrontendManager: """ A class to manage ComfyUI frontend versions and installations. - + This class handles the initialization and management of different frontend versions, including the default frontend from the pip package and custom frontend versions from GitHub repositories. 
- + Attributes: CUSTOM_FRONTENDS_ROOT (str): The root directory where custom frontend versions are stored. """ @@ -168,10 +168,10 @@ class FrontendManager: def default_frontend_path(cls) -> str: """ Get the path to the default frontend installation from the pip package. - + Returns: str: The path to the default frontend static files. - + Raises: SystemExit: If the comfyui-frontend-package is not installed. """ @@ -197,10 +197,10 @@ def default_frontend_path(cls) -> str: def templates_path(cls) -> str: """ Get the path to the workflow templates. - + Returns: str: The path to the workflow templates directory. - + Raises: SystemExit: If the comfyui-workflow-templates package is not installed. """ @@ -240,16 +240,16 @@ def embedded_docs_path(cls) -> str: def parse_version_string(cls, value: str) -> tuple[str, str, str]: """ Parse a version string into its components. - + The version string should be in the format: 'owner/repo@version' where version can be either a semantic version (v1.2.3) or 'latest'. - + Args: value (str): The version string to parse. - + Returns: tuple[str, str, str]: A tuple containing (owner, repo, version). - + Raises: argparse.ArgumentTypeError: If the version string is invalid. """ @@ -266,18 +266,18 @@ def init_frontend_unsafe( ) -> str: """ Initialize a frontend version without error handling. - + This method attempts to initialize a specific frontend version, either from the default pip package or from a custom GitHub repository. It will download and extract the frontend files if necessary. - + Args: version_string (str): The version string specifying which frontend to use. provider (FrontEndProvider, optional): The provider to use for custom frontends. - + Returns: str: The path to the initialized frontend. - + Raises: Exception: If there is an error during initialization (e.g., network timeout, invalid URL, or missing assets). @@ -333,13 +333,13 @@ def init_frontend_unsafe( def init_frontend(cls, version_string: str) -> str: """ Initialize a frontend version with error handling. - + This is the main method to initialize a frontend version. It wraps init_frontend_unsafe with error handling, falling back to the default frontend if initialization fails. - + Args: version_string (str): The version string specifying which frontend to use. - + Returns: str: The path to the initialized frontend. If initialization fails, returns the path to the default frontend. From 7d5160f92c7688519ce2e062be036218bc18cd0d Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Sun, 1 Jun 2025 15:45:15 +0100 Subject: [PATCH 04/82] Tidy --- alembic_db/versions/e9c714da8d57_init.py | 4 +--- app/model_processor.py | 6 +++--- comfy/cli_args.py | 2 +- folder_paths.py | 8 ++++---- tests-unit/app_test/model_processor_test.py | 8 ++++---- 5 files changed, 13 insertions(+), 15 deletions(-) diff --git a/alembic_db/versions/e9c714da8d57_init.py b/alembic_db/versions/e9c714da8d57_init.py index 1a296104436f..995365f90c62 100644 --- a/alembic_db/versions/e9c714da8d57_init.py +++ b/alembic_db/versions/e9c714da8d57_init.py @@ -1,7 +1,7 @@ """init Revision ID: e9c714da8d57 -Revises: +Revises: Create Date: 2025-05-30 20:14:33.772039 """ @@ -20,7 +20,6 @@ def upgrade() -> None: """Upgrade schema.""" - # ### commands auto generated by Alembic - please adjust! 
### op.create_table('model', sa.Column('type', sa.Text(), nullable=False), sa.Column('path', sa.Text(), nullable=False), @@ -32,7 +31,6 @@ def upgrade() -> None: sa.Column('date_added', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True), sa.PrimaryKeyConstraint('type', 'path') ) - # ### end Alembic commands ### def downgrade() -> None: diff --git a/app/model_processor.py b/app/model_processor.py index 6cf8fd6fae7a..5018c2fe6039 100644 --- a/app/model_processor.py +++ b/app/model_processor.py @@ -64,7 +64,7 @@ def _get_existing_model(self, session, model_type, model_relative_path): .filter(Model.path == model_relative_path) .first() ) - + def _ensure_source_url(self, session, model, source_url): if model.source_url is None: model.source_url = source_url @@ -171,9 +171,9 @@ def retrieve_model_by_hash(self, model_hash, model_type=None, session=None): try: if not can_create_session(): return - + dispose_session = False - + if session is None: session = create_session() dispose_session = True diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 154491fe07d3..ce20eb404eb4 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -204,7 +204,7 @@ def is_valid_directory(path: str) -> str: ) database_default_path = os.path.abspath( - os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db") + os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db") ) parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.") parser.add_argument("--disable-model-processing", action="store_true", help="Disable model file processing, e.g. computing hashes and extracting metadata.") diff --git a/folder_paths.py b/folder_paths.py index 5b5554a30669..e4916eff484e 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -303,23 +303,23 @@ def get_full_path_or_raise(folder_name: str, filename: str) -> str: def get_relative_path(full_path: str) -> tuple[str, str] | None: """Convert a full path back to a type-relative path. 
- + Args: full_path: The full path to the file - + Returns: tuple[str, str] | None: A tuple of (model_type, relative_path) if found, None otherwise """ global folder_names_and_paths full_path = os.path.normpath(full_path) - + for model_type, (paths, _) in folder_names_and_paths.items(): for base_path in paths: base_path = os.path.normpath(base_path) if full_path.startswith(base_path): relative_path = os.path.relpath(full_path, base_path) return model_type, relative_path - + return None def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]: diff --git a/tests-unit/app_test/model_processor_test.py b/tests-unit/app_test/model_processor_test.py index d1e43d375e7e..692b631b2485 100644 --- a/tests-unit/app_test/model_processor_test.py +++ b/tests-unit/app_test/model_processor_test.py @@ -195,7 +195,7 @@ def test_ensure_downloaded_hash_mismatch(model_processor, db_session): def test_process_file_without_hash(model_processor, db_session): # Test processing file without provided hash model_processor.file_exists[TEST_DESTINATION_PATH] = True - + with patch.object(model_processor, "_hash_file", return_value=TEST_EXPECTED_HASH): result = model_processor.process_file(TEST_DESTINATION_PATH) assert result is not None @@ -241,13 +241,13 @@ def test_validate_file_extension_valid_extensions(model_processor): def test_process_file_existing_without_source_url(model_processor, db_session): # Test processing an existing file that needs its source URL updated model_processor.file_exists[TEST_DESTINATION_PATH] = True - + create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH) result = model_processor.process_file(TEST_DESTINATION_PATH, source_url=TEST_URL) - + assert result is not None assert result.hash == TEST_EXPECTED_HASH assert result.source_url == TEST_URL - + db_model = db_session.query(Model).filter_by(path=TEST_FILE_NAME).first() assert db_model.source_url == TEST_URL From d7062277a717f876247bc802e65b033a9e88880f Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Sun, 3 Aug 2025 16:40:27 +0100 Subject: [PATCH 05/82] fix bad merge --- main.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/main.py b/main.py index 3b97d48b3603..81001c0b5b86 100644 --- a/main.py +++ b/main.py @@ -286,15 +286,6 @@ def setup_database(): except Exception as e: logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}") -def setup_database(): - try: - from app.database.db import init_db, dependencies_available - if dependencies_available(): - init_db() - except Exception as e: - logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}") - - def start_comfyui(asyncio_loop=None): """ Starts the ComfyUI server using the provided asyncio event loop or creates a new one. 
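
A minimal sketch (not part of any patch in this series; the names below are only illustrative) of why the duplicated setup_database() removed by the "fix bad merge" patch was redundant rather than harmful: in Python, a second module-level def with the same name simply rebinds it, so only the later of the two identical copies was ever called.

    # Standalone illustration of function rebinding; unrelated to ComfyUI internals.
    def setup_database():
        return "first definition"

    def setup_database():  # noqa: F811 - redefinition rebinds the name above
        return "second definition"

    assert setup_database() == "second definition"
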
From f032c1a50a854a06c8951adffbfbe2cc1a6fbe49 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 11 Aug 2025 22:25:42 -0700 Subject: [PATCH 06/82] Brainstorming abstraction for asset management stuff --- comfy/asset_management.py | 90 +++++++++++++++++++++++++++++++++++++++ comfy_api/latest/_io.py | 14 ++++++ 2 files changed, 104 insertions(+) create mode 100644 comfy/asset_management.py diff --git a/comfy/asset_management.py b/comfy/asset_management.py new file mode 100644 index 000000000000..6bdb18f45264 --- /dev/null +++ b/comfy/asset_management.py @@ -0,0 +1,90 @@ +from abc import ABC, abstractmethod +from typing import Any, TypedDict +import os +import logging +import comfy.utils + +import folder_paths + +class AssetMetadata(TypedDict): + device: Any + return_metadata: bool + subdir: str + download_url: str + +class AssetInfo: + def __init__(self, hash: str=None, name: str=None, tags: list[str]=None, metadata: AssetMetadata={}): + self.hash = hash + self.name = name + self.tags = tags + self.metadata = metadata + + +class ReturnedAssetABC(ABC): + def __init__(self, mimetype: str): + self.mimetype = mimetype + + +class ModelReturnedAsset(ReturnedAssetABC): + def __init__(self, model: dict[str, str] | tuple[dict[str, str], dict[str, str]]): + super().__init__("model") + self.model = model + + +class AssetResolverABC(ABC): + @abstractmethod + def resolve(self, asset_info: AssetInfo) -> ReturnedAssetABC: + ... + + +class LocalAssetResolver(AssetResolverABC): + def resolve(self, asset_info: AssetInfo) -> ReturnedAssetABC: + # currently only supports models - make sure models is in the tags + if "models" not in asset_info.tags: + return None + # TODO: if hash exists, call model processor to try to get info about model: + if asset_info.hash: + ... + # if subdir metadata and name exists, use that as the model name going forward + if "subdir" in asset_info.metadata and asset_info.name: + relative_path = os.path.join(asset_info.metadata["subdir"], asset_info.name) + # the good ol' bread and butter - folder_paths's keys as tags + folder_keys = folder_paths.folder_names_and_paths.keys() + parent_paths = [] + for tag in asset_info.tags: + if tag in folder_keys: + parent_paths.append(tag) + if len(parent_paths) == 0: + return None + # now we have the parent keys, we can try to get the local path + chosen_parent = None + full_path = None + for parent_path in parent_paths: + full_path = folder_paths.get_full_path(parent_path, relative_path) + if full_path: + chosen_parent = parent_path + break + logging.info(f"Resolved {asset_info.name} to {full_path} in {chosen_parent}") + # we know the path, so load the model and return it + model = comfy.utils.load_torch_file(full_path, safe_load=True, device=asset_info.metadata.get("device", None), return_metadata=asset_info.metadata.get("return_metadata", False)) + return ModelReturnedAsset(model) + # TODO: if name exists, try to find model by name in all subdirs of parent paths + if asset_info.name: + ... + # TODO: if download_url metadata exists, download the model and load it + if asset_info.metadata.get("download_url", None): + ... 
+ return None + + +resolvers: list[AssetResolverABC] = [] + + +def resolve(asset_info: AssetInfo) -> Any: + global resolvers + for resolver in resolvers: + try: + return resolver.resolve(asset_info) + except Exception as e: + logging.error(f"Error resolving asset {asset_info.hash}: {e}") + return None diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index ec1efb51d1a4..bc5582094f09 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -384,6 +384,20 @@ def as_dict(self): }) return to_return +@comfytype(io_type="ASSET") +class Asset(ComfyTypeI): + class Input(WidgetInput): + def __init__(self, id: str, query_tags: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, + default: str=None, socketless: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless) + self.query_tags = query_tags + + def as_dict(self): + to_return = super().as_dict() | prune_dict({ + "query_tags": self.query_tags + }) + return to_return + @comfytype(io_type="IMAGE") class Image(ComfyTypeIO): Type = torch.Tensor From 1aa089e0b6c21185a930377f0255f19e2e5202ba Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 12 Aug 2025 17:59:16 -0700 Subject: [PATCH 07/82] More progress on brainstorming code for asset management for models --- comfy/asset_management.py | 60 ++++++++++++++++++++----------- comfy_extras/nodes_assets_test.py | 56 +++++++++++++++++++++++++++++ nodes.py | 1 + 3 files changed, 97 insertions(+), 20 deletions(-) create mode 100644 comfy_extras/nodes_assets_test.py diff --git a/comfy/asset_management.py b/comfy/asset_management.py index 6bdb18f45264..e47996320cf5 100644 --- a/comfy/asset_management.py +++ b/comfy/asset_management.py @@ -26,9 +26,10 @@ def __init__(self, mimetype: str): class ModelReturnedAsset(ReturnedAssetABC): - def __init__(self, model: dict[str, str] | tuple[dict[str, str], dict[str, str]]): + def __init__(self, state_dict: dict[str, str], metadata: dict[str, str]=None): super().__init__("model") - self.model = model + self.state_dict = state_dict + self.metadata = metadata class AssetResolverABC(ABC): @@ -38,24 +39,30 @@ def resolve(self, asset_info: AssetInfo) -> ReturnedAssetABC: class LocalAssetResolver(AssetResolverABC): - def resolve(self, asset_info: AssetInfo) -> ReturnedAssetABC: + def resolve(self, asset_info: AssetInfo, cache_result: bool=True) -> ReturnedAssetABC: # currently only supports models - make sure models is in the tags if "models" not in asset_info.tags: return None # TODO: if hash exists, call model processor to try to get info about model: if asset_info.hash: - ... 
+ try: + from app.model_processor import model_processor + model_db = model_processor.retrieve_model_by_hash(asset_info.hash) + full_path = model_db.path + except Exception as e: + logging.error(f"Could not get model by hash with error: {e}") + # the good ol' bread and butter - folder_paths's keys as tags + folder_keys = folder_paths.folder_names_and_paths.keys() + parent_paths = [] + for tag in asset_info.tags: + if tag in folder_keys: + parent_paths.append(tag) # if subdir metadata and name exists, use that as the model name going forward if "subdir" in asset_info.metadata and asset_info.name: - relative_path = os.path.join(asset_info.metadata["subdir"], asset_info.name) - # the good ol' bread and butter - folder_paths's keys as tags - folder_keys = folder_paths.folder_names_and_paths.keys() - parent_paths = [] - for tag in asset_info.tags: - if tag in folder_keys: - parent_paths.append(tag) + # if no matching parent paths, then something went wrong and should return None if len(parent_paths) == 0: return None + relative_path = os.path.join(asset_info.metadata["subdir"], asset_info.name) # now we have the parent keys, we can try to get the local path chosen_parent = None full_path = None @@ -64,27 +71,40 @@ def resolve(self, asset_info: AssetInfo) -> ReturnedAssetABC: if full_path: chosen_parent = parent_path break - logging.info(f"Resolved {asset_info.name} to {full_path} in {chosen_parent}") - # we know the path, so load the model and return it - model = comfy.utils.load_torch_file(full_path, safe_load=True, device=asset_info.metadata.get("device", None), return_metadata=asset_info.metadata.get("return_metadata", False)) - return ModelReturnedAsset(model) - # TODO: if name exists, try to find model by name in all subdirs of parent paths + if full_path is not None: + logging.info(f"Resolved {asset_info.name} to {full_path} in {chosen_parent}") + # we know the path, so load the model and return it + state_dict, metadata = comfy.utils.load_torch_file(full_path, safe_load=True, device=asset_info.metadata.get("device", None), return_metadata=True) + # TODO: handle caching + return ModelReturnedAsset(state_dict, metadata) + # if just name exists, try to find model by name in all subdirs of parent paths + # TODO: this behavior should be configurable by user if asset_info.name: - ... - # TODO: if download_url metadata exists, download the model and load it + for parent_path in parent_paths: + filelist = folder_paths.get_filename_list(parent_path) + for file in filelist: + if os.path.basename(file) == asset_info.name: + full_path = folder_paths.get_full_path(parent_path, file) + state_dict, metadata = comfy.utils.load_torch_file(full_path, safe_load=True, device=asset_info.metadata.get("device", None), return_metadata=True) + # TODO: handle caching + return ModelReturnedAsset(state_dict, metadata) + # TODO: if download_url metadata exists, download the model and load it; this should be configurable by user if asset_info.metadata.get("download_url", None): ... 
return None resolvers: list[AssetResolverABC] = [] +resolvers.append(LocalAssetResolver()) def resolve(asset_info: AssetInfo) -> Any: global resolvers for resolver in resolvers: try: - return resolver.resolve(asset_info) + to_return = resolver.resolve(asset_info) + if to_return is not None: + return resolver.resolve(asset_info) except Exception as e: - logging.error(f"Error resolving asset {asset_info.hash}: {e}") + logging.error(f"Error resolving asset {asset_info.name} using {resolver.__class__.__name__}: {e}") return None diff --git a/comfy_extras/nodes_assets_test.py b/comfy_extras/nodes_assets_test.py new file mode 100644 index 000000000000..5172cd6282c6 --- /dev/null +++ b/comfy_extras/nodes_assets_test.py @@ -0,0 +1,56 @@ +from comfy_api.latest import io, ComfyExtension +import comfy.asset_management +import comfy.sd +import folder_paths +import logging +import os + + +class AssetTestNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="AssetTestNode", + is_experimental=True, + inputs=[ + io.Combo.Input("ckpt_name", folder_paths.get_filename_list("checkpoints")), + ], + outputs=[ + io.Model.Output(), + io.Clip.Output(), + io.Vae.Output(), + ], + ) + + @classmethod + def execute(cls, ckpt_name: str): + hash = None + # lets get the full path just so we can retrieve the hash from db, if exists + try: + full_path = folder_paths.get_full_path("checkpoints", ckpt_name) + if full_path is None: + raise Exception(f"Model {ckpt_name} not found") + from app.model_processor import model_processor + hash = model_processor.retrieve_hash(full_path) + except Exception as e: + logging.error(f"Could not get model by hash with error: {e}") + subdir, name = os.path.split(ckpt_name) + asset_info = comfy.asset_management.AssetInfo(hash=hash, name=name, tags=["models", "checkpoints"], metadata={"subdir": subdir}) + asset = comfy.asset_management.resolve(asset_info) + # /\ the stuff above should happen in execution code instead of inside the node + # \/ the stuff below should happen in the node - confirm is a model asset, do stuff to it (already loaded? or should be called to 'load'?) 
+ if asset is None: + raise Exception(f"Model {asset_info.name} not found") + assert isinstance(asset, comfy.asset_management.ModelReturnedAsset) + out = comfy.sd.load_state_dict_guess_config(asset.state_dict, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"), metadata=asset.metadata) + return io.NodeOutput(out[0], out[1], out[2]) + + +class AssetTestExtension(ComfyExtension): + @classmethod + async def get_node_list(cls): + return [AssetTestNode] + + +def comfy_entrypoint(): + return AssetTestExtension() diff --git a/nodes.py b/nodes.py index 9448f9c1b383..0f1f4e93793e 100644 --- a/nodes.py +++ b/nodes.py @@ -2320,6 +2320,7 @@ async def init_builtin_extra_nodes(): "nodes_camera_trajectory.py", "nodes_edit_model.py", "nodes_tcfg.py", + "nodes_assets_test.py", ] import_failed = [] From f92307cd4c13a3b1981cc818fce0a56280d3e404 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Tue, 19 Aug 2025 19:56:59 +0300 Subject: [PATCH 08/82] dev: Everything is Assets --- alembic_db/versions/0001_assets.py | 158 +++++ alembic_db/versions/e9c714da8d57_init.py | 40 -- app/api/__init__.py | 0 app/api/assets_routes.py | 110 ++++ app/assets_manager.py | 148 +++++ app/database/__init__.py | 0 app/database/db.py | 287 ++++++-- app/database/models.py | 298 +++++++-- app/database/services.py | 683 ++++++++++++++++++++ app/model_processor.py | 331 ---------- app/storage/__init__.py | 0 app/storage/hashing.py | 72 +++ comfy/asset_management.py | 110 ---- comfy/cli_args.py | 2 +- comfy/utils.py | 6 - comfy_extras/nodes_assets_test.py | 56 -- main.py | 9 +- nodes.py | 7 +- requirements.txt | 1 + server.py | 2 + tests-unit/app_test/model_manager_test.py | 62 -- tests-unit/app_test/model_processor_test.py | 253 -------- 22 files changed, 1654 insertions(+), 981 deletions(-) create mode 100644 alembic_db/versions/0001_assets.py delete mode 100644 alembic_db/versions/e9c714da8d57_init.py create mode 100644 app/api/__init__.py create mode 100644 app/api/assets_routes.py create mode 100644 app/assets_manager.py create mode 100644 app/database/__init__.py create mode 100644 app/database/services.py delete mode 100644 app/model_processor.py create mode 100644 app/storage/__init__.py create mode 100644 app/storage/hashing.py delete mode 100644 comfy/asset_management.py delete mode 100644 comfy_extras/nodes_assets_test.py delete mode 100644 tests-unit/app_test/model_manager_test.py delete mode 100644 tests-unit/app_test/model_processor_test.py diff --git a/alembic_db/versions/0001_assets.py b/alembic_db/versions/0001_assets.py new file mode 100644 index 000000000000..6705d8122440 --- /dev/null +++ b/alembic_db/versions/0001_assets.py @@ -0,0 +1,158 @@ +"""initial assets schema + per-asset state cache + +Revision ID: 0001_assets +Revises: +Create Date: 2025-08-20 00:00:00 +""" + +from alembic import op +import sqlalchemy as sa + +revision = "0001_assets" +down_revision = None +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ASSETS: content identity (deduplicated by hash) + op.create_table( + "assets", + sa.Column("hash", sa.String(length=128), primary_key=True), + sa.Column("size_bytes", sa.BigInteger(), nullable=False, server_default="0"), + sa.Column("mime_type", sa.String(length=255), nullable=True), + sa.Column("refcount", sa.BigInteger(), nullable=False, server_default="0"), + sa.Column("storage_backend", sa.String(length=32), nullable=False, server_default="fs"), + sa.Column("storage_locator", sa.Text(), nullable=False), + sa.Column("created_at", 
sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")), + sa.CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"), + sa.CheckConstraint("refcount >= 0", name="ck_assets_refcount_nonneg"), + ) + op.create_index("ix_assets_mime_type", "assets", ["mime_type"]) + op.create_index("ix_assets_backend_locator", "assets", ["storage_backend", "storage_locator"]) + + # ASSETS_INFO: user-visible references (mutable metadata) + op.create_table( + "assets_info", + sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True), + sa.Column("owner_id", sa.String(length=128), nullable=True), + sa.Column("name", sa.String(length=512), nullable=False), + sa.Column("asset_hash", sa.String(length=128), sa.ForeignKey("assets.hash", ondelete="RESTRICT"), nullable=False), + sa.Column("preview_hash", sa.String(length=128), sa.ForeignKey("assets.hash", ondelete="SET NULL"), nullable=True), + sa.Column("user_metadata", sa.JSON(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")), + sa.Column("last_access_time", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")), + sqlite_autoincrement=True, + ) + op.create_index("ix_assets_info_owner_id", "assets_info", ["owner_id"]) + op.create_index("ix_assets_info_asset_hash", "assets_info", ["asset_hash"]) + op.create_index("ix_assets_info_name", "assets_info", ["name"]) + op.create_index("ix_assets_info_created_at", "assets_info", ["created_at"]) + op.create_index("ix_assets_info_last_access_time", "assets_info", ["last_access_time"]) + + # TAGS: normalized tag vocabulary + op.create_table( + "tags", + sa.Column("name", sa.String(length=128), primary_key=True), + sa.Column("tag_type", sa.String(length=32), nullable=False, server_default="user"), + sa.CheckConstraint("name = lower(name)", name="ck_tags_lowercase"), + ) + op.create_index("ix_tags_tag_type", "tags", ["tag_type"]) + + # ASSET_INFO_TAGS: many-to-many for tags on AssetInfo + op.create_table( + "asset_info_tags", + sa.Column("asset_info_id", sa.BigInteger(), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False), + sa.Column("tag_name", sa.String(length=128), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False), + sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"), + sa.Column("added_by", sa.String(length=128), nullable=True), + sa.Column("added_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")), + sa.PrimaryKeyConstraint("asset_info_id", "tag_name", name="pk_asset_info_tags"), + ) + op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"]) + op.create_index("ix_asset_info_tags_asset_info_id", "asset_info_tags", ["asset_info_id"]) + + # ASSET_LOCATOR_STATE: 1:1 filesystem metadata(for fast integrity checking) for an Asset records + op.create_table( + "asset_locator_state", + sa.Column("asset_hash", sa.String(length=128), sa.ForeignKey("assets.hash", ondelete="CASCADE"), primary_key=True), + sa.Column("mtime_ns", sa.BigInteger(), nullable=True), + sa.Column("etag", sa.String(length=256), nullable=True), + sa.Column("last_modified", sa.String(length=128), nullable=True), + sa.CheckConstraint("(mtime_ns IS 
NULL) OR (mtime_ns >= 0)", name="ck_als_mtime_nonneg"), + ) + + # ASSET_INFO_META: typed KV projection of user_metadata for filtering/sorting + op.create_table( + "asset_info_meta", + sa.Column("asset_info_id", sa.Integer(), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False), + sa.Column("key", sa.String(length=256), nullable=False), + sa.Column("ordinal", sa.Integer(), nullable=False, server_default="0"), + sa.Column("val_str", sa.String(length=2048), nullable=True), + sa.Column("val_num", sa.Numeric(38, 10), nullable=True), + sa.Column("val_bool", sa.Boolean(), nullable=True), + sa.Column("val_json", sa.JSON(), nullable=True), + sa.PrimaryKeyConstraint("asset_info_id", "key", "ordinal", name="pk_asset_info_meta"), + ) + op.create_index("ix_asset_info_meta_key", "asset_info_meta", ["key"]) + op.create_index("ix_asset_info_meta_key_val_str", "asset_info_meta", ["key", "val_str"]) + op.create_index("ix_asset_info_meta_key_val_num", "asset_info_meta", ["key", "val_num"]) + op.create_index("ix_asset_info_meta_key_val_bool", "asset_info_meta", ["key", "val_bool"]) + + # Tags vocabulary for models + tags_table = sa.table( + "tags", + sa.column("name", sa.String()), + sa.column("tag_type", sa.String()), + ) + op.bulk_insert( + tags_table, + [ + # Core concept tags + {"name": "models", "tag_type": "system"}, + + # Canonical single-word types + {"name": "checkpoint", "tag_type": "system"}, + {"name": "lora", "tag_type": "system"}, + {"name": "vae", "tag_type": "system"}, + {"name": "text-encoder", "tag_type": "system"}, + {"name": "clip-vision", "tag_type": "system"}, + {"name": "embedding", "tag_type": "system"}, + {"name": "controlnet", "tag_type": "system"}, + {"name": "upscale", "tag_type": "system"}, + {"name": "diffusion-model", "tag_type": "system"}, + {"name": "hypernetwork", "tag_type": "system"}, + {"name": "vae_approx", "tag_type": "system"}, + # TODO: decide what to do with: style_models, diffusers, gligen, photomaker, classifiers + ], + ) + + +def downgrade() -> None: + op.drop_index("ix_asset_info_meta_key_val_bool", table_name="asset_info_meta") + op.drop_index("ix_asset_info_meta_key_val_num", table_name="asset_info_meta") + op.drop_index("ix_asset_info_meta_key_val_str", table_name="asset_info_meta") + op.drop_index("ix_asset_info_meta_key", table_name="asset_info_meta") + op.drop_table("asset_info_meta") + + op.drop_table("asset_locator_state") + + op.drop_index("ix_asset_info_tags_asset_info_id", table_name="asset_info_tags") + op.drop_index("ix_asset_info_tags_tag_name", table_name="asset_info_tags") + op.drop_table("asset_info_tags") + + op.drop_index("ix_tags_tag_type", table_name="tags") + op.drop_table("tags") + + op.drop_index("ix_assets_info_last_access_time", table_name="assets_info") + op.drop_index("ix_assets_info_created_at", table_name="assets_info") + op.drop_index("ix_assets_info_name", table_name="assets_info") + op.drop_index("ix_assets_info_asset_hash", table_name="assets_info") + op.drop_index("ix_assets_info_owner_id", table_name="assets_info") + op.drop_table("assets_info") + + op.drop_index("ix_assets_backend_locator", table_name="assets") + op.drop_index("ix_assets_mime_type", table_name="assets") + op.drop_table("assets") diff --git a/alembic_db/versions/e9c714da8d57_init.py b/alembic_db/versions/e9c714da8d57_init.py deleted file mode 100644 index 995365f90c62..000000000000 --- a/alembic_db/versions/e9c714da8d57_init.py +++ /dev/null @@ -1,40 +0,0 @@ -"""init - -Revision ID: e9c714da8d57 -Revises: -Create Date: 2025-05-30 20:14:33.772039 
- -""" -from typing import Sequence, Union - -from alembic import op -import sqlalchemy as sa - - -# revision identifiers, used by Alembic. -revision: str = 'e9c714da8d57' -down_revision: Union[str, None] = None -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - """Upgrade schema.""" - op.create_table('model', - sa.Column('type', sa.Text(), nullable=False), - sa.Column('path', sa.Text(), nullable=False), - sa.Column('file_name', sa.Text(), nullable=True), - sa.Column('file_size', sa.Integer(), nullable=True), - sa.Column('hash', sa.Text(), nullable=True), - sa.Column('hash_algorithm', sa.Text(), nullable=True), - sa.Column('source_url', sa.Text(), nullable=True), - sa.Column('date_added', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True), - sa.PrimaryKeyConstraint('type', 'path') - ) - - -def downgrade() -> None: - """Downgrade schema.""" - # ### commands auto generated by Alembic - please adjust! ### - op.drop_table('model') - # ### end Alembic commands ### diff --git a/app/api/__init__.py b/app/api/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/app/api/assets_routes.py b/app/api/assets_routes.py new file mode 100644 index 000000000000..aed1d3cea1f5 --- /dev/null +++ b/app/api/assets_routes.py @@ -0,0 +1,110 @@ +import json +from typing import Sequence +from aiohttp import web + +from app import assets_manager + + +ROUTES = web.RouteTableDef() + + +@ROUTES.get("/api/assets") +async def list_assets(request: web.Request) -> web.Response: + q = request.rel_url.query + + include_tags: Sequence[str] = _parse_csv_tags(q.get("include_tags")) + exclude_tags: Sequence[str] = _parse_csv_tags(q.get("exclude_tags")) + name_contains = q.get("name_contains") + + # Optional JSON metadata filter (top-level key equality only for now) + metadata_filter = None + raw_meta = q.get("metadata_filter") + if raw_meta: + try: + metadata_filter = json.loads(raw_meta) + if not isinstance(metadata_filter, dict): + metadata_filter = None + except Exception: + # Silently ignore malformed JSON for first iteration; could 400 in future + metadata_filter = None + + limit = _parse_int(q.get("limit"), default=20, lo=1, hi=100) + offset = _parse_int(q.get("offset"), default=0, lo=0, hi=10_000_000) + sort = q.get("sort", "created_at") + order = q.get("order", "desc") + + payload = await assets_manager.list_assets( + include_tags=include_tags, + exclude_tags=exclude_tags, + name_contains=name_contains, + metadata_filter=metadata_filter, + limit=limit, + offset=offset, + sort=sort, + order=order, + ) + return web.json_response(payload) + + +@ROUTES.put("/api/assets/{id}") +async def update_asset(request: web.Request) -> web.Response: + asset_info_id_raw = request.match_info.get("id") + try: + asset_info_id = int(asset_info_id_raw) + except Exception: + return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid integer.") + + try: + payload = await request.json() + except Exception: + return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") + + name = payload.get("name", None) + tags = payload.get("tags", None) + user_metadata = payload.get("user_metadata", None) + + if name is None and tags is None and user_metadata is None: + return _error_response(400, "NO_FIELDS", "Provide at least one of: name, tags, user_metadata.") + + if tags is not None and (not isinstance(tags, list) or not all(isinstance(t, str) for t in tags)): + 
return _error_response(400, "INVALID_TAGS", "Field 'tags' must be an array of strings.") + + if user_metadata is not None and not isinstance(user_metadata, dict): + return _error_response(400, "INVALID_METADATA", "Field 'user_metadata' must be an object.") + + try: + result = await assets_manager.update_asset( + asset_info_id=asset_info_id, + name=name, + tags=tags, + user_metadata=user_metadata, + ) + except ValueError as ve: + return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) + except Exception: + return _error_response(500, "INTERNAL", "Unexpected server error.") + return web.json_response(result, status=200) + + +def register_assets_routes(app: web.Application) -> None: + app.add_routes(ROUTES) + + +def _parse_csv_tags(raw: str | None) -> list[str]: + if not raw: + return [] + return [t.strip() for t in raw.split(",") if t.strip()] + + +def _parse_int(qval: str | None, default: int, lo: int, hi: int) -> int: + if not qval: + return default + try: + v = int(qval) + except Exception: + return default + return max(lo, min(hi, v)) + + +def _error_response(status: int, code: str, message: str, details: dict | None = None) -> web.Response: + return web.json_response({"error": {"code": code, "message": message, "details": details or {}}}, status=status) diff --git a/app/assets_manager.py b/app/assets_manager.py new file mode 100644 index 000000000000..1cccd6acb443 --- /dev/null +++ b/app/assets_manager.py @@ -0,0 +1,148 @@ +import os +from datetime import datetime, timezone +from typing import Optional, Sequence + +from .database.db import create_session +from .storage import hashing +from .database.services import ( + check_fs_asset_exists_quick, + ingest_fs_asset, + touch_asset_infos_by_fs_path, + list_asset_infos_page, + update_asset_info_full, + get_asset_tags, +) + + +def get_size_mtime_ns(path: str) -> tuple[int, int]: + st = os.stat(path, follow_symlinks=True) + return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)) + + +async def add_local_asset(tags: list[str], file_name: str, file_path: str) -> None: + """Adds a local asset to the DB. If already present and unchanged, does nothing. + + Notes: + - Uses absolute path as the canonical locator for the 'fs' backend. + - Computes BLAKE3 only when the fast existence check indicates it's needed. + - This function ensures the identity row and seeds mtime in asset_locator_state. 
+ """ + abs_path = os.path.abspath(file_path) + size_bytes, mtime_ns = get_size_mtime_ns(abs_path) + if not size_bytes: + return + + async with await create_session() as session: + if await check_fs_asset_exists_quick(session, file_path=abs_path, size_bytes=size_bytes, mtime_ns=mtime_ns): + await touch_asset_infos_by_fs_path(session, abs_path=abs_path, ts=datetime.now(timezone.utc)) + await session.commit() + return + + asset_hash = hashing.blake3_hash_sync(abs_path) + + async with await create_session() as session: + await ingest_fs_asset( + session, + asset_hash="blake3:" + asset_hash, + abs_path=abs_path, + size_bytes=size_bytes, + mtime_ns=mtime_ns, + mime_type=None, + info_name=file_name, + tag_origin="automatic", + tags=tags, + ) + await session.commit() + + +async def list_assets( + *, + include_tags: Sequence[str] | None = None, + exclude_tags: Sequence[str] | None = None, + name_contains: Optional[str] = None, + metadata_filter: Optional[dict] = None, + limit: int = 20, + offset: int = 0, + sort: str | None = "created_at", + order: str | None = "desc", +) -> dict: + sort = _safe_sort_field(sort) + order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower() + + async with await create_session() as session: + infos, tag_map, total = await list_asset_infos_page( + session, + include_tags=include_tags, + exclude_tags=exclude_tags, + name_contains=name_contains, + metadata_filter=metadata_filter, + limit=limit, + offset=offset, + sort=sort, + order=order, + ) + + assets_json = [] + for info in infos: + asset = info.asset # populated via contains_eager + tags = tag_map.get(info.id, []) + assets_json.append( + { + "id": info.id, + "name": info.name, + "asset_hash": info.asset_hash, + "size": int(asset.size_bytes) if asset else None, + "mime_type": asset.mime_type if asset else None, + "tags": tags, + "preview_url": f"/api/v1/assets/{info.id}/content", # TODO: implement actual content endpoint later + "created_at": info.created_at.isoformat() if info.created_at else None, + "updated_at": info.updated_at.isoformat() if info.updated_at else None, + "last_access_time": info.last_access_time.isoformat() if info.last_access_time else None, + } + ) + + return { + "assets": assets_json, + "total": total, + "has_more": (offset + len(assets_json)) < total, + } + + +async def update_asset( + *, + asset_info_id: int, + name: str | None = None, + tags: list[str] | None = None, + user_metadata: dict | None = None, +) -> dict: + async with await create_session() as session: + info = await update_asset_info_full( + session, + asset_info_id=asset_info_id, + name=name, + tags=tags, + user_metadata=user_metadata, + tag_origin="manual", + added_by=None, + ) + + tag_names = await get_asset_tags(session, asset_info_id=asset_info_id) + await session.commit() + + return { + "id": info.id, + "name": info.name, + "asset_hash": info.asset_hash, + "tags": tag_names, + "user_metadata": info.user_metadata or {}, + "updated_at": info.updated_at.isoformat() if info.updated_at else None, + } + + +def _safe_sort_field(requested: str | None) -> str: + if not requested: + return "created_at" + v = requested.lower() + if v in {"name", "created_at", "updated_at", "size", "last_access_time"}: + return v + return "created_at" diff --git a/app/database/__init__.py b/app/database/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/app/database/db.py b/app/database/db.py index 1de8b80edd8a..2a619f13b751 100644 --- a/app/database/db.py +++ b/app/database/db.py @@ -1,112 +1,267 
@@ import logging import os import shutil +from contextlib import asynccontextmanager +from typing import Optional + from app.logger import log_startup_warning from utils.install_util import get_missing_requirements_message from comfy.cli_args import args -_DB_AVAILABLE = False -Session = None +LOGGER = logging.getLogger(__name__) +# Attempt imports which may not exist in some environments try: from alembic import command from alembic.config import Config from alembic.runtime.migration import MigrationContext from alembic.script import ScriptDirectory - from sqlalchemy import create_engine - from sqlalchemy.orm import sessionmaker + from sqlalchemy import create_engine, text + from sqlalchemy.engine import make_url + from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine _DB_AVAILABLE = True + ENGINE: AsyncEngine | None = None + SESSION: async_sessionmaker | None = None except ImportError as e: log_startup_warning( - f""" ------------------------------------------------------------------------- -Error importing dependencies: {e} -{get_missing_requirements_message()} -This error is happening because ComfyUI now uses a local sqlite database. ------------------------------------------------------------------------- -""".strip() + ( + "------------------------------------------------------------------------\n" + f"Error importing DB dependencies: {e}\n" + f"{get_missing_requirements_message()}\n" + "This error is happening because ComfyUI now uses a local database.\n" + "------------------------------------------------------------------------" + ).strip() ) + _DB_AVAILABLE = False + ENGINE = None + SESSION = None -def dependencies_available(): - """ - Temporary function to check if the dependencies are available - """ +def dependencies_available() -> bool: + """Check if DB dependencies are importable.""" return _DB_AVAILABLE -def can_create_session(): - """ - Temporary function to check if the database is available to create a session - During initial release there may be environmental issues (or missing dependencies) that prevent the database from being created - """ - return dependencies_available() and Session is not None - - -def get_alembic_config(): - root_path = os.path.join(os.path.dirname(__file__), "../..") +def _root_paths(): + """Resolve alembic.ini and migrations script folder.""" + root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) config_path = os.path.abspath(os.path.join(root_path, "alembic.ini")) scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db")) + return config_path, scripts_path - config = Config(config_path) - config.set_main_option("script_location", scripts_path) - config.set_main_option("sqlalchemy.url", args.database_url) - return config +def _absolutize_sqlite_url(db_url: str) -> str: + """Make SQLite database path absolute. 
No-op for non-SQLite URLs.""" + try: + u = make_url(db_url) + except Exception: + return db_url + if not u.drivername.startswith("sqlite"): + return db_url -def get_db_path(): - url = args.database_url - if url.startswith("sqlite:///"): - return url.split("///")[1] + # Make path absolute if relative + db_path = u.database or "" + if not os.path.isabs(db_path): + db_path = os.path.abspath(os.path.join(os.getcwd(), db_path)) + u = u.set(database=db_path) + return str(u) + + +def _to_sync_driver_url(async_url: str) -> str: + """Convert an async SQLAlchemy URL to a sync URL for Alembic.""" + u = make_url(async_url) + driver = u.drivername + + if driver.startswith("sqlite+aiosqlite"): + u = u.set(drivername="sqlite") + elif driver.startswith("postgresql+asyncpg"): + u = u.set(drivername="postgresql") else: - raise ValueError(f"Unsupported database URL '{url}'.") + # Generic: strip the async driver part if present + if "+" in driver: + u = u.set(drivername=driver.split("+", 1)[0]) + + return str(u) + + +def _get_sqlite_file_path(sync_url: str) -> Optional[str]: + """Return the on-disk path for a SQLite URL, else None.""" + try: + u = make_url(sync_url) + except Exception: + return None + if not u.drivername.startswith("sqlite"): + return None + return u.database -def init_db(): - db_url = args.database_url - logging.debug(f"Database URL: {db_url}") - db_path = get_db_path() - db_exists = os.path.exists(db_path) - config = get_alembic_config() +def _get_alembic_config(sync_url: str) -> Config: + """Prepare Alembic Config with script location and DB URL.""" + config_path, scripts_path = _root_paths() + cfg = Config(config_path) + cfg.set_main_option("script_location", scripts_path) + cfg.set_main_option("sqlalchemy.url", sync_url) + return cfg + + +async def init_db_engine() -> None: + """Initialize async engine + sessionmaker and run migrations to head. + + This must be called once on application startup before any DB usage. 
+ """ + global ENGINE, SESSION + + if not dependencies_available(): + raise RuntimeError("Database dependencies are not available.") + + if ENGINE is not None: + return + + raw_url = args.database_url + if not raw_url: + raise RuntimeError("Database URL is not configured.") + + # Absolutize SQLite path for async engine + db_url = _absolutize_sqlite_url(raw_url) + + # Prepare async engine + connect_args = {} + if db_url.startswith("sqlite"): + connect_args = { + "check_same_thread": False, + "timeout": 12, + } + + ENGINE = create_async_engine( + db_url, + connect_args=connect_args, + pool_pre_ping=True, + future=True, + ) + + # Enforce SQLite pragmas on the async engine + if db_url.startswith("sqlite"): + async with ENGINE.begin() as conn: + # WAL for concurrency and durability, Foreign Keys for referential integrity + current_mode = (await conn.execute(text("PRAGMA journal_mode;"))).scalar() + if str(current_mode).lower() != "wal": + new_mode = (await conn.execute(text("PRAGMA journal_mode=WAL;"))).scalar() + if str(new_mode).lower() != "wal": + raise RuntimeError("Failed to set SQLite journal mode to WAL.") + LOGGER.info("SQLite journal mode set to WAL.") - # Check if we need to upgrade - engine = create_engine(db_url) - conn = engine.connect() + await conn.execute(text("PRAGMA foreign_keys = ON;")) + await conn.execute(text("PRAGMA synchronous = NORMAL;")) - context = MigrationContext.configure(conn) - current_rev = context.get_current_revision() + await _run_migrations(raw_url=db_url) - script = ScriptDirectory.from_config(config) + SESSION = async_sessionmaker( + bind=ENGINE, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False, + autocommit=False, + ) + + +async def _run_migrations(raw_url: str) -> None: + """ + Run Alembic migrations up to head. + + We deliberately use a synchronous engine for migrations because Alembic's + programmatic API is synchronous by default and this path is robust. 
+ """ + # Convert to sync URL and make SQLite URL an absolute one + sync_url = _to_sync_driver_url(raw_url) + sync_url = _absolutize_sqlite_url(sync_url) + + cfg = _get_alembic_config(sync_url) + + # Inspect current and target heads + engine = create_engine(sync_url, future=True) + with engine.connect() as conn: + context = MigrationContext.configure(conn) + current_rev = context.get_current_revision() + + script = ScriptDirectory.from_config(cfg) target_rev = script.get_current_head() if target_rev is None: - logging.warning("No target revision found.") - elif current_rev != target_rev: - # Backup the database pre upgrade - backup_path = db_path + ".bkp" - if db_exists: - shutil.copy(db_path, backup_path) - else: - backup_path = None + LOGGER.warning("Alembic: no target revision found.") + return + + if current_rev == target_rev: + LOGGER.debug("Alembic: database already at head %s", target_rev) + return + + LOGGER.info("Alembic: upgrading database from %s to %s", current_rev, target_rev) + # Optional backup for SQLite file DBs + backup_path = None + sqlite_path = _get_sqlite_file_path(sync_url) + if sqlite_path and os.path.exists(sqlite_path): + backup_path = sqlite_path + ".bkp" try: - command.upgrade(config, target_rev) - logging.info(f"Database upgraded from {current_rev} to {target_rev}") - except Exception as e: - if backup_path: - # Restore the database from backup if upgrade fails - shutil.copy(backup_path, db_path) + shutil.copy(sqlite_path, backup_path) + except Exception as exc: + LOGGER.warning("Failed to create SQLite backup before migration: %s", exc) + + try: + command.upgrade(cfg, target_rev) + except Exception: + if backup_path and os.path.exists(backup_path): + LOGGER.exception("Error upgrading database, attempting restore from backup.") + try: + shutil.copy(backup_path, sqlite_path) # restore os.remove(backup_path) - logging.exception("Error upgrading database: ") - raise e + except Exception as re: + LOGGER.error("Failed to restore SQLite backup: %s", re) + else: + LOGGER.exception("Error upgrading database, backup is not available.") + raise + + +def get_engine(): + """Return the global async engine (initialized after init_db_engine()).""" + if ENGINE is None: + raise RuntimeError("Engine is not initialized. Call init_db_engine() first.") + return ENGINE + + +def get_session_maker(): + """Return the global async_sessionmaker (initialized after init_db_engine()).""" + if SESSION is None: + raise RuntimeError("Session maker is not initialized. Call init_db_engine() first.") + return SESSION - global Session - Session = sessionmaker(bind=engine) +@asynccontextmanager +async def session_scope(): + """Async context manager for a unit of work: -def create_session(): - return Session() + async with session_scope() as sess: + ... use sess ... + """ + maker = get_session_maker() + async with maker() as sess: + try: + yield sess + await sess.commit() + except Exception: + await sess.rollback() + raise + + +async def create_session(): + """Convenience helper to acquire a single AsyncSession instance. + + Typical usage: + async with (await create_session()) as sess: + ... 
+ """ + maker = get_session_maker() + return maker() diff --git a/app/database/models.py b/app/database/models.py index b0225c41204c..ca7ad67f81e6 100644 --- a/app/database/models.py +++ b/app/database/models.py @@ -1,59 +1,257 @@ +from datetime import datetime +from typing import Any, Optional + from sqlalchemy import ( - Column, Integer, - Text, + BigInteger, DateTime, + ForeignKey, + Index, + JSON, + String, + Text, + CheckConstraint, + Numeric, + Boolean, ) -from sqlalchemy.orm import declarative_base from sqlalchemy.sql import func +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship, foreign -Base = declarative_base() +class Base(DeclarativeBase): + pass -def to_dict(obj): + +def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]: fields = obj.__table__.columns.keys() - return { - field: (val.to_dict() if hasattr(val, "to_dict") else val) - for field in fields - if (val := getattr(obj, field)) - } - - -class Model(Base): - """ - sqlalchemy model representing a model file in the system. - - This class defines the database schema for storing information about model files, - including their type, path, hash, and when they were added to the system. - - Attributes: - type (Text): The type of the model, this is the name of the folder in the models folder (primary key) - path (Text): The file path of the model relative to the type folder (primary key) - file_name (Text): The name of the model file - file_size (Integer): The size of the model file in bytes - hash (Text): A hash of the model file - hash_algorithm (Text): The algorithm used to generate the hash - source_url (Text): The URL of the model file - date_added (DateTime): Timestamp of when the model was added to the system - """ - - __tablename__ = "model" - - type = Column(Text, primary_key=True) - path = Column(Text, primary_key=True) - file_name = Column(Text) - file_size = Column(Integer) - hash = Column(Text) - hash_algorithm = Column(Text) - source_url = Column(Text) - date_added = Column(DateTime, server_default=func.now()) - - def to_dict(self): - """ - Convert the model instance to a dictionary representation. 
- - Returns: - dict: A dictionary containing the attributes of the model - """ - dict = to_dict(self) - return dict + out: dict[str, Any] = {} + for field in fields: + val = getattr(obj, field) + if val is None and not include_none: + continue + if isinstance(val, datetime): + out[field] = val.isoformat() + else: + out[field] = val + return out + + +class Asset(Base): + __tablename__ = "assets" + + hash: Mapped[str] = mapped_column(String(256), primary_key=True) + size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0) + mime_type: Mapped[str | None] = mapped_column(String(255)) + refcount: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0) + storage_backend: Mapped[str] = mapped_column(String(32), nullable=False, default="fs") + storage_locator: Mapped[str] = mapped_column(Text, nullable=False) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.current_timestamp() + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.current_timestamp() + ) + + infos: Mapped[list["AssetInfo"]] = relationship( + "AssetInfo", + back_populates="asset", + primaryjoin=lambda: Asset.hash == foreign(AssetInfo.asset_hash), + foreign_keys=lambda: [AssetInfo.asset_hash], + cascade="all,delete-orphan", + passive_deletes=True, + ) + + preview_of: Mapped[list["AssetInfo"]] = relationship( + "AssetInfo", + back_populates="preview_asset", + primaryjoin=lambda: Asset.hash == foreign(AssetInfo.preview_hash), + foreign_keys=lambda: [AssetInfo.preview_hash], + viewonly=True, + ) + + locator_state: Mapped["AssetLocatorState | None"] = relationship( + back_populates="asset", + uselist=False, + cascade="all, delete-orphan", + passive_deletes=True, + ) + + __table_args__ = ( + Index("ix_assets_mime_type", "mime_type"), + Index("ix_assets_backend_locator", "storage_backend", "storage_locator"), + ) + + def to_dict(self, include_none: bool = False) -> dict[str, Any]: + return to_dict(self, include_none=include_none) + + def __repr__(self) -> str: + return f"" + + +class AssetLocatorState(Base): + __tablename__ = "asset_locator_state" + + asset_hash: Mapped[str] = mapped_column( + String(256), ForeignKey("assets.hash", ondelete="CASCADE"), primary_key=True + ) + # For fs backends: nanosecond mtime; nullable if not applicable + mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True) + # For HTTP/S3/GCS/Azure, etc.: optional validators + etag: Mapped[str | None] = mapped_column(String(256), nullable=True) + last_modified: Mapped[str | None] = mapped_column(String(128), nullable=True) + + asset: Mapped["Asset"] = relationship(back_populates="locator_state", uselist=False) + + __table_args__ = ( + CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_als_mtime_nonneg"), + ) + + def to_dict(self, include_none: bool = False) -> dict[str, Any]: + return to_dict(self, include_none=include_none) + + def __repr__(self) -> str: + return f"" + + +class AssetInfo(Base): + __tablename__ = "assets_info" + + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + owner_id: Mapped[str | None] = mapped_column(String(128)) + name: Mapped[str] = mapped_column(String(512), nullable=False) + asset_hash: Mapped[str] = mapped_column( + String(256), ForeignKey("assets.hash", ondelete="RESTRICT"), nullable=False + ) + preview_hash: Mapped[str | None] = mapped_column(String(256), ForeignKey("assets.hash", ondelete="SET NULL")) + user_metadata: 
Mapped[dict[str, Any] | None] = mapped_column(JSON) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.current_timestamp() + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.current_timestamp() + ) + last_access_time: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.current_timestamp() + ) + + # Relationships + asset: Mapped[Asset] = relationship( + "Asset", + back_populates="infos", + foreign_keys=[asset_hash], + ) + preview_asset: Mapped[Asset | None] = relationship( + "Asset", + back_populates="preview_of", + foreign_keys=[preview_hash], + ) + + metadata_entries: Mapped[list["AssetInfoMeta"]] = relationship( + back_populates="asset_info", + cascade="all,delete-orphan", + passive_deletes=True, + ) + + tag_links: Mapped[list["AssetInfoTag"]] = relationship( + back_populates="asset_info", + cascade="all,delete-orphan", + passive_deletes=True, + overlaps="tags,asset_infos", + ) + + tags: Mapped[list["Tag"]] = relationship( + secondary="asset_info_tags", + back_populates="asset_infos", + lazy="joined", + viewonly=True, + overlaps="tag_links,asset_info_links,asset_infos,tag", + ) + + __table_args__ = ( + Index("ix_assets_info_owner_id", "owner_id"), + Index("ix_assets_info_asset_hash", "asset_hash"), + Index("ix_assets_info_name", "name"), + Index("ix_assets_info_created_at", "created_at"), + Index("ix_assets_info_last_access_time", "last_access_time"), + {"sqlite_autoincrement": True}, + ) + + def to_dict(self, include_none: bool = False) -> dict[str, Any]: + data = to_dict(self, include_none=include_none) + data["tags"] = [t.name for t in self.tags] + return data + + def __repr__(self) -> str: + return f"" + + + +class AssetInfoMeta(Base): + __tablename__ = "asset_info_meta" + + asset_info_id: Mapped[int] = mapped_column( + Integer, ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True + ) + key: Mapped[str] = mapped_column(String(256), primary_key=True) + ordinal: Mapped[int] = mapped_column(Integer, primary_key=True, default=0) + + val_str: Mapped[Optional[str]] = mapped_column(String(2048), nullable=True) + val_num: Mapped[Optional[float]] = mapped_column(Numeric(38, 10), nullable=True) + val_bool: Mapped[Optional[bool]] = mapped_column(Boolean, nullable=True) + val_json: Mapped[Optional[Any]] = mapped_column(JSON, nullable=True) + + asset_info: Mapped["AssetInfo"] = relationship(back_populates="metadata_entries") + + __table_args__ = ( + Index("ix_asset_info_meta_key", "key"), + Index("ix_asset_info_meta_key_val_str", "key", "val_str"), + Index("ix_asset_info_meta_key_val_num", "key", "val_num"), + Index("ix_asset_info_meta_key_val_bool", "key", "val_bool"), + ) + + +class AssetInfoTag(Base): + __tablename__ = "asset_info_tags" + + asset_info_id: Mapped[int] = mapped_column( + Integer, ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True + ) + tag_name: Mapped[str] = mapped_column( + String(128), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True + ) + origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual") + added_by: Mapped[str | None] = mapped_column(String(128)) + added_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) + + asset_info: Mapped["AssetInfo"] = relationship(back_populates="tag_links") + tag: Mapped["Tag"] = relationship(back_populates="asset_info_links") + + __table_args__ = ( + 
Index("ix_asset_info_tags_tag_name", "tag_name"), + Index("ix_asset_info_tags_asset_info_id", "asset_info_id"), + ) + + +class Tag(Base): + __tablename__ = "tags" + + name: Mapped[str] = mapped_column(String(128), primary_key=True) + tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user") + + asset_info_links: Mapped[list["AssetInfoTag"]] = relationship( + back_populates="tag", + overlaps="asset_infos,tags", + ) + asset_infos: Mapped[list["AssetInfo"]] = relationship( + secondary="asset_info_tags", + back_populates="tags", + viewonly=True, + overlaps="asset_info_links,tag_links,tags,asset_info", + ) + + __table_args__ = ( + Index("ix_tags_tag_type", "tag_type"), + ) + + def __repr__(self) -> str: + return f"" diff --git a/app/database/services.py b/app/database/services.py new file mode 100644 index 000000000000..c2792b4c4501 --- /dev/null +++ b/app/database/services.py @@ -0,0 +1,683 @@ +import os +import logging +from collections import defaultdict +from datetime import datetime, timezone +from decimal import Decimal +from typing import Any, Sequence, Optional, Iterable + +import sqlalchemy as sa +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import select, delete, exists, func +from sqlalchemy.orm import contains_eager +from sqlalchemy.exc import IntegrityError + +from .models import Asset, AssetInfo, AssetInfoTag, AssetLocatorState, Tag, AssetInfoMeta + + +async def check_fs_asset_exists_quick( + session, + *, + file_path: str, + size_bytes: Optional[int] = None, + mtime_ns: Optional[int] = None, +) -> bool: + """ + Returns 'True' if there is already an Asset present whose canonical locator matches this absolute path, + AND (if provided) mtime_ns matches stored locator-state, + AND (if provided) size_bytes matches verified size when known. + """ + locator = os.path.abspath(file_path) + + stmt = select(sa.literal(True)).select_from(Asset) + + conditions = [ + Asset.storage_backend == "fs", + Asset.storage_locator == locator, + ] + + # If size_bytes provided require equality when the asset has a verified (non-zero) size. + # If verified size is 0 (unknown), we don't force equality. + if size_bytes is not None: + conditions.append(sa.or_(Asset.size_bytes == 0, Asset.size_bytes == int(size_bytes))) + + # If mtime_ns provided require the locator-state to exist and match. + if mtime_ns is not None: + stmt = stmt.join(AssetLocatorState, AssetLocatorState.asset_hash == Asset.hash) + conditions.append(AssetLocatorState.mtime_ns == int(mtime_ns)) + + stmt = stmt.where(*conditions).limit(1) + + row = (await session.execute(stmt)).first() + return row is not None + + +async def ingest_fs_asset( + session: AsyncSession, + *, + asset_hash: str, + abs_path: str, + size_bytes: int, + mtime_ns: int, + mime_type: Optional[str] = None, + info_name: Optional[str] = None, + owner_id: Optional[str] = None, + preview_hash: Optional[str] = None, + user_metadata: Optional[dict] = None, + tags: Sequence[str] = (), + tag_origin: str = "manual", + added_by: Optional[str] = None, + require_existing_tags: bool = False, +) -> dict: + """ + Creates or updates Asset record for a local (fs) asset. + + Always: + - Insert Asset if missing; else update size_bytes (and updated_at) if different. + - Insert AssetLocatorState if missing; else update mtime_ns if different. + + Optionally (when info_name is provided): + - Create an AssetInfo (no refcount changes). + - Link provided tags to that AssetInfo. 
+ * If the require_existing_tags=True, raises ValueError if any tag does not exist in `tags` table. + * If False (default), silently skips unknown tags. + + Returns flags and ids: + { + "asset_created": bool, + "asset_updated": bool, + "state_created": bool, + "state_updated": bool, + "asset_info_id": int | None, + "tags_added": list[str], + "tags_missing": list[str], # filled only when require_existing_tags=False + } + """ + locator = os.path.abspath(abs_path) + datetime_now = datetime.now(timezone.utc) + + out = { + "asset_created": False, + "asset_updated": False, + "state_created": False, + "state_updated": False, + "asset_info_id": None, + "tags_added": [], + "tags_missing": [], + } + + # ---- Step 1: INSERT Asset or UPDATE size_bytes/updated_at if exists ---- + async with session.begin_nested() as sp1: + try: + session.add( + Asset( + hash=asset_hash, + size_bytes=int(size_bytes), + mime_type=mime_type, + refcount=0, + storage_backend="fs", + storage_locator=locator, + created_at=datetime_now, + updated_at=datetime_now, + ) + ) + await session.flush() + out["asset_created"] = True + except IntegrityError: + await sp1.rollback() + # Already exists by hash -> update selected fields if different + existing = await session.get(Asset, asset_hash) + if existing is not None: + desired_size = int(size_bytes) + if existing.size_bytes != desired_size: + existing.size_bytes = desired_size + existing.updated_at = datetime_now + out["asset_updated"] = True + else: + # This should not occur. Log for visibility. + logging.error("Asset %s not found after conflict; skipping update.", asset_hash) + except Exception: + await sp1.rollback() + logging.exception("Unexpected error inserting Asset (hash=%s, locator=%s)", asset_hash, locator) + raise + + # ---- Step 2: INSERT/UPDATE AssetLocatorState (mtime_ns) ---- + async with session.begin_nested() as sp2: + try: + session.add( + AssetLocatorState( + asset_hash=asset_hash, + mtime_ns=int(mtime_ns), + ) + ) + await session.flush() + out["state_created"] = True + except IntegrityError: + await sp2.rollback() + state = await session.get(AssetLocatorState, asset_hash) + if state is not None: + desired_mtime = int(mtime_ns) + if state.mtime_ns != desired_mtime: + state.mtime_ns = desired_mtime + out["state_updated"] = True + else: + logging.debug("Locator state missing for %s after conflict; skipping update.", asset_hash) + except Exception: + await sp2.rollback() + logging.exception("Unexpected error inserting AssetLocatorState (hash=%s)", asset_hash) + raise + + # ---- Optional: AssetInfo + tag links ---- + if info_name: + # 2a) Create AssetInfo (no refcount bump) + async with session.begin_nested() as sp3: + try: + info = AssetInfo( + owner_id=owner_id, + name=info_name, + asset_hash=asset_hash, + preview_hash=preview_hash, + created_at=datetime_now, + updated_at=datetime_now, + last_access_time=datetime_now, + ) + session.add(info) + await session.flush() # get info.id + out["asset_info_id"] = info.id + except Exception: + await sp3.rollback() + logging.exception( + "Unexpected error inserting AssetInfo (hash=%s, name=%s)", asset_hash, info_name + ) + raise + + # 2b) Link tags (if any). We DO NOT create new Tag rows here by default. + norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()] + if norm and out["asset_info_id"] is not None: + # Which tags exist? 
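+            # (tag names were already trimmed/lower-cased into `norm` above; unknown
+            #  tags are either rejected or skipped depending on require_existing_tags)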
+ existing_tag_names = set( + name for (name,) in (await session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all() + ) + missing = [t for t in norm if t not in existing_tag_names] + if missing and require_existing_tags: + raise ValueError(f"Unknown tags: {missing}") + + # Which links already exist? + existing_links = set( + tag_name + for (tag_name,) in ( + await session.execute( + select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"]) + ) + ).all() + ) + to_add = [t for t in norm if t in existing_tag_names and t not in existing_links] + if to_add: + session.add_all( + [ + AssetInfoTag( + asset_info_id=out["asset_info_id"], + tag_name=t, + origin=tag_origin, + added_by=added_by, + added_at=datetime_now, + ) + for t in to_add + ] + ) + await session.flush() + out["tags_added"] = to_add + out["tags_missing"] = missing + + # 2c) Rebuild metadata projection if provided + if user_metadata is not None and out["asset_info_id"] is not None: + await replace_asset_info_metadata_projection( + session, + asset_info_id=out["asset_info_id"], + user_metadata=user_metadata, + ) + return out + + +async def touch_asset_infos_by_fs_path( + session: AsyncSession, + *, + abs_path: str, + ts: Optional[datetime] = None, + only_if_newer: bool = True, +) -> int: + locator = os.path.abspath(abs_path) + ts = ts or datetime.now(timezone.utc) + + stmt = sa.update(AssetInfo).where( + sa.exists( + sa.select(sa.literal(1)) + .select_from(Asset) + .where( + Asset.hash == AssetInfo.asset_hash, + Asset.storage_backend == "fs", + Asset.storage_locator == locator, + ) + ) + ) + + if only_if_newer: + stmt = stmt.where( + sa.or_( + AssetInfo.last_access_time.is_(None), + AssetInfo.last_access_time < ts, + ) + ) + + stmt = stmt.values(last_access_time=ts) + + res = await session.execute(stmt) + return int(res.rowcount or 0) + + +async def list_asset_infos_page( + session: AsyncSession, + *, + include_tags: Sequence[str] | None = None, + exclude_tags: Sequence[str] | None = None, + name_contains: str | None = None, + metadata_filter: dict | None = None, + limit: int = 20, + offset: int = 0, + sort: str = "created_at", + order: str = "desc", +) -> tuple[list[AssetInfo], dict[int, list[str]], int]: + """ + Returns a page of AssetInfo rows with their Asset eagerly loaded (no N+1), + plus a map of asset_info_id -> [tags], and the total count. + + We purposely collect tags in a separate (single) query to avoid row explosion. 
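+    Illustrative call (argument values are only examples):
+        infos, tag_map, total = await list_asset_infos_page(
+            session, include_tags=["models"], sort="created_at", limit=20
+        )
+    which yields the ORM rows (with Asset eagerly loaded), a {asset_info_id: [tag, ...]}
+    map, and the total number of rows matching the filters.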
+ """ + # Clamp + if limit <= 0: + limit = 1 + if limit > 100: + limit = 100 + if offset < 0: + offset = 0 + + # Build base query + base = ( + select(AssetInfo) + .join(Asset, Asset.hash == AssetInfo.asset_hash) + .options(contains_eager(AssetInfo.asset)) + ) + + # Filters + if name_contains: + base = base.where(AssetInfo.name.ilike(f"%{name_contains}%")) + + base = _apply_tag_filters(base, include_tags, exclude_tags) + + base = _apply_metadata_filter(base, metadata_filter) + + # Sort + sort = (sort or "created_at").lower() + order = (order or "desc").lower() + sort_map = { + "name": AssetInfo.name, + "created_at": AssetInfo.created_at, + "updated_at": AssetInfo.updated_at, + "last_access_time": AssetInfo.last_access_time, + "size": Asset.size_bytes, + } + sort_col = sort_map.get(sort, AssetInfo.created_at) + sort_exp = sort_col.desc() if order == "desc" else sort_col.asc() + + base = base.order_by(sort_exp).limit(limit).offset(offset) + + # Total count (same filters, no ordering/limit/offset) + count_stmt = ( + select(func.count()) + .select_from(AssetInfo) + .join(Asset, Asset.hash == AssetInfo.asset_hash) + ) + if name_contains: + count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{name_contains}%")) + count_stmt = _apply_tag_filters(count_stmt, include_tags, exclude_tags) + count_stmt = _apply_metadata_filter(count_stmt, metadata_filter) + + total = (await session.execute(count_stmt)).scalar_one() + + # Fetch rows + infos = (await session.execute(base)).scalars().unique().all() + + # Collect tags in bulk (single query) + id_list = [i.id for i in infos] + tag_map: dict[int, list[str]] = defaultdict(list) + if id_list: + rows = await session.execute( + select(AssetInfoTag.asset_info_id, Tag.name) + .join(Tag, Tag.name == AssetInfoTag.tag_name) + .where(AssetInfoTag.asset_info_id.in_(id_list)) + ) + for aid, tag_name in rows.all(): + tag_map[aid].append(tag_name) + + return infos, tag_map, total + + +async def set_asset_info_tags( + session: AsyncSession, + *, + asset_info_id: int, + tags: Sequence[str], + origin: str = "manual", + added_by: Optional[str] = None, +) -> dict: + """ + Replace the tag set on an AssetInfo with `tags`. Idempotent. + Creates missing tag names as 'user'. 
+ """ + desired = _normalize_tags(tags) + now = datetime.now(timezone.utc) + + # current links + current = set( + tag_name for (tag_name,) in ( + await session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)) + ).all() + ) + + to_add = [t for t in desired if t not in current] + to_remove = [t for t in current if t not in desired] + + if to_add: + await _ensure_tags_exist(session, to_add, tag_type="user") + session.add_all([ + AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_by=added_by, added_at=now) + for t in to_add + ]) + await session.flush() + + if to_remove: + await session.execute( + delete(AssetInfoTag) + .where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove)) + ) + await session.flush() + + return {"added": to_add, "removed": to_remove, "total": desired} + + +async def update_asset_info_full( + session: AsyncSession, + *, + asset_info_id: int, + name: Optional[str] = None, + tags: Optional[Sequence[str]] = None, + user_metadata: Optional[dict] = None, + tag_origin: str = "manual", + added_by: Optional[str] = None, +) -> AssetInfo: + """ + Update AssetInfo fields: + - name (if provided) + - user_metadata blob + rebuild projection (if provided) + - replace tags with provided set (if provided) + Returns the updated AssetInfo. + """ + info = await session.get(AssetInfo, asset_info_id) + if not info: + raise ValueError(f"AssetInfo {asset_info_id} not found") + + touched = False + if name is not None and name != info.name: + info.name = name + touched = True + + if user_metadata is not None: + await replace_asset_info_metadata_projection( + session, asset_info_id=asset_info_id, user_metadata=user_metadata + ) + touched = True + + if tags is not None: + await set_asset_info_tags( + session, + asset_info_id=asset_info_id, + tags=tags, + origin=tag_origin, + added_by=added_by, + ) + touched = True + + if touched and user_metadata is None: + info.updated_at = datetime.now(timezone.utc) + await session.flush() + + return info + + +async def replace_asset_info_metadata_projection( + session: AsyncSession, + *, + asset_info_id: int, + user_metadata: dict | None, +) -> None: + """Replaces the `assets_info.user_metadata` AND rebuild the projection rows in `asset_info_meta`.""" + info = await session.get(AssetInfo, asset_info_id) + if not info: + raise ValueError(f"AssetInfo {asset_info_id} not found") + + info.user_metadata = user_metadata or {} + info.updated_at = datetime.now(timezone.utc) + await session.flush() + + await session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id)) + await session.flush() + + if not user_metadata: + return + + rows: list[AssetInfoMeta] = [] + for k, v in user_metadata.items(): + for r in _project_kv(k, v): + rows.append( + AssetInfoMeta( + asset_info_id=asset_info_id, + key=r["key"], + ordinal=int(r["ordinal"]), + val_str=r.get("val_str"), + val_num=r.get("val_num"), + val_bool=r.get("val_bool"), + val_json=r.get("val_json"), + ) + ) + if rows: + session.add_all(rows) + await session.flush() + + +async def get_asset_tags(session: AsyncSession, *, asset_info_id: int) -> list[Tag]: + return [ + tag_name + for (tag_name,) in ( + await session.execute( + sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) + ) + ).all() + ] + + +def _normalize_tags(tags: Sequence[str] | None) -> list[str]: + return [t.strip().lower() for t in (tags or []) if (t or "").strip()] + + +async def _ensure_tags_exist(session: 
AsyncSession, names: Iterable[str], tag_type: str = "user") -> list[Tag]: + wanted = _normalize_tags(list(names)) + if not wanted: + return [] + existing = (await session.execute(select(Tag).where(Tag.name.in_(wanted)))).scalars().all() + by_name = {t.name: t for t in existing} + to_create = [Tag(name=n, tag_type=tag_type) for n in wanted if n not in by_name] + if to_create: + session.add_all(to_create) + await session.flush() + by_name.update({t.name: t for t in to_create}) + return [by_name[n] for n in wanted] + + +def _apply_tag_filters( + stmt: sa.sql.Select, + include_tags: Sequence[str] | None, + exclude_tags: Sequence[str] | None, +) -> sa.sql.Select: + """include_tags: every tag must be present; exclude_tags: none may be present.""" + include_tags = _normalize_tags(include_tags) + exclude_tags = _normalize_tags(exclude_tags) + + if include_tags: + for tag_name in include_tags: + stmt = stmt.where( + exists().where( + (AssetInfoTag.asset_info_id == AssetInfo.id) + & (AssetInfoTag.tag_name == tag_name) + ) + ) + + if exclude_tags: + stmt = stmt.where( + ~exists().where( + (AssetInfoTag.asset_info_id == AssetInfo.id) + & (AssetInfoTag.tag_name.in_(exclude_tags)) + ) + ) + return stmt + +def _apply_metadata_filter( + stmt: sa.sql.Select, + metadata_filter: dict | None, +) -> sa.sql.Select: + """Apply metadata filters using the projection table asset_info_meta. + + Semantics: + - For scalar values: require EXISTS(asset_info_meta) with matching key + typed value. + - For None: key is missing OR key has explicit null (val_json IS NULL). + - For list values: ANY-of the list elements matches (EXISTS for any). + (Change to ALL-of by 'for each element: stmt = stmt.where(_meta_exists_clause(key, elem))') + """ + if not metadata_filter: + return stmt + + def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement: + subquery = ( + select(sa.literal(1)) + .select_from(AssetInfoMeta) + .where( + AssetInfoMeta.asset_info_id == AssetInfo.id, + AssetInfoMeta.key == key, + *preds, + ) + .limit(1) + ) + return sa.exists(subquery) + + def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement: + # Missing OR null: + if value is None: + # either: no row for key OR a row for key with explicit null + no_row_for_key = ~sa.exists( + select(sa.literal(1)) + .select_from(AssetInfoMeta) + .where( + AssetInfoMeta.asset_info_id == AssetInfo.id, + AssetInfoMeta.key == key, + ) + .limit(1) + ) + null_row = _exists_for_pred(key, AssetInfoMeta.val_json.is_(None)) + return sa.or_(no_row_for_key, null_row) + + # Typed scalar matches: + if isinstance(value, bool): + return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value)) + if isinstance(value, (int, float, Decimal)): + # store as Decimal for equality against NUMERIC(38,10) + num = value if isinstance(value, Decimal) else Decimal(str(value)) + return _exists_for_pred(key, AssetInfoMeta.val_num == num) + if isinstance(value, str): + return _exists_for_pred(key, AssetInfoMeta.val_str == value) + + # Complex: compare JSON (no index, but supported) + return _exists_for_pred(key, AssetInfoMeta.val_json == value) + + for k, v in metadata_filter.items(): + if isinstance(v, list): + # ANY-of (exists for any element) + ors = [ _exists_clause_for_value(k, elem) for elem in v ] + if ors: + stmt = stmt.where(sa.or_(*ors)) + else: + stmt = stmt.where(_exists_clause_for_value(k, v)) + return stmt + + +def _is_scalar(v: Any) -> bool: + if v is None: # treat None as a value (explicit null) so it can be indexed for "is null" queries + return True + if 
isinstance(v, bool): + return True + if isinstance(v, (int, float, Decimal, str)): + return True + return False + + +def _project_kv(key: str, value: Any) -> list[dict]: + """ + Turn a metadata key/value into one or more projection rows: + - scalar -> one row (ordinal=0) in the proper typed column + - list of scalars -> one row per element with ordinal=i + - dict or list with non-scalars -> single row with val_json (or one per element w/ val_json if list) + - None -> single row with val_json = None + Each row: {"key": key, "ordinal": i, "val_str"/"val_num"/"val_bool"/"val_json": ...} + """ + rows: list[dict] = [] + + # None + if value is None: + rows.append({"key": key, "ordinal": 0, "val_json": None}) + return rows + + # Scalars + if _is_scalar(value): + if isinstance(value, bool): + rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)}) + elif isinstance(value, (int, float, Decimal)): + # store numeric; SQLAlchemy will coerce to Numeric + rows.append({"key": key, "ordinal": 0, "val_num": value}) + elif isinstance(value, str): + rows.append({"key": key, "ordinal": 0, "val_str": value}) + else: + # Fallback to json + rows.append({"key": key, "ordinal": 0, "val_json": value}) + return rows + + # Lists + if isinstance(value, list): + # list of scalars? + if all(_is_scalar(x) for x in value): + for i, x in enumerate(value): + if x is None: + rows.append({"key": key, "ordinal": i, "val_json": None}) + elif isinstance(x, bool): + rows.append({"key": key, "ordinal": i, "val_bool": bool(x)}) + elif isinstance(x, (int, float, Decimal)): + rows.append({"key": key, "ordinal": i, "val_num": x}) + elif isinstance(x, str): + rows.append({"key": key, "ordinal": i, "val_str": x}) + else: + rows.append({"key": key, "ordinal": i, "val_json": x}) + return rows + # list contains objects -> one val_json per element + for i, x in enumerate(value): + rows.append({"key": key, "ordinal": i, "val_json": x}) + return rows + + # Dict or any other structure -> single json row + rows.append({"key": key, "ordinal": 0, "val_json": value}) + return rows diff --git a/app/model_processor.py b/app/model_processor.py deleted file mode 100644 index 5018c2fe6039..000000000000 --- a/app/model_processor.py +++ /dev/null @@ -1,331 +0,0 @@ -import os -import logging -import time - -import requests -from tqdm import tqdm -from folder_paths import get_relative_path, get_full_path -from app.database.db import create_session, dependencies_available, can_create_session -import blake3 -import comfy.utils - - -if dependencies_available(): - from app.database.models import Model - - -class ModelProcessor: - def _validate_path(self, model_path): - try: - if not self._file_exists(model_path): - logging.error(f"Model file not found: {model_path}") - return None - - result = get_relative_path(model_path) - if not result: - logging.error( - f"Model file not in a recognized model directory: {model_path}" - ) - return None - - return result - except Exception as e: - logging.error(f"Error validating model path {model_path}: {str(e)}") - return None - - def _file_exists(self, path): - """Check if a file exists.""" - return os.path.exists(path) - - def _get_file_size(self, path): - """Get file size.""" - return os.path.getsize(path) - - def _get_hasher(self): - return blake3.blake3() - - def _hash_file(self, model_path): - try: - hasher = self._get_hasher() - with open(model_path, "rb", buffering=0) as f: - b = bytearray(128 * 1024) - mv = memoryview(b) - while n := f.readinto(mv): - hasher.update(mv[:n]) - return hasher.hexdigest() - 
except Exception as e: - logging.error(f"Error hashing file {model_path}: {str(e)}") - return None - - def _get_existing_model(self, session, model_type, model_relative_path): - return ( - session.query(Model) - .filter(Model.type == model_type) - .filter(Model.path == model_relative_path) - .first() - ) - - def _ensure_source_url(self, session, model, source_url): - if model.source_url is None: - model.source_url = source_url - session.commit() - - def _update_database( - self, - session, - model_type, - model_path, - model_relative_path, - model_hash, - model, - source_url, - ): - try: - if not model: - model = self._get_existing_model( - session, model_type, model_relative_path - ) - - if not model: - model = Model( - path=model_relative_path, - type=model_type, - file_name=os.path.basename(model_path), - ) - session.add(model) - - model.file_size = self._get_file_size(model_path) - model.hash = model_hash - if model_hash: - model.hash_algorithm = "blake3" - model.source_url = source_url - - session.commit() - return model - except Exception as e: - logging.error( - f"Error updating database for {model_relative_path}: {str(e)}" - ) - - def process_file(self, model_path, source_url=None, model_hash=None): - """ - Process a model file and update the database with metadata. - If the file already exists and matches the database, it will not be processed again. - Returns the model object or if an error occurs, returns None. - """ - try: - if not can_create_session(): - return - - result = self._validate_path(model_path) - if not result: - return - model_type, model_relative_path = result - - with create_session() as session: - session.expire_on_commit = False - - existing_model = self._get_existing_model( - session, model_type, model_relative_path - ) - if ( - existing_model - and existing_model.hash - and existing_model.file_size == self._get_file_size(model_path) - ): - # File exists with hash and same size, no need to process - self._ensure_source_url(session, existing_model, source_url) - return existing_model - - if model_hash: - model_hash = model_hash.lower() - logging.info(f"Using provided hash: {model_hash}") - else: - start_time = time.time() - logging.info(f"Hashing model {model_relative_path}") - model_hash = self._hash_file(model_path) - if not model_hash: - return - logging.info( - f"Model hash: {model_hash} (duration: {time.time() - start_time} seconds)" - ) - - return self._update_database( - session, - model_type, - model_path, - model_relative_path, - model_hash, - existing_model, - source_url, - ) - except Exception as e: - logging.error(f"Error processing model file {model_path}: {str(e)}") - return None - - def retrieve_model_by_hash(self, model_hash, model_type=None, session=None): - """ - Retrieve a model file from the database by hash and optionally by model type. - Returns the model object or None if the model doesnt exist or an error occurs. - """ - try: - if not can_create_session(): - return - - dispose_session = False - - if session is None: - session = create_session() - dispose_session = True - - model = session.query(Model).filter(Model.hash == model_hash) - if model_type is not None: - model = model.filter(Model.type == model_type) - return model.first() - except Exception as e: - logging.error(f"Error retrieving model by hash {model_hash}: {str(e)}") - return None - finally: - if dispose_session: - session.close() - - def retrieve_hash(self, model_path, model_type=None): - """ - Retrieve the hash of a model file from the database. 
- Returns the hash or None if the model doesnt exist or an error occurs. - """ - try: - if not can_create_session(): - return - - if model_type is not None: - result = self._validate_path(model_path) - if not result: - return None - model_type, model_relative_path = result - - with create_session() as session: - model = self._get_existing_model( - session, model_type, model_relative_path - ) - if model and model.hash: - return model.hash - return None - except Exception as e: - logging.error(f"Error retrieving hash for {model_path}: {str(e)}") - return None - - def _validate_file_extension(self, file_name): - """Validate that the file extension is supported.""" - extension = os.path.splitext(file_name)[1] - if extension not in (".safetensors", ".sft", ".txt", ".csv", ".json", ".yaml"): - raise ValueError(f"Unsupported unsafe file for download: {file_name}") - - def _check_existing_file(self, model_type, file_name, expected_hash): - """Check if file exists and has correct hash.""" - destination_path = get_full_path(model_type, file_name, allow_missing=True) - if self._file_exists(destination_path): - model = self.process_file(destination_path) - if model and (expected_hash is None or model.hash == expected_hash): - logging.debug( - f"File {destination_path} already exists in the database and has the correct hash or no hash was provided." - ) - return destination_path - else: - raise ValueError( - f"File {destination_path} exists with hash {model.hash if model else 'unknown'} but expected {expected_hash}. Please delete the file and try again." - ) - return None - - def _check_existing_file_by_hash(self, hash, type, url): - """Check if a file with the given hash exists in the database and on disk.""" - hash = hash.lower() - with create_session() as session: - model = self.retrieve_model_by_hash(hash, type, session) - if model: - existing_path = get_full_path(type, model.path) - if existing_path: - logging.debug( - f"File {model.path} already exists in the database at {existing_path}" - ) - self._ensure_source_url(session, model, url) - return existing_path - else: - logging.debug( - f"File {model.path} exists in the database but not on disk" - ) - return None - - def _download_file(self, url, destination_path, hasher): - """Download a file and update the hasher with its contents.""" - response = requests.get(url, stream=True) - logging.info(f"Downloading {url} to {destination_path}") - - with open(destination_path, "wb") as f: - total_size = int(response.headers.get("content-length", 0)) - if total_size > 0: - pbar = comfy.utils.ProgressBar(total_size) - else: - pbar = None - with tqdm(total=total_size, unit="B", unit_scale=True) as progress_bar: - for chunk in response.iter_content(chunk_size=128 * 1024): - if chunk: - f.write(chunk) - hasher.update(chunk) - progress_bar.update(len(chunk)) - if pbar: - pbar.update(len(chunk)) - - def _verify_downloaded_hash(self, calculated_hash, expected_hash, destination_path): - """Verify that the downloaded file has the expected hash.""" - if expected_hash is not None and calculated_hash != expected_hash: - self._remove_file(destination_path) - raise ValueError( - f"Downloaded file hash {calculated_hash} does not match expected hash {expected_hash}" - ) - - def _remove_file(self, file_path): - """Remove a file from disk.""" - os.remove(file_path) - - def ensure_downloaded(self, type, url, desired_file_name, hash=None): - """ - Ensure a model file is downloaded and has the correct hash. - Returns the path to the downloaded file. 
- """ - logging.debug( - f"Ensuring {type} file is downloaded. URL='{url}' Destination='{desired_file_name}' Hash='{hash}'" - ) - - # Validate file extension - self._validate_file_extension(desired_file_name) - - # Check if file exists with correct hash - if hash: - existing_path = self._check_existing_file_by_hash(hash, type, url) - if existing_path: - return existing_path - - # Check if file exists locally - destination_path = get_full_path(type, desired_file_name, allow_missing=True) - existing_path = self._check_existing_file(type, desired_file_name, hash) - if existing_path: - return existing_path - - # Download the file - hasher = self._get_hasher() - self._download_file(url, destination_path, hasher) - - # Verify hash - calculated_hash = hasher.hexdigest() - self._verify_downloaded_hash(calculated_hash, hash, destination_path) - - # Update database - self.process_file(destination_path, url, calculated_hash) - - # TODO: Notify frontend to reload models - - return destination_path - - -model_processor = ModelProcessor() diff --git a/app/storage/__init__.py b/app/storage/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/app/storage/hashing.py b/app/storage/hashing.py new file mode 100644 index 000000000000..3eaed77a33eb --- /dev/null +++ b/app/storage/hashing.py @@ -0,0 +1,72 @@ +import asyncio +import os +from typing import IO, Union + +from blake3 import blake3 + +DEFAULT_CHUNK = 8 * 1024 * 1024 # 8 MiB + + +def _hash_file_obj_sync(file_obj: IO[bytes], chunk_size: int) -> str: + """Hash an already-open binary file object by streaming in chunks. + - Seeks to the beginning before reading (if supported). + - Restores the original position afterward (if tell/seek are supported). + """ + if chunk_size <= 0: + chunk_size = DEFAULT_CHUNK + + orig_pos = None + if hasattr(file_obj, "tell"): + orig_pos = file_obj.tell() + + try: + if hasattr(file_obj, "seek"): + file_obj.seek(0) + + h = blake3() + while True: + chunk = file_obj.read(chunk_size) + if not chunk: + break + h.update(chunk) + return h.hexdigest() + finally: + if hasattr(file_obj, "seek") and orig_pos is not None: + file_obj.seek(orig_pos) + + +def blake3_hash_sync( + fp: Union[str, bytes, os.PathLike[str], os.PathLike[bytes], IO[bytes]], + chunk_size: int = DEFAULT_CHUNK, +) -> str: + """Returns a BLAKE3 hex digest for ``fp``, which may be: + - a filename (str/bytes) or PathLike + - an open binary file object + + If ``fp`` is a file object, it must be opened in **binary** mode and support + ``read``, ``seek``, and ``tell``. The function will seek to the start before + reading and will attempt to restore the original position afterward. + """ + if hasattr(fp, "read"): + return _hash_file_obj_sync(fp, chunk_size) + + with open(os.fspath(fp), "rb") as f: + return _hash_file_obj_sync(f, chunk_size) + + +async def blake3_hash( + fp: Union[str, bytes, os.PathLike[str], os.PathLike[bytes], IO[bytes]], + chunk_size: int = DEFAULT_CHUNK, +) -> str: + """Async wrapper for ``blake3_hash_sync``. + Uses a worker thread so the event loop remains responsive. + """ + # If it is a path, open inside the worker thread to keep I/O off the loop. 
+ if hasattr(fp, "read"): + return await asyncio.to_thread(blake3_hash_sync, fp, chunk_size) + + def _worker() -> str: + with open(os.fspath(fp), "rb") as f: + return _hash_file_obj_sync(f, chunk_size) + + return await asyncio.to_thread(_worker) diff --git a/comfy/asset_management.py b/comfy/asset_management.py deleted file mode 100644 index e47996320cf5..000000000000 --- a/comfy/asset_management.py +++ /dev/null @@ -1,110 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Any, TypedDict -import os -import logging -import comfy.utils - -import folder_paths - -class AssetMetadata(TypedDict): - device: Any - return_metadata: bool - subdir: str - download_url: str - -class AssetInfo: - def __init__(self, hash: str=None, name: str=None, tags: list[str]=None, metadata: AssetMetadata={}): - self.hash = hash - self.name = name - self.tags = tags - self.metadata = metadata - - -class ReturnedAssetABC(ABC): - def __init__(self, mimetype: str): - self.mimetype = mimetype - - -class ModelReturnedAsset(ReturnedAssetABC): - def __init__(self, state_dict: dict[str, str], metadata: dict[str, str]=None): - super().__init__("model") - self.state_dict = state_dict - self.metadata = metadata - - -class AssetResolverABC(ABC): - @abstractmethod - def resolve(self, asset_info: AssetInfo) -> ReturnedAssetABC: - ... - - -class LocalAssetResolver(AssetResolverABC): - def resolve(self, asset_info: AssetInfo, cache_result: bool=True) -> ReturnedAssetABC: - # currently only supports models - make sure models is in the tags - if "models" not in asset_info.tags: - return None - # TODO: if hash exists, call model processor to try to get info about model: - if asset_info.hash: - try: - from app.model_processor import model_processor - model_db = model_processor.retrieve_model_by_hash(asset_info.hash) - full_path = model_db.path - except Exception as e: - logging.error(f"Could not get model by hash with error: {e}") - # the good ol' bread and butter - folder_paths's keys as tags - folder_keys = folder_paths.folder_names_and_paths.keys() - parent_paths = [] - for tag in asset_info.tags: - if tag in folder_keys: - parent_paths.append(tag) - # if subdir metadata and name exists, use that as the model name going forward - if "subdir" in asset_info.metadata and asset_info.name: - # if no matching parent paths, then something went wrong and should return None - if len(parent_paths) == 0: - return None - relative_path = os.path.join(asset_info.metadata["subdir"], asset_info.name) - # now we have the parent keys, we can try to get the local path - chosen_parent = None - full_path = None - for parent_path in parent_paths: - full_path = folder_paths.get_full_path(parent_path, relative_path) - if full_path: - chosen_parent = parent_path - break - if full_path is not None: - logging.info(f"Resolved {asset_info.name} to {full_path} in {chosen_parent}") - # we know the path, so load the model and return it - state_dict, metadata = comfy.utils.load_torch_file(full_path, safe_load=True, device=asset_info.metadata.get("device", None), return_metadata=True) - # TODO: handle caching - return ModelReturnedAsset(state_dict, metadata) - # if just name exists, try to find model by name in all subdirs of parent paths - # TODO: this behavior should be configurable by user - if asset_info.name: - for parent_path in parent_paths: - filelist = folder_paths.get_filename_list(parent_path) - for file in filelist: - if os.path.basename(file) == asset_info.name: - full_path = folder_paths.get_full_path(parent_path, file) - state_dict, 
metadata = comfy.utils.load_torch_file(full_path, safe_load=True, device=asset_info.metadata.get("device", None), return_metadata=True) - # TODO: handle caching - return ModelReturnedAsset(state_dict, metadata) - # TODO: if download_url metadata exists, download the model and load it; this should be configurable by user - if asset_info.metadata.get("download_url", None): - ... - return None - - -resolvers: list[AssetResolverABC] = [] -resolvers.append(LocalAssetResolver()) - - -def resolve(asset_info: AssetInfo) -> Any: - global resolvers - for resolver in resolvers: - try: - to_return = resolver.resolve(asset_info) - if to_return is not None: - return resolver.resolve(asset_info) - except Exception as e: - logging.error(f"Error resolving asset {asset_info.name} using {resolver.__class__.__name__}: {e}") - return None diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 84d0173143b1..9ab78b99b149 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -211,7 +211,7 @@ def is_valid_directory(path: str) -> str: database_default_path = os.path.abspath( os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db") ) -parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.") +parser.add_argument("--database-url", type=str, default=f"sqlite+aiosqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite+aiosqlite:///:memory:'.") parser.add_argument("--disable-model-processing", action="store_true", help="Disable model file processing, e.g. computing hashes and extracting metadata.") if comfy.options.args_parsing: diff --git a/comfy/utils.py b/comfy/utils.py index af2aace0a37c..220492941342 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -102,12 +102,6 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False): else: sd = pl_sd - try: - from app.model_processor import model_processor - model_processor.process_file(ckpt) - except Exception as e: - logging.error(f"Error processing file {ckpt}: {e}") - return (sd, metadata) if return_metadata else sd def save_torch_file(sd, ckpt, metadata=None): diff --git a/comfy_extras/nodes_assets_test.py b/comfy_extras/nodes_assets_test.py deleted file mode 100644 index 5172cd6282c6..000000000000 --- a/comfy_extras/nodes_assets_test.py +++ /dev/null @@ -1,56 +0,0 @@ -from comfy_api.latest import io, ComfyExtension -import comfy.asset_management -import comfy.sd -import folder_paths -import logging -import os - - -class AssetTestNode(io.ComfyNode): - @classmethod - def define_schema(cls): - return io.Schema( - node_id="AssetTestNode", - is_experimental=True, - inputs=[ - io.Combo.Input("ckpt_name", folder_paths.get_filename_list("checkpoints")), - ], - outputs=[ - io.Model.Output(), - io.Clip.Output(), - io.Vae.Output(), - ], - ) - - @classmethod - def execute(cls, ckpt_name: str): - hash = None - # lets get the full path just so we can retrieve the hash from db, if exists - try: - full_path = folder_paths.get_full_path("checkpoints", ckpt_name) - if full_path is None: - raise Exception(f"Model {ckpt_name} not found") - from app.model_processor import model_processor - hash = model_processor.retrieve_hash(full_path) - except Exception as e: - logging.error(f"Could not get model by hash with error: {e}") - subdir, name = os.path.split(ckpt_name) - asset_info = comfy.asset_management.AssetInfo(hash=hash, name=name, tags=["models", 
"checkpoints"], metadata={"subdir": subdir}) - asset = comfy.asset_management.resolve(asset_info) - # /\ the stuff above should happen in execution code instead of inside the node - # \/ the stuff below should happen in the node - confirm is a model asset, do stuff to it (already loaded? or should be called to 'load'?) - if asset is None: - raise Exception(f"Model {asset_info.name} not found") - assert isinstance(asset, comfy.asset_management.ModelReturnedAsset) - out = comfy.sd.load_state_dict_guess_config(asset.state_dict, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"), metadata=asset.metadata) - return io.NodeOutput(out[0], out[1], out[2]) - - -class AssetTestExtension(ComfyExtension): - @classmethod - async def get_node_list(cls): - return [AssetTestNode] - - -def comfy_entrypoint(): - return AssetTestExtension() diff --git a/main.py b/main.py index 81001c0b5b86..557961d40ddb 100644 --- a/main.py +++ b/main.py @@ -278,11 +278,11 @@ def cleanup_temp(): if os.path.exists(temp_dir): shutil.rmtree(temp_dir, ignore_errors=True) -def setup_database(): +async def setup_database(): try: - from app.database.db import init_db, dependencies_available + from app.database.db import init_db_engine, dependencies_available if dependencies_available(): - init_db() + await init_db_engine() except Exception as e: logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}") @@ -309,6 +309,8 @@ def start_comfyui(asyncio_loop=None): asyncio.set_event_loop(asyncio_loop) prompt_server = server.PromptServer(asyncio_loop) + asyncio_loop.run_until_complete(setup_database()) + hook_breaker_ac10a0.save_functions() asyncio_loop.run_until_complete(nodes.init_extra_nodes( init_custom_nodes=(not args.disable_all_custom_nodes) or len(args.whitelist_custom_nodes) > 0, @@ -317,7 +319,6 @@ def start_comfyui(asyncio_loop=None): hook_breaker_ac10a0.restore_functions() cuda_malloc_warning() - setup_database() prompt_server.add_routes() hijack_progress(prompt_server) diff --git a/nodes.py b/nodes.py index 95599b8aac8a..b74cfc58ed9d 100644 --- a/nodes.py +++ b/nodes.py @@ -28,9 +28,10 @@ import comfy.utils import comfy.controlnet from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict, FileLocator -from comfy_api.internal import register_versions, ComfyAPIWithVersion +from comfy_api.internal import async_to_sync, register_versions, ComfyAPIWithVersion from comfy_api.version_list import supported_versions from comfy_api.latest import io, ComfyExtension +from app.assets_manager import add_local_asset import comfy.clip_vision @@ -777,6 +778,9 @@ def load_vae(self, vae_name): else: vae_path = folder_paths.get_full_path_or_raise("vae", vae_name) sd = comfy.utils.load_torch_file(vae_path) + async_to_sync.AsyncToSyncConverter.run_async_in_thread( + add_local_asset, tags=["models", "vae"], file_name=vae_name, file_path=vae_path + ) vae = comfy.sd.VAE(sd=sd) vae.throw_exception_if_invalid() return (vae,) @@ -2321,7 +2325,6 @@ async def init_builtin_extra_nodes(): "nodes_edit_model.py", "nodes_tcfg.py", "nodes_context_windows.py", - "nodes_assets_test.py", ] import_failed = [] diff --git a/requirements.txt b/requirements.txt index 0b0e78791b96..f12d2e3facf0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,6 +20,7 @@ tqdm psutil alembic SQLAlchemy +aiosqlite av>=14.2.0 blake3 diff --git a/server.py b/server.py index 
8f9c88ebf771..30c1a8fe7731 100644 --- a/server.py +++ b/server.py @@ -37,6 +37,7 @@ from app.custom_node_manager import CustomNodeManager from typing import Optional, Union from api_server.routes.internal.internal_routes import InternalRoutes +from app.api.assets_routes import register_assets_routes from protocol import BinaryEventTypes async def send_socket_catch_exception(function, message): @@ -183,6 +184,7 @@ def __init__(self, loop): else args.front_end_root ) logging.info(f"[Prompt Server] web root: {self.web_root}") + register_assets_routes(self.app) routes = web.RouteTableDef() self.routes = routes self.last_node_id = None diff --git a/tests-unit/app_test/model_manager_test.py b/tests-unit/app_test/model_manager_test.py deleted file mode 100644 index ae59206f6563..000000000000 --- a/tests-unit/app_test/model_manager_test.py +++ /dev/null @@ -1,62 +0,0 @@ -import pytest -import base64 -import json -import struct -from io import BytesIO -from PIL import Image -from aiohttp import web -from unittest.mock import patch -from app.model_manager import ModelFileManager - -pytestmark = ( - pytest.mark.asyncio -) # This applies the asyncio mark to all test functions in the module - -@pytest.fixture -def model_manager(): - return ModelFileManager() - -@pytest.fixture -def app(model_manager): - app = web.Application() - routes = web.RouteTableDef() - model_manager.add_routes(routes) - app.add_routes(routes) - return app - -async def test_get_model_preview_safetensors(aiohttp_client, app, tmp_path): - img = Image.new('RGB', (100, 100), 'white') - img_byte_arr = BytesIO() - img.save(img_byte_arr, format='PNG') - img_byte_arr.seek(0) - img_b64 = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8') - - safetensors_file = tmp_path / "test_model.safetensors" - header_bytes = json.dumps({ - "__metadata__": { - "ssmd_cover_images": json.dumps([img_b64]) - } - }).encode('utf-8') - length_bytes = struct.pack(' Date: Sat, 23 Aug 2025 20:14:22 +0300 Subject: [PATCH 09/82] dev: refactor; populate models in more nodes; use Pydantic in endpoints for input validation --- alembic_db/versions/0001_assets.py | 8 +- app/api/assets_routes.py | 95 +++++--------- app/api/schemas_in.py | 66 ++++++++++ app/assets_manager.py | 18 ++- app/model_manager.py | 195 ----------------------------- comfy/cli_args.py | 2 +- nodes.py | 32 +++-- server.py | 3 - 8 files changed, 140 insertions(+), 279 deletions(-) create mode 100644 app/api/schemas_in.py delete mode 100644 app/model_manager.py diff --git a/alembic_db/versions/0001_assets.py b/alembic_db/versions/0001_assets.py index 6705d8122440..369d6710b1c6 100644 --- a/alembic_db/versions/0001_assets.py +++ b/alembic_db/versions/0001_assets.py @@ -124,8 +124,12 @@ def upgrade() -> None: {"name": "upscale", "tag_type": "system"}, {"name": "diffusion-model", "tag_type": "system"}, {"name": "hypernetwork", "tag_type": "system"}, - {"name": "vae_approx", "tag_type": "system"}, - # TODO: decide what to do with: style_models, diffusers, gligen, photomaker, classifiers + {"name": "vae-approx", "tag_type": "system"}, + {"name": "gligen", "tag_type": "system"}, + {"name": "style-model", "tag_type": "system"}, + {"name": "encoder", "tag_type": "system"}, + {"name": "decoder", "tag_type": "system"}, + # TODO: decide what to do with: photomaker, classifiers ], ) diff --git a/app/api/assets_routes.py b/app/api/assets_routes.py index aed1d3cea1f5..2e58532b861c 100644 --- a/app/api/assets_routes.py +++ b/app/api/assets_routes.py @@ -1,8 +1,10 @@ -import json -from typing import Sequence 
+from typing import Optional + from aiohttp import web +from pydantic import ValidationError -from app import assets_manager +from .. import assets_manager +from .schemas_in import ListAssetsQuery, UpdateAssetBody ROUTES = web.RouteTableDef() @@ -10,38 +12,22 @@ @ROUTES.get("/api/assets") async def list_assets(request: web.Request) -> web.Response: - q = request.rel_url.query - - include_tags: Sequence[str] = _parse_csv_tags(q.get("include_tags")) - exclude_tags: Sequence[str] = _parse_csv_tags(q.get("exclude_tags")) - name_contains = q.get("name_contains") - - # Optional JSON metadata filter (top-level key equality only for now) - metadata_filter = None - raw_meta = q.get("metadata_filter") - if raw_meta: - try: - metadata_filter = json.loads(raw_meta) - if not isinstance(metadata_filter, dict): - metadata_filter = None - except Exception: - # Silently ignore malformed JSON for first iteration; could 400 in future - metadata_filter = None - - limit = _parse_int(q.get("limit"), default=20, lo=1, hi=100) - offset = _parse_int(q.get("offset"), default=0, lo=0, hi=10_000_000) - sort = q.get("sort", "created_at") - order = q.get("order", "desc") + query_dict = dict(request.rel_url.query) + + try: + q = ListAssetsQuery.model_validate(query_dict) + except ValidationError as ve: + return _validation_error_response("INVALID_QUERY", ve) payload = await assets_manager.list_assets( - include_tags=include_tags, - exclude_tags=exclude_tags, - name_contains=name_contains, - metadata_filter=metadata_filter, - limit=limit, - offset=offset, - sort=sort, - order=order, + include_tags=q.include_tags, + exclude_tags=q.exclude_tags, + name_contains=q.name_contains, + metadata_filter=q.metadata_filter, + limit=q.limit, + offset=q.offset, + sort=q.sort, + order=q.order, ) return web.json_response(payload) @@ -55,29 +41,18 @@ async def update_asset(request: web.Request) -> web.Response: return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid integer.") try: - payload = await request.json() + body = UpdateAssetBody.model_validate(await request.json()) + except ValidationError as ve: + return _validation_error_response("INVALID_BODY", ve) except Exception: return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") - name = payload.get("name", None) - tags = payload.get("tags", None) - user_metadata = payload.get("user_metadata", None) - - if name is None and tags is None and user_metadata is None: - return _error_response(400, "NO_FIELDS", "Provide at least one of: name, tags, user_metadata.") - - if tags is not None and (not isinstance(tags, list) or not all(isinstance(t, str) for t in tags)): - return _error_response(400, "INVALID_TAGS", "Field 'tags' must be an array of strings.") - - if user_metadata is not None and not isinstance(user_metadata, dict): - return _error_response(400, "INVALID_METADATA", "Field 'user_metadata' must be an object.") - try: result = await assets_manager.update_asset( asset_info_id=asset_info_id, - name=name, - tags=tags, - user_metadata=user_metadata, + name=body.name, + tags=body.tags, + user_metadata=body.user_metadata, ) except ValueError as ve: return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) @@ -90,21 +65,9 @@ def register_assets_routes(app: web.Application) -> None: app.add_routes(ROUTES) -def _parse_csv_tags(raw: str | None) -> list[str]: - if not raw: - return [] - return [t.strip() for t in raw.split(",") if t.strip()] - - -def _parse_int(qval: str | None, default: int, lo: int, hi: int) 
-> int: - if not qval: - return default - try: - v = int(qval) - except Exception: - return default - return max(lo, min(hi, v)) +def _error_response(status: int, code: str, message: str, details: Optional[dict] = None) -> web.Response: + return web.json_response({"error": {"code": code, "message": message, "details": details or {}}}, status=status) -def _error_response(status: int, code: str, message: str, details: dict | None = None) -> web.Response: - return web.json_response({"error": {"code": code, "message": message, "details": details or {}}}, status=status) +def _validation_error_response(code: str, ve: ValidationError) -> web.Response: + return _error_response(400, code, "Validation failed.", {"errors": ve.errors()}) diff --git a/app/api/schemas_in.py b/app/api/schemas_in.py new file mode 100644 index 000000000000..fb936a79af45 --- /dev/null +++ b/app/api/schemas_in.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from typing import Any, Optional, Literal +from pydantic import BaseModel, Field, field_validator, model_validator, conint + + +class ListAssetsQuery(BaseModel): + include_tags: list[str] = Field(default_factory=list) + exclude_tags: list[str] = Field(default_factory=list) + name_contains: Optional[str] = None + + # Accept either a JSON string (query param) or a dict + metadata_filter: Optional[dict[str, Any]] = None + + limit: conint(ge=1, le=500) = 20 + offset: conint(ge=0) = 0 + + sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = "created_at" + order: Literal["asc", "desc"] = "desc" + + @field_validator("include_tags", "exclude_tags", mode="before") + @classmethod + def _split_csv_tags(cls, v): + # Accept "a,b,c" or ["a","b"] (we are liberal in what we accept) + if v is None: + return [] + if isinstance(v, str): + return [t.strip() for t in v.split(",") if t.strip()] + if isinstance(v, list): + out: list[str] = [] + for item in v: + if isinstance(item, str): + out.extend([t.strip() for t in item.split(",") if t.strip()]) + return out + return v + + @field_validator("metadata_filter", mode="before") + @classmethod + def _parse_metadata_json(cls, v): + if v is None or isinstance(v, dict): + return v + if isinstance(v, str) and v.strip(): + import json + try: + parsed = json.loads(v) + except Exception as e: + raise ValueError(f"metadata_filter must be JSON: {e}") from e + if not isinstance(parsed, dict): + raise ValueError("metadata_filter must be a JSON object") + return parsed + return None + + +class UpdateAssetBody(BaseModel): + name: Optional[str] = None + tags: Optional[list[str]] = None + user_metadata: Optional[dict[str, Any]] = None + + @model_validator(mode="after") + def _at_least_one(self): + if self.name is None and self.tags is None and self.user_metadata is None: + raise ValueError("Provide at least one of: name, tags, user_metadata.") + if self.tags is not None: + if not isinstance(self.tags, list) or not all(isinstance(t, str) for t in self.tags): + raise ValueError("Field 'tags' must be an array of strings.") + return self diff --git a/app/assets_manager.py b/app/assets_manager.py index 1cccd6acb443..05031a1bf9a8 100644 --- a/app/assets_manager.py +++ b/app/assets_manager.py @@ -2,6 +2,9 @@ from datetime import datetime, timezone from typing import Optional, Sequence +from comfy.cli_args import args +from comfy_api.internal import async_to_sync + from .database.db import create_session from .storage import hashing from .database.services import ( @@ -14,9 +17,11 @@ ) -def get_size_mtime_ns(path: str) -> 
tuple[int, int]: - st = os.stat(path, follow_symlinks=True) - return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)) +def populate_db_with_asset(tags: list[str], file_name: str, file_path: str) -> None: + if not args.disable_model_processing: + async_to_sync.AsyncToSyncConverter.run_async_in_thread( + add_local_asset, tags=tags, file_name=file_name, file_path=file_path + ) async def add_local_asset(tags: list[str], file_name: str, file_path: str) -> None: @@ -28,7 +33,7 @@ async def add_local_asset(tags: list[str], file_name: str, file_path: str) -> No - This function ensures the identity row and seeds mtime in asset_locator_state. """ abs_path = os.path.abspath(file_path) - size_bytes, mtime_ns = get_size_mtime_ns(abs_path) + size_bytes, mtime_ns = _get_size_mtime_ns(abs_path) if not size_bytes: return @@ -146,3 +151,8 @@ def _safe_sort_field(requested: str | None) -> str: if v in {"name", "created_at", "updated_at", "size", "last_access_time"}: return v return "created_at" + + +def _get_size_mtime_ns(path: str) -> tuple[int, int]: + st = os.stat(path, follow_symlinks=True) + return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)) diff --git a/app/model_manager.py b/app/model_manager.py deleted file mode 100644 index ab36bca74414..000000000000 --- a/app/model_manager.py +++ /dev/null @@ -1,195 +0,0 @@ -from __future__ import annotations - -import os -import base64 -import json -import time -import logging -import folder_paths -import glob -import comfy.utils -from aiohttp import web -from PIL import Image -from io import BytesIO -from folder_paths import map_legacy, filter_files_extensions, filter_files_content_types - - -class ModelFileManager: - def __init__(self) -> None: - self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {} - - def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None: - return self.cache.get(key, default) - - def set_cache(self, key: str, value: tuple[list[dict], dict[str, float], float]): - self.cache[key] = value - - def clear_cache(self): - self.cache.clear() - - def add_routes(self, routes): - # NOTE: This is an experiment to replace `/models` - @routes.get("/experiment/models") - async def get_model_folders(request): - model_types = list(folder_paths.folder_names_and_paths.keys()) - folder_black_list = ["configs", "custom_nodes"] - output_folders: list[dict] = [] - for folder in model_types: - if folder in folder_black_list: - continue - output_folders.append({"name": folder, "folders": folder_paths.get_folder_paths(folder)}) - return web.json_response(output_folders) - - # NOTE: This is an experiment to replace `/models/{folder}` - @routes.get("/experiment/models/{folder}") - async def get_all_models(request): - folder = request.match_info.get("folder", None) - if not folder in folder_paths.folder_names_and_paths: - return web.Response(status=404) - files = self.get_model_file_list(folder) - return web.json_response(files) - - @routes.get("/experiment/models/preview/{folder}/{path_index}/{filename:.*}") - async def get_model_preview(request): - folder_name = request.match_info.get("folder", None) - path_index = int(request.match_info.get("path_index", None)) - filename = request.match_info.get("filename", None) - - if not folder_name in folder_paths.folder_names_and_paths: - return web.Response(status=404) - - folders = folder_paths.folder_names_and_paths[folder_name] - folder = folders[0][path_index] - full_filename = os.path.join(folder, filename) - - 
previews = self.get_model_previews(full_filename) - default_preview = previews[0] if len(previews) > 0 else None - if default_preview is None or (isinstance(default_preview, str) and not os.path.isfile(default_preview)): - return web.Response(status=404) - - try: - with Image.open(default_preview) as img: - img_bytes = BytesIO() - img.save(img_bytes, format="WEBP") - img_bytes.seek(0) - return web.Response(body=img_bytes.getvalue(), content_type="image/webp") - except: - return web.Response(status=404) - - def get_model_file_list(self, folder_name: str): - folder_name = map_legacy(folder_name) - folders = folder_paths.folder_names_and_paths[folder_name] - output_list: list[dict] = [] - - for index, folder in enumerate(folders[0]): - if not os.path.isdir(folder): - continue - out = self.cache_model_file_list_(folder) - if out is None: - out = self.recursive_search_models_(folder, index) - self.set_cache(folder, out) - output_list.extend(out[0]) - - return output_list - - def cache_model_file_list_(self, folder: str): - model_file_list_cache = self.get_cache(folder) - - if model_file_list_cache is None: - return None - if not os.path.isdir(folder): - return None - if os.path.getmtime(folder) != model_file_list_cache[1]: - return None - for x in model_file_list_cache[1]: - time_modified = model_file_list_cache[1][x] - folder = x - if os.path.getmtime(folder) != time_modified: - return None - - return model_file_list_cache - - def recursive_search_models_(self, directory: str, pathIndex: int) -> tuple[list[str], dict[str, float], float]: - if not os.path.isdir(directory): - return [], {}, time.perf_counter() - - excluded_dir_names = [".git"] - # TODO use settings - include_hidden_files = False - - result: list[str] = [] - dirs: dict[str, float] = {} - - for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True): - subdirs[:] = [d for d in subdirs if d not in excluded_dir_names] - if not include_hidden_files: - subdirs[:] = [d for d in subdirs if not d.startswith(".")] - filenames = [f for f in filenames if not f.startswith(".")] - - filenames = filter_files_extensions(filenames, folder_paths.supported_pt_extensions) - - for file_name in filenames: - try: - full_path = os.path.join(dirpath, file_name) - relative_path = os.path.relpath(full_path, directory) - - # Get file metadata - file_info = { - "name": relative_path, - "pathIndex": pathIndex, - "modified": os.path.getmtime(full_path), # Add modification time - "created": os.path.getctime(full_path), # Add creation time - "size": os.path.getsize(full_path) # Add file size - } - result.append(file_info) - - except Exception as e: - logging.warning(f"Warning: Unable to access {file_name}. Error: {e}. Skipping this file.") - continue - - for d in subdirs: - path: str = os.path.join(dirpath, d) - try: - dirs[path] = os.path.getmtime(path) - except FileNotFoundError: - logging.warning(f"Warning: Unable to access {path}. 
Skipping this path.") - continue - - return result, dirs, time.perf_counter() - - def get_model_previews(self, filepath: str) -> list[str | BytesIO]: - dirname = os.path.dirname(filepath) - - if not os.path.exists(dirname): - return [] - - basename = os.path.splitext(filepath)[0] - match_files = glob.glob(f"{basename}.*", recursive=False) - image_files = filter_files_content_types(match_files, "image") - safetensors_file = next(filter(lambda x: x.endswith(".safetensors"), match_files), None) - safetensors_metadata = {} - - result: list[str | BytesIO] = [] - - for filename in image_files: - _basename = os.path.splitext(filename)[0] - if _basename == basename: - result.append(filename) - if _basename == f"{basename}.preview": - result.append(filename) - - if safetensors_file: - safetensors_filepath = os.path.join(dirname, safetensors_file) - header = comfy.utils.safetensors_header(safetensors_filepath, max_size=8*1024*1024) - if header: - safetensors_metadata = json.loads(header) - safetensors_images = safetensors_metadata.get("__metadata__", {}).get("ssmd_cover_images", None) - if safetensors_images: - safetensors_images = json.loads(safetensors_images) - for image in safetensors_images: - result.append(BytesIO(base64.b64decode(image))) - - return result - - def __exit__(self, exc_type, exc_value, traceback): - self.clear_cache() diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 9ab78b99b149..7de4adbdc900 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -212,7 +212,7 @@ def is_valid_directory(path: str) -> str: os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db") ) parser.add_argument("--database-url", type=str, default=f"sqlite+aiosqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite+aiosqlite:///:memory:'.") -parser.add_argument("--disable-model-processing", action="store_true", help="Disable model file processing, e.g. 
computing hashes and extracting metadata.") +parser.add_argument("--disable-model-processing", action="store_true", help="Disable automatic processing of the model file, such as calculating hashes and populating the database.") if comfy.options.args_parsing: args = parser.parse_args() diff --git a/nodes.py b/nodes.py index b74cfc58ed9d..04b60ab2fca7 100644 --- a/nodes.py +++ b/nodes.py @@ -28,10 +28,10 @@ import comfy.utils import comfy.controlnet from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict, FileLocator -from comfy_api.internal import async_to_sync, register_versions, ComfyAPIWithVersion +from comfy_api.internal import register_versions, ComfyAPIWithVersion from comfy_api.version_list import supported_versions from comfy_api.latest import io, ComfyExtension -from app.assets_manager import add_local_asset +from app.assets_manager import populate_db_with_asset import comfy.clip_vision @@ -555,7 +555,9 @@ def INPUT_TYPES(s): def load_checkpoint(self, config_name, ckpt_name): config_path = folder_paths.get_full_path("configs", config_name) ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) - return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) + out = comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) + populate_db_with_asset(["models", "checkpoint"], ckpt_name, ckpt_path) + return out class CheckpointLoaderSimple: @classmethod @@ -577,6 +579,7 @@ def INPUT_TYPES(s): def load_checkpoint(self, ckpt_name): ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) + populate_db_with_asset(["models", "checkpoint"], ckpt_name, ckpt_path) return out[:3] class DiffusersLoader: @@ -619,6 +622,7 @@ def INPUT_TYPES(s): def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True): ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) + populate_db_with_asset(["models", "checkpoint"], ckpt_name, ckpt_path) return out class CLIPSetLastLayer: @@ -677,6 +681,7 @@ def load_lora(self, model, clip, lora_name, strength_model, strength_clip): self.loaded_lora = (lora_path, lora) model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip) + populate_db_with_asset(["models", "lora"], lora_name, lora_path) return (model_lora, clip_lora) class LoraLoaderModelOnly(LoraLoader): @@ -741,11 +746,15 @@ def load_taesd(name): encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes)) decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes)) - enc = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", encoder)) + encoder_path = folder_paths.get_full_path_or_raise("vae_approx", encoder) + populate_db_with_asset(["models", "vae-approx", "encoder"], name, encoder_path) + enc = comfy.utils.load_torch_file(encoder_path) for k in enc: sd["taesd_encoder.{}".format(k)] = enc[k] - dec = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", decoder)) + decoder_path = 
folder_paths.get_full_path_or_raise("vae_approx", decoder) + populate_db_with_asset(["models", "vae-approx", "decoder"], name, decoder_path) + dec = comfy.utils.load_torch_file(decoder_path) for k in dec: sd["taesd_decoder.{}".format(k)] = dec[k] @@ -778,9 +787,7 @@ def load_vae(self, vae_name): else: vae_path = folder_paths.get_full_path_or_raise("vae", vae_name) sd = comfy.utils.load_torch_file(vae_path) - async_to_sync.AsyncToSyncConverter.run_async_in_thread( - add_local_asset, tags=["models", "vae"], file_name=vae_name, file_path=vae_path - ) + populate_db_with_asset(["models", "vae"], vae_name, vae_path) vae = comfy.sd.VAE(sd=sd) vae.throw_exception_if_invalid() return (vae,) @@ -800,6 +807,7 @@ def load_controlnet(self, control_net_name): controlnet = comfy.controlnet.load_controlnet(controlnet_path) if controlnet is None: raise RuntimeError("ERROR: controlnet file is invalid and does not contain a valid controlnet model.") + populate_db_with_asset(["models", "controlnet"], control_net_name, controlnet_path) return (controlnet,) class DiffControlNetLoader: @@ -816,6 +824,7 @@ def INPUT_TYPES(s): def load_controlnet(self, model, control_net_name): controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name) controlnet = comfy.controlnet.load_controlnet(controlnet_path, model) + populate_db_with_asset(["models", "controlnet"], control_net_name, controlnet_path) return (controlnet,) @@ -923,6 +932,7 @@ def load_unet(self, unet_name, weight_dtype): unet_path = folder_paths.get_full_path_or_raise("diffusion_models", unet_name) model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options) + populate_db_with_asset(["models", "diffusion-model"], unet_name, unet_path) return (model,) class CLIPLoader: @@ -950,6 +960,7 @@ def load_clip(self, clip_name, type="stable_diffusion", device="default"): clip_path = folder_paths.get_full_path_or_raise("text_encoders", clip_name) clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options) + populate_db_with_asset(["models", "text-encoder"], clip_name, clip_path) return (clip,) class DualCLIPLoader: @@ -980,6 +991,8 @@ def load_clip(self, clip_name1, clip_name2, type, device="default"): model_options["load_device"] = model_options["offload_device"] = torch.device("cpu") clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options) + populate_db_with_asset(["models", "text-encoder"], clip_name1, clip_path1) + populate_db_with_asset(["models", "text-encoder"], clip_name2, clip_path2) return (clip,) class CLIPVisionLoader: @@ -997,6 +1010,7 @@ def load_clip(self, clip_name): clip_vision = comfy.clip_vision.load(clip_path) if clip_vision is None: raise RuntimeError("ERROR: clip vision file is invalid and does not contain a valid vision model.") + populate_db_with_asset(["models", "clip-vision"], clip_name, clip_path) return (clip_vision,) class CLIPVisionEncode: @@ -1031,6 +1045,7 @@ def INPUT_TYPES(s): def load_style_model(self, style_model_name): style_model_path = folder_paths.get_full_path_or_raise("style_models", style_model_name) style_model = comfy.sd.load_style_model(style_model_path) + populate_db_with_asset(["models", "style-model"], style_model_name, style_model_path) return (style_model,) @@ -1128,6 +1143,7 @@ def INPUT_TYPES(s): def load_gligen(self, gligen_name): gligen_path = 
folder_paths.get_full_path_or_raise("gligen", gligen_name) gligen = comfy.sd.load_gligen(gligen_path) + populate_db_with_asset(["models", "gligen"], gligen_name, gligen_path) return (gligen,) class GLIGENTextBoxApply: diff --git a/server.py b/server.py index 30c1a8fe7731..ba368654fcc2 100644 --- a/server.py +++ b/server.py @@ -33,7 +33,6 @@ from comfy_api.internal import _ComfyNodeInternal from app.user_manager import UserManager -from app.model_manager import ModelFileManager from app.custom_node_manager import CustomNodeManager from typing import Optional, Union from api_server.routes.internal.internal_routes import InternalRoutes @@ -155,7 +154,6 @@ def __init__(self, loop): mimetypes.add_type('image/webp', '.webp') self.user_manager = UserManager() - self.model_file_manager = ModelFileManager() self.custom_node_manager = CustomNodeManager() self.internal_routes = InternalRoutes(self) self.supports = ["custom_nodes_from_web"] @@ -764,7 +762,6 @@ async def setup(self): def add_routes(self): self.user_manager.add_routes(self.routes) - self.model_file_manager.add_routes(self.routes) self.custom_node_manager.add_routes(self.routes, self.app, nodes.LOADED_MODULE_DIRS.items()) self.app.add_subapp('/internal', self.internal_routes.get_app()) From 8d46bec951fe5858dfe9992516733dc7d1e39677 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Sun, 24 Aug 2025 11:02:30 +0300 Subject: [PATCH 10/82] use Pydantic for output; finished Tags endpoints --- alembic_db/versions/0001_assets.py | 4 +- app/api/assets_routes.py | 92 ++++++++++++++- app/api/schemas_in.py | 47 +++++++- app/api/schemas_out.py | 69 ++++++++++++ app/assets_manager.py | 125 ++++++++++++++++----- app/database/services.py | 175 ++++++++++++++++++++++++++++- 6 files changed, 473 insertions(+), 39 deletions(-) create mode 100644 app/api/schemas_out.py diff --git a/alembic_db/versions/0001_assets.py b/alembic_db/versions/0001_assets.py index 369d6710b1c6..1c5563d3fe94 100644 --- a/alembic_db/versions/0001_assets.py +++ b/alembic_db/versions/0001_assets.py @@ -65,7 +65,7 @@ def upgrade() -> None: op.create_table( "asset_info_tags", sa.Column("asset_info_id", sa.BigInteger(), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False), - sa.Column("tag_name", sa.String(length=128), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False), + sa.Column("tag_name", sa.String(length=512), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False), sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"), sa.Column("added_by", sa.String(length=128), nullable=True), sa.Column("added_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")), @@ -104,7 +104,7 @@ def upgrade() -> None: # Tags vocabulary for models tags_table = sa.table( "tags", - sa.column("name", sa.String()), + sa.column("name", sa.String(length=512)), sa.column("tag_type", sa.String()), ) op.bulk_insert( diff --git a/app/api/assets_routes.py b/app/api/assets_routes.py index 2e58532b861c..8c037fd97735 100644 --- a/app/api/assets_routes.py +++ b/app/api/assets_routes.py @@ -4,7 +4,7 @@ from pydantic import ValidationError from .. import assets_manager -from .schemas_in import ListAssetsQuery, UpdateAssetBody +from . 
import schemas_in ROUTES = web.RouteTableDef() @@ -15,7 +15,7 @@ async def list_assets(request: web.Request) -> web.Response: query_dict = dict(request.rel_url.query) try: - q = ListAssetsQuery.model_validate(query_dict) + q = schemas_in.ListAssetsQuery.model_validate(query_dict) except ValidationError as ve: return _validation_error_response("INVALID_QUERY", ve) @@ -29,7 +29,7 @@ async def list_assets(request: web.Request) -> web.Response: sort=q.sort, order=q.order, ) - return web.json_response(payload) + return web.json_response(payload.model_dump(mode="json")) @ROUTES.put("/api/assets/{id}") @@ -41,7 +41,7 @@ async def update_asset(request: web.Request) -> web.Response: return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid integer.") try: - body = UpdateAssetBody.model_validate(await request.json()) + body = schemas_in.UpdateAssetBody.model_validate(await request.json()) except ValidationError as ve: return _validation_error_response("INVALID_BODY", ve) except Exception: @@ -58,7 +58,89 @@ async def update_asset(request: web.Request) -> web.Response: return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) except Exception: return _error_response(500, "INTERNAL", "Unexpected server error.") - return web.json_response(result, status=200) + return web.json_response(result.model_dump(mode="json"), status=200) + + +@ROUTES.get("/api/tags") +async def get_tags(request: web.Request) -> web.Response: + query_map = dict(request.rel_url.query) + + try: + query = schemas_in.TagsListQuery.model_validate(query_map) + except ValidationError as ve: + return web.json_response( + {"error": {"code": "INVALID_QUERY", "message": "Invalid query parameters", "details": ve.errors()}}, + status=400, + ) + + result = await assets_manager.list_tags( + prefix=query.prefix, + limit=query.limit, + offset=query.offset, + order=query.order, + include_zero=query.include_zero, + ) + return web.json_response(result.model_dump(mode="json")) + + +@ROUTES.post("/api/assets/{id}/tags") +async def add_asset_tags(request: web.Request) -> web.Response: + asset_info_id_raw = request.match_info.get("id") + try: + asset_info_id = int(asset_info_id_raw) + except Exception: + return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid integer.") + + try: + payload = await request.json() + data = schemas_in.TagsAdd.model_validate(payload) + except ValidationError as ve: + return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags add.", {"errors": ve.errors()}) + except Exception: + return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") + + try: + result = await assets_manager.add_tags_to_asset( + asset_info_id=asset_info_id, + tags=data.tags, + origin="manual", + added_by=None, + ) + except ValueError as ve: + return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) + except Exception: + return _error_response(500, "INTERNAL", "Unexpected server error.") + + return web.json_response(result.model_dump(mode="json"), status=200) + + +@ROUTES.delete("/api/assets/{id}/tags") +async def delete_asset_tags(request: web.Request) -> web.Response: + asset_info_id_raw = request.match_info.get("id") + try: + asset_info_id = int(asset_info_id_raw) + except Exception: + return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid integer.") + + try: + payload = await request.json() + data = schemas_in.TagsRemove.model_validate(payload) + except 
ValidationError as ve: + return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags remove.", {"errors": ve.errors()}) + except Exception: + return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") + + try: + result = await assets_manager.remove_tags_from_asset( + asset_info_id=asset_info_id, + tags=data.tags, + ) + except ValueError as ve: + return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) + except Exception: + return _error_response(500, "INTERNAL", "Unexpected server error.") + + return web.json_response(result.model_dump(mode="json"), status=200) def register_assets_routes(app: web.Application) -> None: diff --git a/app/api/schemas_in.py b/app/api/schemas_in.py index fb936a79af45..4e0eb62536d6 100644 --- a/app/api/schemas_in.py +++ b/app/api/schemas_in.py @@ -1,7 +1,7 @@ from __future__ import annotations from typing import Any, Optional, Literal -from pydantic import BaseModel, Field, field_validator, model_validator, conint +from pydantic import BaseModel, Field, ConfigDict, field_validator, model_validator, conint class ListAssetsQuery(BaseModel): @@ -64,3 +64,48 @@ def _at_least_one(self): if not isinstance(self.tags, list) or not all(isinstance(t, str) for t in self.tags): raise ValueError("Field 'tags' must be an array of strings.") return self + + +class TagsListQuery(BaseModel): + model_config = ConfigDict(extra="ignore", str_strip_whitespace=True) + + prefix: Optional[str] = Field(None, min_length=1, max_length=256) + limit: int = Field(100, ge=1, le=1000) + offset: int = Field(0, ge=0, le=10_000_000) + order: Literal["count_desc", "name_asc"] = "count_desc" + include_zero: bool = True + + @field_validator("prefix") + @classmethod + def normalize_prefix(cls, v: Optional[str]) -> Optional[str]: + if v is None: + return v + v = v.strip() + return v.lower() or None + + +class TagsAdd(BaseModel): + model_config = ConfigDict(extra="ignore") + tags: list[str] = Field(..., min_length=1) + + @field_validator("tags") + @classmethod + def normalize_tags(cls, v: list[str]) -> list[str]: + out = [] + for t in v: + if not isinstance(t, str): + raise TypeError("tags must be strings") + tnorm = t.strip().lower() + if tnorm: + out.append(tnorm) + seen = set() + deduplicated = [] + for x in out: + if x not in seen: + seen.add(x) + deduplicated.append(x) + return deduplicated + + +class TagsRemove(TagsAdd): + pass diff --git a/app/api/schemas_out.py b/app/api/schemas_out.py new file mode 100644 index 000000000000..f86da3523d92 --- /dev/null +++ b/app/api/schemas_out.py @@ -0,0 +1,69 @@ +from datetime import datetime +from typing import Any, Optional +from pydantic import BaseModel, ConfigDict, Field, field_serializer + + +class AssetSummary(BaseModel): + id: int + name: str + asset_hash: str + size: Optional[int] = None + mime_type: Optional[str] = None + tags: list[str] = Field(default_factory=list) + preview_url: Optional[str] = None + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + last_access_time: Optional[datetime] = None + + model_config = ConfigDict(from_attributes=True) + + @field_serializer("created_at", "updated_at", "last_access_time") + def _ser_dt(self, v: Optional[datetime], _info): + return v.isoformat() if v else None + + +class AssetsList(BaseModel): + assets: list[AssetSummary] + total: int + has_more: bool + + +class AssetUpdated(BaseModel): + id: int + name: str + asset_hash: str + tags: list[str] = Field(default_factory=list) + user_metadata: dict[str, Any] = 
Field(default_factory=dict) + updated_at: Optional[datetime] = None + + model_config = ConfigDict(from_attributes=True) + + @field_serializer("updated_at") + def _ser_updated(self, v: Optional[datetime], _info): + return v.isoformat() if v else None + + +class TagUsage(BaseModel): + name: str + count: int + type: str + + +class TagsList(BaseModel): + tags: list[TagUsage] = Field(default_factory=list) + total: int + has_more: bool + + +class TagsAdd(BaseModel): + model_config = ConfigDict(str_strip_whitespace=True) + added: list[str] = Field(default_factory=list) + already_present: list[str] = Field(default_factory=list) + total_tags: list[str] = Field(default_factory=list) + + +class TagsRemove(BaseModel): + model_config = ConfigDict(str_strip_whitespace=True) + removed: list[str] = Field(default_factory=list) + not_present: list[str] = Field(default_factory=list) + total_tags: list[str] = Field(default_factory=list) diff --git a/app/assets_manager.py b/app/assets_manager.py index 05031a1bf9a8..60c3f08cdb96 100644 --- a/app/assets_manager.py +++ b/app/assets_manager.py @@ -14,7 +14,11 @@ list_asset_infos_page, update_asset_info_full, get_asset_tags, + list_tags_with_usage, + add_tags_to_asset_info, + remove_tags_from_asset_info, ) +from .api import schemas_out def populate_db_with_asset(tags: list[str], file_name: str, file_path: str) -> None: @@ -70,7 +74,7 @@ async def list_assets( offset: int = 0, sort: str | None = "created_at", order: str | None = "desc", -) -> dict: +) -> schemas_out.AssetsList: sort = _safe_sort_field(sort) order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower() @@ -87,30 +91,30 @@ async def list_assets( order=order, ) - assets_json = [] + summaries: list[schemas_out.AssetSummary] = [] for info in infos: - asset = info.asset # populated via contains_eager + asset = info.asset tags = tag_map.get(info.id, []) - assets_json.append( - { - "id": info.id, - "name": info.name, - "asset_hash": info.asset_hash, - "size": int(asset.size_bytes) if asset else None, - "mime_type": asset.mime_type if asset else None, - "tags": tags, - "preview_url": f"/api/v1/assets/{info.id}/content", # TODO: implement actual content endpoint later - "created_at": info.created_at.isoformat() if info.created_at else None, - "updated_at": info.updated_at.isoformat() if info.updated_at else None, - "last_access_time": info.last_access_time.isoformat() if info.last_access_time else None, - } + summaries.append( + schemas_out.AssetSummary( + id=info.id, + name=info.name, + asset_hash=info.asset_hash, + size=int(asset.size_bytes) if asset else None, + mime_type=asset.mime_type if asset else None, + tags=tags, + preview_url=f"/api/v1/assets/{info.id}/content", # TODO: implement actual content endpoint later + created_at=info.created_at, + updated_at=info.updated_at, + last_access_time=info.last_access_time, + ) ) - return { - "assets": assets_json, - "total": total, - "has_more": (offset + len(assets_json)) < total, - } + return schemas_out.AssetsList( + assets=summaries, + total=total, + has_more=(offset + len(summaries)) < total, + ) async def update_asset( @@ -119,7 +123,7 @@ async def update_asset( name: str | None = None, tags: list[str] | None = None, user_metadata: dict | None = None, -) -> dict: +) -> schemas_out.AssetUpdated: async with await create_session() as session: info = await update_asset_info_full( session, @@ -134,14 +138,40 @@ async def update_asset( tag_names = await get_asset_tags(session, asset_info_id=asset_info_id) await session.commit() - return { 
- "id": info.id, - "name": info.name, - "asset_hash": info.asset_hash, - "tags": tag_names, - "user_metadata": info.user_metadata or {}, - "updated_at": info.updated_at.isoformat() if info.updated_at else None, - } + return schemas_out.AssetUpdated( + id=info.id, + name=info.name, + asset_hash=info.asset_hash, + tags=tag_names, + user_metadata=info.user_metadata or {}, + updated_at=info.updated_at, + ) + + + +async def list_tags( + *, + prefix: str | None = None, + limit: int = 100, + offset: int = 0, + order: str = "count_desc", + include_zero: bool = True, +) -> schemas_out.TagsList: + limit = max(1, min(1000, limit)) + offset = max(0, offset) + + async with await create_session() as session: + rows, total = await list_tags_with_usage( + session, + prefix=prefix, + limit=limit, + offset=offset, + include_zero=include_zero, + order=order, + ) + + tags = [schemas_out.TagUsage(name=name, count=count, type=tag_type) for (name, tag_type, count) in rows] + return schemas_out.TagsList(tags=tags, total=total, has_more=(offset + len(tags)) < total) def _safe_sort_field(requested: str | None) -> str: @@ -156,3 +186,38 @@ def _safe_sort_field(requested: str | None) -> str: def _get_size_mtime_ns(path: str) -> tuple[int, int]: st = os.stat(path, follow_symlinks=True) return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)) + + +async def add_tags_to_asset( + *, + asset_info_id: int, + tags: list[str], + origin: str = "manual", + added_by: str | None = None, +) -> schemas_out.TagsAdd: + async with await create_session() as session: + data = await add_tags_to_asset_info( + session, + asset_info_id=asset_info_id, + tags=tags, + origin=origin, + added_by=added_by, + create_if_missing=True, + ) + await session.commit() + return schemas_out.TagsAdd(**data) + + +async def remove_tags_from_asset( + *, + asset_info_id: int, + tags: list[str], +) -> schemas_out.TagsRemove: + async with await create_session() as session: + data = await remove_tags_from_asset_info( + session, + asset_info_id=asset_info_id, + tags=tags, + ) + await session.commit() + return schemas_out.TagsRemove(**data) diff --git a/app/database/services.py b/app/database/services.py index c2792b4c4501..3280fd534134 100644 --- a/app/database/services.py +++ b/app/database/services.py @@ -493,7 +493,7 @@ async def replace_asset_info_metadata_projection( await session.flush() -async def get_asset_tags(session: AsyncSession, *, asset_info_id: int) -> list[Tag]: +async def get_asset_tags(session: AsyncSession, *, asset_info_id: int) -> list[str]: return [ tag_name for (tag_name,) in ( @@ -504,6 +504,179 @@ async def get_asset_tags(session: AsyncSession, *, asset_info_id: int) -> list[T ] +async def list_tags_with_usage( + session, + *, + prefix: str | None = None, + limit: int = 100, + offset: int = 0, + include_zero: bool = True, + order: str = "count_desc", # "count_desc" | "name_asc" +) -> tuple[list[tuple[str, str, int]], int]: + """ + Returns: + rows: list of (name, tag_type, count) + total: number of tags matching filter (independent of pagination) + """ + # Subquery with counts by tag_name + counts_sq = ( + select( + AssetInfoTag.tag_name.label("tag_name"), + func.count(AssetInfoTag.asset_info_id).label("cnt"), + ) + .group_by(AssetInfoTag.tag_name) + .subquery() + ) + + # Base select with LEFT JOIN so we can include zero-usage tags + q = ( + select( + Tag.name, + Tag.tag_type, + func.coalesce(counts_sq.c.cnt, 0).label("count"), + ) + .select_from(Tag) + .join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True) 
+ ) + + # Prefix filter (tags are lowercase by check constraint) + if prefix: + q = q.where(Tag.name.like(prefix.strip().lower() + "%")) + + # Include_zero toggles: if False, drop zero-usage tags + if not include_zero: + q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0) + + # Ordering + if order == "name_asc": + q = q.order_by(Tag.name.asc()) + else: # default "count_desc" + q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc()) + + # Total (without limit/offset, same filters) + total_q = select(func.count()).select_from(Tag) + if prefix: + total_q = total_q.where(Tag.name.like(prefix.strip().lower() + "%")) + if not include_zero: + # count only names that appear in counts subquery + total_q = total_q.where( + Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name)) + ) + + rows = (await session.execute(q.limit(limit).offset(offset))).all() + total = (await session.execute(total_q)).scalar_one() + + # Normalize counts to int for SQLite/Postgres consistency + rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows] + return rows_norm, int(total or 0) + + +async def add_tags_to_asset_info( + session: AsyncSession, + *, + asset_info_id: int, + tags: Sequence[str], + origin: str = "manual", + added_by: Optional[str] = None, + create_if_missing: bool = True, +) -> dict: + """Adds tags to an AssetInfo. + If create_if_missing=True, missing tag rows are created as 'user'. + Returns: {"added": [...], "already_present": [...], "total_tags": [...]} + """ + info = await session.get(AssetInfo, asset_info_id) + if not info: + raise ValueError(f"AssetInfo {asset_info_id} not found") + + norm = _normalize_tags(tags) + if not norm: + total = await get_asset_tags(session, asset_info_id=asset_info_id) + return {"added": [], "already_present": [], "total_tags": total} + + # Ensure tag rows exist if requested. + if create_if_missing: + await _ensure_tags_exist(session, norm, tag_type="user") + + # Current links + existing = { + tname + for (tname,) in ( + await session.execute( + sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) + ) + ).all() + } + + to_add = [t for t in norm if t not in existing] + already = [t for t in norm if t in existing] + + if to_add: + now = datetime.now(timezone.utc) + # Make insert race-safe with a nested tx; ignore dup conflicts if any. + async with session.begin_nested(): + session.add_all([ + AssetInfoTag( + asset_info_id=asset_info_id, + tag_name=t, + origin=origin, + added_by=added_by, + added_at=now, + ) for t in to_add + ]) + try: + await session.flush() + except IntegrityError: + # Another writer linked the same tag at the same time -> ok, treat as already present. + await session.rollback() + + total = await get_asset_tags(session, asset_info_id=asset_info_id) + return {"added": sorted(set(to_add)), "already_present": sorted(set(already)), "total_tags": total} + + +async def remove_tags_from_asset_info( + session: AsyncSession, + *, + asset_info_id: int, + tags: Sequence[str], +) -> dict: + """Removes tags from an AssetInfo. 
+ Returns: {"removed": [...], "not_present": [...], "total_tags": [...]} + """ + info = await session.get(AssetInfo, asset_info_id) + if not info: + raise ValueError(f"AssetInfo {asset_info_id} not found") + + norm = _normalize_tags(tags) + if not norm: + total = await get_asset_tags(session, asset_info_id=asset_info_id) + return {"removed": [], "not_present": [], "total_tags": total} + + existing = { + tname + for (tname,) in ( + await session.execute( + sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) + ) + ).all() + } + + to_remove = sorted(set(t for t in norm if t in existing)) + not_present = sorted(set(t for t in norm if t not in existing)) + + if to_remove: + await session.execute( + delete(AssetInfoTag) + .where( + AssetInfoTag.asset_info_id == asset_info_id, + AssetInfoTag.tag_name.in_(to_remove), + ) + ) + await session.flush() + + total = await get_asset_tags(session, asset_info_id=asset_info_id) + return {"removed": to_remove, "not_present": not_present, "total_tags": total} + + def _normalize_tags(tags: Sequence[str] | None) -> list[str]: return [t.strip().lower() for t in (tags or []) if (t or "").strip()] From 0755e5320a9d1fa2814e9db4f215f188fea16f91 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Sun, 24 Aug 2025 12:01:59 +0300 Subject: [PATCH 11/82] remove timezone; download asset, delete asset endpoints --- alembic_db/versions/0001_assets.py | 12 ++--- app/api/assets_routes.py | 52 +++++++++++++++++++++ app/assets_manager.py | 47 ++++++++++++++++++- app/database/models.py | 17 ++++--- app/database/services.py | 74 ++++++++++++++++++++++-------- app/database/timeutil.py | 6 +++ 6 files changed, 174 insertions(+), 34 deletions(-) create mode 100644 app/database/timeutil.py diff --git a/alembic_db/versions/0001_assets.py b/alembic_db/versions/0001_assets.py index 1c5563d3fe94..47bb43dd8dd0 100644 --- a/alembic_db/versions/0001_assets.py +++ b/alembic_db/versions/0001_assets.py @@ -24,8 +24,8 @@ def upgrade() -> None: sa.Column("refcount", sa.BigInteger(), nullable=False, server_default="0"), sa.Column("storage_backend", sa.String(length=32), nullable=False, server_default="fs"), sa.Column("storage_locator", sa.Text(), nullable=False), - sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")), - sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")), + sa.Column("created_at", sa.DateTime(timezone=False), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False), sa.CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"), sa.CheckConstraint("refcount >= 0", name="ck_assets_refcount_nonneg"), ) @@ -41,9 +41,9 @@ def upgrade() -> None: sa.Column("asset_hash", sa.String(length=128), sa.ForeignKey("assets.hash", ondelete="RESTRICT"), nullable=False), sa.Column("preview_hash", sa.String(length=128), sa.ForeignKey("assets.hash", ondelete="SET NULL"), nullable=True), sa.Column("user_metadata", sa.JSON(), nullable=True), - sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")), - sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")), - sa.Column("last_access_time", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")), + sa.Column("created_at", sa.DateTime(timezone=False), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=False), 
nullable=False), + sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False), sqlite_autoincrement=True, ) op.create_index("ix_assets_info_owner_id", "assets_info", ["owner_id"]) @@ -68,7 +68,7 @@ def upgrade() -> None: sa.Column("tag_name", sa.String(length=512), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False), sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"), sa.Column("added_by", sa.String(length=128), nullable=True), - sa.Column("added_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")), + sa.Column("added_at", sa.DateTime(timezone=False), nullable=False), sa.PrimaryKeyConstraint("asset_info_id", "tag_name", name="pk_asset_info_tags"), ) op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"]) diff --git a/app/api/assets_routes.py b/app/api/assets_routes.py index 8c037fd97735..014e324d75ef 100644 --- a/app/api/assets_routes.py +++ b/app/api/assets_routes.py @@ -1,3 +1,4 @@ +import urllib.parse from typing import Optional from aiohttp import web @@ -32,6 +33,39 @@ async def list_assets(request: web.Request) -> web.Response: return web.json_response(payload.model_dump(mode="json")) + +@ROUTES.get("/api/assets/{id}/content") +async def download_asset_content(request: web.Request) -> web.Response: + asset_info_id_raw = request.match_info.get("id") + try: + asset_info_id = int(asset_info_id_raw) + except Exception: + return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid integer.") + + disposition = request.query.get("disposition", "attachment").lower().strip() + if disposition not in {"inline", "attachment"}: + disposition = "attachment" + + try: + abs_path, content_type, filename = await assets_manager.resolve_asset_content_for_download( + asset_info_id=asset_info_id + ) + except ValueError as ve: + return _error_response(404, "ASSET_NOT_FOUND", str(ve)) + except NotImplementedError as nie: + return _error_response(501, "BACKEND_UNSUPPORTED", str(nie)) + except FileNotFoundError: + return _error_response(404, "FILE_NOT_FOUND", "Underlying file not found on disk.") + + quoted = filename.replace('"', "'") + cd = f'{disposition}; filename="{quoted}"; filename*=UTF-8\'\'{urllib.parse.quote(filename)}' + + resp = web.FileResponse(abs_path) + resp.content_type = content_type + resp.headers["Content-Disposition"] = cd + return resp + + @ROUTES.put("/api/assets/{id}") async def update_asset(request: web.Request) -> web.Response: asset_info_id_raw = request.match_info.get("id") @@ -61,6 +95,24 @@ async def update_asset(request: web.Request) -> web.Response: return web.json_response(result.model_dump(mode="json"), status=200) +@ROUTES.delete("/api/assets/{id}") +async def delete_asset(request: web.Request) -> web.Response: + asset_info_id_raw = request.match_info.get("id") + try: + asset_info_id = int(asset_info_id_raw) + except Exception: + return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid integer.") + + try: + deleted = await assets_manager.delete_asset_reference(asset_info_id=asset_info_id) + except Exception: + return _error_response(500, "INTERNAL", "Unexpected server error.") + + if not deleted: + return _error_response(404, "ASSET_NOT_FOUND", f"AssetInfo {asset_info_id} not found.") + return web.Response(status=204) + + @ROUTES.get("/api/tags") async def get_tags(request: web.Request) -> web.Response: query_map = dict(request.rel_url.query) diff --git a/app/assets_manager.py 
b/app/assets_manager.py index 60c3f08cdb96..2c07db4b2ffa 100644 --- a/app/assets_manager.py +++ b/app/assets_manager.py @@ -1,5 +1,5 @@ +import mimetypes import os -from datetime import datetime, timezone from typing import Optional, Sequence from comfy.cli_args import args @@ -17,6 +17,9 @@ list_tags_with_usage, add_tags_to_asset_info, remove_tags_from_asset_info, + fetch_asset_info_and_asset, + touch_asset_info_by_id, + delete_asset_info_by_id, ) from .api import schemas_out @@ -43,7 +46,7 @@ async def add_local_asset(tags: list[str], file_name: str, file_path: str) -> No async with await create_session() as session: if await check_fs_asset_exists_quick(session, file_path=abs_path, size_bytes=size_bytes, mtime_ns=mtime_ns): - await touch_asset_infos_by_fs_path(session, abs_path=abs_path, ts=datetime.now(timezone.utc)) + await touch_asset_infos_by_fs_path(session, abs_path=abs_path) await session.commit() return @@ -117,6 +120,40 @@ async def list_assets( ) +async def resolve_asset_content_for_download( + *, asset_info_id: int +) -> tuple[str, str, str]: + """ + Returns (abs_path, content_type, download_name) for the given AssetInfo id. + Also touches last_access_time (only_if_newer). + Raises: + ValueError if AssetInfo not found + NotImplementedError for unsupported backend + FileNotFoundError if underlying file does not exist (fs backend) + """ + async with await create_session() as session: + pair = await fetch_asset_info_and_asset(session, asset_info_id=asset_info_id) + if not pair: + raise ValueError(f"AssetInfo {asset_info_id} not found") + + info, asset = pair + + if asset.storage_backend != "fs": + # Future: support http/s3/gcs/... + raise NotImplementedError(f"backend {asset.storage_backend!r} not supported yet") + + abs_path = os.path.abspath(asset.storage_locator) + if not os.path.exists(abs_path): + raise FileNotFoundError(abs_path) + + await touch_asset_info_by_id(session, asset_info_id=asset_info_id) + await session.commit() + + ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream" + download_name = info.name or os.path.basename(abs_path) + return abs_path, ctype, download_name + + async def update_asset( *, asset_info_id: int, @@ -148,6 +185,12 @@ async def update_asset( ) +async def delete_asset_reference(*, asset_info_id: int) -> bool: + async with await create_session() as session: + r = await delete_asset_info_by_id(session, asset_info_id=asset_info_id) + await session.commit() + return r + async def list_tags( *, diff --git a/app/database/models.py b/app/database/models.py index ca7ad67f81e6..06e46815dcdc 100644 --- a/app/database/models.py +++ b/app/database/models.py @@ -14,9 +14,10 @@ Numeric, Boolean, ) -from sqlalchemy.sql import func from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship, foreign +from .timeutil import utcnow + class Base(DeclarativeBase): pass @@ -46,10 +47,10 @@ class Asset(Base): storage_backend: Mapped[str] = mapped_column(String(32), nullable=False, default="fs") storage_locator: Mapped[str] = mapped_column(Text, nullable=False) created_at: Mapped[datetime] = mapped_column( - DateTime(timezone=True), nullable=False, server_default=func.current_timestamp() + DateTime(timezone=False), nullable=False, default=utcnow ) updated_at: Mapped[datetime] = mapped_column( - DateTime(timezone=True), nullable=False, server_default=func.current_timestamp() + DateTime(timezone=False), nullable=False, default=utcnow ) infos: Mapped[list["AssetInfo"]] = relationship( @@ -125,13 +126,13 
@@ class AssetInfo(Base): preview_hash: Mapped[str | None] = mapped_column(String(256), ForeignKey("assets.hash", ondelete="SET NULL")) user_metadata: Mapped[dict[str, Any] | None] = mapped_column(JSON) created_at: Mapped[datetime] = mapped_column( - DateTime(timezone=True), nullable=False, server_default=func.current_timestamp() + DateTime(timezone=False), nullable=False, default=utcnow ) updated_at: Mapped[datetime] = mapped_column( - DateTime(timezone=True), nullable=False, server_default=func.current_timestamp() + DateTime(timezone=False), nullable=False, default=utcnow ) last_access_time: Mapped[datetime] = mapped_column( - DateTime(timezone=True), nullable=False, server_default=func.current_timestamp() + DateTime(timezone=False), nullable=False, default=utcnow ) # Relationships @@ -221,7 +222,9 @@ class AssetInfoTag(Base): ) origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual") added_by: Mapped[str | None] = mapped_column(String(128)) - added_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) + added_at: Mapped[datetime] = mapped_column( + DateTime(timezone=False), nullable=False, default=utcnow + ) asset_info: Mapped["AssetInfo"] = relationship(back_populates="tag_links") tag: Mapped["Tag"] = relationship(back_populates="asset_info_links") diff --git a/app/database/services.py b/app/database/services.py index 3280fd534134..98a5ae624db6 100644 --- a/app/database/services.py +++ b/app/database/services.py @@ -1,7 +1,7 @@ import os import logging from collections import defaultdict -from datetime import datetime, timezone +from datetime import datetime from decimal import Decimal from typing import Any, Sequence, Optional, Iterable @@ -12,6 +12,7 @@ from sqlalchemy.exc import IntegrityError from .models import Asset, AssetInfo, AssetInfoTag, AssetLocatorState, Tag, AssetInfoMeta +from .timeutil import utcnow async def check_fs_asset_exists_quick( @@ -93,7 +94,7 @@ async def ingest_fs_asset( } """ locator = os.path.abspath(abs_path) - datetime_now = datetime.now(timezone.utc) + datetime_now = utcnow() out = { "asset_created": False, @@ -246,7 +247,7 @@ async def touch_asset_infos_by_fs_path( only_if_newer: bool = True, ) -> int: locator = os.path.abspath(abs_path) - ts = ts or datetime.now(timezone.utc) + ts = ts or utcnow() stmt = sa.update(AssetInfo).where( sa.exists( @@ -274,13 +275,31 @@ async def touch_asset_infos_by_fs_path( return int(res.rowcount or 0) +async def touch_asset_info_by_id( + session: AsyncSession, + *, + asset_info_id: int, + ts: Optional[datetime] = None, + only_if_newer: bool = True, +) -> int: + ts = ts or utcnow() + stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id) + if only_if_newer: + stmt = stmt.where( + sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts) + ) + stmt = stmt.values(last_access_time=ts) + res = await session.execute(stmt) + return int(res.rowcount or 0) + + async def list_asset_infos_page( session: AsyncSession, *, - include_tags: Sequence[str] | None = None, - exclude_tags: Sequence[str] | None = None, - name_contains: str | None = None, - metadata_filter: dict | None = None, + include_tags: Optional[Sequence[str]] = None, + exclude_tags: Optional[Sequence[str]] = None, + name_contains: Optional[str] = None, + metadata_filter: Optional[dict] = None, limit: int = 20, offset: int = 0, sort: str = "created_at", @@ -361,6 +380,19 @@ async def list_asset_infos_page( return infos, tag_map, total +async def fetch_asset_info_and_asset(session: AsyncSession, 
*, asset_info_id: int) -> Optional[tuple[AssetInfo, Asset]]: + row = await session.execute( + select(AssetInfo, Asset) + .join(Asset, Asset.hash == AssetInfo.asset_hash) + .where(AssetInfo.id == asset_info_id) + .limit(1) + ) + pair = row.first() + if not pair: + return None + return pair[0], pair[1] + + async def set_asset_info_tags( session: AsyncSession, *, @@ -374,7 +406,6 @@ async def set_asset_info_tags( Creates missing tag names as 'user'. """ desired = _normalize_tags(tags) - now = datetime.now(timezone.utc) # current links current = set( @@ -389,7 +420,7 @@ async def set_asset_info_tags( if to_add: await _ensure_tags_exist(session, to_add, tag_type="user") session.add_all([ - AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_by=added_by, added_at=now) + AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_by=added_by, added_at=utcnow()) for t in to_add ]) await session.flush() @@ -447,17 +478,23 @@ async def update_asset_info_full( touched = True if touched and user_metadata is None: - info.updated_at = datetime.now(timezone.utc) + info.updated_at = utcnow() await session.flush() return info +async def delete_asset_info_by_id(session: AsyncSession, *, asset_info_id: int) -> bool: + """Delete the user-visible AssetInfo row. Cascades clear tags and metadata.""" + res = await session.execute(delete(AssetInfo).where(AssetInfo.id == asset_info_id)) + return bool(res.rowcount) + + async def replace_asset_info_metadata_projection( session: AsyncSession, *, asset_info_id: int, - user_metadata: dict | None, + user_metadata: Optional[dict], ) -> None: """Replaces the `assets_info.user_metadata` AND rebuild the projection rows in `asset_info_meta`.""" info = await session.get(AssetInfo, asset_info_id) @@ -465,7 +502,7 @@ async def replace_asset_info_metadata_projection( raise ValueError(f"AssetInfo {asset_info_id} not found") info.user_metadata = user_metadata or {} - info.updated_at = datetime.now(timezone.utc) + info.updated_at = utcnow() await session.flush() await session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id)) @@ -507,7 +544,7 @@ async def get_asset_tags(session: AsyncSession, *, asset_info_id: int) -> list[s async def list_tags_with_usage( session, *, - prefix: str | None = None, + prefix: Optional[str] = None, limit: int = 100, offset: int = 0, include_zero: bool = True, @@ -611,7 +648,6 @@ async def add_tags_to_asset_info( already = [t for t in norm if t in existing] if to_add: - now = datetime.now(timezone.utc) # Make insert race-safe with a nested tx; ignore dup conflicts if any. 
async with session.begin_nested(): session.add_all([ @@ -620,7 +656,7 @@ async def add_tags_to_asset_info( tag_name=t, origin=origin, added_by=added_by, - added_at=now, + added_at=utcnow(), ) for t in to_add ]) try: @@ -677,7 +713,7 @@ async def remove_tags_from_asset_info( return {"removed": to_remove, "not_present": not_present, "total_tags": total} -def _normalize_tags(tags: Sequence[str] | None) -> list[str]: +def _normalize_tags(tags: Optional[Sequence[str]]) -> list[str]: return [t.strip().lower() for t in (tags or []) if (t or "").strip()] @@ -697,8 +733,8 @@ async def _ensure_tags_exist(session: AsyncSession, names: Iterable[str], tag_ty def _apply_tag_filters( stmt: sa.sql.Select, - include_tags: Sequence[str] | None, - exclude_tags: Sequence[str] | None, + include_tags: Optional[Sequence[str]], + exclude_tags: Optional[Sequence[str]], ) -> sa.sql.Select: """include_tags: every tag must be present; exclude_tags: none may be present.""" include_tags = _normalize_tags(include_tags) @@ -724,7 +760,7 @@ def _apply_tag_filters( def _apply_metadata_filter( stmt: sa.sql.Select, - metadata_filter: dict | None, + metadata_filter: Optional[dict], ) -> sa.sql.Select: """Apply metadata filters using the projection table asset_info_meta. diff --git a/app/database/timeutil.py b/app/database/timeutil.py new file mode 100644 index 000000000000..e8fab12ee7c1 --- /dev/null +++ b/app/database/timeutil.py @@ -0,0 +1,6 @@ +from datetime import datetime, timezone + + +def utcnow() -> datetime: + """Naive UTC timestamp (no tzinfo). We always treat DB datetimes as UTC.""" + return datetime.now(timezone.utc).replace(tzinfo=None) From f2ea0bc22c74ca0158c39ffc64a4baa6058798a5 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Sun, 24 Aug 2025 14:15:21 +0300 Subject: [PATCH 12/82] added create_asset_from_hash endpoint --- app/api/assets_routes.py | 30 ++++++++++++ app/api/schemas_in.py | 41 +++++++++++++++- app/api/schemas_out.py | 20 ++++++++ app/assets_manager.py | 101 ++++++++++++++++++++++++++++++--------- app/database/services.py | 54 +++++++++++++++++++++ 5 files changed, 221 insertions(+), 25 deletions(-) diff --git a/app/api/assets_routes.py b/app/api/assets_routes.py index 014e324d75ef..63610099824d 100644 --- a/app/api/assets_routes.py +++ b/app/api/assets_routes.py @@ -11,6 +11,15 @@ ROUTES = web.RouteTableDef() +@ROUTES.head("/api/assets/hash/{hash}") +async def head_asset_by_hash(request: web.Request) -> web.Response: + hash_str = request.match_info.get("hash", "").strip().lower() + if not hash_str or ":" not in hash_str: + return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") + exists = await assets_manager.asset_exists(asset_hash=hash_str) + return web.Response(status=200 if exists else 404) + + @ROUTES.get("/api/assets") async def list_assets(request: web.Request) -> web.Response: query_dict = dict(request.rel_url.query) @@ -95,6 +104,27 @@ async def update_asset(request: web.Request) -> web.Response: return web.json_response(result.model_dump(mode="json"), status=200) +@ROUTES.post("/api/assets/from-hash") +async def create_asset_from_hash(request: web.Request) -> web.Response: + try: + payload = await request.json() + body = schemas_in.CreateFromHashBody.model_validate(payload) + except ValidationError as ve: + return _validation_error_response("INVALID_BODY", ve) + except Exception: + return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") + + result = await assets_manager.create_asset_from_hash( + hash_str=body.hash, + name=body.name, + 
tags=body.tags, + user_metadata=body.user_metadata, + ) + if result is None: + return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist") + return web.json_response(result.model_dump(mode="json"), status=201) + + @ROUTES.delete("/api/assets/{id}") async def delete_asset(request: web.Request) -> web.Response: asset_info_id_raw = request.match_info.get("id") diff --git a/app/api/schemas_in.py b/app/api/schemas_in.py index 4e0eb62536d6..0f07bf19d615 100644 --- a/app/api/schemas_in.py +++ b/app/api/schemas_in.py @@ -1,4 +1,4 @@ -from __future__ import annotations +import json from typing import Any, Optional, Literal from pydantic import BaseModel, Field, ConfigDict, field_validator, model_validator, conint @@ -40,7 +40,6 @@ def _parse_metadata_json(cls, v): if v is None or isinstance(v, dict): return v if isinstance(v, str) and v.strip(): - import json try: parsed = json.loads(v) except Exception as e: @@ -66,6 +65,44 @@ def _at_least_one(self): return self +class CreateFromHashBody(BaseModel): + model_config = ConfigDict(extra="ignore", str_strip_whitespace=True) + + hash: str + name: str + tags: list[str] = Field(default_factory=list) + user_metadata: dict[str, Any] = Field(default_factory=dict) + + @field_validator("hash") + @classmethod + def _require_blake3(cls, v): + s = (v or "").strip().lower() + if ":" not in s: + raise ValueError("hash must be 'blake3:'") + algo, digest = s.split(":", 1) + if algo != "blake3": + raise ValueError("only canonical 'blake3:' is accepted here") + if not digest or any(c for c in digest if c not in "0123456789abcdef"): + raise ValueError("hash digest must be lowercase hex") + return s + + @field_validator("tags", mode="before") + @classmethod + def _tags_norm(cls, v): + if v is None: + return [] + if isinstance(v, list): + out = [str(t).strip().lower() for t in v if str(t).strip()] + seen = set(); dedup = [] + for t in out: + if t not in seen: + seen.add(t); dedup.append(t) + return dedup + if isinstance(v, str): + return [t.strip().lower() for t in v.split(",") if t.strip()] + return [] + + class TagsListQuery(BaseModel): model_config = ConfigDict(extra="ignore", str_strip_whitespace=True) diff --git a/app/api/schemas_out.py b/app/api/schemas_out.py index f86da3523d92..0a71b8bc960d 100644 --- a/app/api/schemas_out.py +++ b/app/api/schemas_out.py @@ -43,6 +43,26 @@ def _ser_updated(self, v: Optional[datetime], _info): return v.isoformat() if v else None +class AssetCreated(BaseModel): + id: int + name: str + asset_hash: str + size: Optional[int] = None + mime_type: Optional[str] = None + tags: list[str] = Field(default_factory=list) + user_metadata: dict[str, Any] = Field(default_factory=dict) + preview_hash: Optional[str] = None + created_at: Optional[datetime] = None + last_access_time: Optional[datetime] = None + created_new: bool + + model_config = ConfigDict(from_attributes=True) + + @field_serializer("created_at", "last_access_time") + def _ser_dt(self, v: Optional[datetime], _info): + return v.isoformat() if v else None + + class TagUsage(BaseModel): name: str count: int diff --git a/app/assets_manager.py b/app/assets_manager.py index 2c07db4b2ffa..f92232a3dbe0 100644 --- a/app/assets_manager.py +++ b/app/assets_manager.py @@ -20,10 +20,18 @@ fetch_asset_info_and_asset, touch_asset_info_by_id, delete_asset_info_by_id, + asset_exists_by_hash, + get_asset_by_hash, + create_asset_info_for_existing_asset, ) from .api import schemas_out +async def asset_exists(*, asset_hash: str) -> bool: + async with await 
create_session() as session: + return await asset_exists_by_hash(session, asset_hash=asset_hash) + + def populate_db_with_asset(tags: list[str], file_name: str, file_path: str) -> None: if not args.disable_model_processing: async_to_sync.AsyncToSyncConverter.run_async_in_thread( @@ -69,14 +77,14 @@ async def add_local_asset(tags: list[str], file_name: str, file_path: str) -> No async def list_assets( *, - include_tags: Sequence[str] | None = None, - exclude_tags: Sequence[str] | None = None, + include_tags: Optional[Sequence[str]] = None, + exclude_tags: Optional[Sequence[str]] = None, name_contains: Optional[str] = None, metadata_filter: Optional[dict] = None, limit: int = 20, offset: int = 0, - sort: str | None = "created_at", - order: str | None = "desc", + sort: str = "created_at", + order: str = "desc", ) -> schemas_out.AssetsList: sort = _safe_sort_field(sort) order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower() @@ -157,9 +165,9 @@ async def resolve_asset_content_for_download( async def update_asset( *, asset_info_id: int, - name: str | None = None, - tags: list[str] | None = None, - user_metadata: dict | None = None, + name: Optional[str] = None, + tags: Optional[list[str]] = None, + user_metadata: Optional[dict] = None, ) -> schemas_out.AssetUpdated: async with await create_session() as session: info = await update_asset_info_full( @@ -192,9 +200,49 @@ async def delete_asset_reference(*, asset_info_id: int) -> bool: return r +async def create_asset_from_hash( + *, + hash_str: str, + name: str, + tags: Optional[list[str]] = None, + user_metadata: Optional[dict] = None, +) -> Optional[schemas_out.AssetCreated]: + canonical = hash_str.strip().lower() + async with await create_session() as session: + asset = await get_asset_by_hash(session, asset_hash=canonical) + if not asset: + return None + + info = await create_asset_info_for_existing_asset( + session, + asset_hash=canonical, + name=_safe_filename(name, fallback=canonical.split(":", 1)[1]), + user_metadata=user_metadata or {}, + tags=tags or [], + tag_origin="manual", + added_by=None, + ) + tag_names = await get_asset_tags(session, asset_info_id=info.id) + await session.commit() + + return schemas_out.AssetCreated( + id=info.id, + name=info.name, + asset_hash=info.asset_hash, + size=int(asset.size_bytes), + mime_type=asset.mime_type, + tags=tag_names, + user_metadata=info.user_metadata or {}, + preview_hash=info.preview_hash, + created_at=info.created_at, + last_access_time=info.last_access_time, + created_new=False, + ) + + async def list_tags( *, - prefix: str | None = None, + prefix: Optional[str] = None, limit: int = 100, offset: int = 0, order: str = "count_desc", @@ -217,26 +265,12 @@ async def list_tags( return schemas_out.TagsList(tags=tags, total=total, has_more=(offset + len(tags)) < total) -def _safe_sort_field(requested: str | None) -> str: - if not requested: - return "created_at" - v = requested.lower() - if v in {"name", "created_at", "updated_at", "size", "last_access_time"}: - return v - return "created_at" - - -def _get_size_mtime_ns(path: str) -> tuple[int, int]: - st = os.stat(path, follow_symlinks=True) - return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)) - - async def add_tags_to_asset( *, asset_info_id: int, tags: list[str], origin: str = "manual", - added_by: str | None = None, + added_by: Optional[str] = None, ) -> schemas_out.TagsAdd: async with await create_session() as session: data = await add_tags_to_asset_info( @@ -264,3 +298,24 @@ async 
def remove_tags_from_asset( ) await session.commit() return schemas_out.TagsRemove(**data) + + +def _safe_sort_field(requested: Optional[str]) -> str: + if not requested: + return "created_at" + v = requested.lower() + if v in {"name", "created_at", "updated_at", "size", "last_access_time"}: + return v + return "created_at" + + +def _get_size_mtime_ns(path: str) -> tuple[int, int]: + st = os.stat(path, follow_symlinks=True) + return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)) + + +def _safe_filename(name: Optional[str] , fallback: str) -> str: + n = os.path.basename((name or "").strip() or fallback) + if n: + return n + return fallback diff --git a/app/database/services.py b/app/database/services.py index 98a5ae624db6..b916a2055a89 100644 --- a/app/database/services.py +++ b/app/database/services.py @@ -15,6 +15,20 @@ from .timeutil import utcnow + +async def asset_exists_by_hash(session: AsyncSession, *, asset_hash: str) -> bool: + row = ( + await session.execute( + select(sa.literal(True)).select_from(Asset).where(Asset.hash == asset_hash).limit(1) + ) + ).first() + return row is not None + + +async def get_asset_by_hash(session: AsyncSession, *, asset_hash: str) -> Optional[Asset]: + return await session.get(Asset, asset_hash) + + async def check_fs_asset_exists_quick( session, *, @@ -393,6 +407,46 @@ async def fetch_asset_info_and_asset(session: AsyncSession, *, asset_info_id: in return pair[0], pair[1] +async def create_asset_info_for_existing_asset( + session: AsyncSession, + *, + asset_hash: str, + name: str, + user_metadata: Optional[dict] = None, + tags: Optional[Sequence[str]] = None, + tag_origin: str = "manual", + added_by: Optional[str] = None, +) -> AssetInfo: + """Create a new AssetInfo referencing an existing Asset (no content write).""" + now = utcnow() + info = AssetInfo( + owner_id=None, + name=name, + asset_hash=asset_hash, + preview_hash=None, + created_at=now, + updated_at=now, + last_access_time=now, + ) + session.add(info) + await session.flush() # get info.id + + if user_metadata is not None: + await replace_asset_info_metadata_projection( + session, asset_info_id=info.id, user_metadata=user_metadata + ) + + if tags is not None: + await set_asset_info_tags( + session, + asset_info_id=info.id, + tags=tags, + origin=tag_origin, + added_by=added_by, + ) + return info + + async def set_asset_info_tags( session: AsyncSession, *, From a82577f64aaba1a6b871f9ddc57f1f9ea7531e32 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Sun, 24 Aug 2025 15:08:53 +0300 Subject: [PATCH 13/82] auto-creation of tags and fixed population DB when cloned asset is already present --- alembic_db/versions/0001_assets.py | 16 +-- app/assets_manager.py | 10 +- app/database/models.py | 6 +- app/database/services.py | 167 ++++++++++++++++------------- 4 files changed, 114 insertions(+), 85 deletions(-) diff --git a/alembic_db/versions/0001_assets.py b/alembic_db/versions/0001_assets.py index 47bb43dd8dd0..cdda63fbef6f 100644 --- a/alembic_db/versions/0001_assets.py +++ b/alembic_db/versions/0001_assets.py @@ -18,7 +18,7 @@ def upgrade() -> None: # ASSETS: content identity (deduplicated by hash) op.create_table( "assets", - sa.Column("hash", sa.String(length=128), primary_key=True), + sa.Column("hash", sa.String(length=256), primary_key=True), sa.Column("size_bytes", sa.BigInteger(), nullable=False, server_default="0"), sa.Column("mime_type", sa.String(length=255), nullable=True), sa.Column("refcount", sa.BigInteger(), nullable=False, server_default="0"), @@ -36,14 +36,15 
@@ def upgrade() -> None: op.create_table( "assets_info", sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True), - sa.Column("owner_id", sa.String(length=128), nullable=True), + sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""), sa.Column("name", sa.String(length=512), nullable=False), - sa.Column("asset_hash", sa.String(length=128), sa.ForeignKey("assets.hash", ondelete="RESTRICT"), nullable=False), - sa.Column("preview_hash", sa.String(length=128), sa.ForeignKey("assets.hash", ondelete="SET NULL"), nullable=True), + sa.Column("asset_hash", sa.String(length=256), sa.ForeignKey("assets.hash", ondelete="RESTRICT"), nullable=False), + sa.Column("preview_hash", sa.String(length=256), sa.ForeignKey("assets.hash", ondelete="SET NULL"), nullable=True), sa.Column("user_metadata", sa.JSON(), nullable=True), sa.Column("created_at", sa.DateTime(timezone=False), nullable=False), sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False), sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False), + sa.UniqueConstraint("asset_hash", "owner_id", "name", name="uq_assets_info_hash_owner_name"), sqlite_autoincrement=True, ) op.create_index("ix_assets_info_owner_id", "assets_info", ["owner_id"]) @@ -65,7 +66,7 @@ def upgrade() -> None: op.create_table( "asset_info_tags", sa.Column("asset_info_id", sa.BigInteger(), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False), - sa.Column("tag_name", sa.String(length=512), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False), + sa.Column("tag_name", sa.String(length=128), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False), sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"), sa.Column("added_by", sa.String(length=128), nullable=True), sa.Column("added_at", sa.DateTime(timezone=False), nullable=False), @@ -77,7 +78,7 @@ def upgrade() -> None: # ASSET_LOCATOR_STATE: 1:1 filesystem metadata(for fast integrity checking) for an Asset records op.create_table( "asset_locator_state", - sa.Column("asset_hash", sa.String(length=128), sa.ForeignKey("assets.hash", ondelete="CASCADE"), primary_key=True), + sa.Column("asset_hash", sa.String(length=256), sa.ForeignKey("assets.hash", ondelete="CASCADE"), primary_key=True), sa.Column("mtime_ns", sa.BigInteger(), nullable=True), sa.Column("etag", sa.String(length=256), nullable=True), sa.Column("last_modified", sa.String(length=128), nullable=True), @@ -112,6 +113,8 @@ def upgrade() -> None: [ # Core concept tags {"name": "models", "tag_type": "system"}, + {"name": "input", "tag_type": "system"}, + {"name": "output", "tag_type": "system"}, # Canonical single-word types {"name": "checkpoint", "tag_type": "system"}, @@ -150,6 +153,7 @@ def downgrade() -> None: op.drop_index("ix_tags_tag_type", table_name="tags") op.drop_table("tags") + op.drop_constraint("uq_assets_info_hash_owner_name", table_name="assets_info") op.drop_index("ix_assets_info_last_access_time", table_name="assets_info") op.drop_index("ix_assets_info_created_at", table_name="assets_info") op.drop_index("ix_assets_info_name", table_name="assets_info") diff --git a/app/assets_manager.py b/app/assets_manager.py index f92232a3dbe0..cece144864a1 100644 --- a/app/assets_manager.py +++ b/app/assets_manager.py @@ -1,6 +1,7 @@ import mimetypes import os from typing import Optional, Sequence +from pathlib import Path from comfy.cli_args import args from comfy_api.internal import async_to_sync @@ -34,8 +35,13 @@ async def asset_exists(*, 
asset_hash: str) -> bool: def populate_db_with_asset(tags: list[str], file_name: str, file_path: str) -> None: if not args.disable_model_processing: + p = Path(file_name) + dir_parts = [part for part in p.parent.parts if part not in (".", "..", p.anchor)] async_to_sync.AsyncToSyncConverter.run_async_in_thread( - add_local_asset, tags=tags, file_name=file_name, file_path=file_path + add_local_asset, + tags=list(dict.fromkeys([*tags, *dir_parts])), + file_name=p.name, + file_path=file_path, ) @@ -114,7 +120,7 @@ async def list_assets( size=int(asset.size_bytes) if asset else None, mime_type=asset.mime_type if asset else None, tags=tags, - preview_url=f"/api/v1/assets/{info.id}/content", # TODO: implement actual content endpoint later + preview_url=f"/api/v1/assets/{info.id}/content", created_at=info.created_at, updated_at=info.updated_at, last_access_time=info.last_access_time, diff --git a/app/database/models.py b/app/database/models.py index 06e46815dcdc..20b88ca68705 100644 --- a/app/database/models.py +++ b/app/database/models.py @@ -7,6 +7,7 @@ DateTime, ForeignKey, Index, + UniqueConstraint, JSON, String, Text, @@ -118,7 +119,7 @@ class AssetInfo(Base): __tablename__ = "assets_info" id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - owner_id: Mapped[str | None] = mapped_column(String(128)) + owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="") name: Mapped[str] = mapped_column(String(512), nullable=False) asset_hash: Mapped[str] = mapped_column( String(256), ForeignKey("assets.hash", ondelete="RESTRICT"), nullable=False @@ -169,6 +170,8 @@ class AssetInfo(Base): ) __table_args__ = ( + UniqueConstraint("asset_hash", "owner_id", "name", name="uq_assets_info_hash_owner_name"), + Index("ix_assets_info_owner_name", "owner_id", "name"), Index("ix_assets_info_owner_id", "owner_id"), Index("ix_assets_info_asset_hash", "asset_hash"), Index("ix_assets_info_name", "name"), @@ -186,7 +189,6 @@ def __repr__(self) -> str: return f"" - class AssetInfoMeta(Base): __tablename__ = "asset_info_meta" diff --git a/app/database/services.py b/app/database/services.py index b916a2055a89..960788f9e395 100644 --- a/app/database/services.py +++ b/app/database/services.py @@ -1,3 +1,4 @@ +import contextlib import os import logging from collections import defaultdict @@ -75,7 +76,7 @@ async def ingest_fs_asset( mtime_ns: int, mime_type: Optional[str] = None, info_name: Optional[str] = None, - owner_id: Optional[str] = None, + owner_id: str = "", preview_hash: Optional[str] = None, user_metadata: Optional[dict] = None, tags: Sequence[str] = (), @@ -94,7 +95,7 @@ async def ingest_fs_asset( - Create an AssetInfo (no refcount changes). - Link provided tags to that AssetInfo. * If the require_existing_tags=True, raises ValueError if any tag does not exist in `tags` table. - * If False (default), silently skips unknown tags. + * If False (default), create unknown tags. 
Returns flags and ids: { @@ -103,8 +104,6 @@ async def ingest_fs_asset( "state_created": bool, "state_updated": bool, "asset_info_id": int | None, - "tags_added": list[str], - "tags_missing": list[str], # filled only when require_existing_tags=False } """ locator = os.path.abspath(abs_path) @@ -116,13 +115,11 @@ async def ingest_fs_asset( "state_created": False, "state_updated": False, "asset_info_id": None, - "tags_added": [], - "tags_missing": [], } # ---- Step 1: INSERT Asset or UPDATE size_bytes/updated_at if exists ---- - async with session.begin_nested() as sp1: - try: + with contextlib.suppress(IntegrityError): + async with session.begin_nested(): session.add( Asset( hash=asset_hash, @@ -137,27 +134,29 @@ async def ingest_fs_asset( ) await session.flush() out["asset_created"] = True - except IntegrityError: - await sp1.rollback() - # Already exists by hash -> update selected fields if different - existing = await session.get(Asset, asset_hash) - if existing is not None: - desired_size = int(size_bytes) - if existing.size_bytes != desired_size: - existing.size_bytes = desired_size - existing.updated_at = datetime_now - out["asset_updated"] = True - else: - # This should not occur. Log for visibility. - logging.error("Asset %s not found after conflict; skipping update.", asset_hash) - except Exception: - await sp1.rollback() - logging.exception("Unexpected error inserting Asset (hash=%s, locator=%s)", asset_hash, locator) - raise + + if not out["asset_created"]: + existing = await session.get(Asset, asset_hash) + if existing is not None: + changed = False + if existing.size_bytes != size_bytes: + existing.size_bytes = size_bytes + changed = True + if mime_type and existing.mime_type != mime_type: + existing.mime_type = mime_type + changed = True + if existing.storage_locator != locator: + existing.storage_locator = locator + changed = True + if changed: + existing.updated_at = datetime_now + out["asset_updated"] = True + else: + logging.error("Asset %s not found after PK conflict; skipping update.", asset_hash) # ---- Step 2: INSERT/UPDATE AssetLocatorState (mtime_ns) ---- - async with session.begin_nested() as sp2: - try: + with contextlib.suppress(IntegrityError): + async with session.begin_nested(): session.add( AssetLocatorState( asset_hash=asset_hash, @@ -166,26 +165,22 @@ async def ingest_fs_asset( ) await session.flush() out["state_created"] = True - except IntegrityError: - await sp2.rollback() - state = await session.get(AssetLocatorState, asset_hash) - if state is not None: - desired_mtime = int(mtime_ns) - if state.mtime_ns != desired_mtime: - state.mtime_ns = desired_mtime - out["state_updated"] = True - else: - logging.debug("Locator state missing for %s after conflict; skipping update.", asset_hash) - except Exception: - await sp2.rollback() - logging.exception("Unexpected error inserting AssetLocatorState (hash=%s)", asset_hash) - raise + + if not out["state_created"]: + state = await session.get(AssetLocatorState, asset_hash) + if state is not None: + desired_mtime = int(mtime_ns) + if state.mtime_ns != desired_mtime: + state.mtime_ns = desired_mtime + out["state_updated"] = True + else: + logging.error("Locator state missing for %s after conflict; skipping update.", asset_hash) # ---- Optional: AssetInfo + tag links ---- if info_name: - # 2a) Create AssetInfo (no refcount bump) - async with session.begin_nested() as sp3: - try: + # 2a) Upsert AssetInfo idempotently on (asset_hash, owner_id, name) + with contextlib.suppress(IntegrityError): + async with 
session.begin_nested(): info = AssetInfo( owner_id=owner_id, name=info_name, @@ -198,16 +193,35 @@ async def ingest_fs_asset( session.add(info) await session.flush() # get info.id out["asset_info_id"] = info.id - except Exception: - await sp3.rollback() - logging.exception( - "Unexpected error inserting AssetInfo (hash=%s, name=%s)", asset_hash, info_name + + existing_info = ( + await session.execute( + select(AssetInfo) + .where( + AssetInfo.asset_hash == asset_hash, + AssetInfo.name == info_name, + (AssetInfo.owner_id == owner_id), ) - raise + .limit(1) + ) + ).unique().scalar_one_or_none() + if not existing_info: + raise RuntimeError("Failed to update or insert AssetInfo.") + + if preview_hash is not None and existing_info.preview_hash != preview_hash: + existing_info.preview_hash = preview_hash + existing_info.updated_at = datetime_now + if existing_info.last_access_time < datetime_now: + existing_info.last_access_time = datetime_now + await session.flush() + out["asset_info_id"] = existing_info.id # 2b) Link tags (if any). We DO NOT create new Tag rows here by default. norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()] if norm and out["asset_info_id"] is not None: + if not require_existing_tags: + await _ensure_tags_exist(session, norm, tag_type="user") + # Which tags exist? existing_tag_names = set( name for (name,) in (await session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all() @@ -240,8 +254,6 @@ async def ingest_fs_asset( ] ) await session.flush() - out["tags_added"] = to_add - out["tags_missing"] = missing # 2c) Rebuild metadata projection if provided if user_metadata is not None and out["asset_info_id"] is not None: @@ -420,7 +432,7 @@ async def create_asset_info_for_existing_asset( """Create a new AssetInfo referencing an existing Asset (no content write).""" now = utcnow() info = AssetInfo( - owner_id=None, + owner_id="", name=name, asset_hash=asset_hash, preview_hash=None, @@ -688,39 +700,44 @@ async def add_tags_to_asset_info( if create_if_missing: await _ensure_tags_exist(session, norm, tag_type="user") - # Current links - existing = { - tname - for (tname,) in ( + # Snapshot current links + current = { + tag_name + for (tag_name,) in ( await session.execute( sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) ) ).all() } - to_add = [t for t in norm if t not in existing] - already = [t for t in norm if t in existing] + want = set(norm) + to_add = sorted(want - current) if to_add: - # Make insert race-safe with a nested tx; ignore dup conflicts if any. - async with session.begin_nested(): - session.add_all([ - AssetInfoTag( - asset_info_id=asset_info_id, - tag_name=t, - origin=origin, - added_by=added_by, - added_at=utcnow(), - ) for t in to_add - ]) + async with session.begin_nested() as nested: try: + session.add_all( + [ + AssetInfoTag( + asset_info_id=asset_info_id, + tag_name=t, + origin=origin, + added_by=added_by, + added_at=utcnow(), + ) + for t in to_add + ] + ) await session.flush() except IntegrityError: - # Another writer linked the same tag at the same time -> ok, treat as already present. 
- await session.rollback() + await nested.rollback() - total = await get_asset_tags(session, asset_info_id=asset_info_id) - return {"added": sorted(set(to_add)), "already_present": sorted(set(already)), "total_tags": total} + after = set(await get_asset_tags(session, asset_info_id=asset_info_id)) + return { + "added": sorted(((after - current) & want)), + "already_present": sorted(want & current), + "total_tags": sorted(after), + } async def remove_tags_from_asset_info( @@ -742,8 +759,8 @@ async def remove_tags_from_asset_info( return {"removed": [], "not_present": [], "total_tags": total} existing = { - tname - for (tname,) in ( + tag_name + for (tag_name,) in ( await session.execute( sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) ) From d7464e9e73846c5b175a1dc69bcac5c2c3d98e7e Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Sun, 24 Aug 2025 18:27:08 +0300 Subject: [PATCH 14/82] implemented assets scanner --- alembic_db/versions/0001_assets.py | 7 +- app/api/assets_routes.py | 27 ++- app/api/schemas_in.py | 26 +++ app/api/schemas_out.py | 19 +- app/assets_scanner.py | 319 +++++++++++++++++++++++++++++ 5 files changed, 393 insertions(+), 5 deletions(-) create mode 100644 app/assets_scanner.py diff --git a/alembic_db/versions/0001_assets.py b/alembic_db/versions/0001_assets.py index cdda63fbef6f..7fc054652278 100644 --- a/alembic_db/versions/0001_assets.py +++ b/alembic_db/versions/0001_assets.py @@ -111,12 +111,12 @@ def upgrade() -> None: op.bulk_insert( tags_table, [ - # Core concept tags + # Root folder tags {"name": "models", "tag_type": "system"}, {"name": "input", "tag_type": "system"}, {"name": "output", "tag_type": "system"}, - # Canonical single-word types + # Core tags {"name": "checkpoint", "tag_type": "system"}, {"name": "lora", "tag_type": "system"}, {"name": "vae", "tag_type": "system"}, @@ -130,9 +130,10 @@ def upgrade() -> None: {"name": "vae-approx", "tag_type": "system"}, {"name": "gligen", "tag_type": "system"}, {"name": "style-model", "tag_type": "system"}, + {"name": "photomaker", "tag_type": "system"}, + {"name": "classifier", "tag_type": "system"}, {"name": "encoder", "tag_type": "system"}, {"name": "decoder", "tag_type": "system"}, - # TODO: decide what to do with: photomaker, classifiers ], ) diff --git a/app/api/assets_routes.py b/app/api/assets_routes.py index 63610099824d..be3005a298ae 100644 --- a/app/api/assets_routes.py +++ b/app/api/assets_routes.py @@ -4,7 +4,7 @@ from aiohttp import web from pydantic import ValidationError -from .. import assets_manager +from .. import assets_manager, assets_scanner from .
import schemas_in @@ -225,6 +225,31 @@ async def delete_asset_tags(request: web.Request) -> web.Response: return web.json_response(result.model_dump(mode="json"), status=200) +@ROUTES.post("/api/assets/scan/schedule") +async def schedule_asset_scan(request: web.Request) -> web.Response: + try: + payload = await request.json() + except Exception: + payload = {} + + try: + body = schemas_in.ScheduleAssetScanBody.model_validate(payload) + except ValidationError as ve: + return _validation_error_response("INVALID_BODY", ve) + + states = await assets_scanner.schedule_scans(body.roots) + return web.json_response(states.model_dump(mode="json"), status=202) + + +@ROUTES.get("/api/assets/scan") +async def get_asset_scan_status(request: web.Request) -> web.Response: + root = request.query.get("root", "").strip().lower() + states = assets_scanner.current_statuses() + if root in {"models", "input", "output"}: + return web.json_response({"scans": [s.model_dump(mode="json") for s in states.scans if s.root == root]}, status=200) + return web.json_response(states.model_dump(mode="json"), status=200) + + def register_assets_routes(app: web.Application) -> None: app.add_routes(ROUTES) diff --git a/app/api/schemas_in.py b/app/api/schemas_in.py index 0f07bf19d615..fa42146d3306 100644 --- a/app/api/schemas_in.py +++ b/app/api/schemas_in.py @@ -146,3 +146,29 @@ def normalize_tags(cls, v: list[str]) -> list[str]: class TagsRemove(TagsAdd): pass + + +class ScheduleAssetScanBody(BaseModel): + roots: list[Literal["models","input","output"]] = Field(default_factory=list) + + @field_validator("roots", mode="before") + @classmethod + def _normalize_roots(cls, v): + if v is None: + return [] + if isinstance(v, str): + items = [x.strip().lower() for x in v.split(",")] + elif isinstance(v, list): + items = [] + for x in v: + if isinstance(x, str): + items.extend([p.strip().lower() for p in x.split(",")]) + else: + return [] + out = [] + seen = set() + for r in items: + if r in {"models","input","output"} and r not in seen: + out.append(r) + seen.add(r) + return out diff --git a/app/api/schemas_out.py b/app/api/schemas_out.py index 0a71b8bc960d..8aca0ee012b7 100644 --- a/app/api/schemas_out.py +++ b/app/api/schemas_out.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Any, Optional +from typing import Any, Literal, Optional from pydantic import BaseModel, ConfigDict, Field, field_serializer @@ -87,3 +87,20 @@ class TagsRemove(BaseModel): removed: list[str] = Field(default_factory=list) not_present: list[str] = Field(default_factory=list) total_tags: list[str] = Field(default_factory=list) + + +class AssetScanStatus(BaseModel): + scan_id: str + root: Literal["models","input","output"] + status: Literal["scheduled","running","completed","failed","cancelled"] + scheduled_at: Optional[str] = None + started_at: Optional[str] = None + finished_at: Optional[str] = None + discovered: int = 0 + processed: int = 0 + errors: int = 0 + last_error: Optional[str] = None + + +class AssetScanStatusResponse(BaseModel): + scans: list[AssetScanStatus] = Field(default_factory=list) diff --git a/app/assets_scanner.py b/app/assets_scanner.py new file mode 100644 index 000000000000..7ffef80b3b57 --- /dev/null +++ b/app/assets_scanner.py @@ -0,0 +1,319 @@ +import asyncio +import logging +import os +import uuid +import time +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Literal, Optional, Sequence + +from .
import assets_manager +from .api import schemas_out + +import folder_paths + +LOGGER = logging.getLogger(__name__) + +RootType = Literal["models", "input", "output"] +ALLOWED_ROOTS: tuple[RootType, ...] = ("models", "input", "output") + +# We run at most one scan per root; overall max parallelism is therefore 3 +# We also bound per-scan ingestion concurrency to avoid swamping threads/DB +DEFAULT_PER_SCAN_CONCURRENCY = 1 + + +@dataclass +class ScanProgress: + scan_id: str + root: RootType + status: Literal["scheduled", "running", "completed", "failed", "cancelled"] = "scheduled" + scheduled_at: float = field(default_factory=lambda: time.time()) + started_at: Optional[float] = None + finished_at: Optional[float] = None + + discovered: int = 0 + processed: int = 0 + errors: int = 0 + last_error: Optional[str] = None + + # Optional details for diagnostics + details: dict[str, int] = field(default_factory=dict) + + +RUNNING_TASKS: dict[RootType, asyncio.Task] = {} +PROGRESS_BY_ROOT: dict[RootType, ScanProgress] = {} + + +def _new_scan_id(root: RootType) -> str: + return f"scan-{root}-{uuid.uuid4().hex[:8]}" + + +def current_statuses() -> schemas_out.AssetScanStatusResponse: + # make shallow copies to avoid external mutation + states = [PROGRESS_BY_ROOT[r] for r in ALLOWED_ROOTS if r in PROGRESS_BY_ROOT] + return schemas_out.AssetScanStatusResponse( + scans=[ + schemas_out.AssetScanStatus( + scan_id=s.scan_id, + root=s.root, + status=s.status, + scheduled_at=_ts_to_iso(s.scheduled_at), + started_at=_ts_to_iso(s.started_at), + finished_at=_ts_to_iso(s.finished_at), + discovered=s.discovered, + processed=s.processed, + errors=s.errors, + last_error=s.last_error, + ) + for s in states + ] + ) + + +async def schedule_scans(roots: Sequence[str]) -> schemas_out.AssetScanStatusResponse: + """Schedule scans for the provided roots; returns progress snapshots. + + Rules: + - Only roots in {models, input, output} are accepted. + - If a root is already scanning, we do NOT enqueue another one. Status returned as-is. + - Otherwise a new task is created and started immediately. + - Files with zero size are skipped. 
+ """ + normalized: list[RootType] = [] + seen = set() + for r in roots or []: + if not isinstance(r, str): + continue + rr = r.strip().lower() + if rr in ALLOWED_ROOTS and rr not in seen: + normalized.append(rr) # type: ignore + seen.add(rr) + if not normalized: + normalized = list(ALLOWED_ROOTS) # schedule all by default + + results: list[ScanProgress] = [] + for root in normalized: + if root in RUNNING_TASKS and not RUNNING_TASKS[root].done(): + # already running; return the live progress object + results.append(PROGRESS_BY_ROOT[root]) + continue + + # Create fresh progress + prog = ScanProgress(scan_id=_new_scan_id(root), root=root, status="scheduled") + PROGRESS_BY_ROOT[root] = prog + + # Start task + task = asyncio.create_task(_run_scan_for_root(root, prog), name=f"asset-scan:{root}") + RUNNING_TASKS[root] = task + results.append(prog) + + return schemas_out.AssetScanStatusResponse( + scans=[ + schemas_out.AssetScanStatus( + scan_id=s.scan_id, + root=s.root, + status=s.status, + scheduled_at=_ts_to_iso(s.scheduled_at), + started_at=_ts_to_iso(s.started_at), + finished_at=_ts_to_iso(s.finished_at), + discovered=s.discovered, + processed=s.processed, + errors=s.errors, + last_error=s.last_error, + ) + for s in results + ] + ) + + +async def _run_scan_for_root(root: RootType, prog: ScanProgress) -> None: + prog.started_at = time.time() + prog.status = "running" + try: + if root == "models": + await _scan_models(prog) + elif root == "input": + base = folder_paths.get_input_directory() + await _scan_directory_tree(base, root, prog) + elif root == "output": + base = folder_paths.get_output_directory() + await _scan_directory_tree(base, root, prog) + else: + raise RuntimeError(f"Unsupported root: {root}") + prog.status = "completed" + except asyncio.CancelledError: + prog.status = "cancelled" + raise + except Exception as exc: + LOGGER.exception("Asset scan failed for %s", root) + prog.status = "failed" + prog.errors += 1 + prog.last_error = str(exc) + finally: + prog.finished_at = time.time() + # Drop the task entry if it's the current one + t = RUNNING_TASKS.get(root) + if t and t.done(): + RUNNING_TASKS.pop(root, None) + + +async def _scan_models(prog: ScanProgress) -> None: + # Iterate all folder_names whose base paths lie under the Comfy 'models' directory + models_root = os.path.abspath(os.path.join(folder_paths.base_path, "models")) + + # Build list of (folder_name, base_paths[]) that are configured for this category. + # If any path for the category lies under 'models', include the category. 
+ targets: list[tuple[str, list[str]]] = [] + for name, (paths, _exts) in folder_paths.folder_names_and_paths.items(): + if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths): + targets.append((name, paths)) + + plans: list[tuple[str, str]] = [] # (abs_path, file_name_for_tags) + per_bucket: dict[str, int] = {} + + for folder_name, bases in targets: + rel_files = folder_paths.get_filename_list(folder_name) or [] + count_valid = 0 + + for rel_path in rel_files: + abs_path = folder_paths.get_full_path(folder_name, rel_path) + if not abs_path: + continue + abs_path = os.path.abspath(abs_path) + + # Extra safety: ensure file is inside one of the allowed base paths + allowed = False + for base in bases: + base_abs = os.path.abspath(base) + try: + common = os.path.commonpath([abs_path, base_abs]) + except ValueError: + common = "" # Different drives on Windows + if common == base_abs: + allowed = True + break + if not allowed: + LOGGER.warning("Skipping file outside models base: %s", abs_path) + continue + + try: + if not os.path.getsize(abs_path): + continue + except OSError as e: + LOGGER.warning("Could not stat %s: %s – skipping", abs_path, e) + continue + + file_name_for_tags = os.path.join(folder_name, rel_path) + plans.append((abs_path, file_name_for_tags)) + count_valid += 1 + + if count_valid: + per_bucket[folder_name] = per_bucket.get(folder_name, 0) + count_valid + + prog.discovered = len(plans) + for k, v in per_bucket.items(): + prog.details[k] = prog.details.get(k, 0) + v + + if not plans: + LOGGER.info("Model scan %s: nothing to ingest", prog.scan_id) + return + + sem = asyncio.Semaphore(DEFAULT_PER_SCAN_CONCURRENCY) + tasks: list[asyncio.Task] = [] + + for abs_path, name_for_tags in plans: + async def worker(fp_abs: str = abs_path, fn_rel: str = name_for_tags): + try: + # Offload sync ingestion into a thread + await asyncio.to_thread( + assets_manager.populate_db_with_asset, + ["models"], + fn_rel, + fp_abs, + ) + except Exception as e: + prog.errors += 1 + prog.last_error = str(e) + LOGGER.debug("Error ingesting %s: %s", fp_abs, e) + finally: + prog.processed += 1 + sem.release() + + await sem.acquire() + tasks.append(asyncio.create_task(worker())) + + if tasks: + await asyncio.gather(*tasks) + LOGGER.info( + "Model scan %s finished: discovered=%d processed=%d errors=%d", + prog.scan_id, prog.discovered, prog.processed, prog.errors + ) + + +def _count_files_in_tree(base_abs: str) -> int: + if not os.path.isdir(base_abs): + return 0 + total = 0 + for _dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False): + total += len(filenames) + return total + + +async def _scan_directory_tree(base_dir: str, root: RootType, prog: ScanProgress) -> None: + # Guard: base_dir must be a directory + base_abs = os.path.abspath(base_dir) + if not os.path.isdir(base_abs): + LOGGER.info("Scan root %s skipped: base directory missing: %s", root, base_abs) + return + + prog.discovered = _count_files_in_tree(base_abs) + + sem = asyncio.Semaphore(DEFAULT_PER_SCAN_CONCURRENCY) + tasks: list[asyncio.Task] = [] + for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False): + for name in filenames: + rel = os.path.relpath(os.path.join(dirpath, name), base_abs) + abs_path = os.path.join(base_abs, rel) + # Safety: ensure within base + try: + if os.path.commonpath([os.path.abspath(abs_path), base_abs]) != base_abs: + LOGGER.warning("Skipping path outside root %s: %s", root, abs_path) + continue + except ValueError: + continue + + async def 
worker(fp_abs: str = abs_path, fn_rel: str = rel): + try: + await asyncio.to_thread( + assets_manager.populate_db_with_asset, + [root], + fn_rel, + fp_abs, + ) + except Exception as e: + prog.errors += 1 + prog.last_error = str(e) + finally: + prog.processed += 1 + sem.release() + + await sem.acquire() + tasks.append(asyncio.create_task(worker())) + + if tasks: + await asyncio.gather(*tasks) + + LOGGER.info( + "%s scan %s finished: discovered=%d processed=%d errors=%d", + root.capitalize(), prog.scan_id, prog.discovered, prog.processed, prog.errors + ) + + +def _ts_to_iso(ts: Optional[float]) -> Optional[str]: + if ts is None: + return None + # interpret ts as seconds since epoch UTC and return naive UTC (consistent with other models) + try: + return datetime.fromtimestamp(float(ts), tz=timezone.utc).replace(tzinfo=None).isoformat() + except Exception: + return None From 09dabf95bc8df60ff871dc230745235cebcb7d06 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Mon, 25 Aug 2025 13:31:56 +0300 Subject: [PATCH 15/82] refactoring: use the same code for "scan task" and realtime DB population --- alembic_db/versions/0001_assets.py | 26 ++++---- app/_assets_helpers.py | 99 ++++++++++++++++++++++++++++++ app/assets_manager.py | 25 +++++--- app/assets_scanner.py | 80 +++++++++++------------- app/database/services.py | 18 +++--- comfy/utils.py | 2 + nodes.py | 26 +------- 7 files changed, 178 insertions(+), 98 deletions(-) create mode 100644 app/_assets_helpers.py diff --git a/alembic_db/versions/0001_assets.py b/alembic_db/versions/0001_assets.py index 7fc054652278..b180edacc3a5 100644 --- a/alembic_db/versions/0001_assets.py +++ b/alembic_db/versions/0001_assets.py @@ -117,21 +117,25 @@ def upgrade() -> None: {"name": "output", "tag_type": "system"}, # Core tags - {"name": "checkpoint", "tag_type": "system"}, - {"name": "lora", "tag_type": "system"}, + {"name": "configs", "tag_type": "system"}, + {"name": "checkpoints", "tag_type": "system"}, + {"name": "loras", "tag_type": "system"}, {"name": "vae", "tag_type": "system"}, - {"name": "text-encoder", "tag_type": "system"}, - {"name": "clip-vision", "tag_type": "system"}, - {"name": "embedding", "tag_type": "system"}, + {"name": "text_encoders", "tag_type": "system"}, + {"name": "diffusion_models", "tag_type": "system"}, + {"name": "clip_vision", "tag_type": "system"}, + {"name": "style_models", "tag_type": "system"}, + {"name": "embeddings", "tag_type": "system"}, + {"name": "diffusers", "tag_type": "system"}, + {"name": "vae_approx", "tag_type": "system"}, {"name": "controlnet", "tag_type": "system"}, - {"name": "upscale", "tag_type": "system"}, - {"name": "diffusion-model", "tag_type": "system"}, - {"name": "hypernetwork", "tag_type": "system"}, - {"name": "vae-approx", "tag_type": "system"}, {"name": "gligen", "tag_type": "system"}, - {"name": "style-model", "tag_type": "system"}, + {"name": "upscale_models", "tag_type": "system"}, + {"name": "hypernetworks", "tag_type": "system"}, {"name": "photomaker", "tag_type": "system"}, - {"name": "classifier", "tag_type": "system"}, + {"name": "classifiers", "tag_type": "system"}, + + # Extra basic tags (used for vae_approx, ...) 
{"name": "encoder", "tag_type": "system"}, {"name": "decoder", "tag_type": "system"}, ], diff --git a/app/_assets_helpers.py b/app/_assets_helpers.py new file mode 100644 index 000000000000..49adfaaeec58 --- /dev/null +++ b/app/_assets_helpers.py @@ -0,0 +1,99 @@ +import os +from pathlib import Path +from typing import Optional, Literal, Sequence + +import folder_paths + + +def get_comfy_models_folders() -> list[tuple[str, list[str]]]: + """Build a list of (folder_name, base_paths[]) categories that are configured for model locations. + + We trust `folder_paths.folder_names_and_paths` and include a category if + *any* of its base paths lies under the Comfy `models_dir`. + """ + targets: list[tuple[str, list[str]]] = [] + models_root = os.path.abspath(folder_paths.models_dir) + for name, (paths, _exts) in folder_paths.folder_names_and_paths.items(): + if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths): + targets.append((name, paths)) + return targets + + +def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]: + """Given an absolute or relative file path, determine which root category the path belongs to: + - 'input' if the file resides under `folder_paths.get_input_directory()` + - 'output' if the file resides under `folder_paths.get_output_directory()` + - 'models' if the file resides under any base path of categories returned by `get_comfy_models_folders()` + + Returns: + (root_category, relative_path_inside_that_root) + For 'models', the relative path is prefixed with the category name: + e.g. ('models', 'vae/test/sub/ae.safetensors') + + Raises: + ValueError: if the path does not belong to input, output, or configured model bases. + """ + fp_abs = os.path.abspath(file_path) + + def _is_within(child: str, parent: str) -> bool: + try: + return os.path.commonpath([child, parent]) == parent + except Exception: + return False + + def _rel(child: str, parent: str) -> str: + return os.path.relpath(os.path.join(os.sep, os.path.relpath(child, parent)), os.sep) + + # 1) input + input_base = os.path.abspath(folder_paths.get_input_directory()) + if _is_within(fp_abs, input_base): + return "input", _rel(fp_abs, input_base) + + # 2) output + output_base = os.path.abspath(folder_paths.get_output_directory()) + if _is_within(fp_abs, output_base): + return "output", _rel(fp_abs, output_base) + + # 3) models (check deepest matching base to avoid ambiguity) + best: Optional[tuple[int, str, str]] = None # (base_len, bucket, rel_inside_bucket) + for bucket, bases in get_comfy_models_folders(): + for b in bases: + base_abs = os.path.abspath(b) + if not _is_within(fp_abs, base_abs): + continue + cand = (len(base_abs), bucket, _rel(fp_abs, base_abs)) + if best is None or cand[0] > best[0]: + best = cand + + if best is not None: + _, bucket, rel_inside = best + combined = os.path.join(bucket, rel_inside) + return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep) + + raise ValueError(f"Path is not within input, output, or configured model bases: {file_path}") + + +def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]: + """Return a tuple (name, tags) derived from a filesystem path. + + Semantics: + - Root category is determined by `get_relative_to_root_category_path_of_asset`. + - The returned `name` is the base filename with extension from the relative path. 
+ - The returned `tags` are: + [root_category] + parent folders of the relative path (in order) + For 'models', this means: + file '/.../ModelsDir/vae/test_tag/ae.safetensors' + -> root_category='models', some_path='vae/test_tag/ae.safetensors' + -> name='ae.safetensors', tags=['models', 'vae', 'test_tag'] + + Raises: + ValueError: if the path does not belong to input, output, or configured model bases. + """ + root_category, some_path = get_relative_to_root_category_path_of_asset(file_path) + p = Path(some_path) + parent_parts = [part for part in p.parent.parts if part not in (".", "..", p.anchor)] + return p.name, normalize_tags([root_category, *parent_parts]) + + +def normalize_tags(tags: Optional[Sequence[str]]) -> list[str]: + return [t.strip().lower() for t in (tags or []) if (t or "").strip()] diff --git a/app/assets_manager.py b/app/assets_manager.py index cece144864a1..0c008b47148f 100644 --- a/app/assets_manager.py +++ b/app/assets_manager.py @@ -1,7 +1,7 @@ +import logging import mimetypes import os from typing import Optional, Sequence -from pathlib import Path from comfy.cli_args import args from comfy_api.internal import async_to_sync @@ -26,6 +26,7 @@ create_asset_info_for_existing_asset, ) from .api import schemas_out +from ._assets_helpers import get_name_and_tags_from_asset_path async def asset_exists(*, asset_hash: str) -> bool: @@ -33,16 +34,20 @@ async def asset_exists(*, asset_hash: str) -> bool: return await asset_exists_by_hash(session, asset_hash=asset_hash) -def populate_db_with_asset(tags: list[str], file_name: str, file_path: str) -> None: +def populate_db_with_asset(file_path: str, tags: Optional[list[str]] = None) -> None: if not args.disable_model_processing: - p = Path(file_name) - dir_parts = [part for part in p.parent.parts if part not in (".", "..", p.anchor)] - async_to_sync.AsyncToSyncConverter.run_async_in_thread( - add_local_asset, - tags=list(dict.fromkeys([*tags, *dir_parts])), - file_name=p.name, - file_path=file_path, - ) + if tags is None: + tags = [] + try: + asset_name, path_tags = get_name_and_tags_from_asset_path(file_path) + async_to_sync.AsyncToSyncConverter.run_async_in_thread( + add_local_asset, + tags=list(dict.fromkeys([*path_tags, *tags])), + file_name=asset_name, + file_path=file_path, + ) + except ValueError: + logging.exception("Can't parse '%s' as an asset file path.", file_path) async def add_local_asset(tags: list[str], file_name: str, file_path: str) -> None: diff --git a/app/assets_scanner.py b/app/assets_scanner.py index 7ffef80b3b57..691472156cd0 100644 --- a/app/assets_scanner.py +++ b/app/assets_scanner.py @@ -7,10 +7,11 @@ from datetime import datetime, timezone from typing import Literal, Optional, Sequence +import folder_paths + from .
import assets_manager from .api import schemas_out - -import folder_paths +from ._assets_helpers import get_comfy_models_folders LOGGER = logging.getLogger(__name__) @@ -36,7 +37,7 @@ class ScanProgress: errors: int = 0 last_error: Optional[str] = None - # Optional details for diagnostics + # Optional details for diagnostics (e.g., files per bucket) details: dict[str, int] = field(default_factory=dict) @@ -49,8 +50,6 @@ def _new_scan_id(root: RootType) -> str: def current_statuses() -> schemas_out.AssetScanStatusResponse: - # make shallow copies to avoid external mutation - states = [PROGRESS_BY_ROOT[r] for r in ALLOWED_ROOTS if r in PROGRESS_BY_ROOT] return schemas_out.AssetScanStatusResponse( scans=[ schemas_out.AssetScanStatus( @@ -65,7 +64,7 @@ def current_statuses() -> schemas_out.AssetScanStatusResponse: errors=s.errors, last_error=s.last_error, ) - for s in states + for s in [PROGRESS_BY_ROOT[r] for r in ALLOWED_ROOTS if r in PROGRESS_BY_ROOT] ] ) @@ -94,15 +93,12 @@ async def schedule_scans(roots: Sequence[str]) -> schemas_out.AssetScanStatusRes results: list[ScanProgress] = [] for root in normalized: if root in RUNNING_TASKS and not RUNNING_TASKS[root].done(): - # already running; return the live progress object results.append(PROGRESS_BY_ROOT[root]) continue - # Create fresh progress prog = ScanProgress(scan_id=_new_scan_id(root), root=root, status="scheduled") PROGRESS_BY_ROOT[root] = prog - # Start task task = asyncio.create_task(_run_scan_for_root(root, prog), name=f"asset-scan:{root}") RUNNING_TASKS[root] = task results.append(prog) @@ -151,24 +147,21 @@ async def _run_scan_for_root(root: RootType, prog: ScanProgress) -> None: prog.last_error = str(exc) finally: prog.finished_at = time.time() - # Drop the task entry if it's the current one t = RUNNING_TASKS.get(root) if t and t.done(): RUNNING_TASKS.pop(root, None) async def _scan_models(prog: ScanProgress) -> None: - # Iterate all folder_names whose base paths lie under the Comfy 'models' directory - models_root = os.path.abspath(os.path.join(folder_paths.base_path, "models")) - - # Build list of (folder_name, base_paths[]) that are configured for this category. - # If any path for the category lies under 'models', include the category. - targets: list[tuple[str, list[str]]] = [] - for name, (paths, _exts) in folder_paths.folder_names_and_paths.items(): - if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths): - targets.append((name, paths)) + """ + Scan all configured model buckets from folder_paths.folder_names_and_paths, + restricted to entries whose base paths lie under folder_paths.models_dir + (per get_comfy_models_folders). We trust those mappings and do not try to + infer anything else here. 
+ """ + targets: list[tuple[str, list[str]]] = get_comfy_models_folders() - plans: list[tuple[str, str]] = [] # (abs_path, file_name_for_tags) + plans: list[str] = [] # absolute file paths to ingest per_bucket: dict[str, int] = {} for folder_name, bases in targets: @@ -198,13 +191,12 @@ async def _scan_models(prog: ScanProgress) -> None: try: if not os.path.getsize(abs_path): - continue + continue # skip empty files except OSError as e: LOGGER.warning("Could not stat %s: %s – skipping", abs_path, e) continue - file_name_for_tags = os.path.join(folder_name, rel_path) - plans.append((abs_path, file_name_for_tags)) + plans.append(abs_path) count_valid += 1 if count_valid: @@ -221,16 +213,12 @@ async def _scan_models(prog: ScanProgress) -> None: sem = asyncio.Semaphore(DEFAULT_PER_SCAN_CONCURRENCY) tasks: list[asyncio.Task] = [] - for abs_path, name_for_tags in plans: - async def worker(fp_abs: str = abs_path, fn_rel: str = name_for_tags): + for abs_path in plans: + async def worker(fp_abs: str = abs_path): try: - # Offload sync ingestion into a thread - await asyncio.to_thread( - assets_manager.populate_db_with_asset, - ["models"], - fn_rel, - fp_abs, - ) + # Offload sync ingestion into a thread; populate_db_with_asset + # derives name and tags from the path using _assets_helpers. + await asyncio.to_thread(assets_manager.populate_db_with_asset, fp_abs) except Exception as e: prog.errors += 1 prog.last_error = str(e) @@ -260,7 +248,10 @@ def _count_files_in_tree(base_abs: str) -> int: async def _scan_directory_tree(base_dir: str, root: RootType, prog: ScanProgress) -> None: - # Guard: base_dir must be a directory + """ + Generic scanner for input/output roots. We pass only the absolute path to + populate_db_with_asset and let it derive the relative name and tags. 
+ """ base_abs = os.path.abspath(base_dir) if not os.path.isdir(base_abs): LOGGER.info("Scan root %s skipped: base directory missing: %s", root, base_abs) @@ -272,24 +263,27 @@ async def _scan_directory_tree(base_dir: str, root: RootType, prog: ScanProgress tasks: list[asyncio.Task] = [] for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False): for name in filenames: - rel = os.path.relpath(os.path.join(dirpath, name), base_abs) - abs_path = os.path.join(base_abs, rel) + abs_path = os.path.abspath(os.path.join(dirpath, name)) + # Safety: ensure within base try: - if os.path.commonpath([os.path.abspath(abs_path), base_abs]) != base_abs: + if os.path.commonpath([abs_path, base_abs]) != base_abs: LOGGER.warning("Skipping path outside root %s: %s", root, abs_path) continue except ValueError: continue - async def worker(fp_abs: str = abs_path, fn_rel: str = rel): + # Skip empty files and handle stat errors + try: + if not os.path.getsize(abs_path): + continue + except OSError as e: + LOGGER.warning("Could not stat %s: %s – skipping", abs_path, e) + continue + + async def worker(fp_abs: str = abs_path): try: - await asyncio.to_thread( - assets_manager.populate_db_with_asset, - [root], - fn_rel, - fp_abs, - ) + await asyncio.to_thread(assets_manager.populate_db_with_asset, fp_abs) except Exception as e: prog.errors += 1 prog.last_error = str(e) diff --git a/app/database/services.py b/app/database/services.py index 960788f9e395..34029b139d89 100644 --- a/app/database/services.py +++ b/app/database/services.py @@ -14,7 +14,7 @@ from .models import Asset, AssetInfo, AssetInfoTag, AssetLocatorState, Tag, AssetInfoMeta from .timeutil import utcnow - +from .._assets_helpers import normalize_tags async def asset_exists_by_hash(session: AsyncSession, *, asset_hash: str) -> bool: @@ -471,7 +471,7 @@ async def set_asset_info_tags( Replace the tag set on an AssetInfo with `tags`. Idempotent. Creates missing tag names as 'user'. 
""" - desired = _normalize_tags(tags) + desired = normalize_tags(tags) # current links current = set( @@ -691,7 +691,7 @@ async def add_tags_to_asset_info( if not info: raise ValueError(f"AssetInfo {asset_info_id} not found") - norm = _normalize_tags(tags) + norm = normalize_tags(tags) if not norm: total = await get_asset_tags(session, asset_info_id=asset_info_id) return {"added": [], "already_present": [], "total_tags": total} @@ -753,7 +753,7 @@ async def remove_tags_from_asset_info( if not info: raise ValueError(f"AssetInfo {asset_info_id} not found") - norm = _normalize_tags(tags) + norm = normalize_tags(tags) if not norm: total = await get_asset_tags(session, asset_info_id=asset_info_id) return {"removed": [], "not_present": [], "total_tags": total} @@ -784,12 +784,8 @@ async def remove_tags_from_asset_info( return {"removed": to_remove, "not_present": not_present, "total_tags": total} -def _normalize_tags(tags: Optional[Sequence[str]]) -> list[str]: - return [t.strip().lower() for t in (tags or []) if (t or "").strip()] - - async def _ensure_tags_exist(session: AsyncSession, names: Iterable[str], tag_type: str = "user") -> list[Tag]: - wanted = _normalize_tags(list(names)) + wanted = normalize_tags(list(names)) if not wanted: return [] existing = (await session.execute(select(Tag).where(Tag.name.in_(wanted)))).scalars().all() @@ -808,8 +804,8 @@ def _apply_tag_filters( exclude_tags: Optional[Sequence[str]], ) -> sa.sql.Select: """include_tags: every tag must be present; exclude_tags: none may be present.""" - include_tags = _normalize_tags(include_tags) - exclude_tags = _normalize_tags(exclude_tags) + include_tags = normalize_tags(include_tags) + exclude_tags = normalize_tags(exclude_tags) if include_tags: for tag_name in include_tags: diff --git a/comfy/utils.py b/comfy/utils.py index 220492941342..f13a780e8c30 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -29,6 +29,7 @@ from torch.nn.functional import interpolate from einops import rearrange from comfy.cli_args import args +from app.assets_manager import populate_db_with_asset MMAP_TORCH_FILES = args.mmap_torch_files DISABLE_MMAP = args.disable_mmap @@ -102,6 +103,7 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False): else: sd = pl_sd + populate_db_with_asset(ckpt) return (sd, metadata) if return_metadata else sd def save_torch_file(sd, ckpt, metadata=None): diff --git a/nodes.py b/nodes.py index 04b60ab2fca7..860a236aaa4e 100644 --- a/nodes.py +++ b/nodes.py @@ -31,7 +31,6 @@ from comfy_api.internal import register_versions, ComfyAPIWithVersion from comfy_api.version_list import supported_versions from comfy_api.latest import io, ComfyExtension -from app.assets_manager import populate_db_with_asset import comfy.clip_vision @@ -555,9 +554,7 @@ def INPUT_TYPES(s): def load_checkpoint(self, config_name, ckpt_name): config_path = folder_paths.get_full_path("configs", config_name) ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) - out = comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) - populate_db_with_asset(["models", "checkpoint"], ckpt_name, ckpt_path) - return out + return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) class CheckpointLoaderSimple: @classmethod @@ -579,7 +576,6 @@ def INPUT_TYPES(s): def load_checkpoint(self, ckpt_name): ckpt_path = 
folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) - populate_db_with_asset(["models", "checkpoint"], ckpt_name, ckpt_path) return out[:3] class DiffusersLoader: @@ -622,7 +618,6 @@ def INPUT_TYPES(s): def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True): ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) - populate_db_with_asset(["models", "checkpoint"], ckpt_name, ckpt_path) return out class CLIPSetLastLayer: @@ -681,7 +676,6 @@ def load_lora(self, model, clip, lora_name, strength_model, strength_clip): self.loaded_lora = (lora_path, lora) model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip) - populate_db_with_asset(["models", "lora"], lora_name, lora_path) return (model_lora, clip_lora) class LoraLoaderModelOnly(LoraLoader): @@ -746,15 +740,11 @@ def load_taesd(name): encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes)) decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes)) - encoder_path = folder_paths.get_full_path_or_raise("vae_approx", encoder) - populate_db_with_asset(["models", "vae-approx", "encoder"], name, encoder_path) - enc = comfy.utils.load_torch_file(encoder_path) + enc = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", encoder)) for k in enc: sd["taesd_encoder.{}".format(k)] = enc[k] - decoder_path = folder_paths.get_full_path_or_raise("vae_approx", decoder) - populate_db_with_asset(["models", "vae-approx", "decoder"], name, decoder_path) - dec = comfy.utils.load_torch_file(decoder_path) + dec = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", decoder)) for k in dec: sd["taesd_decoder.{}".format(k)] = dec[k] @@ -787,7 +777,6 @@ def load_vae(self, vae_name): else: vae_path = folder_paths.get_full_path_or_raise("vae", vae_name) sd = comfy.utils.load_torch_file(vae_path) - populate_db_with_asset(["models", "vae"], vae_name, vae_path) vae = comfy.sd.VAE(sd=sd) vae.throw_exception_if_invalid() return (vae,) @@ -807,7 +796,6 @@ def load_controlnet(self, control_net_name): controlnet = comfy.controlnet.load_controlnet(controlnet_path) if controlnet is None: raise RuntimeError("ERROR: controlnet file is invalid and does not contain a valid controlnet model.") - populate_db_with_asset(["models", "controlnet"], control_net_name, controlnet_path) return (controlnet,) class DiffControlNetLoader: @@ -824,7 +812,6 @@ def INPUT_TYPES(s): def load_controlnet(self, model, control_net_name): controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name) controlnet = comfy.controlnet.load_controlnet(controlnet_path, model) - populate_db_with_asset(["models", "controlnet"], control_net_name, controlnet_path) return (controlnet,) @@ -932,7 +919,6 @@ def load_unet(self, unet_name, weight_dtype): unet_path = folder_paths.get_full_path_or_raise("diffusion_models", unet_name) model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options) - populate_db_with_asset(["models", "diffusion-model"], unet_name, unet_path) return (model,) class CLIPLoader: @@ -960,7 +946,6 @@ def load_clip(self, clip_name, 
type="stable_diffusion", device="default"): clip_path = folder_paths.get_full_path_or_raise("text_encoders", clip_name) clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options) - populate_db_with_asset(["models", "text-encoder"], clip_name, clip_path) return (clip,) class DualCLIPLoader: @@ -991,8 +976,6 @@ def load_clip(self, clip_name1, clip_name2, type, device="default"): model_options["load_device"] = model_options["offload_device"] = torch.device("cpu") clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options) - populate_db_with_asset(["models", "text-encoder"], clip_name1, clip_path1) - populate_db_with_asset(["models", "text-encoder"], clip_name2, clip_path2) return (clip,) class CLIPVisionLoader: @@ -1010,7 +993,6 @@ def load_clip(self, clip_name): clip_vision = comfy.clip_vision.load(clip_path) if clip_vision is None: raise RuntimeError("ERROR: clip vision file is invalid and does not contain a valid vision model.") - populate_db_with_asset(["models", "clip-vision"], clip_name, clip_path) return (clip_vision,) class CLIPVisionEncode: @@ -1045,7 +1027,6 @@ def INPUT_TYPES(s): def load_style_model(self, style_model_name): style_model_path = folder_paths.get_full_path_or_raise("style_models", style_model_name) style_model = comfy.sd.load_style_model(style_model_path) - populate_db_with_asset(["models", "style-model"], style_model_name, style_model_path) return (style_model,) @@ -1143,7 +1124,6 @@ def INPUT_TYPES(s): def load_gligen(self, gligen_name): gligen_path = folder_paths.get_full_path_or_raise("gligen", gligen_name) gligen = comfy.sd.load_gligen(gligen_path) - populate_db_with_asset(["models", "gligen"], gligen_name, gligen_path) return (gligen,) class GLIGENTextBoxApply: From a763cbd39db955066d88bad4016d2c5284b61890 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Mon, 25 Aug 2025 15:30:55 +0300 Subject: [PATCH 16/82] add upload asset endpoint --- app/_assets_helpers.py | 33 ++++++++++ app/api/assets_routes.py | 138 ++++++++++++++++++++++++++++++++------- app/api/schemas_in.py | 90 +++++++++++++++++++++++++ app/assets_manager.py | 95 ++++++++++++++++++++++++++- 4 files changed, 332 insertions(+), 24 deletions(-) diff --git a/app/_assets_helpers.py b/app/_assets_helpers.py index 49adfaaeec58..4f1ad4446b1c 100644 --- a/app/_assets_helpers.py +++ b/app/_assets_helpers.py @@ -97,3 +97,36 @@ def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]: def normalize_tags(tags: Optional[Sequence[str]]) -> list[str]: return [t.strip().lower() for t in (tags or []) if (t or "").strip()] + + +def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]: + """Validates and maps tags -> (base_dir, subdirs_for_fs)""" + root = tags[0] + if root == "models": + if len(tags) < 2: + raise ValueError("at least two tags required for model asset") + bases = folder_paths.folder_names_and_paths[tags[1]][0] + if not bases: + raise ValueError(f"no base path configured for category '{tags[1]}'") + base_dir = os.path.abspath(bases[0]) + raw_subdirs = tags[2:] + else: + base_dir = os.path.abspath( + folder_paths.get_input_directory() if root == "input" else folder_paths.get_output_directory() + ) + raw_subdirs = tags[1:] + for i in raw_subdirs: + if i in (".", ".."): + raise ValueError("invalid path component in tags") + + return base_dir, 
raw_subdirs if raw_subdirs else [] + + +def ensure_within_base(candidate: str, base: str) -> None: + cand_abs = os.path.abspath(candidate) + base_abs = os.path.abspath(base) + try: + if os.path.commonpath([cand_abs, base_abs]) != base_abs: + raise ValueError("destination escapes base directory") + except Exception: + raise ValueError("invalid destination path") diff --git a/app/api/assets_routes.py b/app/api/assets_routes.py index be3005a298ae..7fbc694671c3 100644 --- a/app/api/assets_routes.py +++ b/app/api/assets_routes.py @@ -1,9 +1,13 @@ +import os +import uuid import urllib.parse from typing import Optional from aiohttp import web from pydantic import ValidationError +import folder_paths + from .. import assets_manager, assets_scanner from . import schemas_in @@ -42,7 +46,6 @@ async def list_assets(request: web.Request) -> web.Response: return web.json_response(payload.model_dump(mode="json")) - @ROUTES.get("/api/assets/{id}/content") async def download_asset_content(request: web.Request) -> web.Response: asset_info_id_raw = request.match_info.get("id") @@ -75,6 +78,118 @@ async def download_asset_content(request: web.Request) -> web.Response: return resp +@ROUTES.post("/api/assets/from-hash") +async def create_asset_from_hash(request: web.Request) -> web.Response: + try: + payload = await request.json() + body = schemas_in.CreateFromHashBody.model_validate(payload) + except ValidationError as ve: + return _validation_error_response("INVALID_BODY", ve) + except Exception: + return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") + + result = await assets_manager.create_asset_from_hash( + hash_str=body.hash, + name=body.name, + tags=body.tags, + user_metadata=body.user_metadata, + ) + if result is None: + return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist") + return web.json_response(result.model_dump(mode="json"), status=201) + + +@ROUTES.post("/api/assets") +async def upload_asset(request: web.Request) -> web.Response: + """Multipart/form-data endpoint for Asset uploads.""" + + if not (request.content_type or "").lower().startswith("multipart/"): + return _error_response(415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads.") + + reader = await request.multipart() + + file_field = None + file_client_name: Optional[str] = None + tags_raw: list[str] = [] + provided_name: Optional[str] = None + user_metadata_raw: Optional[str] = None + file_written = 0 + + while True: + field = await reader.next() + if field is None: + break + + fname = getattr(field, "name", None) or "" + if fname == "file": + # Save to temp + uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads") + unique_dir = os.path.join(uploads_root, uuid.uuid4().hex) + os.makedirs(unique_dir, exist_ok=True) + tmp_path = os.path.join(unique_dir, ".upload.part") + + file_field = field + file_client_name = (field.filename or "").strip() + try: + with open(tmp_path, "wb") as f: + while True: + chunk = await field.read_chunk(8 * 1024 * 1024) + if not chunk: + break + f.write(chunk) + file_written += len(chunk) + except Exception: + try: + if os.path.exists(tmp_path): + os.remove(tmp_path) + finally: + return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file.") + elif fname == "tags": + tags_raw.append((await field.text()) or "") + elif fname == "name": + provided_name = (await field.text()) or None + elif fname == "user_metadata": + user_metadata_raw = (await field.text()) or None + + if file_field is None: 
+ return _error_response(400, "MISSING_FILE", "Form must include a 'file' part.") + + if file_written == 0: + try: + os.remove(tmp_path) + finally: + return _error_response(400, "EMPTY_UPLOAD", "Uploaded file is empty.") + + try: + spec = schemas_in.UploadAssetSpec.model_validate({ + "tags": tags_raw, + "name": provided_name, + "user_metadata": user_metadata_raw, + }) + except ValidationError as ve: + try: + os.remove(tmp_path) + finally: + return _validation_error_response("INVALID_BODY", ve) + + if spec.tags[0] == "models" and spec.tags[1] not in folder_paths.folder_names_and_paths: + return _error_response(400, "INVALID_BODY", f"unknown models category '{spec.tags[1]}'") + + try: + created = await assets_manager.upload_asset_from_temp_path( + spec, + temp_path=tmp_path, + client_filename=file_client_name, + ) + return web.json_response(created.model_dump(mode="json"), status=201) + except Exception: + try: + if os.path.exists(tmp_path): + os.remove(tmp_path) + finally: + return _error_response(500, "INTERNAL", "Unexpected server error.") + + @ROUTES.put("/api/assets/{id}") async def update_asset(request: web.Request) -> web.Response: asset_info_id_raw = request.match_info.get("id") @@ -104,27 +219,6 @@ async def update_asset(request: web.Request) -> web.Response: return web.json_response(result.model_dump(mode="json"), status=200) -@ROUTES.post("/api/assets/from-hash") -async def create_asset_from_hash(request: web.Request) -> web.Response: - try: - payload = await request.json() - body = schemas_in.CreateFromHashBody.model_validate(payload) - except ValidationError as ve: - return _validation_error_response("INVALID_BODY", ve) - except Exception: - return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") - - result = await assets_manager.create_asset_from_hash( - hash_str=body.hash, - name=body.name, - tags=body.tags, - user_metadata=body.user_metadata, - ) - if result is None: - return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist") - return web.json_response(result.model_dump(mode="json"), status=201) - - @ROUTES.delete("/api/assets/{id}") async def delete_asset(request: web.Request) -> web.Response: asset_info_id_raw = request.match_info.get("id") diff --git a/app/api/schemas_in.py b/app/api/schemas_in.py index fa42146d3306..9694a67a62f6 100644 --- a/app/api/schemas_in.py +++ b/app/api/schemas_in.py @@ -172,3 +172,93 @@ def _normalize_roots(cls, v): out.append(r) seen.add(r) return out + + +class UploadAssetSpec(BaseModel): + """Upload Asset operation. + - tags: ordered; first is root ('models'|'input'|'output'); + if root == 'models', second must be a valid category from folder_paths.folder_names_and_paths + - name: desired filename (optional); fallback will be the file hash + - user_metadata: arbitrary JSON object (optional) + """ + model_config = ConfigDict(extra="ignore", str_strip_whitespace=True) + + tags: list[str] = Field(..., min_length=1) + name: Optional[str] = Field(default=None, max_length=512) + user_metadata: dict[str, Any] = Field(default_factory=dict) + + @field_validator("tags", mode="before") + @classmethod + def _parse_tags(cls, v): + """ + Accepts a list of strings (possibly multiple form fields), + where each string can be: + - JSON array (e.g., '["models","loras","foo"]') + - comma-separated ('models, loras, foo') + - single token ('models') + Returns a normalized, deduplicated, ordered list. 
+ """ + items: list[str] = [] + if v is None: + return [] + if isinstance(v, str): + v = [v] + + if isinstance(v, list): + for item in v: + if item is None: + continue + s = str(item).strip() + if not s: + continue + if s.startswith("["): + try: + arr = json.loads(s) + if isinstance(arr, list): + items.extend(str(x) for x in arr) + continue + except Exception: + pass # fallback to CSV parse below + items.extend([p for p in s.split(",") if p.strip()]) + else: + return [] + + # normalize + dedupe + norm = [] + seen = set() + for t in items: + tnorm = str(t).strip().lower() + if tnorm and tnorm not in seen: + seen.add(tnorm) + norm.append(tnorm) + return norm + + @field_validator("user_metadata", mode="before") + @classmethod + def _parse_metadata_json(cls, v): + if v is None or isinstance(v, dict): + return v or {} + if isinstance(v, str): + s = v.strip() + if not s: + return {} + try: + parsed = json.loads(s) + except Exception as e: + raise ValueError(f"user_metadata must be JSON: {e}") from e + if not isinstance(parsed, dict): + raise ValueError("user_metadata must be a JSON object") + return parsed + return {} + + @model_validator(mode="after") + def _validate_order(self): + if not self.tags: + raise ValueError("tags must be provided and non-empty") + root = self.tags[0] + if root not in {"models", "input", "output"}: + raise ValueError("first tag must be one of: models, input, output") + if root == "models": + if len(self.tags) < 2: + raise ValueError("models uploads require a category tag as the second tag") + return self diff --git a/app/assets_manager.py b/app/assets_manager.py index 0c008b47148f..f6c839b8b0ac 100644 --- a/app/assets_manager.py +++ b/app/assets_manager.py @@ -25,8 +25,8 @@ get_asset_by_hash, create_asset_info_for_existing_asset, ) -from .api import schemas_out -from ._assets_helpers import get_name_and_tags_from_asset_path +from .api import schemas_in, schemas_out +from ._assets_helpers import get_name_and_tags_from_asset_path, resolve_destination_from_tags, ensure_within_base async def asset_exists(*, asset_hash: str) -> bool: @@ -173,6 +173,97 @@ async def resolve_asset_content_for_download( return abs_path, ctype, download_name +async def upload_asset_from_temp_path( + spec: schemas_in.UploadAssetSpec, + *, + temp_path: str, + client_filename: Optional[str] = None, +) -> schemas_out.AssetCreated: + """ + Finalize an uploaded temp file: + - compute blake3 hash + - resolve destination from tags + - decide filename (spec.name or client filename or hash) + - move file atomically + - ingest into DB (assets, locator state, asset_info + tags) + Returns a populated AssetCreated payload. 
+ """ + + try: + digest = await hashing.blake3_hash(temp_path) + except Exception as e: + raise RuntimeError(f"failed to hash uploaded file: {e}") + asset_hash = "blake3:" + digest + + # Resolve destination + base_dir, subdirs = resolve_destination_from_tags(spec.tags) + dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir + os.makedirs(dest_dir, exist_ok=True) + + # Decide filename + desired_name = _safe_filename(spec.name or (client_filename or ""), fallback=digest) + dest_abs = os.path.abspath(os.path.join(dest_dir, desired_name)) + ensure_within_base(dest_abs, base_dir) + + # Content type based on final name + content_type = mimetypes.guess_type(desired_name, strict=False)[0] or "application/octet-stream" + + # Atomic move into place + try: + os.replace(temp_path, dest_abs) + except Exception as e: + raise RuntimeError(f"failed to move uploaded file into place: {e}") + + # Stat final file + try: + size_bytes, mtime_ns = _get_size_mtime_ns(dest_abs) + except OSError as e: + raise RuntimeError(f"failed to stat destination file: {e}") + + # Ingest + build response + async with await create_session() as session: + result = await ingest_fs_asset( + session, + asset_hash=asset_hash, + abs_path=dest_abs, + size_bytes=size_bytes, + mtime_ns=mtime_ns, + mime_type=content_type, + info_name=os.path.basename(dest_abs), + owner_id="", + preview_hash=None, + user_metadata=spec.user_metadata or {}, + tags=spec.tags, + tag_origin="manual", + added_by=None, + require_existing_tags=False, + ) + info_id = result.get("asset_info_id") + if not info_id: + raise RuntimeError("failed to create asset metadata") + + pair = await fetch_asset_info_and_asset(session, asset_info_id=int(info_id)) + if not pair: + raise RuntimeError("inconsistent DB state after ingest") + info, asset = pair + tag_names = await get_asset_tags(session, asset_info_id=info.id) + await session.commit() + + return schemas_out.AssetCreated( + id=info.id, + name=info.name, + asset_hash=info.asset_hash, + size=int(asset.size_bytes), + mime_type=asset.mime_type, + tags=tag_names, + user_metadata=info.user_metadata or {}, + preview_hash=info.preview_hash, + created_at=info.created_at, + last_access_time=info.last_access_time, + created_new=True, + ) + + async def update_asset( *, asset_info_id: int, From 6fade5da38e727497d674ecbcdac9877602ac620 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Tue, 26 Aug 2025 14:19:56 +0300 Subject: [PATCH 17/82] add AssetsResolver support --- alembic_db/versions/0001_assets.py | 45 +++++++--- app/_assets_helpers.py | 5 +- app/api/assets_routes.py | 23 +++-- app/assets_fetcher.py | 132 ++++++++++++++++++++++++++++ app/assets_manager.py | 44 +++++----- app/assets_scanner.py | 4 +- app/database/models.py | 53 +++++++---- app/database/services.py | 136 ++++++++++++++++++++--------- app/resolvers/__init__.py | 35 ++++++++ 9 files changed, 371 insertions(+), 106 deletions(-) create mode 100644 app/assets_fetcher.py create mode 100644 app/resolvers/__init__.py diff --git a/alembic_db/versions/0001_assets.py b/alembic_db/versions/0001_assets.py index b180edacc3a5..9fb80ea8c676 100644 --- a/alembic_db/versions/0001_assets.py +++ b/alembic_db/versions/0001_assets.py @@ -1,3 +1,4 @@ +# File: /alembic_db/versions/0001_assets.py """initial assets schema + per-asset state cache Revision ID: 0001_assets @@ -22,15 +23,12 @@ def upgrade() -> None: sa.Column("size_bytes", sa.BigInteger(), nullable=False, server_default="0"), sa.Column("mime_type", sa.String(length=255), nullable=True), sa.Column("refcount", 
sa.BigInteger(), nullable=False, server_default="0"), - sa.Column("storage_backend", sa.String(length=32), nullable=False, server_default="fs"), - sa.Column("storage_locator", sa.Text(), nullable=False), sa.Column("created_at", sa.DateTime(timezone=False), nullable=False), sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False), sa.CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"), sa.CheckConstraint("refcount >= 0", name="ck_assets_refcount_nonneg"), ) op.create_index("ix_assets_mime_type", "assets", ["mime_type"]) - op.create_index("ix_assets_backend_locator", "assets", ["storage_backend", "storage_locator"]) # ASSETS_INFO: user-visible references (mutable metadata) op.create_table( @@ -52,11 +50,12 @@ def upgrade() -> None: op.create_index("ix_assets_info_name", "assets_info", ["name"]) op.create_index("ix_assets_info_created_at", "assets_info", ["created_at"]) op.create_index("ix_assets_info_last_access_time", "assets_info", ["last_access_time"]) + op.create_index("ix_assets_info_owner_name", "assets_info", ["owner_id", "name"]) # TAGS: normalized tag vocabulary op.create_table( "tags", - sa.Column("name", sa.String(length=128), primary_key=True), + sa.Column("name", sa.String(length=512), primary_key=True), sa.Column("tag_type", sa.String(length=32), nullable=False, server_default="user"), sa.CheckConstraint("name = lower(name)", name="ck_tags_lowercase"), ) @@ -65,8 +64,8 @@ def upgrade() -> None: # ASSET_INFO_TAGS: many-to-many for tags on AssetInfo op.create_table( "asset_info_tags", - sa.Column("asset_info_id", sa.BigInteger(), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False), - sa.Column("tag_name", sa.String(length=128), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False), + sa.Column("asset_info_id", sa.Integer(), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False), + sa.Column("tag_name", sa.String(length=512), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False), sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"), sa.Column("added_by", sa.String(length=128), nullable=True), sa.Column("added_at", sa.DateTime(timezone=False), nullable=False), @@ -75,15 +74,15 @@ def upgrade() -> None: op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"]) op.create_index("ix_asset_info_tags_asset_info_id", "asset_info_tags", ["asset_info_id"]) - # ASSET_LOCATOR_STATE: 1:1 filesystem metadata(for fast integrity checking) for an Asset records + # ASSET_CACHE_STATE: 1:1 local cache metadata for an Asset op.create_table( - "asset_locator_state", + "asset_cache_state", sa.Column("asset_hash", sa.String(length=256), sa.ForeignKey("assets.hash", ondelete="CASCADE"), primary_key=True), + sa.Column("file_path", sa.Text(), nullable=False), # absolute local path to cached file sa.Column("mtime_ns", sa.BigInteger(), nullable=True), - sa.Column("etag", sa.String(length=256), nullable=True), - sa.Column("last_modified", sa.String(length=128), nullable=True), - sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_als_mtime_nonneg"), + sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"), ) + op.create_index("ix_asset_cache_state_file_path", "asset_cache_state", ["file_path"]) # ASSET_INFO_META: typed KV projection of user_metadata for filtering/sorting op.create_table( @@ -102,6 +101,21 @@ def upgrade() -> None: op.create_index("ix_asset_info_meta_key_val_num", "asset_info_meta", ["key", "val_num"]) 
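# --- illustrative aside (editor sketch, not part of this migration) -----------
# A minimal sketch of how the typed KV projection created above is meant to be
# queried, assuming the AssetInfoMeta ORM model mirrors these columns
# (key, ordinal, val_str/val_num/val_bool) and exposes an asset_info_id foreign
# key to assets_info; the helper function name here is hypothetical.
from sqlalchemy import select
from app.database.models import AssetInfo, AssetInfoMeta

def infos_with_numeric_meta_at_least(key: str, minimum: float):
    # Filters on a projected numeric metadata value; this shape of predicate is
    # what the ("key", "val_num") composite index above is intended to serve.
    return (
        select(AssetInfo)
        .join(AssetInfoMeta, AssetInfoMeta.asset_info_id == AssetInfo.id)
        .where(AssetInfoMeta.key == key, AssetInfoMeta.val_num >= minimum)
    )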
op.create_index("ix_asset_info_meta_key_val_bool", "asset_info_meta", ["key", "val_bool"]) + # ASSET_LOCATIONS: remote locations per asset + op.create_table( + "asset_locations", + sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True), + sa.Column("asset_hash", sa.String(length=256), sa.ForeignKey("assets.hash", ondelete="CASCADE"), nullable=False), + sa.Column("provider", sa.String(length=32), nullable=False), # e.g., "gcs" + sa.Column("locator", sa.Text(), nullable=False), # e.g., "gs://bucket/path/to/blob" + sa.Column("expected_size_bytes", sa.BigInteger(), nullable=True), + sa.Column("etag", sa.String(length=256), nullable=True), + sa.Column("last_modified", sa.String(length=128), nullable=True), + sa.UniqueConstraint("asset_hash", "provider", "locator", name="uq_asset_locations_triplet"), + ) + op.create_index("ix_asset_locations_hash", "asset_locations", ["asset_hash"]) + op.create_index("ix_asset_locations_provider", "asset_locations", ["provider"]) + # Tags vocabulary for models tags_table = sa.table( "tags", @@ -143,13 +157,18 @@ def upgrade() -> None: def downgrade() -> None: + op.drop_index("ix_asset_locations_provider", table_name="asset_locations") + op.drop_index("ix_asset_locations_hash", table_name="asset_locations") + op.drop_table("asset_locations") + op.drop_index("ix_asset_info_meta_key_val_bool", table_name="asset_info_meta") op.drop_index("ix_asset_info_meta_key_val_num", table_name="asset_info_meta") op.drop_index("ix_asset_info_meta_key_val_str", table_name="asset_info_meta") op.drop_index("ix_asset_info_meta_key", table_name="asset_info_meta") op.drop_table("asset_info_meta") - op.drop_table("asset_locator_state") + op.drop_index("ix_asset_cache_state_file_path", table_name="asset_cache_state") + op.drop_table("asset_cache_state") op.drop_index("ix_asset_info_tags_asset_info_id", table_name="asset_info_tags") op.drop_index("ix_asset_info_tags_tag_name", table_name="asset_info_tags") @@ -159,6 +178,7 @@ def downgrade() -> None: op.drop_table("tags") op.drop_constraint("uq_assets_info_hash_owner_name", table_name="assets_info") + op.drop_index("ix_assets_info_owner_name", table_name="assets_info") op.drop_index("ix_assets_info_last_access_time", table_name="assets_info") op.drop_index("ix_assets_info_created_at", table_name="assets_info") op.drop_index("ix_assets_info_name", table_name="assets_info") @@ -166,6 +186,5 @@ def downgrade() -> None: op.drop_index("ix_assets_info_owner_id", table_name="assets_info") op.drop_table("assets_info") - op.drop_index("ix_assets_backend_locator", table_name="assets") op.drop_index("ix_assets_mime_type", table_name="assets") op.drop_table("assets") diff --git a/app/_assets_helpers.py b/app/_assets_helpers.py index 4f1ad4446b1c..9fd3600f170d 100644 --- a/app/_assets_helpers.py +++ b/app/_assets_helpers.py @@ -105,7 +105,10 @@ def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]: if root == "models": if len(tags) < 2: raise ValueError("at least two tags required for model asset") - bases = folder_paths.folder_names_and_paths[tags[1]][0] + try: + bases = folder_paths.folder_names_and_paths[tags[1]][0] + except KeyError: + raise ValueError(f"unknown model category '{tags[1]}'") if not bases: raise ValueError(f"no base path configured for category '{tags[1]}'") base_dir = os.path.abspath(bases[0]) diff --git a/app/api/assets_routes.py b/app/api/assets_routes.py index 7fbc694671c3..c0dde790987a 100644 --- a/app/api/assets_routes.py +++ b/app/api/assets_routes.py @@ -9,7 +9,7 @@ import folder_paths 
from .. import assets_manager, assets_scanner -from . import schemas_in +from . import schemas_in, schemas_out ROUTES = web.RouteTableDef() @@ -20,6 +20,9 @@ async def head_asset_by_hash(request: web.Request) -> web.Response: hash_str = request.match_info.get("hash", "").strip().lower() if not hash_str or ":" not in hash_str: return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") + algo, digest = hash_str.split(":", 1) + if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"): + return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") exists = await assets_manager.asset_exists(asset_hash=hash_str) return web.Response(status=200 if exists else 404) @@ -69,7 +72,7 @@ async def download_asset_content(request: web.Request) -> web.Response: except FileNotFoundError: return _error_response(404, "FILE_NOT_FOUND", "Underlying file not found on disk.") - quoted = filename.replace('"', "'") + quoted = (filename or "").replace("\r", "").replace("\n", "").replace('"', "'") cd = f'{disposition}; filename="{quoted}"; filename*=UTF-8\'\'{urllib.parse.quote(filename)}' resp = web.FileResponse(abs_path) @@ -115,6 +118,7 @@ async def upload_asset(request: web.Request) -> web.Response: user_metadata_raw: Optional[str] = None file_written = 0 + tmp_path: Optional[str] = None while True: field = await reader.next() if field is None: @@ -173,6 +177,8 @@ async def upload_asset(request: web.Request) -> web.Response: return _validation_error_response("INVALID_BODY", ve) if spec.tags[0] == "models" and spec.tags[1] not in folder_paths.folder_names_and_paths: + if tmp_path and os.path.exists(tmp_path): + os.remove(tmp_path) return _error_response(400, "INVALID_BODY", f"unknown models category '{spec.tags[1]}'") try: @@ -182,12 +188,14 @@ async def upload_asset(request: web.Request) -> web.Response: client_filename=file_client_name, ) return web.json_response(created.model_dump(mode="json"), status=201) + except ValueError: + if tmp_path and os.path.exists(tmp_path): + os.remove(tmp_path) + return _error_response(400, "BAD_REQUEST", "Invalid inputs.") except Exception: - try: - if os.path.exists(tmp_path): - os.remove(tmp_path) - finally: - return _error_response(500, "INTERNAL", "Unexpected server error.") + if tmp_path and os.path.exists(tmp_path): + os.remove(tmp_path) + return _error_response(500, "INTERNAL", "Unexpected server error.") @ROUTES.put("/api/assets/{id}") @@ -341,6 +349,7 @@ async def get_asset_scan_status(request: web.Request) -> web.Response: states = assets_scanner.current_statuses() if root in {"models", "input", "output"}: states = [s for s in states.scans if s.root == root] # type: ignore + states = schemas_out.AssetScanStatusResponse(scans=states) return web.json_response(states.model_dump(mode="json"), status=200) diff --git a/app/assets_fetcher.py b/app/assets_fetcher.py new file mode 100644 index 000000000000..ea1c8ed00420 --- /dev/null +++ b/app/assets_fetcher.py @@ -0,0 +1,132 @@ +from __future__ import annotations +import asyncio +import os +import tempfile +from typing import Optional +import mimetypes +import aiohttp + +from .storage.hashing import blake3_hash_sync +from .database.db import create_session +from .database.services import ingest_fs_asset, get_cache_state_by_asset_hash +from .resolvers import resolve_asset +from ._assets_helpers import resolve_destination_from_tags, ensure_within_base + +_FETCH_LOCKS: dict[str, asyncio.Lock] = {} + + +def _sanitize_filename(name: str) -> str: + return 
os.path.basename((name or "").strip()) or "file" + + +async def ensure_asset_cached( + asset_hash: str, + *, + preferred_name: Optional[str] = None, + tags_hint: Optional[list[str]] = None, +) -> str: + """ + Ensure there is a verified local file for `asset_hash` in the correct Comfy folder. + Policy: + - Resolver must provide valid tags (root and, for models, category). + - If target path already exists: + * if hash matches -> reuse & ingest + * else -> remove and overwrite with the correct content + """ + lock = _FETCH_LOCKS.setdefault(asset_hash, asyncio.Lock()) + async with lock: + # 1) If we already have a state -> trust the path + async with await create_session() as sess: + state = await get_cache_state_by_asset_hash(sess, asset_hash=asset_hash) + if state and os.path.isfile(state.file_path): + return state.file_path + + # 2) Resolve remote location + placement hints (must include valid tags) + res = await resolve_asset(asset_hash) + if not res: + raise FileNotFoundError(f"No resolver/locations for {asset_hash}") + + placement_tags = tags_hint or res.tags + if not placement_tags: + raise ValueError(f"Resolver did not provide placement tags for {asset_hash}") + + name_hint = res.filename or preferred_name or asset_hash.replace(":", "_") + safe_name = _sanitize_filename(name_hint) + + # 3) Map tags -> destination (strict: raises if invalid root or models category) + base_dir, subdirs = resolve_destination_from_tags(placement_tags) # may raise + dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir + os.makedirs(dest_dir, exist_ok=True) + + final_path = os.path.abspath(os.path.join(dest_dir, safe_name)) + ensure_within_base(final_path, base_dir) + + # 4) If target path exists, try to reuse; else delete invalid cache + if os.path.exists(final_path) and os.path.isfile(final_path): + existing_digest = blake3_hash_sync(final_path) + if f"blake3:{existing_digest}" == asset_hash: + size_bytes = os.path.getsize(final_path) + mtime_ns = getattr(os.stat(final_path), "st_mtime_ns", int(os.path.getmtime(final_path) * 1_000_000_000)) + async with await create_session() as sess: + await ingest_fs_asset( + sess, + asset_hash=asset_hash, + abs_path=final_path, + size_bytes=size_bytes, + mtime_ns=mtime_ns, + mime_type=None, + info_name=None, + tags=(), + ) + await sess.commit() + return final_path + else: + # Invalid cache: remove before re-downloading + os.remove(final_path) + + # 5) Download to temp next to destination + timeout = aiohttp.ClientTimeout(total=60 * 30) + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.get(res.download_url, headers=dict(res.headers)) as resp: + resp.raise_for_status() + cl = resp.headers.get("Content-Length") + if res.expected_size and cl and int(cl) != int(res.expected_size): + raise ValueError("server Content-Length does not match expected size") + with tempfile.NamedTemporaryFile("wb", delete=False, dir=dest_dir) as tmp: + tmp_path = tmp.name + async for chunk in resp.content.iter_chunked(8 * 1024 * 1024): + if chunk: + tmp.write(chunk) + + # 6) Verify content hash + digest = blake3_hash_sync(tmp_path) + canonical = f"blake3:{digest}" + if canonical != asset_hash: + try: + os.remove(tmp_path) + finally: + raise ValueError(f"Hash mismatch: expected {asset_hash}, got {canonical}") + + # 7) Atomically move into place (we already removed an invalid file if it existed) + if os.path.exists(final_path): + os.remove(final_path) + os.replace(tmp_path, final_path) + + # 8) Record identity + cache state (+ mime type) + 
size_bytes = os.path.getsize(final_path) + mtime_ns = getattr(os.stat(final_path), "st_mtime_ns", int(os.path.getmtime(final_path) * 1_000_000_000)) + mime_type = mimetypes.guess_type(safe_name, strict=False)[0] + async with await create_session() as sess: + await ingest_fs_asset( + sess, + asset_hash=asset_hash, + abs_path=final_path, + size_bytes=size_bytes, + mtime_ns=mtime_ns, + mime_type=mime_type, + info_name=None, + tags=(), + ) + await sess.commit() + + return final_path diff --git a/app/assets_manager.py b/app/assets_manager.py index f6c839b8b0ac..72d299467f66 100644 --- a/app/assets_manager.py +++ b/app/assets_manager.py @@ -26,7 +26,8 @@ create_asset_info_for_existing_asset, ) from .api import schemas_in, schemas_out -from ._assets_helpers import get_name_and_tags_from_asset_path, resolve_destination_from_tags, ensure_within_base +from ._assets_helpers import get_name_and_tags_from_asset_path, ensure_within_base, resolve_destination_from_tags +from .assets_fetcher import ensure_asset_cached async def asset_exists(*, asset_hash: str) -> bool: @@ -46,17 +47,17 @@ def populate_db_with_asset(file_path: str, tags: Optional[list[str]] = None) -> file_name=asset_name, file_path=file_path, ) - except ValueError: - logging.exception("Cant parse '%s' as an asset file path.", file_path) + except ValueError as e: + logging.warning("Skipping non-asset path %s: %s", file_path, e) async def add_local_asset(tags: list[str], file_name: str, file_path: str) -> None: """Adds a local asset to the DB. If already present and unchanged, does nothing. Notes: - - Uses absolute path as the canonical locator for the 'fs' backend. + - Uses absolute path as the canonical locator for the cache backend. - Computes BLAKE3 only when the fast existence check indicates it's needed. - - This function ensures the identity row and seeds mtime in asset_locator_state. + - This function ensures the identity row and seeds mtime in asset_cache_state. """ abs_path = os.path.abspath(file_path) size_bytes, mtime_ns = _get_size_mtime_ns(abs_path) @@ -125,7 +126,7 @@ async def list_assets( size=int(asset.size_bytes) if asset else None, mime_type=asset.mime_type if asset else None, tags=tags, - preview_url=f"/api/v1/assets/{info.id}/content", + preview_url=f"/api/assets/{info.id}/content", created_at=info.created_at, updated_at=info.updated_at, last_access_time=info.last_access_time, @@ -143,12 +144,11 @@ async def resolve_asset_content_for_download( *, asset_info_id: int ) -> tuple[str, str, str]: """ - Returns (abs_path, content_type, download_name) for the given AssetInfo id. + Returns (abs_path, content_type, download_name) for the given AssetInfo id and touches last_access_time. Also touches last_access_time (only_if_newer). + Ensures the local cache is present (uses resolver if needed). Raises: - ValueError if AssetInfo not found - NotImplementedError for unsupported backend - FileNotFoundError if underlying file does not exist (fs backend) + ValueError if AssetInfo cannot be found """ async with await create_session() as session: pair = await fetch_asset_info_and_asset(session, asset_info_id=asset_info_id) @@ -156,21 +156,19 @@ async def resolve_asset_content_for_download( raise ValueError(f"AssetInfo {asset_info_id} not found") info, asset = pair + tag_names = await get_asset_tags(session, asset_info_id=info.id) - if asset.storage_backend != "fs": - # Future: support http/s3/gcs/... 
- raise NotImplementedError(f"backend {asset.storage_backend!r} not supported yet") - - abs_path = os.path.abspath(asset.storage_locator) - if not os.path.exists(abs_path): - raise FileNotFoundError(abs_path) + # Ensure cached (download if missing) + preferred_name = info.name or info.asset_hash.split(":", 1)[-1] + abs_path = await ensure_asset_cached(info.asset_hash, preferred_name=preferred_name, tags_hint=tag_names) + async with await create_session() as session: await touch_asset_info_by_id(session, asset_info_id=asset_info_id) await session.commit() - ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream" - download_name = info.name or os.path.basename(abs_path) - return abs_path, ctype, download_name + ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream" + download_name = info.name or os.path.basename(abs_path) + return abs_path, ctype, download_name async def upload_asset_from_temp_path( @@ -238,7 +236,7 @@ async def upload_asset_from_temp_path( added_by=None, require_existing_tags=False, ) - info_id = result.get("asset_info_id") + info_id = result["asset_info_id"] if not info_id: raise RuntimeError("failed to create asset metadata") @@ -260,7 +258,7 @@ async def upload_asset_from_temp_path( preview_hash=info.preview_hash, created_at=info.created_at, last_access_time=info.last_access_time, - created_new=True, + created_new=result["asset_created"], ) @@ -416,7 +414,7 @@ def _get_size_mtime_ns(path: str) -> tuple[int, int]: return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)) -def _safe_filename(name: Optional[str] , fallback: str) -> str: +def _safe_filename(name: Optional[str], fallback: str) -> str: n = os.path.basename((name or "").strip() or fallback) if n: return n diff --git a/app/assets_scanner.py b/app/assets_scanner.py index 691472156cd0..5bafd6bb7c29 100644 --- a/app/assets_scanner.py +++ b/app/assets_scanner.py @@ -147,9 +147,7 @@ async def _run_scan_for_root(root: RootType, prog: ScanProgress) -> None: prog.last_error = str(exc) finally: prog.finished_at = time.time() - t = RUNNING_TASKS.get(root) - if t and t.done(): - RUNNING_TASKS.pop(root, None) + RUNNING_TASKS.pop(root, None) async def _scan_models(prog: ScanProgress) -> None: diff --git a/app/database/models.py b/app/database/models.py index 20b88ca68705..d964a5226deb 100644 --- a/app/database/models.py +++ b/app/database/models.py @@ -45,8 +45,6 @@ class Asset(Base): size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0) mime_type: Mapped[str | None] = mapped_column(String(255)) refcount: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0) - storage_backend: Mapped[str] = mapped_column(String(32), nullable=False, default="fs") - storage_locator: Mapped[str] = mapped_column(Text, nullable=False) created_at: Mapped[datetime] = mapped_column( DateTime(timezone=False), nullable=False, default=utcnow ) @@ -71,48 +69,71 @@ class Asset(Base): viewonly=True, ) - locator_state: Mapped["AssetLocatorState | None"] = relationship( + cache_state: Mapped["AssetCacheState | None"] = relationship( back_populates="asset", uselist=False, cascade="all, delete-orphan", passive_deletes=True, ) + locations: Mapped[list["AssetLocation"]] = relationship( + back_populates="asset", + cascade="all, delete-orphan", + passive_deletes=True, + ) + __table_args__ = ( Index("ix_assets_mime_type", "mime_type"), - Index("ix_assets_backend_locator", "storage_backend", "storage_locator"), 
) def to_dict(self, include_none: bool = False) -> dict[str, Any]: return to_dict(self, include_none=include_none) def __repr__(self) -> str: - return f"" + return f"" -class AssetLocatorState(Base): - __tablename__ = "asset_locator_state" +class AssetCacheState(Base): + __tablename__ = "asset_cache_state" asset_hash: Mapped[str] = mapped_column( String(256), ForeignKey("assets.hash", ondelete="CASCADE"), primary_key=True ) - # For fs backends: nanosecond mtime; nullable if not applicable + file_path: Mapped[str] = mapped_column(Text, nullable=False) mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True) - # For HTTP/S3/GCS/Azure, etc.: optional validators - etag: Mapped[str | None] = mapped_column(String(256), nullable=True) - last_modified: Mapped[str | None] = mapped_column(String(128), nullable=True) - asset: Mapped["Asset"] = relationship(back_populates="locator_state", uselist=False) + asset: Mapped["Asset"] = relationship(back_populates="cache_state", uselist=False) __table_args__ = ( - CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_als_mtime_nonneg"), + Index("ix_asset_cache_state_file_path", "file_path"), + CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"), ) def to_dict(self, include_none: bool = False) -> dict[str, Any]: return to_dict(self, include_none=include_none) def __repr__(self) -> str: - return f"" + return f"" + + +class AssetLocation(Base): + __tablename__ = "asset_locations" + + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + asset_hash: Mapped[str] = mapped_column(String(256), ForeignKey("assets.hash", ondelete="CASCADE"), nullable=False) + provider: Mapped[str] = mapped_column(String(32), nullable=False) # "gcs" + locator: Mapped[str] = mapped_column(Text, nullable=False) # "gs://bucket/object" + expected_size_bytes: Mapped[int | None] = mapped_column(BigInteger, nullable=True) + etag: Mapped[str | None] = mapped_column(String(256), nullable=True) + last_modified: Mapped[str | None] = mapped_column(String(128), nullable=True) + + asset: Mapped["Asset"] = relationship(back_populates="locations") + + __table_args__ = ( + UniqueConstraint("asset_hash", "provider", "locator", name="uq_asset_locations_triplet"), + Index("ix_asset_locations_hash", "asset_hash"), + Index("ix_asset_locations_provider", "provider"), + ) class AssetInfo(Base): @@ -220,7 +241,7 @@ class AssetInfoTag(Base): Integer, ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True ) tag_name: Mapped[str] = mapped_column( - String(128), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True + String(512), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True ) origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual") added_by: Mapped[str | None] = mapped_column(String(128)) @@ -240,7 +261,7 @@ class AssetInfoTag(Base): class Tag(Base): __tablename__ = "tags" - name: Mapped[str] = mapped_column(String(128), primary_key=True) + name: Mapped[str] = mapped_column(String(512), primary_key=True) tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user") asset_info_links: Mapped[list["AssetInfoTag"]] = relationship( diff --git a/app/database/services.py b/app/database/services.py index 34029b139d89..5f1ffffbfd7d 100644 --- a/app/database/services.py +++ b/app/database/services.py @@ -12,7 +12,7 @@ from sqlalchemy.orm import contains_eager from sqlalchemy.exc import IntegrityError -from .models import Asset, AssetInfo, AssetInfoTag, 
AssetLocatorState, Tag, AssetInfoMeta +from .models import Asset, AssetInfo, AssetInfoTag, AssetCacheState, Tag, AssetInfoMeta, AssetLocation from .timeutil import utcnow from .._assets_helpers import normalize_tags @@ -38,30 +38,24 @@ async def check_fs_asset_exists_quick( mtime_ns: Optional[int] = None, ) -> bool: """ - Returns 'True' if there is already an Asset present whose canonical locator matches this absolute path, + Returns 'True' if there is already AssetCacheState record that matches this absolute path, AND (if provided) mtime_ns matches stored locator-state, AND (if provided) size_bytes matches verified size when known. """ locator = os.path.abspath(file_path) - stmt = select(sa.literal(True)).select_from(Asset) + stmt = select(sa.literal(True)).select_from(AssetCacheState).join( + Asset, Asset.hash == AssetCacheState.asset_hash + ).where(AssetCacheState.file_path == locator).limit(1) - conditions = [ - Asset.storage_backend == "fs", - Asset.storage_locator == locator, - ] - - # If size_bytes provided require equality when the asset has a verified (non-zero) size. - # If verified size is 0 (unknown), we don't force equality. - if size_bytes is not None: - conditions.append(sa.or_(Asset.size_bytes == 0, Asset.size_bytes == int(size_bytes))) - - # If mtime_ns provided require the locator-state to exist and match. + conds = [] if mtime_ns is not None: - stmt = stmt.join(AssetLocatorState, AssetLocatorState.asset_hash == Asset.hash) - conditions.append(AssetLocatorState.mtime_ns == int(mtime_ns)) + conds.append(AssetCacheState.mtime_ns == int(mtime_ns)) + if size_bytes is not None: + conds.append(sa.or_(Asset.size_bytes == 0, Asset.size_bytes == int(size_bytes))) - stmt = stmt.where(*conditions).limit(1) + if conds: + stmt = stmt.where(*conds) row = (await session.execute(stmt)).first() return row is not None @@ -85,11 +79,11 @@ async def ingest_fs_asset( require_existing_tags: bool = False, ) -> dict: """ - Creates or updates Asset record for a local (fs) asset. + Upsert Asset identity row + cache state pointing at local file. Always: - Insert Asset if missing; else update size_bytes (and updated_at) if different. - - Insert AssetLocatorState if missing; else update mtime_ns if different. + - Insert AssetCacheState if missing; else update mtime_ns if different. Optionally (when info_name is provided): - Create an AssetInfo (no refcount changes). 
@@ -126,8 +120,6 @@ async def ingest_fs_asset( size_bytes=int(size_bytes), mime_type=mime_type, refcount=0, - storage_backend="fs", - storage_locator=locator, created_at=datetime_now, updated_at=datetime_now, ) @@ -145,21 +137,19 @@ async def ingest_fs_asset( if mime_type and existing.mime_type != mime_type: existing.mime_type = mime_type changed = True - if existing.storage_locator != locator: - existing.storage_locator = locator - changed = True if changed: existing.updated_at = datetime_now out["asset_updated"] = True else: logging.error("Asset %s not found after PK conflict; skipping update.", asset_hash) - # ---- Step 2: INSERT/UPDATE AssetLocatorState (mtime_ns) ---- + # ---- Step 2: INSERT/UPDATE AssetCacheState (mtime_ns, file_path) ---- with contextlib.suppress(IntegrityError): async with session.begin_nested(): session.add( - AssetLocatorState( + AssetCacheState( asset_hash=asset_hash, + file_path=locator, mtime_ns=int(mtime_ns), ) ) @@ -167,11 +157,17 @@ async def ingest_fs_asset( out["state_created"] = True if not out["state_created"]: - state = await session.get(AssetLocatorState, asset_hash) + state = await session.get(AssetCacheState, asset_hash) if state is not None: - desired_mtime = int(mtime_ns) - if state.mtime_ns != desired_mtime: - state.mtime_ns = desired_mtime + changed = False + if state.file_path != locator: + state.file_path = locator + changed = True + if state.mtime_ns != int(mtime_ns): + state.mtime_ns = int(mtime_ns) + changed = True + if changed: + await session.flush() out["state_updated"] = True else: logging.error("Locator state missing for %s after conflict; skipping update.", asset_hash) @@ -278,11 +274,10 @@ async def touch_asset_infos_by_fs_path( stmt = sa.update(AssetInfo).where( sa.exists( sa.select(sa.literal(1)) - .select_from(Asset) + .select_from(AssetCacheState) .where( - Asset.hash == AssetInfo.asset_hash, - Asset.storage_backend == "fs", - Asset.storage_locator == locator, + AssetCacheState.asset_hash == AssetInfo.asset_hash, + AssetCacheState.file_path == locator, ) ) ) @@ -337,13 +332,6 @@ async def list_asset_infos_page( We purposely collect tags in a separate (single) query to avoid row explosion. 
""" - # Clamp - if limit <= 0: - limit = 1 - if limit > 100: - limit = 100 - if offset < 0: - offset = 0 # Build base query base = ( @@ -419,6 +407,66 @@ async def fetch_asset_info_and_asset(session: AsyncSession, *, asset_info_id: in return pair[0], pair[1] +async def get_cache_state_by_asset_hash(session: AsyncSession, *, asset_hash: str) -> Optional[AssetCacheState]: + return await session.get(AssetCacheState, asset_hash) + + +async def list_asset_locations( + session: AsyncSession, *, asset_hash: str, provider: Optional[str] = None +) -> list[AssetLocation]: + stmt = select(AssetLocation).where(AssetLocation.asset_hash == asset_hash) + if provider: + stmt = stmt.where(AssetLocation.provider == provider) + return (await session.execute(stmt)).scalars().all() + + +async def upsert_asset_location( + session: AsyncSession, + *, + asset_hash: str, + provider: str, + locator: str, + expected_size_bytes: Optional[int] = None, + etag: Optional[str] = None, + last_modified: Optional[str] = None, +) -> AssetLocation: + loc = ( + await session.execute( + select(AssetLocation).where( + AssetLocation.asset_hash == asset_hash, + AssetLocation.provider == provider, + AssetLocation.locator == locator, + ).limit(1) + ) + ).scalars().first() + if loc: + changed = False + if expected_size_bytes is not None and loc.expected_size_bytes != expected_size_bytes: + loc.expected_size_bytes = expected_size_bytes + changed = True + if etag is not None and loc.etag != etag: + loc.etag = etag + changed = True + if last_modified is not None and loc.last_modified != last_modified: + loc.last_modified = last_modified + changed = True + if changed: + await session.flush() + return loc + + loc = AssetLocation( + asset_hash=asset_hash, + provider=provider, + locator=locator, + expected_size_bytes=expected_size_bytes, + etag=etag, + last_modified=last_modified, + ) + session.add(loc) + await session.flush() + return loc + + async def create_asset_info_for_existing_asset( session: AsyncSession, *, @@ -925,7 +973,8 @@ def _project_kv(key: str, value: Any) -> list[dict]: rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)}) elif isinstance(value, (int, float, Decimal)): # store numeric; SQLAlchemy will coerce to Numeric - rows.append({"key": key, "ordinal": 0, "val_num": value}) + num = value if isinstance(value, Decimal) else Decimal(str(value)) + rows.append({"key": key, "ordinal": 0, "val_num": num}) elif isinstance(value, str): rows.append({"key": key, "ordinal": 0, "val_str": value}) else: @@ -943,7 +992,8 @@ def _project_kv(key: str, value: Any) -> list[dict]: elif isinstance(x, bool): rows.append({"key": key, "ordinal": i, "val_bool": bool(x)}) elif isinstance(x, (int, float, Decimal)): - rows.append({"key": key, "ordinal": i, "val_num": x}) + num = x if isinstance(x, Decimal) else Decimal(str(x)) + rows.append({"key": key, "ordinal": i, "val_num": num}) elif isinstance(x, str): rows.append({"key": key, "ordinal": i, "val_str": x}) else: diff --git a/app/resolvers/__init__.py b/app/resolvers/__init__.py new file mode 100644 index 000000000000..c489ebad7c57 --- /dev/null +++ b/app/resolvers/__init__.py @@ -0,0 +1,35 @@ +import contextlib +from dataclasses import dataclass +from typing import Protocol, Optional, Mapping + + +@dataclass +class ResolveResult: + provider: str # e.g., "gcs" + download_url: str # fully-qualified URL to fetch bytes + headers: Mapping[str, str] # optional auth headers etc + expected_size: Optional[int] = None + tags: Optional[list[str]] = None # e.g. 
["models","vae","subdir"] + filename: Optional[str] = None # preferred basename + +class AssetResolver(Protocol): + provider: str + async def resolve(self, asset_hash: str) -> Optional[ResolveResult]: ... + + +_REGISTRY: list[AssetResolver] = [] + + +def register_resolver(resolver: AssetResolver) -> None: + """Append Resolver with simple de-dup per provider.""" + global _REGISTRY + _REGISTRY = [r for r in _REGISTRY if r.provider != resolver.provider] + [resolver] + + +async def resolve_asset(asset_hash: str) -> Optional[ResolveResult]: + for r in _REGISTRY: + with contextlib.suppress(Exception): # For Resolver failure we just try the next one + res = await r.resolve(asset_hash) + if res: + return res + return None From 7c1b0be49661768207f6a60c047ddd4066784925 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Wed, 27 Aug 2025 09:58:12 +0300 Subject: [PATCH 18/82] add Get Asset endpoint --- app/api/assets_routes.py | 27 ++++++++++++++++++++++----- app/api/schemas_out.py | 7 +++++-- app/assets_manager.py | 22 ++++++++++++++++++++++ app/database/services.py | 37 ++++++++++++++++++++++++++++++++++++- 4 files changed, 85 insertions(+), 8 deletions(-) diff --git a/app/api/assets_routes.py b/app/api/assets_routes.py index c0dde790987a..2ca2932e591c 100644 --- a/app/api/assets_routes.py +++ b/app/api/assets_routes.py @@ -49,7 +49,7 @@ async def list_assets(request: web.Request) -> web.Response: return web.json_response(payload.model_dump(mode="json")) -@ROUTES.get("/api/assets/{id}/content") +@ROUTES.get("/api/assets/{id:\\d+}/content") async def download_asset_content(request: web.Request) -> web.Response: asset_info_id_raw = request.match_info.get("id") try: @@ -198,7 +198,24 @@ async def upload_asset(request: web.Request) -> web.Response: return _error_response(500, "INTERNAL", "Unexpected server error.") -@ROUTES.put("/api/assets/{id}") +@ROUTES.get("/api/assets/{id:\\d+}") +async def get_asset(request: web.Request) -> web.Response: + asset_info_id_raw = request.match_info.get("id") + try: + asset_info_id = int(asset_info_id_raw) + except Exception: + return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid integer.") + + try: + result = await assets_manager.get_asset(asset_info_id=asset_info_id) + except ValueError as ve: + return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) + except Exception: + return _error_response(500, "INTERNAL", "Unexpected server error.") + return web.json_response(result.model_dump(mode="json"), status=200) + + +@ROUTES.put("/api/assets/{id:\\d+}") async def update_asset(request: web.Request) -> web.Response: asset_info_id_raw = request.match_info.get("id") try: @@ -227,7 +244,7 @@ async def update_asset(request: web.Request) -> web.Response: return web.json_response(result.model_dump(mode="json"), status=200) -@ROUTES.delete("/api/assets/{id}") +@ROUTES.delete("/api/assets/{id:\\d+}") async def delete_asset(request: web.Request) -> web.Response: asset_info_id_raw = request.match_info.get("id") try: @@ -267,7 +284,7 @@ async def get_tags(request: web.Request) -> web.Response: return web.json_response(result.model_dump(mode="json")) -@ROUTES.post("/api/assets/{id}/tags") +@ROUTES.post("/api/assets/{id:\\d+}/tags") async def add_asset_tags(request: web.Request) -> web.Response: asset_info_id_raw = request.match_info.get("id") try: @@ -298,7 +315,7 @@ async def add_asset_tags(request: web.Request) -> web.Response: return web.json_response(result.model_dump(mode="json"), status=200) 
-@ROUTES.delete("/api/assets/{id}/tags") +@ROUTES.delete("/api/assets/{id:\\d+}/tags") async def delete_asset_tags(request: web.Request) -> web.Response: asset_info_id_raw = request.match_info.get("id") try: diff --git a/app/api/schemas_out.py b/app/api/schemas_out.py index 8aca0ee012b7..1b41d8021d32 100644 --- a/app/api/schemas_out.py +++ b/app/api/schemas_out.py @@ -43,7 +43,7 @@ def _ser_updated(self, v: Optional[datetime], _info): return v.isoformat() if v else None -class AssetCreated(BaseModel): +class AssetDetail(BaseModel): id: int name: str asset_hash: str @@ -54,7 +54,6 @@ class AssetCreated(BaseModel): preview_hash: Optional[str] = None created_at: Optional[datetime] = None last_access_time: Optional[datetime] = None - created_new: bool model_config = ConfigDict(from_attributes=True) @@ -63,6 +62,10 @@ def _ser_dt(self, v: Optional[datetime], _info): return v.isoformat() if v else None +class AssetCreated(AssetDetail): + created_new: bool + + class TagUsage(BaseModel): name: str count: int diff --git a/app/assets_manager.py b/app/assets_manager.py index 72d299467f66..e61895c8ad7e 100644 --- a/app/assets_manager.py +++ b/app/assets_manager.py @@ -24,6 +24,7 @@ asset_exists_by_hash, get_asset_by_hash, create_asset_info_for_existing_asset, + fetch_asset_info_asset_and_tags, ) from .api import schemas_in, schemas_out from ._assets_helpers import get_name_and_tags_from_asset_path, ensure_within_base, resolve_destination_from_tags @@ -140,6 +141,27 @@ async def list_assets( ) +async def get_asset(*, asset_info_id: int) -> schemas_out.AssetDetail: + async with await create_session() as session: + res = await fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id) + if not res: + raise ValueError(f"AssetInfo {asset_info_id} not found") + info, asset, tag_names = res + + return schemas_out.AssetDetail( + id=info.id, + name=info.name, + asset_hash=info.asset_hash, + size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None, + mime_type=asset.mime_type if asset else None, + tags=tag_names, + preview_hash=info.preview_hash, + user_metadata=info.user_metadata or {}, + created_at=info.created_at, + last_access_time=info.last_access_time, + ) + + async def resolve_asset_content_for_download( *, asset_info_id: int ) -> tuple[str, str, str]: diff --git a/app/database/services.py b/app/database/services.py index 5f1ffffbfd7d..95a2d07ab979 100644 --- a/app/database/services.py +++ b/app/database/services.py @@ -9,7 +9,7 @@ import sqlalchemy as sa from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select, delete, exists, func -from sqlalchemy.orm import contains_eager +from sqlalchemy.orm import contains_eager, noload from sqlalchemy.exc import IntegrityError from .models import Asset, AssetInfo, AssetInfoTag, AssetCacheState, Tag, AssetInfoMeta, AssetLocation @@ -407,6 +407,41 @@ async def fetch_asset_info_and_asset(session: AsyncSession, *, asset_info_id: in return pair[0], pair[1] +async def fetch_asset_info_asset_and_tags( + session: AsyncSession, + *, + asset_info_id: int, +) -> Optional[tuple[AssetInfo, Asset, list[str]]]: + """Fetch AssetInfo, its Asset, and all tag names. + + Returns: + (AssetInfo, Asset, [tag_names]) or None if the asset_info_id does not exist. 
+ """ + stmt = ( + select(AssetInfo, Asset, Tag.name) + .join(Asset, Asset.hash == AssetInfo.asset_hash) + .join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True) + .join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True) + .where(AssetInfo.id == asset_info_id) + .options(noload(AssetInfo.tags)) + .order_by(Tag.name.asc()) + ) + + rows = (await session.execute(stmt)).all() + if not rows: + return None + + # First row contains the mapped entities; tags may repeat across rows + first_info, first_asset, _ = rows[0] + tags: list[str] = [] + seen: set[str] = set() + for _info, _asset, tag_name in rows: + if tag_name and tag_name not in seen: + seen.add(tag_name) + tags.append(tag_name) + return first_info, first_asset, tags + + async def get_cache_state_by_asset_hash(session: AsyncSession, *, asset_hash: str) -> Optional[AssetCacheState]: return await session.get(AssetCacheState, asset_hash) From 026b7f209c5bb6068d40c6ac614715d0bcff68a3 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Wed, 27 Aug 2025 19:47:55 +0300 Subject: [PATCH 19/82] add "--multi-user" support --- app/_assets_helpers.py | 12 ++++++ app/api/assets_routes.py | 31 ++++++++++---- app/assets_manager.py | 54 ++++++++++++++++++++---- app/database/services.py | 91 +++++++++++++++++++++++----------------- server.py | 4 +- 5 files changed, 136 insertions(+), 56 deletions(-) diff --git a/app/_assets_helpers.py b/app/_assets_helpers.py index 9fd3600f170d..4a8d39625376 100644 --- a/app/_assets_helpers.py +++ b/app/_assets_helpers.py @@ -2,8 +2,12 @@ from pathlib import Path from typing import Optional, Literal, Sequence +import sqlalchemy as sa + import folder_paths +from .database.models import AssetInfo + def get_comfy_models_folders() -> list[tuple[str, list[str]]]: """Build a list of (folder_name, base_paths[]) categories that are configured for model locations. @@ -133,3 +137,11 @@ def ensure_within_base(candidate: str, base: str) -> None: raise ValueError("destination escapes base directory") except Exception: raise ValueError("invalid destination path") + + +def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement: + """Build owner visibility predicate for reads.""" + owner_id = (owner_id or "").strip() + if owner_id == "": + return AssetInfo.owner_id == "" + return AssetInfo.owner_id.in_(["", owner_id]) diff --git a/app/api/assets_routes.py b/app/api/assets_routes.py index 2ca2932e591c..b5c25dcecb2a 100644 --- a/app/api/assets_routes.py +++ b/app/api/assets_routes.py @@ -8,11 +8,12 @@ import folder_paths -from .. import assets_manager, assets_scanner +from .. import assets_manager, assets_scanner, user_manager from . 
import schemas_in, schemas_out ROUTES = web.RouteTableDef() +UserManager: Optional[user_manager.UserManager] = None @ROUTES.head("/api/assets/hash/{hash}") @@ -45,6 +46,7 @@ async def list_assets(request: web.Request) -> web.Response: offset=q.offset, sort=q.sort, order=q.order, + owner_id=UserManager.get_request_user_id(request), ) return web.json_response(payload.model_dump(mode="json")) @@ -63,7 +65,8 @@ async def download_asset_content(request: web.Request) -> web.Response: try: abs_path, content_type, filename = await assets_manager.resolve_asset_content_for_download( - asset_info_id=asset_info_id + asset_info_id=asset_info_id, + owner_id=UserManager.get_request_user_id(request), ) except ValueError as ve: return _error_response(404, "ASSET_NOT_FOUND", str(ve)) @@ -96,6 +99,7 @@ async def create_asset_from_hash(request: web.Request) -> web.Response: name=body.name, tags=body.tags, user_metadata=body.user_metadata, + owner_id=UserManager.get_request_user_id(request), ) if result is None: return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist") @@ -186,6 +190,7 @@ async def upload_asset(request: web.Request) -> web.Response: spec, temp_path=tmp_path, client_filename=file_client_name, + owner_id=UserManager.get_request_user_id(request), ) return web.json_response(created.model_dump(mode="json"), status=201) except ValueError: @@ -207,7 +212,10 @@ async def get_asset(request: web.Request) -> web.Response: return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid integer.") try: - result = await assets_manager.get_asset(asset_info_id=asset_info_id) + result = await assets_manager.get_asset( + asset_info_id=asset_info_id, + owner_id=UserManager.get_request_user_id(request), + ) except ValueError as ve: return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) except Exception: @@ -236,8 +244,9 @@ async def update_asset(request: web.Request) -> web.Response: name=body.name, tags=body.tags, user_metadata=body.user_metadata, + owner_id=UserManager.get_request_user_id(request), ) - except ValueError as ve: + except (ValueError, PermissionError) as ve: return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) except Exception: return _error_response(500, "INTERNAL", "Unexpected server error.") @@ -253,7 +262,10 @@ async def delete_asset(request: web.Request) -> web.Response: return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid integer.") try: - deleted = await assets_manager.delete_asset_reference(asset_info_id=asset_info_id) + deleted = await assets_manager.delete_asset_reference( + asset_info_id=asset_info_id, + owner_id=UserManager.get_request_user_id(request), + ) except Exception: return _error_response(500, "INTERNAL", "Unexpected server error.") @@ -280,6 +292,7 @@ async def get_tags(request: web.Request) -> web.Response: offset=query.offset, order=query.order, include_zero=query.include_zero, + owner_id=UserManager.get_request_user_id(request), ) return web.json_response(result.model_dump(mode="json")) @@ -306,8 +319,9 @@ async def add_asset_tags(request: web.Request) -> web.Response: tags=data.tags, origin="manual", added_by=None, + owner_id=UserManager.get_request_user_id(request), ) - except ValueError as ve: + except (ValueError, PermissionError) as ve: return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) except Exception: return _error_response(500, "INTERNAL", "Unexpected server error.") @@ -335,6 +349,7 
@@ async def delete_asset_tags(request: web.Request) -> web.Response: result = await assets_manager.remove_tags_from_asset( asset_info_id=asset_info_id, tags=data.tags, + owner_id=UserManager.get_request_user_id(request), ) except ValueError as ve: return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) @@ -370,7 +385,9 @@ async def get_asset_scan_status(request: web.Request) -> web.Response: return web.json_response(states.model_dump(mode="json"), status=200) -def register_assets_routes(app: web.Application) -> None: +def register_assets_system(app: web.Application, user_manager_instance: user_manager.UserManager) -> None: + global UserManager + UserManager = user_manager_instance app.add_routes(ROUTES) diff --git a/app/assets_manager.py b/app/assets_manager.py index e61895c8ad7e..8cdf1fffc745 100644 --- a/app/assets_manager.py +++ b/app/assets_manager.py @@ -25,9 +25,14 @@ get_asset_by_hash, create_asset_info_for_existing_asset, fetch_asset_info_asset_and_tags, + get_asset_info_by_id, ) from .api import schemas_in, schemas_out -from ._assets_helpers import get_name_and_tags_from_asset_path, ensure_within_base, resolve_destination_from_tags +from ._assets_helpers import ( + get_name_and_tags_from_asset_path, + ensure_within_base, + resolve_destination_from_tags, +) from .assets_fetcher import ensure_asset_cached @@ -98,6 +103,7 @@ async def list_assets( offset: int = 0, sort: str = "created_at", order: str = "desc", + owner_id: str = "", ) -> schemas_out.AssetsList: sort = _safe_sort_field(sort) order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower() @@ -105,6 +111,7 @@ async def list_assets( async with await create_session() as session: infos, tag_map, total = await list_asset_infos_page( session, + owner_id=owner_id, include_tags=include_tags, exclude_tags=exclude_tags, name_contains=name_contains, @@ -141,9 +148,9 @@ async def list_assets( ) -async def get_asset(*, asset_info_id: int) -> schemas_out.AssetDetail: +async def get_asset(*, asset_info_id: int, owner_id: str = "") -> schemas_out.AssetDetail: async with await create_session() as session: - res = await fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id) + res = await fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id) if not res: raise ValueError(f"AssetInfo {asset_info_id} not found") info, asset, tag_names = res @@ -163,7 +170,9 @@ async def get_asset(*, asset_info_id: int) -> schemas_out.AssetDetail: async def resolve_asset_content_for_download( - *, asset_info_id: int + *, + asset_info_id: int, + owner_id: str = "", ) -> tuple[str, str, str]: """ Returns (abs_path, content_type, download_name) for the given AssetInfo id and touches last_access_time. 
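# --- illustrative aside (editor sketch): owner scoping for reads --------------
# A minimal sketch of how the visible_owner_clause() helper added in
# app/_assets_helpers.py is meant to scope read queries under --multi-user:
# rows owned by "" are shared and visible to everyone, while per-user rows are
# only visible to their owner. The query helper itself is hypothetical.
from sqlalchemy import select
from app._assets_helpers import visible_owner_clause
from app.database.models import AssetInfo

def visible_asset_infos(owner_id: str):
    # owner_id == ""       -> only shared rows (owner_id == "")
    # owner_id == "alice"  -> shared rows plus rows owned by "alice"
    return select(AssetInfo).where(visible_owner_clause(owner_id))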
@@ -173,7 +182,7 @@ async def resolve_asset_content_for_download( ValueError if AssetInfo cannot be found """ async with await create_session() as session: - pair = await fetch_asset_info_and_asset(session, asset_info_id=asset_info_id) + pair = await fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id) if not pair: raise ValueError(f"AssetInfo {asset_info_id} not found") @@ -198,6 +207,7 @@ async def upload_asset_from_temp_path( *, temp_path: str, client_filename: Optional[str] = None, + owner_id: str = "", ) -> schemas_out.AssetCreated: """ Finalize an uploaded temp file: @@ -250,7 +260,7 @@ async def upload_asset_from_temp_path( mtime_ns=mtime_ns, mime_type=content_type, info_name=os.path.basename(dest_abs), - owner_id="", + owner_id=owner_id, preview_hash=None, user_metadata=spec.user_metadata or {}, tags=spec.tags, @@ -262,7 +272,7 @@ async def upload_asset_from_temp_path( if not info_id: raise RuntimeError("failed to create asset metadata") - pair = await fetch_asset_info_and_asset(session, asset_info_id=int(info_id)) + pair = await fetch_asset_info_and_asset(session, asset_info_id=int(info_id), owner_id=owner_id) if not pair: raise RuntimeError("inconsistent DB state after ingest") info, asset = pair @@ -290,8 +300,15 @@ async def update_asset( name: Optional[str] = None, tags: Optional[list[str]] = None, user_metadata: Optional[dict] = None, + owner_id: str = "", ) -> schemas_out.AssetUpdated: async with await create_session() as session: + info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id) + if not info_row: + raise ValueError(f"AssetInfo {asset_info_id} not found") + if info_row.owner_id and info_row.owner_id != owner_id: + raise PermissionError("not owner") + info = await update_asset_info_full( session, asset_info_id=asset_info_id, @@ -300,6 +317,7 @@ async def update_asset( user_metadata=user_metadata, tag_origin="manual", added_by=None, + asset_info_row=info_row, ) tag_names = await get_asset_tags(session, asset_info_id=asset_info_id) @@ -315,9 +333,9 @@ async def update_asset( ) -async def delete_asset_reference(*, asset_info_id: int) -> bool: +async def delete_asset_reference(*, asset_info_id: int, owner_id: str) -> bool: async with await create_session() as session: - r = await delete_asset_info_by_id(session, asset_info_id=asset_info_id) + r = await delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id) await session.commit() return r @@ -328,6 +346,7 @@ async def create_asset_from_hash( name: str, tags: Optional[list[str]] = None, user_metadata: Optional[dict] = None, + owner_id: str = "", ) -> Optional[schemas_out.AssetCreated]: canonical = hash_str.strip().lower() async with await create_session() as session: @@ -343,6 +362,7 @@ async def create_asset_from_hash( tags=tags or [], tag_origin="manual", added_by=None, + owner_id=owner_id, ) tag_names = await get_asset_tags(session, asset_info_id=info.id) await session.commit() @@ -369,6 +389,7 @@ async def list_tags( offset: int = 0, order: str = "count_desc", include_zero: bool = True, + owner_id: str = "", ) -> schemas_out.TagsList: limit = max(1, min(1000, limit)) offset = max(0, offset) @@ -381,6 +402,7 @@ async def list_tags( offset=offset, include_zero=include_zero, order=order, + owner_id=owner_id, ) tags = [schemas_out.TagUsage(name=name, count=count, type=tag_type) for (name, tag_type, count) in rows] @@ -393,8 +415,14 @@ async def add_tags_to_asset( tags: list[str], origin: str = "manual", added_by: Optional[str] = None, + owner_id: 
str = "", ) -> schemas_out.TagsAdd: async with await create_session() as session: + info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id) + if not info_row: + raise ValueError(f"AssetInfo {asset_info_id} not found") + if info_row.owner_id and info_row.owner_id != owner_id: + raise PermissionError("not owner") data = await add_tags_to_asset_info( session, asset_info_id=asset_info_id, @@ -402,6 +430,7 @@ async def add_tags_to_asset( origin=origin, added_by=added_by, create_if_missing=True, + asset_info_row=info_row, ) await session.commit() return schemas_out.TagsAdd(**data) @@ -411,8 +440,15 @@ async def remove_tags_from_asset( *, asset_info_id: int, tags: list[str], + owner_id: str = "", ) -> schemas_out.TagsRemove: async with await create_session() as session: + info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id) + if not info_row: + raise ValueError(f"AssetInfo {asset_info_id} not found") + if info_row.owner_id and info_row.owner_id != owner_id: + raise PermissionError("not owner") + data = await remove_tags_from_asset_info( session, asset_info_id=asset_info_id, diff --git a/app/database/services.py b/app/database/services.py index 95a2d07ab979..4bf09ed9718b 100644 --- a/app/database/services.py +++ b/app/database/services.py @@ -14,7 +14,7 @@ from .models import Asset, AssetInfo, AssetInfoTag, AssetCacheState, Tag, AssetInfoMeta, AssetLocation from .timeutil import utcnow -from .._assets_helpers import normalize_tags +from .._assets_helpers import normalize_tags, visible_owner_clause async def asset_exists_by_hash(session: AsyncSession, *, asset_hash: str) -> bool: @@ -30,6 +30,10 @@ async def get_asset_by_hash(session: AsyncSession, *, asset_hash: str) -> Option return await session.get(Asset, asset_hash) +async def get_asset_info_by_id(session: AsyncSession, *, asset_info_id: int) -> Optional[AssetInfo]: + return await session.get(AssetInfo, asset_info_id) + + async def check_fs_asset_exists_quick( session, *, @@ -317,6 +321,7 @@ async def touch_asset_info_by_id( async def list_asset_infos_page( session: AsyncSession, *, + owner_id: str = "", include_tags: Optional[Sequence[str]] = None, exclude_tags: Optional[Sequence[str]] = None, name_contains: Optional[str] = None, @@ -326,26 +331,18 @@ async def list_asset_infos_page( sort: str = "created_at", order: str = "desc", ) -> tuple[list[AssetInfo], dict[int, list[str]], int]: - """ - Returns a page of AssetInfo rows with their Asset eagerly loaded (no N+1), - plus a map of asset_info_id -> [tags], and the total count. - - We purposely collect tags in a separate (single) query to avoid row explosion. 
- """ - - # Build base query + """Return page of AssetInfo rows in the viewers visibility.""" base = ( select(AssetInfo) .join(Asset, Asset.hash == AssetInfo.asset_hash) .options(contains_eager(AssetInfo.asset)) + .where(visible_owner_clause(owner_id)) ) - # Filters if name_contains: base = base.where(AssetInfo.name.ilike(f"%{name_contains}%")) base = _apply_tag_filters(base, include_tags, exclude_tags) - base = _apply_metadata_filter(base, metadata_filter) # Sort @@ -368,13 +365,14 @@ async def list_asset_infos_page( select(func.count()) .select_from(AssetInfo) .join(Asset, Asset.hash == AssetInfo.asset_hash) + .where(visible_owner_clause(owner_id)) ) if name_contains: count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{name_contains}%")) count_stmt = _apply_tag_filters(count_stmt, include_tags, exclude_tags) count_stmt = _apply_metadata_filter(count_stmt, metadata_filter) - total = (await session.execute(count_stmt)).scalar_one() + total = int((await session.execute(count_stmt)).scalar_one() or 0) # Fetch rows infos = (await session.execute(base)).scalars().unique().all() @@ -394,13 +392,22 @@ async def list_asset_infos_page( return infos, tag_map, total -async def fetch_asset_info_and_asset(session: AsyncSession, *, asset_info_id: int) -> Optional[tuple[AssetInfo, Asset]]: - row = await session.execute( +async def fetch_asset_info_and_asset( + session: AsyncSession, + *, + asset_info_id: int, + owner_id: str = "", +) -> Optional[tuple[AssetInfo, Asset]]: + stmt = ( select(AssetInfo, Asset) .join(Asset, Asset.hash == AssetInfo.asset_hash) - .where(AssetInfo.id == asset_info_id) + .where( + AssetInfo.id == asset_info_id, + visible_owner_clause(owner_id), + ) .limit(1) ) + row = await session.execute(stmt) pair = row.first() if not pair: return None @@ -411,18 +418,17 @@ async def fetch_asset_info_asset_and_tags( session: AsyncSession, *, asset_info_id: int, + owner_id: str = "", ) -> Optional[tuple[AssetInfo, Asset, list[str]]]: - """Fetch AssetInfo, its Asset, and all tag names. - - Returns: - (AssetInfo, Asset, [tag_names]) or None if the asset_info_id does not exist. - """ stmt = ( select(AssetInfo, Asset, Tag.name) .join(Asset, Asset.hash == AssetInfo.asset_hash) .join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True) .join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True) - .where(AssetInfo.id == asset_info_id) + .where( + AssetInfo.id == asset_info_id, + visible_owner_clause(owner_id), + ) .options(noload(AssetInfo.tags)) .order_by(Tag.name.asc()) ) @@ -511,11 +517,12 @@ async def create_asset_info_for_existing_asset( tags: Optional[Sequence[str]] = None, tag_origin: str = "manual", added_by: Optional[str] = None, + owner_id: str = "", ) -> AssetInfo: """Create a new AssetInfo referencing an existing Asset (no content write).""" now = utcnow() info = AssetInfo( - owner_id="", + owner_id=owner_id, name=name, asset_hash=asset_hash, preview_hash=None, @@ -593,6 +600,7 @@ async def update_asset_info_full( user_metadata: Optional[dict] = None, tag_origin: str = "manual", added_by: Optional[str] = None, + asset_info_row: Any = None, ) -> AssetInfo: """ Update AssetInfo fields: @@ -601,9 +609,12 @@ async def update_asset_info_full( - replace tags with provided set (if provided) Returns the updated AssetInfo. 
""" - info = await session.get(AssetInfo, asset_info_id) - if not info: - raise ValueError(f"AssetInfo {asset_info_id} not found") + if not asset_info_row: + info = await session.get(AssetInfo, asset_info_id) + if not info: + raise ValueError(f"AssetInfo {asset_info_id} not found") + else: + info = asset_info_row touched = False if name is not None and name != info.name: @@ -633,9 +644,12 @@ async def update_asset_info_full( return info -async def delete_asset_info_by_id(session: AsyncSession, *, asset_info_id: int) -> bool: +async def delete_asset_info_by_id(session: AsyncSession, *, asset_info_id: int, owner_id: str) -> bool: """Delete the user-visible AssetInfo row. Cascades clear tags and metadata.""" - res = await session.execute(delete(AssetInfo).where(AssetInfo.id == asset_info_id)) + res = await session.execute(delete(AssetInfo).where( + AssetInfo.id == asset_info_id, + visible_owner_clause(owner_id), + )) return bool(res.rowcount) @@ -691,25 +705,24 @@ async def get_asset_tags(session: AsyncSession, *, asset_info_id: int) -> list[s async def list_tags_with_usage( - session, + session: AsyncSession, *, prefix: Optional[str] = None, limit: int = 100, offset: int = 0, include_zero: bool = True, - order: str = "count_desc", # "count_desc" | "name_asc" + order: str = "count_desc", # "count_desc" | "name_asc" + owner_id: str = "", ) -> tuple[list[tuple[str, str, int]], int]: - """ - Returns: - rows: list of (name, tag_type, count) - total: number of tags matching filter (independent of pagination) - """ - # Subquery with counts by tag_name + # Subquery with counts by tag_name and owner_id counts_sq = ( select( AssetInfoTag.tag_name.label("tag_name"), func.count(AssetInfoTag.asset_info_id).label("cnt"), ) + .select_from(AssetInfoTag) + .join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id) + .where(visible_owner_clause(owner_id)) .group_by(AssetInfoTag.tag_name) .subquery() ) @@ -765,14 +778,16 @@ async def add_tags_to_asset_info( origin: str = "manual", added_by: Optional[str] = None, create_if_missing: bool = True, + asset_info_row: Any = None, ) -> dict: """Adds tags to an AssetInfo. If create_if_missing=True, missing tag rows are created as 'user'. 
Returns: {"added": [...], "already_present": [...], "total_tags": [...]} """ - info = await session.get(AssetInfo, asset_info_id) - if not info: - raise ValueError(f"AssetInfo {asset_info_id} not found") + if not asset_info_row: + info = await session.get(AssetInfo, asset_info_id) + if not info: + raise ValueError(f"AssetInfo {asset_info_id} not found") norm = normalize_tags(tags) if not norm: diff --git a/server.py b/server.py index ba368654fcc2..310b7601c778 100644 --- a/server.py +++ b/server.py @@ -36,7 +36,7 @@ from app.custom_node_manager import CustomNodeManager from typing import Optional, Union from api_server.routes.internal.internal_routes import InternalRoutes -from app.api.assets_routes import register_assets_routes +from app.api.assets_routes import register_assets_system from protocol import BinaryEventTypes async def send_socket_catch_exception(function, message): @@ -182,7 +182,7 @@ def __init__(self, loop): else args.front_end_root ) logging.info(f"[Prompt Server] web root: {self.web_root}") - register_assets_routes(self.app) + register_assets_system(self.app, self.user_manager) routes = web.RouteTableDef() self.routes = routes self.last_node_id = None From 0379eff0b56a5a22c145815d861e08f109497411 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Wed, 27 Aug 2025 21:18:26 +0300 Subject: [PATCH 20/82] allow Upload Asset endpoint to accept hash (as documentation requires) --- app/api/assets_routes.py | 117 ++++++++++++++++++++++++++++++++------- app/api/schemas_in.py | 19 +++++++ app/assets_manager.py | 49 ++++++++++++++-- 3 files changed, 162 insertions(+), 23 deletions(-) diff --git a/app/api/assets_routes.py b/app/api/assets_routes.py index b5c25dcecb2a..bdbb3616713d 100644 --- a/app/api/assets_routes.py +++ b/app/api/assets_routes.py @@ -1,3 +1,4 @@ +import contextlib import os import uuid import urllib.parse @@ -115,29 +116,63 @@ async def upload_asset(request: web.Request) -> web.Response: reader = await request.multipart() - file_field = None + file_present = False file_client_name: Optional[str] = None tags_raw: list[str] = [] provided_name: Optional[str] = None user_metadata_raw: Optional[str] = None - file_written = 0 + provided_hash: Optional[str] = None + provided_hash_exists: Optional[bool] = None + file_written = 0 tmp_path: Optional[str] = None while True: field = await reader.next() if field is None: break - fname = getattr(field, "name", None) or "" - if fname == "file": - # Save to temp + fname = getattr(field, "name", "") or "" + + if fname == "hash": + try: + s = ((await field.text()) or "").strip().lower() + except Exception: + return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") + + if s: + if ":" not in s: + return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") + algo, digest = s.split(":", 1) + if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"): + return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") + provided_hash = f"{algo}:{digest}" + try: + provided_hash_exists = await assets_manager.asset_exists(asset_hash=provided_hash) + except Exception: + provided_hash_exists = None # do not fail the whole request here + + elif fname == "file": + file_present = True + file_client_name = (field.filename or "").strip() + + if provided_hash and provided_hash_exists is True: + # If client supplied a hash that we know exists, drain but do not write to disk + try: + while True: + chunk = await field.read_chunk(8 * 1024 * 1024) + if not chunk: + break + file_written 
+= len(chunk) + except Exception: + return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file.") + continue # Do not create temp file; we will create AssetInfo from the existing content + + # Otherwise, store to temp for hashing/ingest uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads") unique_dir = os.path.join(uploads_root, uuid.uuid4().hex) os.makedirs(unique_dir, exist_ok=True) tmp_path = os.path.join(unique_dir, ".upload.part") - file_field = field - file_client_name = (field.filename or "").strip() try: with open(tmp_path, "wb") as f: while True: @@ -148,7 +183,7 @@ async def upload_asset(request: web.Request) -> web.Response: file_written += len(chunk) except Exception: try: - if os.path.exists(tmp_path): + if os.path.exists(tmp_path or ""): os.remove(tmp_path) finally: return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file.") @@ -159,12 +194,15 @@ async def upload_asset(request: web.Request) -> web.Response: elif fname == "user_metadata": user_metadata_raw = (await field.text()) or None - if file_field is None: - return _error_response(400, "MISSING_FILE", "Form must include a 'file' part.") + # If client did not send file, and we are not doing a from-hash fast path -> error + if not file_present and not (provided_hash and provided_hash_exists): + return _error_response(400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'.") - if file_written == 0: + if file_present and file_written == 0 and not (provided_hash and provided_hash_exists): + # Empty upload is only acceptable if we are fast-pathing from existing hash try: - os.remove(tmp_path) + if tmp_path and os.path.exists(tmp_path): + os.remove(tmp_path) finally: return _error_response(400, "EMPTY_UPLOAD", "Uploaded file is empty.") @@ -173,29 +211,70 @@ async def upload_asset(request: web.Request) -> web.Response: "tags": tags_raw, "name": provided_name, "user_metadata": user_metadata_raw, + "hash": provided_hash, }) except ValidationError as ve: try: - os.remove(tmp_path) + if tmp_path and os.path.exists(tmp_path): + os.remove(tmp_path) finally: return _validation_error_response("INVALID_BODY", ve) - if spec.tags[0] == "models" and spec.tags[1] not in folder_paths.folder_names_and_paths: + # Validate models category against configured folders (consistent with previous behavior) + if spec.tags and spec.tags[0] == "models": + if len(spec.tags) < 2 or spec.tags[1] not in folder_paths.folder_names_and_paths: + if tmp_path and os.path.exists(tmp_path): + os.remove(tmp_path) + return _error_response( + 400, "INVALID_BODY", f"unknown models category '{spec.tags[1] if len(spec.tags) >= 2 else ''}'" + ) + + owner_id = UserManager.get_request_user_id(request) + + # Fast path: if a valid provided hash exists, create AssetInfo without writing anything + if spec.hash and provided_hash_exists is True: + try: + result = await assets_manager.create_asset_from_hash( + hash_str=spec.hash, + name=spec.name or (spec.hash.split(":", 1)[1]), + tags=spec.tags, + user_metadata=spec.user_metadata or {}, + owner_id=owner_id, + ) + except Exception: + return _error_response(500, "INTERNAL", "Unexpected server error.") + + if result is None: + return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {spec.hash} does not exist") + + # Drain temp if we accidentally saved (e.g., hash field came after file) if tmp_path and os.path.exists(tmp_path): - os.remove(tmp_path) - return _error_response(400, "INVALID_BODY", f"unknown models category '{spec.tags[1]}'") 
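Client-side, the handler above now supports a fast path: send the canonical blake3 hash alongside (or instead of) the file, and if the server already knows that content only a new AssetInfo is created. A rough aiohttp client sketch; the endpoint URL, form-field encoding, and response handling are assumptions, not part of this patch:

# Sketch (assumption): uploading with an optional "hash" form field for the fast path.
import asyncio
import os

import aiohttp

async def upload_asset(path: str, tags: str, blake3_hash: str | None = None) -> dict:
    async with aiohttp.ClientSession() as session:
        with open(path, "rb") as f:
            form = aiohttp.FormData()
            form.add_field("tags", tags)                 # e.g. "models,checkpoints" (encoding assumed)
            if blake3_hash:
                form.add_field("hash", blake3_hash)      # canonical "blake3:<hex digest>" form
            form.add_field("file", f, filename=os.path.basename(path))
            async with session.post("http://127.0.0.1:8188/api/assets", data=form) as resp:
                return await resp.json()

# asyncio.run(upload_asset("model.safetensors", "models,checkpoints"))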
+ with contextlib.suppress(Exception): + os.remove(tmp_path) + + status = 200 if (not result.created_new) else 201 + return web.json_response(result.model_dump(mode="json"), status=status) + + # Otherwise, we must have a temp file path to ingest + if not tmp_path or not os.path.exists(tmp_path): + # The only case we reach here without a temp file is: client sent a hash that does not exist and no file + return _error_response(404, "ASSET_NOT_FOUND", "Provided hash not found and no file uploaded.") try: created = await assets_manager.upload_asset_from_temp_path( spec, temp_path=tmp_path, client_filename=file_client_name, - owner_id=UserManager.get_request_user_id(request), + owner_id=owner_id, ) - return web.json_response(created.model_dump(mode="json"), status=201) - except ValueError: + status = 201 if created.created_new else 200 + return web.json_response(created.model_dump(mode="json"), status=status) + except ValueError as e: if tmp_path and os.path.exists(tmp_path): os.remove(tmp_path) + msg = str(e) + if "HASH_MISMATCH" in msg or msg.strip().upper() == "HASH_MISMATCH": + return _error_response(400, "HASH_MISMATCH", "Uploaded file hash does not match provided hash.") return _error_response(400, "BAD_REQUEST", "Invalid inputs.") except Exception: if tmp_path and os.path.exists(tmp_path): diff --git a/app/api/schemas_in.py b/app/api/schemas_in.py index 9694a67a62f6..412b72e3af0b 100644 --- a/app/api/schemas_in.py +++ b/app/api/schemas_in.py @@ -180,12 +180,31 @@ class UploadAssetSpec(BaseModel): if root == 'models', second must be a valid category from folder_paths.folder_names_and_paths - name: desired filename (optional); fallback will be the file hash - user_metadata: arbitrary JSON object (optional) + - hash: optional canonical 'blake3:' provided by the client for validation / fast-path """ model_config = ConfigDict(extra="ignore", str_strip_whitespace=True) tags: list[str] = Field(..., min_length=1) name: Optional[str] = Field(default=None, max_length=512) user_metadata: dict[str, Any] = Field(default_factory=dict) + hash: Optional[str] = Field(default=None) + + @field_validator("hash", mode="before") + @classmethod + def _parse_hash(cls, v): + if v is None: + return None + s = str(v).strip().lower() + if not s: + return None + if ":" not in s: + raise ValueError("hash must be 'blake3:'") + algo, digest = s.split(":", 1) + if algo != "blake3": + raise ValueError("only canonical 'blake3:' is accepted here") + if not digest or any(c for c in digest if c not in "0123456789abcdef"): + raise ValueError("hash digest must be lowercase hex") + return f"{algo}:{digest}" @field_validator("tags", mode="before") @classmethod diff --git a/app/assets_manager.py b/app/assets_manager.py index 8cdf1fffc745..bb276249777f 100644 --- a/app/assets_manager.py +++ b/app/assets_manager.py @@ -1,3 +1,4 @@ +import contextlib import logging import mimetypes import os @@ -208,13 +209,14 @@ async def upload_asset_from_temp_path( temp_path: str, client_filename: Optional[str] = None, owner_id: str = "", + expected_asset_hash: Optional[str] = None, ) -> schemas_out.AssetCreated: """ Finalize an uploaded temp file: - compute blake3 hash - - resolve destination from tags - - decide filename (spec.name or client filename or hash) - - move file atomically + - if expected_asset_hash provided, verify equality (400 on mismatch at caller) + - if an Asset with the same hash exists: discard temp, create AssetInfo only (no write) + - else resolve destination from tags and atomically move into place - ingest into DB 
(assets, locator state, asset_info + tags) Returns a populated AssetCreated payload. """ @@ -225,7 +227,46 @@ async def upload_asset_from_temp_path( raise RuntimeError(f"failed to hash uploaded file: {e}") asset_hash = "blake3:" + digest - # Resolve destination + if expected_asset_hash and asset_hash != expected_asset_hash.strip().lower(): + raise ValueError("HASH_MISMATCH") + + # Fast path: content already known --> no writes, just create a reference + async with await create_session() as session: + existing = await get_asset_by_hash(session, asset_hash=asset_hash) + if existing is not None: + with contextlib.suppress(Exception): + if temp_path and os.path.exists(temp_path): + os.remove(temp_path) + + desired_name = _safe_filename(spec.name or (client_filename or ""), fallback=digest) + info = await create_asset_info_for_existing_asset( + session, + asset_hash=asset_hash, + name=desired_name, + user_metadata=spec.user_metadata or {}, + tags=spec.tags or [], + tag_origin="manual", + added_by=None, + owner_id=owner_id, + ) + tag_names = await get_asset_tags(session, asset_info_id=info.id) + await session.commit() + + return schemas_out.AssetCreated( + id=info.id, + name=info.name, + asset_hash=info.asset_hash, + size=int(existing.size_bytes) if existing.size_bytes is not None else None, + mime_type=existing.mime_type, + tags=tag_names, + user_metadata=info.user_metadata or {}, + preview_hash=info.preview_hash, + created_at=info.created_at, + last_access_time=info.last_access_time, + created_new=False, + ) + + # Resolve destination (only for truly new content) base_dir, subdirs = resolve_destination_from_tags(spec.tags) dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir os.makedirs(dest_dir, exist_ok=True) From eb7008a4d37f0f42cf78174248560ef3818914c0 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Wed, 27 Aug 2025 21:26:35 +0300 Subject: [PATCH 21/82] removed not used "added_by" column --- alembic_db/versions/0001_assets.py | 1 - app/api/assets_routes.py | 1 - app/assets_manager.py | 6 ------ app/database/models.py | 1 - app/database/services.py | 11 +---------- 5 files changed, 1 insertion(+), 19 deletions(-) diff --git a/alembic_db/versions/0001_assets.py b/alembic_db/versions/0001_assets.py index 9fb80ea8c676..18b1c71a1108 100644 --- a/alembic_db/versions/0001_assets.py +++ b/alembic_db/versions/0001_assets.py @@ -67,7 +67,6 @@ def upgrade() -> None: sa.Column("asset_info_id", sa.Integer(), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False), sa.Column("tag_name", sa.String(length=512), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False), sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"), - sa.Column("added_by", sa.String(length=128), nullable=True), sa.Column("added_at", sa.DateTime(timezone=False), nullable=False), sa.PrimaryKeyConstraint("asset_info_id", "tag_name", name="pk_asset_info_tags"), ) diff --git a/app/api/assets_routes.py b/app/api/assets_routes.py index bdbb3616713d..61188c0907ef 100644 --- a/app/api/assets_routes.py +++ b/app/api/assets_routes.py @@ -397,7 +397,6 @@ async def add_asset_tags(request: web.Request) -> web.Response: asset_info_id=asset_info_id, tags=data.tags, origin="manual", - added_by=None, owner_id=UserManager.get_request_user_id(request), ) except (ValueError, PermissionError) as ve: diff --git a/app/assets_manager.py b/app/assets_manager.py index bb276249777f..f9046c5ed936 100644 --- a/app/assets_manager.py +++ b/app/assets_manager.py @@ -246,7 +246,6 @@ async def 
upload_asset_from_temp_path( user_metadata=spec.user_metadata or {}, tags=spec.tags or [], tag_origin="manual", - added_by=None, owner_id=owner_id, ) tag_names = await get_asset_tags(session, asset_info_id=info.id) @@ -306,7 +305,6 @@ async def upload_asset_from_temp_path( user_metadata=spec.user_metadata or {}, tags=spec.tags, tag_origin="manual", - added_by=None, require_existing_tags=False, ) info_id = result["asset_info_id"] @@ -357,7 +355,6 @@ async def update_asset( tags=tags, user_metadata=user_metadata, tag_origin="manual", - added_by=None, asset_info_row=info_row, ) @@ -402,7 +399,6 @@ async def create_asset_from_hash( user_metadata=user_metadata or {}, tags=tags or [], tag_origin="manual", - added_by=None, owner_id=owner_id, ) tag_names = await get_asset_tags(session, asset_info_id=info.id) @@ -455,7 +451,6 @@ async def add_tags_to_asset( asset_info_id: int, tags: list[str], origin: str = "manual", - added_by: Optional[str] = None, owner_id: str = "", ) -> schemas_out.TagsAdd: async with await create_session() as session: @@ -469,7 +464,6 @@ async def add_tags_to_asset( asset_info_id=asset_info_id, tags=tags, origin=origin, - added_by=added_by, create_if_missing=True, asset_info_row=info_row, ) diff --git a/app/database/models.py b/app/database/models.py index d964a5226deb..87aa942ed115 100644 --- a/app/database/models.py +++ b/app/database/models.py @@ -244,7 +244,6 @@ class AssetInfoTag(Base): String(512), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True ) origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual") - added_by: Mapped[str | None] = mapped_column(String(128)) added_at: Mapped[datetime] = mapped_column( DateTime(timezone=False), nullable=False, default=utcnow ) diff --git a/app/database/services.py b/app/database/services.py index 4bf09ed9718b..66d5190322c7 100644 --- a/app/database/services.py +++ b/app/database/services.py @@ -79,7 +79,6 @@ async def ingest_fs_asset( user_metadata: Optional[dict] = None, tags: Sequence[str] = (), tag_origin: str = "manual", - added_by: Optional[str] = None, require_existing_tags: bool = False, ) -> dict: """ @@ -247,7 +246,6 @@ async def ingest_fs_asset( asset_info_id=out["asset_info_id"], tag_name=t, origin=tag_origin, - added_by=added_by, added_at=datetime_now, ) for t in to_add @@ -516,7 +514,6 @@ async def create_asset_info_for_existing_asset( user_metadata: Optional[dict] = None, tags: Optional[Sequence[str]] = None, tag_origin: str = "manual", - added_by: Optional[str] = None, owner_id: str = "", ) -> AssetInfo: """Create a new AssetInfo referencing an existing Asset (no content write).""" @@ -544,7 +541,6 @@ async def create_asset_info_for_existing_asset( asset_info_id=info.id, tags=tags, origin=tag_origin, - added_by=added_by, ) return info @@ -555,7 +551,6 @@ async def set_asset_info_tags( asset_info_id: int, tags: Sequence[str], origin: str = "manual", - added_by: Optional[str] = None, ) -> dict: """ Replace the tag set on an AssetInfo with `tags`. Idempotent. 
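set_asset_info_tags replaces the whole tag set, so re-applying the same request is a no-op. The core of that computation, written out as set arithmetic with illustrative names:

# Illustration: idempotent tag replacement as a pair of set differences.
current = {"models", "checkpoints", "flux"}
desired = {"models", "checkpoints", "sdxl"}

to_add = desired - current       # {"sdxl"}
to_remove = current - desired    # {"flux"}
# Applying the same `desired` set again yields two empty sets, i.e. no further changes.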
@@ -576,7 +571,7 @@ async def set_asset_info_tags( if to_add: await _ensure_tags_exist(session, to_add, tag_type="user") session.add_all([ - AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_by=added_by, added_at=utcnow()) + AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow()) for t in to_add ]) await session.flush() @@ -599,7 +594,6 @@ async def update_asset_info_full( tags: Optional[Sequence[str]] = None, user_metadata: Optional[dict] = None, tag_origin: str = "manual", - added_by: Optional[str] = None, asset_info_row: Any = None, ) -> AssetInfo: """ @@ -633,7 +627,6 @@ async def update_asset_info_full( asset_info_id=asset_info_id, tags=tags, origin=tag_origin, - added_by=added_by, ) touched = True @@ -776,7 +769,6 @@ async def add_tags_to_asset_info( asset_info_id: int, tags: Sequence[str], origin: str = "manual", - added_by: Optional[str] = None, create_if_missing: bool = True, asset_info_row: Any = None, ) -> dict: @@ -820,7 +812,6 @@ async def add_tags_to_asset_info( asset_info_id=asset_info_id, tag_name=t, origin=origin, - added_by=added_by, added_at=utcnow(), ) for t in to_add From 871e41aec6225c6c2c1b019db482fbd94b214c32 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Wed, 27 Aug 2025 21:36:31 +0300 Subject: [PATCH 22/82] removed not needed "refcount" column --- alembic_db/versions/0001_assets.py | 2 -- app/database/models.py | 1 - app/database/services.py | 3 +-- 3 files changed, 1 insertion(+), 5 deletions(-) diff --git a/alembic_db/versions/0001_assets.py b/alembic_db/versions/0001_assets.py index 18b1c71a1108..ec41ee6c1c83 100644 --- a/alembic_db/versions/0001_assets.py +++ b/alembic_db/versions/0001_assets.py @@ -22,11 +22,9 @@ def upgrade() -> None: sa.Column("hash", sa.String(length=256), primary_key=True), sa.Column("size_bytes", sa.BigInteger(), nullable=False, server_default="0"), sa.Column("mime_type", sa.String(length=255), nullable=True), - sa.Column("refcount", sa.BigInteger(), nullable=False, server_default="0"), sa.Column("created_at", sa.DateTime(timezone=False), nullable=False), sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False), sa.CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"), - sa.CheckConstraint("refcount >= 0", name="ck_assets_refcount_nonneg"), ) op.create_index("ix_assets_mime_type", "assets", ["mime_type"]) diff --git a/app/database/models.py b/app/database/models.py index 87aa942ed115..f2972b00add0 100644 --- a/app/database/models.py +++ b/app/database/models.py @@ -44,7 +44,6 @@ class Asset(Base): hash: Mapped[str] = mapped_column(String(256), primary_key=True) size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0) mime_type: Mapped[str | None] = mapped_column(String(255)) - refcount: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0) created_at: Mapped[datetime] = mapped_column( DateTime(timezone=False), nullable=False, default=utcnow ) diff --git a/app/database/services.py b/app/database/services.py index 66d5190322c7..5c3bbe42aa1b 100644 --- a/app/database/services.py +++ b/app/database/services.py @@ -89,7 +89,7 @@ async def ingest_fs_asset( - Insert AssetCacheState if missing; else update mtime_ns if different. Optionally (when info_name is provided): - - Create an AssetInfo (no refcount changes). + - Create an AssetInfo. - Link provided tags to that AssetInfo. * If the require_existing_tags=True, raises ValueError if any tag does not exist in `tags` table. * If False (default), create unknown tags. 
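With the stored refcount gone, the number of references to a piece of content is simply the number of assets_info rows pointing at its hash. A sketch of deriving it on demand; the helper below is illustrative, not part of this patch:

# Sketch (assumption): deriving a reference count instead of maintaining a counter column.
import sqlalchemy as sa

from app.database.models import AssetInfo

def refcount_stmt(asset_hash: str) -> sa.Select:
    # COUNT(*) over assets_info rows that reference this content hash.
    return (
        sa.select(sa.func.count())
        .select_from(AssetInfo)
        .where(AssetInfo.asset_hash == asset_hash)
    )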
@@ -122,7 +122,6 @@ async def ingest_fs_asset( hash=asset_hash, size_bytes=int(size_bytes), mime_type=mime_type, - refcount=0, created_at=datetime_now, updated_at=datetime_now, ) From bdf4ba24ceff6fe2b4ac4e18bc51192175ba4220 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Wed, 27 Aug 2025 21:58:17 +0300 Subject: [PATCH 23/82] removed not needed "assets.updated_at" column --- alembic_db/versions/0001_assets.py | 1 - app/database/models.py | 3 --- app/database/services.py | 4 +--- 3 files changed, 1 insertion(+), 7 deletions(-) diff --git a/alembic_db/versions/0001_assets.py b/alembic_db/versions/0001_assets.py index ec41ee6c1c83..8499306ba103 100644 --- a/alembic_db/versions/0001_assets.py +++ b/alembic_db/versions/0001_assets.py @@ -23,7 +23,6 @@ def upgrade() -> None: sa.Column("size_bytes", sa.BigInteger(), nullable=False, server_default="0"), sa.Column("mime_type", sa.String(length=255), nullable=True), sa.Column("created_at", sa.DateTime(timezone=False), nullable=False), - sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False), sa.CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"), ) op.create_index("ix_assets_mime_type", "assets", ["mime_type"]) diff --git a/app/database/models.py b/app/database/models.py index f2972b00add0..ea3f4970d394 100644 --- a/app/database/models.py +++ b/app/database/models.py @@ -47,9 +47,6 @@ class Asset(Base): created_at: Mapped[datetime] = mapped_column( DateTime(timezone=False), nullable=False, default=utcnow ) - updated_at: Mapped[datetime] = mapped_column( - DateTime(timezone=False), nullable=False, default=utcnow - ) infos: Mapped[list["AssetInfo"]] = relationship( "AssetInfo", diff --git a/app/database/services.py b/app/database/services.py index 5c3bbe42aa1b..d1b85e160ef0 100644 --- a/app/database/services.py +++ b/app/database/services.py @@ -85,7 +85,7 @@ async def ingest_fs_asset( Upsert Asset identity row + cache state pointing at local file. Always: - - Insert Asset if missing; else update size_bytes (and updated_at) if different. + - Insert Asset if missing; - Insert AssetCacheState if missing; else update mtime_ns if different. 
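The size_bytes and mtime_ns recorded here are what make the later quick-existence check cheap: when both still match, the file is treated as unchanged and re-hashing is skipped. A minimal sketch of that comparison; the function name is illustrative, and check_fs_asset_exists_quick in this patch additionally consults the database:

# Sketch (assumption): the change-detection idea behind the size/mtime_ns bookkeeping.
import os

def file_unchanged(path: str, recorded_size: int, recorded_mtime_ns: int) -> bool:
    st = os.stat(path)
    return st.st_size == recorded_size and st.st_mtime_ns == recorded_mtime_ns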
Optionally (when info_name is provided): @@ -123,7 +123,6 @@ async def ingest_fs_asset( size_bytes=int(size_bytes), mime_type=mime_type, created_at=datetime_now, - updated_at=datetime_now, ) ) await session.flush() @@ -140,7 +139,6 @@ async def ingest_fs_asset( existing.mime_type = mime_type changed = True if changed: - existing.updated_at = datetime_now out["asset_updated"] = True else: logging.error("Asset %s not found after PK conflict; skipping update.", asset_hash) From 6b86be320a8dd83ab6c666b8680eaff595e5e4bd Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Thu, 28 Aug 2025 08:22:54 +0300 Subject: [PATCH 24/82] use UUID instead of autoincrement Integer for Assets ID field --- alembic_db/versions/0001_assets.py | 7 ++-- app/api/assets_routes.py | 51 ++++++++++++++++-------------- app/api/schemas_out.py | 6 ++-- app/assets_manager.py | 16 +++++----- app/database/models.py | 12 +++---- app/database/services.py | 40 +++++++++++------------ 6 files changed, 67 insertions(+), 65 deletions(-) diff --git a/alembic_db/versions/0001_assets.py b/alembic_db/versions/0001_assets.py index 8499306ba103..c80874aa20ea 100644 --- a/alembic_db/versions/0001_assets.py +++ b/alembic_db/versions/0001_assets.py @@ -30,7 +30,7 @@ def upgrade() -> None: # ASSETS_INFO: user-visible references (mutable metadata) op.create_table( "assets_info", - sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True), + sa.Column("id", sa.String(length=36), primary_key=True), sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""), sa.Column("name", sa.String(length=512), nullable=False), sa.Column("asset_hash", sa.String(length=256), sa.ForeignKey("assets.hash", ondelete="RESTRICT"), nullable=False), @@ -40,7 +40,6 @@ def upgrade() -> None: sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False), sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False), sa.UniqueConstraint("asset_hash", "owner_id", "name", name="uq_assets_info_hash_owner_name"), - sqlite_autoincrement=True, ) op.create_index("ix_assets_info_owner_id", "assets_info", ["owner_id"]) op.create_index("ix_assets_info_asset_hash", "assets_info", ["asset_hash"]) @@ -61,7 +60,7 @@ def upgrade() -> None: # ASSET_INFO_TAGS: many-to-many for tags on AssetInfo op.create_table( "asset_info_tags", - sa.Column("asset_info_id", sa.Integer(), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False), + sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False), sa.Column("tag_name", sa.String(length=512), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False), sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"), sa.Column("added_at", sa.DateTime(timezone=False), nullable=False), @@ -83,7 +82,7 @@ def upgrade() -> None: # ASSET_INFO_META: typed KV projection of user_metadata for filtering/sorting op.create_table( "asset_info_meta", - sa.Column("asset_info_id", sa.Integer(), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False), + sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False), sa.Column("key", sa.String(length=256), nullable=False), sa.Column("ordinal", sa.Integer(), nullable=False, server_default="0"), sa.Column("val_str", sa.String(length=2048), nullable=True), diff --git a/app/api/assets_routes.py b/app/api/assets_routes.py index 61188c0907ef..71e99f2310a2 100644 --- a/app/api/assets_routes.py +++ 
b/app/api/assets_routes.py @@ -16,6 +16,9 @@ ROUTES = web.RouteTableDef() UserManager: Optional[user_manager.UserManager] = None +# UUID regex (canonical hyphenated form, case-insensitive) +UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}" + @ROUTES.head("/api/assets/hash/{hash}") async def head_asset_by_hash(request: web.Request) -> web.Response: @@ -52,13 +55,13 @@ async def list_assets(request: web.Request) -> web.Response: return web.json_response(payload.model_dump(mode="json")) -@ROUTES.get("/api/assets/{id:\\d+}/content") +@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content") async def download_asset_content(request: web.Request) -> web.Response: - asset_info_id_raw = request.match_info.get("id") + asset_info_id_raw = request.match_info.get("id", "") try: - asset_info_id = int(asset_info_id_raw) + asset_info_id = str(uuid.UUID(asset_info_id_raw)) except Exception: - return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid integer.") + return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid UUID.") disposition = request.query.get("disposition", "attachment").lower().strip() if disposition not in {"inline", "attachment"}: @@ -282,13 +285,13 @@ async def upload_asset(request: web.Request) -> web.Response: return _error_response(500, "INTERNAL", "Unexpected server error.") -@ROUTES.get("/api/assets/{id:\\d+}") +@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}") async def get_asset(request: web.Request) -> web.Response: - asset_info_id_raw = request.match_info.get("id") + asset_info_id_raw = request.match_info.get("id", "") try: - asset_info_id = int(asset_info_id_raw) + asset_info_id = str(uuid.UUID(asset_info_id_raw)) except Exception: - return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid integer.") + return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid UUID.") try: result = await assets_manager.get_asset( @@ -302,13 +305,13 @@ async def get_asset(request: web.Request) -> web.Response: return web.json_response(result.model_dump(mode="json"), status=200) -@ROUTES.put("/api/assets/{id:\\d+}") +@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}") async def update_asset(request: web.Request) -> web.Response: - asset_info_id_raw = request.match_info.get("id") + asset_info_id_raw = request.match_info.get("id", "") try: - asset_info_id = int(asset_info_id_raw) + asset_info_id = str(uuid.UUID(asset_info_id_raw)) except Exception: - return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid integer.") + return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid UUID.") try: body = schemas_in.UpdateAssetBody.model_validate(await request.json()) @@ -332,13 +335,13 @@ async def update_asset(request: web.Request) -> web.Response: return web.json_response(result.model_dump(mode="json"), status=200) -@ROUTES.delete("/api/assets/{id:\\d+}") +@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}") async def delete_asset(request: web.Request) -> web.Response: - asset_info_id_raw = request.match_info.get("id") + asset_info_id_raw = request.match_info.get("id", "") try: - asset_info_id = int(asset_info_id_raw) + asset_info_id = str(uuid.UUID(asset_info_id_raw)) except Exception: - return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid integer.") + return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is 
not a valid UUID.") try: deleted = await assets_manager.delete_asset_reference( @@ -376,13 +379,13 @@ async def get_tags(request: web.Request) -> web.Response: return web.json_response(result.model_dump(mode="json")) -@ROUTES.post("/api/assets/{id:\\d+}/tags") +@ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags") async def add_asset_tags(request: web.Request) -> web.Response: - asset_info_id_raw = request.match_info.get("id") + asset_info_id_raw = request.match_info.get("id", "") try: - asset_info_id = int(asset_info_id_raw) + asset_info_id = str(uuid.UUID(asset_info_id_raw)) except Exception: - return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid integer.") + return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid UUID.") try: payload = await request.json() @@ -407,13 +410,13 @@ async def add_asset_tags(request: web.Request) -> web.Response: return web.json_response(result.model_dump(mode="json"), status=200) -@ROUTES.delete("/api/assets/{id:\\d+}/tags") +@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags") async def delete_asset_tags(request: web.Request) -> web.Response: - asset_info_id_raw = request.match_info.get("id") + asset_info_id_raw = request.match_info.get("id", "") try: - asset_info_id = int(asset_info_id_raw) + asset_info_id = str(uuid.UUID(asset_info_id_raw)) except Exception: - return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid integer.") + return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid UUID.") try: payload = await request.json() diff --git a/app/api/schemas_out.py b/app/api/schemas_out.py index 1b41d8021d32..581d717960c3 100644 --- a/app/api/schemas_out.py +++ b/app/api/schemas_out.py @@ -4,7 +4,7 @@ class AssetSummary(BaseModel): - id: int + id: str name: str asset_hash: str size: Optional[int] = None @@ -29,7 +29,7 @@ class AssetsList(BaseModel): class AssetUpdated(BaseModel): - id: int + id: str name: str asset_hash: str tags: list[str] = Field(default_factory=list) @@ -44,7 +44,7 @@ def _ser_updated(self, v: Optional[datetime], _info): class AssetDetail(BaseModel): - id: int + id: str name: str asset_hash: str size: Optional[int] = None diff --git a/app/assets_manager.py b/app/assets_manager.py index f9046c5ed936..3d7c040c4c18 100644 --- a/app/assets_manager.py +++ b/app/assets_manager.py @@ -73,7 +73,7 @@ async def add_local_asset(tags: list[str], file_name: str, file_path: str) -> No async with await create_session() as session: if await check_fs_asset_exists_quick(session, file_path=abs_path, size_bytes=size_bytes, mtime_ns=mtime_ns): - await touch_asset_infos_by_fs_path(session, abs_path=abs_path) + await touch_asset_infos_by_fs_path(session, file_path=abs_path) await session.commit() return @@ -149,7 +149,7 @@ async def list_assets( ) -async def get_asset(*, asset_info_id: int, owner_id: str = "") -> schemas_out.AssetDetail: +async def get_asset(*, asset_info_id: str, owner_id: str = "") -> schemas_out.AssetDetail: async with await create_session() as session: res = await fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id) if not res: @@ -172,7 +172,7 @@ async def get_asset(*, asset_info_id: int, owner_id: str = "") -> schemas_out.As async def resolve_asset_content_for_download( *, - asset_info_id: int, + asset_info_id: str, owner_id: str = "", ) -> tuple[str, str, str]: """ @@ -311,7 +311,7 @@ async def upload_asset_from_temp_path( if not info_id: raise 
RuntimeError("failed to create asset metadata") - pair = await fetch_asset_info_and_asset(session, asset_info_id=int(info_id), owner_id=owner_id) + pair = await fetch_asset_info_and_asset(session, asset_info_id=info_id, owner_id=owner_id) if not pair: raise RuntimeError("inconsistent DB state after ingest") info, asset = pair @@ -335,7 +335,7 @@ async def upload_asset_from_temp_path( async def update_asset( *, - asset_info_id: int, + asset_info_id: str, name: Optional[str] = None, tags: Optional[list[str]] = None, user_metadata: Optional[dict] = None, @@ -371,7 +371,7 @@ async def update_asset( ) -async def delete_asset_reference(*, asset_info_id: int, owner_id: str) -> bool: +async def delete_asset_reference(*, asset_info_id: str, owner_id: str) -> bool: async with await create_session() as session: r = await delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id) await session.commit() @@ -448,7 +448,7 @@ async def list_tags( async def add_tags_to_asset( *, - asset_info_id: int, + asset_info_id: str, tags: list[str], origin: str = "manual", owner_id: str = "", @@ -473,7 +473,7 @@ async def add_tags_to_asset( async def remove_tags_from_asset( *, - asset_info_id: int, + asset_info_id: str, tags: list[str], owner_id: str = "", ) -> schemas_out.TagsRemove: diff --git a/app/database/models.py b/app/database/models.py index ea3f4970d394..47f8bbaf35e5 100644 --- a/app/database/models.py +++ b/app/database/models.py @@ -1,5 +1,6 @@ from datetime import datetime from typing import Any, Optional +import uuid from sqlalchemy import ( Integer, @@ -135,7 +136,7 @@ class AssetLocation(Base): class AssetInfo(Base): __tablename__ = "assets_info" - id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="") name: Mapped[str] = mapped_column(String(512), nullable=False) asset_hash: Mapped[str] = mapped_column( @@ -194,7 +195,6 @@ class AssetInfo(Base): Index("ix_assets_info_name", "name"), Index("ix_assets_info_created_at", "created_at"), Index("ix_assets_info_last_access_time", "last_access_time"), - {"sqlite_autoincrement": True}, ) def to_dict(self, include_none: bool = False) -> dict[str, Any]: @@ -209,8 +209,8 @@ def __repr__(self) -> str: class AssetInfoMeta(Base): __tablename__ = "asset_info_meta" - asset_info_id: Mapped[int] = mapped_column( - Integer, ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True + asset_info_id: Mapped[str] = mapped_column( + String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True ) key: Mapped[str] = mapped_column(String(256), primary_key=True) ordinal: Mapped[int] = mapped_column(Integer, primary_key=True, default=0) @@ -233,8 +233,8 @@ class AssetInfoMeta(Base): class AssetInfoTag(Base): __tablename__ = "asset_info_tags" - asset_info_id: Mapped[int] = mapped_column( - Integer, ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True + asset_info_id: Mapped[str] = mapped_column( + String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True ) tag_name: Mapped[str] = mapped_column( String(512), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True diff --git a/app/database/services.py b/app/database/services.py index d1b85e160ef0..da8c02f67c76 100644 --- a/app/database/services.py +++ b/app/database/services.py @@ -30,7 +30,7 @@ async def get_asset_by_hash(session: AsyncSession, 
*, asset_hash: str) -> Option return await session.get(Asset, asset_hash) -async def get_asset_info_by_id(session: AsyncSession, *, asset_info_id: int) -> Optional[AssetInfo]: +async def get_asset_info_by_id(session: AsyncSession, *, asset_info_id: str) -> Optional[AssetInfo]: return await session.get(AssetInfo, asset_info_id) @@ -100,13 +100,13 @@ async def ingest_fs_asset( "asset_updated": bool, "state_created": bool, "state_updated": bool, - "asset_info_id": int | None, + "asset_info_id": str | None, } """ locator = os.path.abspath(abs_path) datetime_now = utcnow() - out = { + out: dict[str, Any] = { "asset_created": False, "asset_updated": False, "state_created": False, @@ -187,7 +187,7 @@ async def ingest_fs_asset( last_access_time=datetime_now, ) session.add(info) - await session.flush() # get info.id + await session.flush() # get info.id (UUID) out["asset_info_id"] = info.id existing_info = ( @@ -263,11 +263,11 @@ async def ingest_fs_asset( async def touch_asset_infos_by_fs_path( session: AsyncSession, *, - abs_path: str, + file_path: str, ts: Optional[datetime] = None, only_if_newer: bool = True, ) -> int: - locator = os.path.abspath(abs_path) + locator = os.path.abspath(file_path) ts = ts or utcnow() stmt = sa.update(AssetInfo).where( @@ -298,7 +298,7 @@ async def touch_asset_infos_by_fs_path( async def touch_asset_info_by_id( session: AsyncSession, *, - asset_info_id: int, + asset_info_id: str, ts: Optional[datetime] = None, only_if_newer: bool = True, ) -> int: @@ -325,7 +325,7 @@ async def list_asset_infos_page( offset: int = 0, sort: str = "created_at", order: str = "desc", -) -> tuple[list[AssetInfo], dict[int, list[str]], int]: +) -> tuple[list[AssetInfo], dict[str, list[str]], int]: """Return page of AssetInfo rows in the viewers visibility.""" base = ( select(AssetInfo) @@ -373,8 +373,8 @@ async def list_asset_infos_page( infos = (await session.execute(base)).scalars().unique().all() # Collect tags in bulk (single query) - id_list = [i.id for i in infos] - tag_map: dict[int, list[str]] = defaultdict(list) + id_list: list[str] = [i.id for i in infos] + tag_map: dict[str, list[str]] = defaultdict(list) if id_list: rows = await session.execute( select(AssetInfoTag.asset_info_id, Tag.name) @@ -390,7 +390,7 @@ async def list_asset_infos_page( async def fetch_asset_info_and_asset( session: AsyncSession, *, - asset_info_id: int, + asset_info_id: str, owner_id: str = "", ) -> Optional[tuple[AssetInfo, Asset]]: stmt = ( @@ -412,7 +412,7 @@ async def fetch_asset_info_and_asset( async def fetch_asset_info_asset_and_tags( session: AsyncSession, *, - asset_info_id: int, + asset_info_id: str, owner_id: str = "", ) -> Optional[tuple[AssetInfo, Asset, list[str]]]: stmt = ( @@ -449,7 +449,7 @@ async def get_cache_state_by_asset_hash(session: AsyncSession, *, asset_hash: st async def list_asset_locations( session: AsyncSession, *, asset_hash: str, provider: Optional[str] = None -) -> list[AssetLocation]: +) -> list[AssetLocation] | Sequence[AssetLocation]: stmt = select(AssetLocation).where(AssetLocation.asset_hash == asset_hash) if provider: stmt = stmt.where(AssetLocation.provider == provider) @@ -545,7 +545,7 @@ async def create_asset_info_for_existing_asset( async def set_asset_info_tags( session: AsyncSession, *, - asset_info_id: int, + asset_info_id: str, tags: Sequence[str], origin: str = "manual", ) -> dict: @@ -586,7 +586,7 @@ async def set_asset_info_tags( async def update_asset_info_full( session: AsyncSession, *, - asset_info_id: int, + asset_info_id: str, name: Optional[str] = 
None, tags: Optional[Sequence[str]] = None, user_metadata: Optional[dict] = None, @@ -634,7 +634,7 @@ async def update_asset_info_full( return info -async def delete_asset_info_by_id(session: AsyncSession, *, asset_info_id: int, owner_id: str) -> bool: +async def delete_asset_info_by_id(session: AsyncSession, *, asset_info_id: str, owner_id: str) -> bool: """Delete the user-visible AssetInfo row. Cascades clear tags and metadata.""" res = await session.execute(delete(AssetInfo).where( AssetInfo.id == asset_info_id, @@ -646,7 +646,7 @@ async def delete_asset_info_by_id(session: AsyncSession, *, asset_info_id: int, async def replace_asset_info_metadata_projection( session: AsyncSession, *, - asset_info_id: int, + asset_info_id: str, user_metadata: Optional[dict], ) -> None: """Replaces the `assets_info.user_metadata` AND rebuild the projection rows in `asset_info_meta`.""" @@ -683,7 +683,7 @@ async def replace_asset_info_metadata_projection( await session.flush() -async def get_asset_tags(session: AsyncSession, *, asset_info_id: int) -> list[str]: +async def get_asset_tags(session: AsyncSession, *, asset_info_id: str) -> list[str]: return [ tag_name for (tag_name,) in ( @@ -763,7 +763,7 @@ async def list_tags_with_usage( async def add_tags_to_asset_info( session: AsyncSession, *, - asset_info_id: int, + asset_info_id: str, tags: Sequence[str], origin: str = "manual", create_if_missing: bool = True, @@ -829,7 +829,7 @@ async def add_tags_to_asset_info( async def remove_tags_from_asset_info( session: AsyncSession, *, - asset_info_id: int, + asset_info_id: str, tags: Sequence[str], ) -> dict: """Removes tags from an AssetInfo. From bf8363ec875246f179f36fa11a1d22deceb2a05f Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Fri, 29 Aug 2025 19:48:42 +0300 Subject: [PATCH 25/82] always autofill "filename" in the metadata --- app/_assets_helpers.py | 27 +++++++++++ app/database/services.py | 99 ++++++++++++++++++++++++++++++++++++---- 2 files changed, 116 insertions(+), 10 deletions(-) diff --git a/app/_assets_helpers.py b/app/_assets_helpers.py index 4a8d39625376..ddc43f1eaa4e 100644 --- a/app/_assets_helpers.py +++ b/app/_assets_helpers.py @@ -145,3 +145,30 @@ def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement: if owner_id == "": return AssetInfo.owner_id == "" return AssetInfo.owner_id.in_(["", owner_id]) + + +def compute_model_relative_filename(file_path: str) -> str | None: + """ + Return the model's path relative to the last well-known folder (the model category), + using forward slashes, eg: + /.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors" + /.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors" + + For non-model paths, returns None. + NOTE: this is a temporary helper, used only for initializing metadata["filename"] field. 
+ """ + try: + root_category, rel_path = get_relative_to_root_category_path_of_asset(file_path) + except ValueError: + return None + + if root_category != "models": + return None + + p = Path(rel_path) + # parts[0] is the well-known category (eg "checkpoints" or "text_encoders") + parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)] + if not parts: + return None + inside = parts[1:] if len(parts) > 1 else [parts[0]] + return "/".join(inside) # normalize to POSIX style for portability diff --git a/app/database/services.py b/app/database/services.py index da8c02f67c76..af8861001b81 100644 --- a/app/database/services.py +++ b/app/database/services.py @@ -14,7 +14,7 @@ from .models import Asset, AssetInfo, AssetInfoTag, AssetCacheState, Tag, AssetInfoMeta, AssetLocation from .timeutil import utcnow -from .._assets_helpers import normalize_tags, visible_owner_clause +from .._assets_helpers import normalize_tags, visible_owner_clause, compute_model_relative_filename async def asset_exists_by_hash(session: AsyncSession, *, asset_hash: str) -> bool: @@ -251,12 +251,38 @@ async def ingest_fs_asset( await session.flush() # 2c) Rebuild metadata projection if provided - if user_metadata is not None and out["asset_info_id"] is not None: - await replace_asset_info_metadata_projection( - session, - asset_info_id=out["asset_info_id"], - user_metadata=user_metadata, - ) + # Uncomment next code, and remove code after it, once the hack with "metadata[filename" is not needed anymore + # if user_metadata is not None and out["asset_info_id"] is not None: + # await replace_asset_info_metadata_projection( + # session, + # asset_info_id=out["asset_info_id"], + # user_metadata=user_metadata, + # ) + # start of adding metadata["filename"] + if out["asset_info_id"] is not None: + computed_filename = compute_model_relative_filename(abs_path) + + # Start from current metadata on this AssetInfo, if any + current_meta = existing_info.user_metadata or {} + new_meta = dict(current_meta) + + # Merge caller-provided metadata, if any (caller keys override current) + if user_metadata is not None: + for k, v in user_metadata.items(): + new_meta[k] = v + + # Enforce correct model-relative filename when known + if computed_filename: + new_meta["filename"] = computed_filename + + # Only write when there is a change + if new_meta != current_meta: + await replace_asset_info_metadata_projection( + session, + asset_info_id=out["asset_info_id"], + user_metadata=new_meta, + ) + # end of adding metadata["filename"] return out @@ -527,10 +553,33 @@ async def create_asset_info_for_existing_asset( session.add(info) await session.flush() # get info.id - if user_metadata is not None: + # Uncomment next code, and remove code after it, once the hack with "metadata[filename" is not needed anymore + # if user_metadata is not None: + # await replace_asset_info_metadata_projection( + # session, asset_info_id=info.id, user_metadata=user_metadata + # ) + + # start of adding metadata["filename"] + new_meta = dict(user_metadata or {}) + + computed_filename = None + try: + state = await get_cache_state_by_asset_hash(session, asset_hash=asset_hash) + if state and state.file_path: + computed_filename = compute_model_relative_filename(state.file_path) + except Exception: + computed_filename = None + + if computed_filename: + new_meta["filename"] = computed_filename + + if new_meta: await replace_asset_info_metadata_projection( - session, asset_info_id=info.id, user_metadata=user_metadata + session, + asset_info_id=info.id, + 
user_metadata=new_meta, ) + # end of adding metadata["filename"] if tags is not None: await set_asset_info_tags( @@ -612,11 +661,41 @@ async def update_asset_info_full( info.name = name touched = True + # Uncomment next code, and remove code after it, once the hack with "metadata[filename" is not needed anymore + # if user_metadata is not None: + # await replace_asset_info_metadata_projection( + # session, asset_info_id=asset_info_id, user_metadata=user_metadata + # ) + # touched = True + + # start of adding metadata["filename"] + computed_filename = None + try: + state = await get_cache_state_by_asset_hash(session, asset_hash=info.asset_hash) + if state and state.file_path: + computed_filename = compute_model_relative_filename(state.file_path) + except Exception: + computed_filename = None + if user_metadata is not None: + new_meta = dict(user_metadata) + if computed_filename: + new_meta["filename"] = computed_filename await replace_asset_info_metadata_projection( - session, asset_info_id=asset_info_id, user_metadata=user_metadata + session, asset_info_id=asset_info_id, user_metadata=new_meta ) touched = True + else: + if computed_filename: + current_meta = info.user_metadata or {} + if current_meta.get("filename") != computed_filename: + new_meta = dict(current_meta) + new_meta["filename"] = computed_filename + await replace_asset_info_metadata_projection( + session, asset_info_id=asset_info_id, user_metadata=new_meta + ) + touched = True + # end of adding metadata["filename"] if tags is not None: await set_asset_info_tags( From ce270ba090149baacdcb958dc7ba65a37d87e648 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Fri, 5 Sep 2025 17:46:09 +0300 Subject: [PATCH 26/82] added Assets Autoscan feature --- app/api/schemas_out.py | 16 +- app/assets_manager.py | 2 +- app/assets_scanner.py | 512 ++++++++++++++++++++++++++--------------- comfy/cli_args.py | 3 +- main.py | 25 ++ 5 files changed, 368 insertions(+), 190 deletions(-) diff --git a/app/api/schemas_out.py b/app/api/schemas_out.py index 581d717960c3..8bb34096bb75 100644 --- a/app/api/schemas_out.py +++ b/app/api/schemas_out.py @@ -92,17 +92,25 @@ class TagsRemove(BaseModel): total_tags: list[str] = Field(default_factory=list) +class AssetScanError(BaseModel): + path: str + message: str + phase: Literal["fast", "slow"] + at: Optional[str] = Field(None, description="ISO timestamp") + + class AssetScanStatus(BaseModel): scan_id: str - root: Literal["models","input","output"] - status: Literal["scheduled","running","completed","failed","cancelled"] + root: Literal["models", "input", "output"] + status: Literal["scheduled", "running", "completed", "failed", "cancelled"] scheduled_at: Optional[str] = None started_at: Optional[str] = None finished_at: Optional[str] = None discovered: int = 0 processed: int = 0 - errors: int = 0 - last_error: Optional[str] = None + slow_queue_total: int = 0 + slow_queue_finished: int = 0 + file_errors: list[AssetScanError] = Field(default_factory=list) class AssetScanStatusResponse(BaseModel): diff --git a/app/assets_manager.py b/app/assets_manager.py index 3d7c040c4c18..b84b61508762 100644 --- a/app/assets_manager.py +++ b/app/assets_manager.py @@ -43,7 +43,7 @@ async def asset_exists(*, asset_hash: str) -> bool: def populate_db_with_asset(file_path: str, tags: Optional[list[str]] = None) -> None: - if not args.disable_model_processing: + if not args.enable_model_processing: if tags is None: tags = [] try: diff --git a/app/assets_scanner.py b/app/assets_scanner.py index 5bafd6bb7c29..ccfc8e9e5efe 100644 --- 
a/app/assets_scanner.py +++ b/app/assets_scanner.py @@ -5,22 +5,22 @@ import time from dataclasses import dataclass, field from datetime import datetime, timezone -from typing import Literal, Optional, Sequence +from typing import Callable, Literal, Optional, Sequence import folder_paths from . import assets_manager from .api import schemas_out from ._assets_helpers import get_comfy_models_folders +from .database.db import create_session +from .database.services import check_fs_asset_exists_quick LOGGER = logging.getLogger(__name__) RootType = Literal["models", "input", "output"] ALLOWED_ROOTS: tuple[RootType, ...] = ("models", "input", "output") -# We run at most one scan per root; overall max parallelism is therefore 3 -# We also bound per-scan ingestion concurrency to avoid swamping threads/DB -DEFAULT_PER_SCAN_CONCURRENCY = 1 +SLOW_HASH_CONCURRENCY = 1 @dataclass @@ -34,15 +34,25 @@ class ScanProgress: discovered: int = 0 processed: int = 0 - errors: int = 0 - last_error: Optional[str] = None + slow_queue_total: int = 0 + slow_queue_finished: int = 0 + file_errors: list[dict] = field(default_factory=list) # {"path","message","phase","at"} - # Optional details for diagnostics (e.g., files per bucket) - details: dict[str, int] = field(default_factory=dict) + # Internal diagnostics for logs + _fast_total_seen: int = 0 + _fast_clean: int = 0 + + +@dataclass +class SlowQueueState: + queue: asyncio.Queue + workers: list[asyncio.Task] = field(default_factory=list) + closed: bool = False RUNNING_TASKS: dict[RootType, asyncio.Task] = {} PROGRESS_BY_ROOT: dict[RootType, ScanProgress] = {} +SLOW_STATE_BY_ROOT: dict[RootType, SlowQueueState] = {} def _new_scan_id(root: RootType) -> str: @@ -50,23 +60,13 @@ def _new_scan_id(root: RootType) -> str: def current_statuses() -> schemas_out.AssetScanStatusResponse: - return schemas_out.AssetScanStatusResponse( - scans=[ - schemas_out.AssetScanStatus( - scan_id=s.scan_id, - root=s.root, - status=s.status, - scheduled_at=_ts_to_iso(s.scheduled_at), - started_at=_ts_to_iso(s.started_at), - finished_at=_ts_to_iso(s.finished_at), - discovered=s.discovered, - processed=s.processed, - errors=s.errors, - last_error=s.last_error, - ) - for s in [PROGRESS_BY_ROOT[r] for r in ALLOWED_ROOTS if r in PROGRESS_BY_ROOT] - ] - ) + scans = [] + for root in ALLOWED_ROOTS: + prog = PROGRESS_BY_ROOT.get(root) + if not prog: + continue + scans.append(_scan_progress_to_scan_status_model(prog)) + return schemas_out.AssetScanStatusResponse(scans=scans) async def schedule_scans(roots: Sequence[str]) -> schemas_out.AssetScanStatusResponse: @@ -81,8 +81,6 @@ async def schedule_scans(roots: Sequence[str]) -> schemas_out.AssetScanStatusRes normalized: list[RootType] = [] seen = set() for r in roots or []: - if not isinstance(r, str): - continue rr = r.strip().lower() if rr in ALLOWED_ROOTS and rr not in seen: normalized.append(rr) # type: ignore @@ -98,142 +96,311 @@ async def schedule_scans(roots: Sequence[str]) -> schemas_out.AssetScanStatusRes prog = ScanProgress(scan_id=_new_scan_id(root), root=root, status="scheduled") PROGRESS_BY_ROOT[root] = prog - - task = asyncio.create_task(_run_scan_for_root(root, prog), name=f"asset-scan:{root}") - RUNNING_TASKS[root] = task + SLOW_STATE_BY_ROOT[root] = SlowQueueState(queue=asyncio.Queue()) + RUNNING_TASKS[root] = asyncio.create_task( + _pipeline_for_root(root, prog, progress_cb=None), + name=f"asset-scan:{root}", + ) results.append(prog) + return _status_response_for(results) + - return schemas_out.AssetScanStatusResponse( - scans=[ - 
schemas_out.AssetScanStatus( - scan_id=s.scan_id, - root=s.root, - status=s.status, - scheduled_at=_ts_to_iso(s.scheduled_at), - started_at=_ts_to_iso(s.started_at), - finished_at=_ts_to_iso(s.finished_at), - discovered=s.discovered, - processed=s.processed, - errors=s.errors, - last_error=s.last_error, +async def fast_reconcile_and_kickoff( + roots: Sequence[str] | None = None, + *, + progress_cb: Optional[Callable[[dict], None]] = None, +) -> schemas_out.AssetScanStatusResponse: + """ + Startup helper: do the fast pass now (so we know queue size), + start slow hashing in the background, return immediately. + """ + normalized = [*ALLOWED_ROOTS] if not roots else [r for r in roots if r in ALLOWED_ROOTS] + snaps: list[ScanProgress] = [] + + for root in normalized: + if root in RUNNING_TASKS and not RUNNING_TASKS[root].done(): + snaps.append(PROGRESS_BY_ROOT[root]) + continue + + prog = ScanProgress(scan_id=_new_scan_id(root), root=root, status="scheduled") + PROGRESS_BY_ROOT[root] = prog + state = SlowQueueState(queue=asyncio.Queue()) + SLOW_STATE_BY_ROOT[root] = state + + prog.status = "running" + prog.started_at = time.time() + try: + await _fast_reconcile_into_queue(root, prog, state, progress_cb=progress_cb) + except Exception as e: + _append_error(prog, phase="fast", path="", message=str(e)) + prog.status = "failed" + prog.finished_at = time.time() + LOGGER.exception("Fast reconcile failed for %s", root) + snaps.append(prog) + continue + + _start_slow_workers(root, prog, state, progress_cb=progress_cb) + RUNNING_TASKS[root] = asyncio.create_task( + _await_workers_then_finish(root, prog, state, progress_cb=progress_cb), + name=f"asset-hash:{root}", + ) + snaps.append(prog) + return _status_response_for(snaps) + + +def _status_response_for(progresses: list[ScanProgress]) -> schemas_out.AssetScanStatusResponse: + return schemas_out.AssetScanStatusResponse(scans=[_scan_progress_to_scan_status_model(p) for p in progresses]) + + +def _scan_progress_to_scan_status_model(progress: ScanProgress) -> schemas_out.AssetScanStatus: + return schemas_out.AssetScanStatus( + scan_id=progress.scan_id, + root=progress.root, + status=progress.status, + scheduled_at=_ts_to_iso(progress.scheduled_at), + started_at=_ts_to_iso(progress.started_at), + finished_at=_ts_to_iso(progress.finished_at), + discovered=progress.discovered, + processed=progress.processed, + slow_queue_total=progress.slow_queue_total, + slow_queue_finished=progress.slow_queue_finished, + file_errors=[ + schemas_out.AssetScanError( + path=e.get("path", ""), + message=e.get("message", ""), + phase=e.get("phase", "slow"), + at=e.get("at"), ) - for s in results - ] + for e in (progress.file_errors or []) + ], ) -async def _run_scan_for_root(root: RootType, prog: ScanProgress) -> None: - prog.started_at = time.time() +async def _pipeline_for_root( + root: RootType, + prog: ScanProgress, + progress_cb: Optional[Callable[[dict], None]], +) -> None: + state = SLOW_STATE_BY_ROOT.get(root) or SlowQueueState(queue=asyncio.Queue()) + SLOW_STATE_BY_ROOT[root] = state + prog.status = "running" + prog.started_at = time.time() + try: - if root == "models": - await _scan_models(prog) - elif root == "input": - base = folder_paths.get_input_directory() - await _scan_directory_tree(base, root, prog) - elif root == "output": - base = folder_paths.get_output_directory() - await _scan_directory_tree(base, root, prog) - else: - raise RuntimeError(f"Unsupported root: {root}") - prog.status = "completed" + await _fast_reconcile_into_queue(root, prog, state, 
progress_cb=progress_cb) + _start_slow_workers(root, prog, state, progress_cb=progress_cb) + await _await_workers_then_finish(root, prog, state, progress_cb=progress_cb) except asyncio.CancelledError: prog.status = "cancelled" raise except Exception as exc: - LOGGER.exception("Asset scan failed for %s", root) + _append_error(prog, phase="slow", path="", message=str(exc)) prog.status = "failed" - prog.errors += 1 - prog.last_error = str(exc) - finally: prog.finished_at = time.time() + LOGGER.exception("Asset scan failed for %s", root) + finally: RUNNING_TASKS.pop(root, None) -async def _scan_models(prog: ScanProgress) -> None: +async def _fast_reconcile_into_queue( + root: RootType, + prog: ScanProgress, + state: SlowQueueState, + *, + progress_cb: Optional[Callable[[dict], None]], +) -> None: """ - Scan all configured model buckets from folder_paths.folder_names_and_paths, - restricted to entries whose base paths lie under folder_paths.models_dir - (per get_comfy_models_folders). We trust those mappings and do not try to - infer anything else here. + Enumerate files, set 'discovered' to total files seen, increment 'processed' for fast-matched files, + and queue the rest for slow hashing. """ - targets: list[tuple[str, list[str]]] = get_comfy_models_folders() + if root == "models": + files = _collect_models_files() + preset_discovered = len(files) + files_iter = asyncio.Queue() + for p in files: + await files_iter.put(p) + await files_iter.put(None) # sentinel for our local draining loop + elif root == "input": + base = folder_paths.get_input_directory() + preset_discovered = _count_files_in_tree(os.path.abspath(base)) + files_iter = await _queue_tree_files(base) + elif root == "output": + base = folder_paths.get_output_directory() + preset_discovered = _count_files_in_tree(os.path.abspath(base)) + files_iter = await _queue_tree_files(base) + else: + raise RuntimeError(f"Unsupported root: {root}") + + prog.discovered = int(preset_discovered or 0) + + queued = 0 + checked = 0 + clean = 0 + + # Single session for the whole fast pass + async with await create_session() as sess: + while True: + item = await files_iter.get() + files_iter.task_done() + if item is None: + break + + abs_path = item + checked += 1 + + # Stat; skip empty/unreadable + try: + st = os.stat(abs_path, follow_symlinks=True) + if not st.st_size: + continue + size_bytes = int(st.st_size) + mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)) + except OSError as e: + _append_error(prog, phase="fast", path=abs_path, message=str(e)) + continue - plans: list[str] = [] # absolute file paths to ingest - per_bucket: dict[str, int] = {} + # Known good -> count as processed immediately + try: + known = await check_fs_asset_exists_quick( + sess, + file_path=abs_path, + size_bytes=size_bytes, + mtime_ns=mtime_ns, + ) + except Exception as e: + _append_error(prog, phase="fast", path=abs_path, message=str(e)) + known = False + + if known: + clean += 1 + prog.processed += 1 # preserve original semantics + else: + await state.queue.put(abs_path) + queued += 1 + prog.slow_queue_total += 1 + + if progress_cb: + progress_cb({ + "root": root, + "phase": "fast", + "checked": checked, + "clean": clean, + "queued": queued, + "discovered": prog.discovered, + "processed": prog.processed, + }) + + prog._fast_total_seen = checked + prog._fast_clean = clean + + if progress_cb: + progress_cb({ + "root": root, + "phase": "fast", + "checked": checked, + "clean": clean, + "queued": queued, + "discovered": prog.discovered, + "processed": 
prog.processed, + "done": True, + }) + + state.closed = True + + +def _start_slow_workers( + root: RootType, + prog: ScanProgress, + state: SlowQueueState, + *, + progress_cb: Optional[Callable[[dict], None]], +) -> None: + if state.workers: + return - for folder_name, bases in targets: + async def _worker(_worker_id: int): + while True: + item = await state.queue.get() + try: + if item is None: + return + try: + await asyncio.to_thread(assets_manager.populate_db_with_asset, item) + except Exception as e: + _append_error(prog, phase="slow", path=item, message=str(e)) + finally: + # Slow queue finished for this item; also counts toward overall processed + prog.slow_queue_finished += 1 + prog.processed += 1 + if progress_cb: + progress_cb({ + "root": root, + "phase": "slow", + "processed": prog.processed, + "slow_queue_finished": prog.slow_queue_finished, + "slow_queue_total": prog.slow_queue_total, + }) + finally: + state.queue.task_done() + + state.workers = [asyncio.create_task(_worker(i), name=f"asset-hash:{root}:{i}") for i in range(SLOW_HASH_CONCURRENCY)] + + async def _close_when_empty(): + # When the fast phase closed the queue, push sentinels to end workers + while not state.closed: + await asyncio.sleep(0.05) + for _ in range(SLOW_HASH_CONCURRENCY): + await state.queue.put(None) + + asyncio.create_task(_close_when_empty()) + + +async def _await_workers_then_finish( + root: RootType, + prog: ScanProgress, + state: SlowQueueState, + *, + progress_cb: Optional[Callable[[dict], None]], +) -> None: + if state.workers: + await asyncio.gather(*state.workers, return_exceptions=True) + prog.finished_at = time.time() + prog.status = "completed" + if progress_cb: + progress_cb({ + "root": root, + "phase": "slow", + "processed": prog.processed, + "slow_queue_finished": prog.slow_queue_finished, + "slow_queue_total": prog.slow_queue_total, + "done": True, + }) + + +def _collect_models_files() -> list[str]: + """Collect absolute file paths from configured model buckets under models_dir.""" + out: list[str] = [] + for folder_name, bases in get_comfy_models_folders(): rel_files = folder_paths.get_filename_list(folder_name) or [] - count_valid = 0 - for rel_path in rel_files: abs_path = folder_paths.get_full_path(folder_name, rel_path) if not abs_path: continue abs_path = os.path.abspath(abs_path) - - # Extra safety: ensure file is inside one of the allowed base paths + # ensure within allowed bases allowed = False - for base in bases: - base_abs = os.path.abspath(base) + for b in bases: + base_abs = os.path.abspath(b) try: - common = os.path.commonpath([abs_path, base_abs]) - except ValueError: - common = "" # Different drives on Windows - if common == base_abs: - allowed = True - break - if not allowed: - LOGGER.warning("Skipping file outside models base: %s", abs_path) - continue - - try: - if not os.path.getsize(abs_path): - continue # skip empty files - except OSError as e: - LOGGER.warning("Could not stat %s: %s – skipping", abs_path, e) - continue - - plans.append(abs_path) - count_valid += 1 - - if count_valid: - per_bucket[folder_name] = per_bucket.get(folder_name, 0) + count_valid - - prog.discovered = len(plans) - for k, v in per_bucket.items(): - prog.details[k] = prog.details.get(k, 0) + v - - if not plans: - LOGGER.info("Model scan %s: nothing to ingest", prog.scan_id) - return - - sem = asyncio.Semaphore(DEFAULT_PER_SCAN_CONCURRENCY) - tasks: list[asyncio.Task] = [] - - for abs_path in plans: - async def worker(fp_abs: str = abs_path): - try: - # Offload sync ingestion into a thread; 
populate_db_with_asset - # derives name and tags from the path using _assets_helpers. - await asyncio.to_thread(assets_manager.populate_db_with_asset, fp_abs) - except Exception as e: - prog.errors += 1 - prog.last_error = str(e) - LOGGER.debug("Error ingesting %s: %s", fp_abs, e) - finally: - prog.processed += 1 - sem.release() - - await sem.acquire() - tasks.append(asyncio.create_task(worker())) - - if tasks: - await asyncio.gather(*tasks) - LOGGER.info( - "Model scan %s finished: discovered=%d processed=%d errors=%d", - prog.scan_id, prog.discovered, prog.processed, prog.errors - ) + if os.path.commonpath([abs_path, base_abs]) == base_abs: + allowed = True + break + except Exception: + pass + if allowed: + out.append(abs_path) + return out def _count_files_in_tree(base_abs: str) -> int: @@ -245,60 +412,37 @@ def _count_files_in_tree(base_abs: str) -> int: return total -async def _scan_directory_tree(base_dir: str, root: RootType, prog: ScanProgress) -> None: +async def _queue_tree_files(base_dir: str) -> asyncio.Queue: """ - Generic scanner for input/output roots. We pass only the absolute path to - populate_db_with_asset and let it derive the relative name and tags. + Walk base_dir in a worker thread and return a queue prefilled with all paths, + terminated by a single None sentinel for the draining loop in fast reconcile. """ + q: asyncio.Queue = asyncio.Queue() base_abs = os.path.abspath(base_dir) if not os.path.isdir(base_abs): - LOGGER.info("Scan root %s skipped: base directory missing: %s", root, base_abs) - return - - prog.discovered = _count_files_in_tree(base_abs) - - sem = asyncio.Semaphore(DEFAULT_PER_SCAN_CONCURRENCY) - tasks: list[asyncio.Task] = [] - for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False): - for name in filenames: - abs_path = os.path.abspath(os.path.join(dirpath, name)) - - # Safety: ensure within base - try: - if os.path.commonpath([abs_path, base_abs]) != base_abs: - LOGGER.warning("Skipping path outside root %s: %s", root, abs_path) - continue - except ValueError: - continue - - # Skip empty files and handle stat errors - try: - if not os.path.getsize(abs_path): - continue - except OSError as e: - LOGGER.warning("Could not stat %s: %s – skipping", abs_path, e) - continue - - async def worker(fp_abs: str = abs_path): - try: - await asyncio.to_thread(assets_manager.populate_db_with_asset, fp_abs) - except Exception as e: - prog.errors += 1 - prog.last_error = str(e) - finally: - prog.processed += 1 - sem.release() - - await sem.acquire() - tasks.append(asyncio.create_task(worker())) - - if tasks: - await asyncio.gather(*tasks) - - LOGGER.info( - "%s scan %s finished: discovered=%d processed=%d errors=%d", - root.capitalize(), prog.scan_id, prog.discovered, prog.processed, prog.errors - ) + await q.put(None) + return q + + def _walk_list(): + paths: list[str] = [] + for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False): + for name in filenames: + paths.append(os.path.abspath(os.path.join(dirpath, name))) + return paths + + for p in await asyncio.to_thread(_walk_list): + await q.put(p) + await q.put(None) + return q + + +def _append_error(prog: ScanProgress, *, phase: Literal["fast", "slow"], path: str, message: str) -> None: + prog.file_errors.append({ + "path": path, + "message": message, + "phase": phase, + "at": _ts_to_iso(time.time()), + }) def _ts_to_iso(ts: Optional[float]) -> Optional[str]: diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 7de4adbdc900..5e301b505de0 100644 --- 
a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -212,7 +212,8 @@ def is_valid_directory(path: str) -> str: os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db") ) parser.add_argument("--database-url", type=str, default=f"sqlite+aiosqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite+aiosqlite:///:memory:'.") -parser.add_argument("--disable-model-processing", action="store_true", help="Disable automatic processing of the model file, such as calculating hashes and populating the database.") +parser.add_argument("--enable-model-processing", action="store_true", help="Enable automatic processing of the model file, such as calculating hashes and populating the database.") +parser.add_argument("--disable-assets-autoscan", action="store_true", help="Disable asset scanning on startup for database synchronization.") if comfy.options.args_parsing: args = parser.parse_args() diff --git a/main.py b/main.py index 557961d40ddb..017f88a63a07 100644 --- a/main.py +++ b/main.py @@ -279,10 +279,35 @@ def cleanup_temp(): shutil.rmtree(temp_dir, ignore_errors=True) async def setup_database(): + def _console_cb(e: dict): + root = e.get("root") + phase = e.get("phase") + if phase == "fast": + if e.get("done"): + logging.info( + f"[assets][{root}] fast done: processed={e['processed']}/{e['discovered']} queued={e['queued']}" + ) + elif e.get("checked", 0) % 500 == 0: # do not spam with fast progress + logging.info(f"[assets][{root}] fast progress: processed={e['processed']}/{e['discovered']}" + ) + elif phase == "slow": + if e.get("done"): + logging.info( + f"[assets][{root}] slow done: {e.get('slow_queue_finished', 0)}/{e.get('slow_queue_total', 0)}" + ) + else: + logging.info( + f"[assets][{root}] slow progress: {e.get('slow_queue_finished', 0)}/{e.get('slow_queue_total', 0)}" + ) + try: from app.database.db import init_db_engine, dependencies_available if dependencies_available(): await init_db_engine() + if not args.disable_assets_autoscan: + from app import assets_scanner + + await assets_scanner.fast_reconcile_and_kickoff(progress_cb=_console_cb) except Exception as e: logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. 
If the error persists, please report this as in future the database will be required: {e}") From 84384ca0b49004f0435ac66105fe5407c963d066 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Fri, 5 Sep 2025 23:02:26 +0300 Subject: [PATCH 27/82] temporary restore ModelManager --- app/model_manager.py | 195 +++++++++++++++++++++++++++++++++++++++++++ server.py | 3 + 2 files changed, 198 insertions(+) create mode 100644 app/model_manager.py diff --git a/app/model_manager.py b/app/model_manager.py new file mode 100644 index 000000000000..ab36bca74414 --- /dev/null +++ b/app/model_manager.py @@ -0,0 +1,195 @@ +from __future__ import annotations + +import os +import base64 +import json +import time +import logging +import folder_paths +import glob +import comfy.utils +from aiohttp import web +from PIL import Image +from io import BytesIO +from folder_paths import map_legacy, filter_files_extensions, filter_files_content_types + + +class ModelFileManager: + def __init__(self) -> None: + self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {} + + def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None: + return self.cache.get(key, default) + + def set_cache(self, key: str, value: tuple[list[dict], dict[str, float], float]): + self.cache[key] = value + + def clear_cache(self): + self.cache.clear() + + def add_routes(self, routes): + # NOTE: This is an experiment to replace `/models` + @routes.get("/experiment/models") + async def get_model_folders(request): + model_types = list(folder_paths.folder_names_and_paths.keys()) + folder_black_list = ["configs", "custom_nodes"] + output_folders: list[dict] = [] + for folder in model_types: + if folder in folder_black_list: + continue + output_folders.append({"name": folder, "folders": folder_paths.get_folder_paths(folder)}) + return web.json_response(output_folders) + + # NOTE: This is an experiment to replace `/models/{folder}` + @routes.get("/experiment/models/{folder}") + async def get_all_models(request): + folder = request.match_info.get("folder", None) + if not folder in folder_paths.folder_names_and_paths: + return web.Response(status=404) + files = self.get_model_file_list(folder) + return web.json_response(files) + + @routes.get("/experiment/models/preview/{folder}/{path_index}/{filename:.*}") + async def get_model_preview(request): + folder_name = request.match_info.get("folder", None) + path_index = int(request.match_info.get("path_index", None)) + filename = request.match_info.get("filename", None) + + if not folder_name in folder_paths.folder_names_and_paths: + return web.Response(status=404) + + folders = folder_paths.folder_names_and_paths[folder_name] + folder = folders[0][path_index] + full_filename = os.path.join(folder, filename) + + previews = self.get_model_previews(full_filename) + default_preview = previews[0] if len(previews) > 0 else None + if default_preview is None or (isinstance(default_preview, str) and not os.path.isfile(default_preview)): + return web.Response(status=404) + + try: + with Image.open(default_preview) as img: + img_bytes = BytesIO() + img.save(img_bytes, format="WEBP") + img_bytes.seek(0) + return web.Response(body=img_bytes.getvalue(), content_type="image/webp") + except: + return web.Response(status=404) + + def get_model_file_list(self, folder_name: str): + folder_name = map_legacy(folder_name) + folders = folder_paths.folder_names_and_paths[folder_name] + output_list: list[dict] = [] + + for index, folder in enumerate(folders[0]): + if not 
os.path.isdir(folder): + continue + out = self.cache_model_file_list_(folder) + if out is None: + out = self.recursive_search_models_(folder, index) + self.set_cache(folder, out) + output_list.extend(out[0]) + + return output_list + + def cache_model_file_list_(self, folder: str): + model_file_list_cache = self.get_cache(folder) + + if model_file_list_cache is None: + return None + if not os.path.isdir(folder): + return None + if os.path.getmtime(folder) != model_file_list_cache[1]: + return None + for x in model_file_list_cache[1]: + time_modified = model_file_list_cache[1][x] + folder = x + if os.path.getmtime(folder) != time_modified: + return None + + return model_file_list_cache + + def recursive_search_models_(self, directory: str, pathIndex: int) -> tuple[list[str], dict[str, float], float]: + if not os.path.isdir(directory): + return [], {}, time.perf_counter() + + excluded_dir_names = [".git"] + # TODO use settings + include_hidden_files = False + + result: list[str] = [] + dirs: dict[str, float] = {} + + for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True): + subdirs[:] = [d for d in subdirs if d not in excluded_dir_names] + if not include_hidden_files: + subdirs[:] = [d for d in subdirs if not d.startswith(".")] + filenames = [f for f in filenames if not f.startswith(".")] + + filenames = filter_files_extensions(filenames, folder_paths.supported_pt_extensions) + + for file_name in filenames: + try: + full_path = os.path.join(dirpath, file_name) + relative_path = os.path.relpath(full_path, directory) + + # Get file metadata + file_info = { + "name": relative_path, + "pathIndex": pathIndex, + "modified": os.path.getmtime(full_path), # Add modification time + "created": os.path.getctime(full_path), # Add creation time + "size": os.path.getsize(full_path) # Add file size + } + result.append(file_info) + + except Exception as e: + logging.warning(f"Warning: Unable to access {file_name}. Error: {e}. Skipping this file.") + continue + + for d in subdirs: + path: str = os.path.join(dirpath, d) + try: + dirs[path] = os.path.getmtime(path) + except FileNotFoundError: + logging.warning(f"Warning: Unable to access {path}. 
Skipping this path.") + continue + + return result, dirs, time.perf_counter() + + def get_model_previews(self, filepath: str) -> list[str | BytesIO]: + dirname = os.path.dirname(filepath) + + if not os.path.exists(dirname): + return [] + + basename = os.path.splitext(filepath)[0] + match_files = glob.glob(f"{basename}.*", recursive=False) + image_files = filter_files_content_types(match_files, "image") + safetensors_file = next(filter(lambda x: x.endswith(".safetensors"), match_files), None) + safetensors_metadata = {} + + result: list[str | BytesIO] = [] + + for filename in image_files: + _basename = os.path.splitext(filename)[0] + if _basename == basename: + result.append(filename) + if _basename == f"{basename}.preview": + result.append(filename) + + if safetensors_file: + safetensors_filepath = os.path.join(dirname, safetensors_file) + header = comfy.utils.safetensors_header(safetensors_filepath, max_size=8*1024*1024) + if header: + safetensors_metadata = json.loads(header) + safetensors_images = safetensors_metadata.get("__metadata__", {}).get("ssmd_cover_images", None) + if safetensors_images: + safetensors_images = json.loads(safetensors_images) + for image in safetensors_images: + result.append(BytesIO(base64.b64decode(image))) + + return result + + def __exit__(self, exc_type, exc_value, traceback): + self.clear_cache() diff --git a/server.py b/server.py index 310b7601c778..d3a0f8628c68 100644 --- a/server.py +++ b/server.py @@ -33,6 +33,7 @@ from comfy_api.internal import _ComfyNodeInternal from app.user_manager import UserManager +from app.model_manager import ModelFileManager from app.custom_node_manager import CustomNodeManager from typing import Optional, Union from api_server.routes.internal.internal_routes import InternalRoutes @@ -154,6 +155,7 @@ def __init__(self, loop): mimetypes.add_type('image/webp', '.webp') self.user_manager = UserManager() + self.model_file_manager = ModelFileManager() self.custom_node_manager = CustomNodeManager() self.internal_routes = InternalRoutes(self) self.supports = ["custom_nodes_from_web"] @@ -762,6 +764,7 @@ async def setup(self): def add_routes(self): self.user_manager.add_routes(self.routes) + self.model_file_manager.add_routes(self.routes) self.custom_node_manager.add_routes(self.routes, self.app, nodes.LOADED_MODULE_DIRS.items()) self.app.add_subapp('/internal', self.internal_routes.get_app()) From 789a62ce3508e1364795bba9be767417cf0f1899 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Sat, 6 Sep 2025 17:43:46 +0300 Subject: [PATCH 28/82] assume that DB packages always present; refactoring & cleanup --- app/__init__.py | 5 +++ app/assets_scanner.py | 74 ++++++++++++++++++++++++++++--------------- app/database/db.py | 47 ++++++--------------------- main.py | 35 +++----------------- 4 files changed, 68 insertions(+), 93 deletions(-) diff --git a/app/__init__.py b/app/__init__.py index e69de29bb2d1..5fade97a49dd 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -0,0 +1,5 @@ +from .database.db import init_db_engine +from .assets_scanner import start_background_assets_scan + + +__all__ = ["init_db_engine", "start_background_assets_scan"] diff --git a/app/assets_scanner.py b/app/assets_scanner.py index ccfc8e9e5efe..86e8b23cd079 100644 --- a/app/assets_scanner.py +++ b/app/assets_scanner.py @@ -55,8 +55,8 @@ class SlowQueueState: SLOW_STATE_BY_ROOT: dict[RootType, SlowQueueState] = {} -def _new_scan_id(root: RootType) -> str: - return f"scan-{root}-{uuid.uuid4().hex[:8]}" +async def start_background_assets_scan(): + await 
fast_reconcile_and_kickoff(progress_cb=_console_cb) def current_statuses() -> schemas_out.AssetScanStatusResponse: @@ -108,7 +108,7 @@ async def schedule_scans(roots: Sequence[str]) -> schemas_out.AssetScanStatusRes async def fast_reconcile_and_kickoff( roots: Sequence[str] | None = None, *, - progress_cb: Optional[Callable[[dict], None]] = None, + progress_cb: Optional[Callable[[str, str, int, bool, dict], None]] = None, ) -> schemas_out.AssetScanStatusResponse: """ Startup helper: do the fast pass now (so we know queue size), @@ -179,7 +179,7 @@ def _scan_progress_to_scan_status_model(progress: ScanProgress) -> schemas_out.A async def _pipeline_for_root( root: RootType, prog: ScanProgress, - progress_cb: Optional[Callable[[dict], None]], + progress_cb: Optional[Callable[[str, str, int, bool, dict], None]], ) -> None: state = SLOW_STATE_BY_ROOT.get(root) or SlowQueueState(queue=asyncio.Queue()) SLOW_STATE_BY_ROOT[root] = state @@ -208,7 +208,7 @@ async def _fast_reconcile_into_queue( prog: ScanProgress, state: SlowQueueState, *, - progress_cb: Optional[Callable[[dict], None]], + progress_cb: Optional[Callable[[str, str, int, bool, dict], None]], ) -> None: """ Enumerate files, set 'discovered' to total files seen, increment 'processed' for fast-matched files, @@ -281,29 +281,22 @@ async def _fast_reconcile_into_queue( prog.slow_queue_total += 1 if progress_cb: - progress_cb({ - "root": root, - "phase": "fast", + progress_cb(root, "fast", prog.processed, False, { "checked": checked, "clean": clean, "queued": queued, "discovered": prog.discovered, - "processed": prog.processed, }) prog._fast_total_seen = checked prog._fast_clean = clean if progress_cb: - progress_cb({ - "root": root, - "phase": "fast", + progress_cb(root, "fast", prog.processed, True, { "checked": checked, "clean": clean, "queued": queued, "discovered": prog.discovered, - "processed": prog.processed, - "done": True, }) state.closed = True @@ -314,7 +307,7 @@ def _start_slow_workers( prog: ScanProgress, state: SlowQueueState, *, - progress_cb: Optional[Callable[[dict], None]], + progress_cb: Optional[Callable[[str, str, int, bool, dict], None]], ) -> None: if state.workers: return @@ -334,10 +327,7 @@ async def _worker(_worker_id: int): prog.slow_queue_finished += 1 prog.processed += 1 if progress_cb: - progress_cb({ - "root": root, - "phase": "slow", - "processed": prog.processed, + progress_cb(root, "slow", prog.processed, False, { "slow_queue_finished": prog.slow_queue_finished, "slow_queue_total": prog.slow_queue_total, }) @@ -361,20 +351,16 @@ async def _await_workers_then_finish( prog: ScanProgress, state: SlowQueueState, *, - progress_cb: Optional[Callable[[dict], None]], + progress_cb: Optional[Callable[[str, str, int, bool, dict], None]], ) -> None: if state.workers: await asyncio.gather(*state.workers, return_exceptions=True) prog.finished_at = time.time() prog.status = "completed" if progress_cb: - progress_cb({ - "root": root, - "phase": "slow", - "processed": prog.processed, + progress_cb(root, "slow", prog.processed, True, { "slow_queue_finished": prog.slow_queue_finished, "slow_queue_total": prog.slow_queue_total, - "done": True, }) @@ -453,3 +439,41 @@ def _ts_to_iso(ts: Optional[float]) -> Optional[str]: return datetime.fromtimestamp(float(ts), tz=timezone.utc).replace(tzinfo=None).isoformat() except Exception: return None + + +def _new_scan_id(root: RootType) -> str: + return f"scan-{root}-{uuid.uuid4().hex[:8]}" + + +def _console_cb(root: str, phase: str, total_processed: int, finished: bool, e: dict): + 
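+    # `e` carries the phase-specific counters supplied by the scanner's progress_cb calls:
+    # "checked"/"clean"/"queued"/"discovered" during the fast pass,
+    # "slow_queue_finished"/"slow_queue_total" during slow hashing.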
if phase == "fast": + if finished: + logging.info( + "[assets][%s] fast done: processed=%s/%s queued=%s", + root, + total_processed, + e["discovered"], + e["queued"], + ) + elif e.get("checked", 0) % 500 == 0: # do not spam with fast progress + logging.info( + "[assets][%s] fast progress: processed=%s/%s", + root, + total_processed, + e["discovered"], + ) + elif phase == "slow": + if finished: + logging.info( + "[assets][%s] slow done: %s/%s", + root, + e.get("slow_queue_finished", 0), + e.get("slow_queue_total", 0), + ) + elif e.get('slow_queue_finished', 0) % 3 == 0: + logging.info( + "[assets][%s] slow progress: %s/%s", + root, + e.get("slow_queue_finished", 0), + e.get("slow_queue_total", 0), + ) diff --git a/app/database/db.py b/app/database/db.py index 2a619f13b751..67ddf412ba3b 100644 --- a/app/database/db.py +++ b/app/database/db.py @@ -4,44 +4,18 @@ from contextlib import asynccontextmanager from typing import Optional -from app.logger import log_startup_warning -from utils.install_util import get_missing_requirements_message from comfy.cli_args import args - +from alembic import command +from alembic.config import Config +from alembic.runtime.migration import MigrationContext +from alembic.script import ScriptDirectory +from sqlalchemy import create_engine, text +from sqlalchemy.engine import make_url +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine LOGGER = logging.getLogger(__name__) - -# Attempt imports which may not exist in some environments -try: - from alembic import command - from alembic.config import Config - from alembic.runtime.migration import MigrationContext - from alembic.script import ScriptDirectory - from sqlalchemy import create_engine, text - from sqlalchemy.engine import make_url - from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine - - _DB_AVAILABLE = True - ENGINE: AsyncEngine | None = None - SESSION: async_sessionmaker | None = None -except ImportError as e: - log_startup_warning( - ( - "------------------------------------------------------------------------\n" - f"Error importing DB dependencies: {e}\n" - f"{get_missing_requirements_message()}\n" - "This error is happening because ComfyUI now uses a local database.\n" - "------------------------------------------------------------------------" - ).strip() - ) - _DB_AVAILABLE = False - ENGINE = None - SESSION = None - - -def dependencies_available() -> bool: - """Check if DB dependencies are importable.""" - return _DB_AVAILABLE +ENGINE: Optional[AsyncEngine] = None +SESSION: Optional[async_sessionmaker] = None def _root_paths(): @@ -115,9 +89,6 @@ async def init_db_engine() -> None: """ global ENGINE, SESSION - if not dependencies_available(): - raise RuntimeError("Database dependencies are not available.") - if ENGINE is not None: return diff --git a/main.py b/main.py index 017f88a63a07..3485a7c76c1d 100644 --- a/main.py +++ b/main.py @@ -279,37 +279,12 @@ def cleanup_temp(): shutil.rmtree(temp_dir, ignore_errors=True) async def setup_database(): - def _console_cb(e: dict): - root = e.get("root") - phase = e.get("phase") - if phase == "fast": - if e.get("done"): - logging.info( - f"[assets][{root}] fast done: processed={e['processed']}/{e['discovered']} queued={e['queued']}" - ) - elif e.get("checked", 0) % 500 == 0: # do not spam with fast progress - logging.info(f"[assets][{root}] fast progress: processed={e['processed']}/{e['discovered']}" - ) - elif phase == "slow": - if e.get("done"): - 
logging.info( - f"[assets][{root}] slow done: {e.get('slow_queue_finished', 0)}/{e.get('slow_queue_total', 0)}" - ) - else: - logging.info( - f"[assets][{root}] slow progress: {e.get('slow_queue_finished', 0)}/{e.get('slow_queue_total', 0)}" - ) + from app import init_db_engine, start_background_assets_scan + + await init_db_engine() + if not args.disable_assets_autoscan: + await start_background_assets_scan() - try: - from app.database.db import init_db_engine, dependencies_available - if dependencies_available(): - await init_db_engine() - if not args.disable_assets_autoscan: - from app import assets_scanner - - await assets_scanner.fast_reconcile_and_kickoff(progress_cb=_console_cb) - except Exception as e: - logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}") def start_comfyui(asyncio_loop=None): """ From 2d9be462d3e0bd66f36906b9c596b8f44445f915 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Sat, 6 Sep 2025 19:22:51 +0300 Subject: [PATCH 29/82] add support for assets duplicates --- alembic_db/versions/0001_assets.py | 13 +++--- app/_assets_helpers.py | 2 +- app/assets_fetcher.py | 29 +++++++------ app/assets_scanner.py | 51 +++++++++++++++-------- app/database/models.py | 16 +++---- app/database/services.py | 67 +++++++++++++++++++++--------- 6 files changed, 116 insertions(+), 62 deletions(-) diff --git a/alembic_db/versions/0001_assets.py b/alembic_db/versions/0001_assets.py index c80874aa20ea..681af26355d2 100644 --- a/alembic_db/versions/0001_assets.py +++ b/alembic_db/versions/0001_assets.py @@ -1,5 +1,4 @@ -# File: /alembic_db/versions/0001_assets.py -"""initial assets schema + per-asset state cache +"""initial assets schema Revision ID: 0001_assets Revises: @@ -69,15 +68,18 @@ def upgrade() -> None: op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"]) op.create_index("ix_asset_info_tags_asset_info_id", "asset_info_tags", ["asset_info_id"]) - # ASSET_CACHE_STATE: 1:1 local cache metadata for an Asset + # ASSET_CACHE_STATE: N:1 local cache metadata rows per Asset op.create_table( "asset_cache_state", - sa.Column("asset_hash", sa.String(length=256), sa.ForeignKey("assets.hash", ondelete="CASCADE"), primary_key=True), + sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True), + sa.Column("asset_hash", sa.String(length=256), sa.ForeignKey("assets.hash", ondelete="CASCADE"), nullable=False), sa.Column("file_path", sa.Text(), nullable=False), # absolute local path to cached file sa.Column("mtime_ns", sa.BigInteger(), nullable=True), sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"), + sa.UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"), ) op.create_index("ix_asset_cache_state_file_path", "asset_cache_state", ["file_path"]) + op.create_index("ix_asset_cache_state_asset_hash", "asset_cache_state", ["asset_hash"]) # ASSET_INFO_META: typed KV projection of user_metadata for filtering/sorting op.create_table( @@ -144,7 +146,7 @@ def upgrade() -> None: {"name": "photomaker", "tag_type": "system"}, {"name": "classifiers", "tag_type": "system"}, - # Extra basic tags (used for vae_approx, ...) 
+ # Extra basic tags {"name": "encoder", "tag_type": "system"}, {"name": "decoder", "tag_type": "system"}, ], @@ -162,6 +164,7 @@ def downgrade() -> None: op.drop_index("ix_asset_info_meta_key", table_name="asset_info_meta") op.drop_table("asset_info_meta") + op.drop_index("ix_asset_cache_state_asset_hash", table_name="asset_cache_state") op.drop_index("ix_asset_cache_state_file_path", table_name="asset_cache_state") op.drop_table("asset_cache_state") diff --git a/app/_assets_helpers.py b/app/_assets_helpers.py index ddc43f1eaa4e..8fb88cd34e19 100644 --- a/app/_assets_helpers.py +++ b/app/_assets_helpers.py @@ -147,7 +147,7 @@ def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement: return AssetInfo.owner_id.in_(["", owner_id]) -def compute_model_relative_filename(file_path: str) -> str | None: +def compute_model_relative_filename(file_path: str) -> Optional[str]: """ Return the model's path relative to the last well-known folder (the model category), using forward slashes, eg: diff --git a/app/assets_fetcher.py b/app/assets_fetcher.py index ea1c8ed00420..36fa64ca954b 100644 --- a/app/assets_fetcher.py +++ b/app/assets_fetcher.py @@ -8,7 +8,7 @@ from .storage.hashing import blake3_hash_sync from .database.db import create_session -from .database.services import ingest_fs_asset, get_cache_state_by_asset_hash +from .database.services import ingest_fs_asset, list_cache_states_by_asset_hash from .resolvers import resolve_asset from ._assets_helpers import resolve_destination_from_tags, ensure_within_base @@ -26,20 +26,25 @@ async def ensure_asset_cached( tags_hint: Optional[list[str]] = None, ) -> str: """ - Ensure there is a verified local file for `asset_hash` in the correct Comfy folder. - Policy: - - Resolver must provide valid tags (root and, for models, category). - - If target path already exists: - * if hash matches -> reuse & ingest - * else -> remove and overwrite with the correct content + Ensure there is a verified local file for asset_hash in the correct Comfy folder. + + Fast path: + - If any cache_state row has a file_path that exists, return it immediately. + Preference order is the oldest ID first for stability. + + Slow path: + - Resolve remote location + placement tags. + - Download to the correct folder, verify hash, move into place. + - Ingest identity + cache state so future fast passes can skip hashing. 
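+
+    Returns the absolute local path to the verified cached file.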
""" lock = _FETCH_LOCKS.setdefault(asset_hash, asyncio.Lock()) async with lock: - # 1) If we already have a state -> trust the path + # 1) If we already have any cache_state path present on disk, use it (oldest-first) async with await create_session() as sess: - state = await get_cache_state_by_asset_hash(sess, asset_hash=asset_hash) - if state and os.path.isfile(state.file_path): - return state.file_path + states = await list_cache_states_by_asset_hash(sess, asset_hash=asset_hash) + for s in states: + if s and s.file_path and os.path.isfile(s.file_path): + return s.file_path # 2) Resolve remote location + placement hints (must include valid tags) res = await resolve_asset(asset_hash) @@ -107,7 +112,7 @@ async def ensure_asset_cached( finally: raise ValueError(f"Hash mismatch: expected {asset_hash}, got {canonical}") - # 7) Atomically move into place (we already removed an invalid file if it existed) + # 7) Atomically move into place if os.path.exists(final_path): os.remove(final_path) os.replace(tmp_path, final_path) diff --git a/app/assets_scanner.py b/app/assets_scanner.py index 86e8b23cd079..42cf123d26f6 100644 --- a/app/assets_scanner.py +++ b/app/assets_scanner.py @@ -1,4 +1,5 @@ import asyncio +import contextlib import logging import os import uuid @@ -106,7 +107,7 @@ async def schedule_scans(roots: Sequence[str]) -> schemas_out.AssetScanStatusRes async def fast_reconcile_and_kickoff( - roots: Sequence[str] | None = None, + roots: Optional[Sequence[str]] = None, *, progress_cb: Optional[Callable[[str, str, int, bool, dict], None]] = None, ) -> schemas_out.AssetScanStatusResponse: @@ -216,18 +217,18 @@ async def _fast_reconcile_into_queue( """ if root == "models": files = _collect_models_files() - preset_discovered = len(files) + preset_discovered = _count_nonzero_in_list(files) files_iter = asyncio.Queue() for p in files: await files_iter.put(p) await files_iter.put(None) # sentinel for our local draining loop elif root == "input": base = folder_paths.get_input_directory() - preset_discovered = _count_files_in_tree(os.path.abspath(base)) + preset_discovered = _count_files_in_tree(os.path.abspath(base), only_nonzero=True) files_iter = await _queue_tree_files(base) elif root == "output": base = folder_paths.get_output_directory() - preset_discovered = _count_files_in_tree(os.path.abspath(base)) + preset_discovered = _count_files_in_tree(os.path.abspath(base), only_nonzero=True) files_iter = await _queue_tree_files(base) else: raise RuntimeError(f"Unsupported root: {root}") @@ -378,26 +379,41 @@ def _collect_models_files() -> list[str]: allowed = False for b in bases: base_abs = os.path.abspath(b) - try: + with contextlib.suppress(Exception): if os.path.commonpath([abs_path, base_abs]) == base_abs: allowed = True break - except Exception: - pass if allowed: out.append(abs_path) return out -def _count_files_in_tree(base_abs: str) -> int: +def _count_files_in_tree(base_abs: str, *, only_nonzero: bool = False) -> int: if not os.path.isdir(base_abs): return 0 total = 0 - for _dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False): - total += len(filenames) + for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False): + if not only_nonzero: + total += len(filenames) + else: + for name in filenames: + with contextlib.suppress(OSError): + st = os.stat(os.path.join(dirpath, name), follow_symlinks=True) + if st.st_size: + total += 1 return total +def _count_nonzero_in_list(paths: list[str]) -> int: + cnt = 0 + for p in paths: + with 
contextlib.suppress(OSError): + st = os.stat(p, follow_symlinks=True) + if st.st_size: + cnt += 1 + return cnt + + async def _queue_tree_files(base_dir: str) -> asyncio.Queue: """ Walk base_dir in a worker thread and return a queue prefilled with all paths, @@ -455,7 +471,7 @@ def _console_cb(root: str, phase: str, total_processed: int, finished: bool, e: e["discovered"], e["queued"], ) - elif e.get("checked", 0) % 500 == 0: # do not spam with fast progress + elif e.get("checked", 0) % 1000 == 0: # do not spam with fast progress logging.info( "[assets][%s] fast progress: processed=%s/%s", root, @@ -464,12 +480,13 @@ def _console_cb(root: str, phase: str, total_processed: int, finished: bool, e: ) elif phase == "slow": if finished: - logging.info( - "[assets][%s] slow done: %s/%s", - root, - e.get("slow_queue_finished", 0), - e.get("slow_queue_total", 0), - ) + if e.get("slow_queue_finished", 0) or e.get("slow_queue_total", 0): + logging.info( + "[assets][%s] slow done: %s/%s", + root, + e.get("slow_queue_finished", 0), + e.get("slow_queue_total", 0), + ) elif e.get('slow_queue_finished', 0) % 3 == 0: logging.info( "[assets][%s] slow progress: %s/%s", diff --git a/app/database/models.py b/app/database/models.py index 47f8bbaf35e5..2038674681fa 100644 --- a/app/database/models.py +++ b/app/database/models.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from datetime import datetime from typing import Any, Optional import uuid @@ -66,9 +68,8 @@ class Asset(Base): viewonly=True, ) - cache_state: Mapped["AssetCacheState | None"] = relationship( + cache_states: Mapped[list["AssetCacheState"]] = relationship( back_populates="asset", - uselist=False, cascade="all, delete-orphan", passive_deletes=True, ) @@ -93,24 +94,25 @@ def __repr__(self) -> str: class AssetCacheState(Base): __tablename__ = "asset_cache_state" - asset_hash: Mapped[str] = mapped_column( - String(256), ForeignKey("assets.hash", ondelete="CASCADE"), primary_key=True - ) + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + asset_hash: Mapped[str] = mapped_column(String(256), ForeignKey("assets.hash", ondelete="CASCADE"), nullable=False) file_path: Mapped[str] = mapped_column(Text, nullable=False) mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True) - asset: Mapped["Asset"] = relationship(back_populates="cache_state", uselist=False) + asset: Mapped["Asset"] = relationship(back_populates="cache_states") __table_args__ = ( Index("ix_asset_cache_state_file_path", "file_path"), + Index("ix_asset_cache_state_asset_hash", "asset_hash"), CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"), + UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"), ) def to_dict(self, include_none: bool = False) -> dict[str, Any]: return to_dict(self, include_none=include_none) def __repr__(self) -> str: - return f"" + return f"" class AssetLocation(Base): diff --git a/app/database/services.py b/app/database/services.py index af8861001b81..94a9b7016c88 100644 --- a/app/database/services.py +++ b/app/database/services.py @@ -4,7 +4,7 @@ from collections import defaultdict from datetime import datetime from decimal import Decimal -from typing import Any, Sequence, Optional, Iterable +from typing import Any, Sequence, Optional, Iterable, Union import sqlalchemy as sa from sqlalchemy.ext.asyncio import AsyncSession @@ -82,14 +82,14 @@ async def ingest_fs_asset( require_existing_tags: bool = False, ) -> dict: """ - Upsert Asset identity row + cache state 
pointing at local file. + Upsert Asset identity row + cache state(s) pointing at local file. Always: - Insert Asset if missing; - - Insert AssetCacheState if missing; else update mtime_ns if different. + - Insert AssetCacheState if missing; else update mtime_ns and asset_hash if different. Optionally (when info_name is provided): - - Create an AssetInfo. + - Create or update an AssetInfo on (asset_hash, owner_id, name). - Link provided tags to that AssetInfo. * If the require_existing_tags=True, raises ValueError if any tag does not exist in `tags` table. * If False (default), create unknown tags. @@ -157,11 +157,16 @@ async def ingest_fs_asset( out["state_created"] = True if not out["state_created"]: - state = await session.get(AssetCacheState, asset_hash) + # most likely a unique(file_path) conflict; update that row + state = ( + await session.execute( + select(AssetCacheState).where(AssetCacheState.file_path == locator).limit(1) + ) + ).scalars().first() if state is not None: changed = False - if state.file_path != locator: - state.file_path = locator + if state.asset_hash != asset_hash: + state.asset_hash = asset_hash changed = True if state.mtime_ns != int(mtime_ns): state.mtime_ns = int(mtime_ns) @@ -260,7 +265,15 @@ async def ingest_fs_asset( # ) # start of adding metadata["filename"] if out["asset_info_id"] is not None: - computed_filename = compute_model_relative_filename(abs_path) + primary_path = ( + await session.execute( + select(AssetCacheState.file_path) + .where(AssetCacheState.asset_hash == asset_hash) + .order_by(AssetCacheState.id.asc()) + .limit(1) + ) + ).scalars().first() + computed_filename = compute_model_relative_filename(primary_path) if primary_path else None # Start from current metadata on this AssetInfo, if any current_meta = existing_info.user_metadata or {} @@ -366,7 +379,6 @@ async def list_asset_infos_page( base = _apply_tag_filters(base, include_tags, exclude_tags) base = _apply_metadata_filter(base, metadata_filter) - # Sort sort = (sort or "created_at").lower() order = (order or "desc").lower() sort_map = { @@ -381,7 +393,6 @@ async def list_asset_infos_page( base = base.order_by(sort_exp).limit(limit).offset(offset) - # Total count (same filters, no ordering/limit/offset) count_stmt = ( select(func.count()) .select_from(AssetInfo) @@ -395,10 +406,9 @@ async def list_asset_infos_page( total = int((await session.execute(count_stmt)).scalar_one() or 0) - # Fetch rows infos = (await session.execute(base)).scalars().unique().all() - # Collect tags in bulk (single query) + # Collect tags in bulk id_list: list[str] = [i.id for i in infos] tag_map: dict[str, list[str]] = defaultdict(list) if id_list: @@ -470,12 +480,33 @@ async def fetch_asset_info_asset_and_tags( async def get_cache_state_by_asset_hash(session: AsyncSession, *, asset_hash: str) -> Optional[AssetCacheState]: - return await session.get(AssetCacheState, asset_hash) + """Return the oldest cache row for this asset.""" + return ( + await session.execute( + select(AssetCacheState) + .where(AssetCacheState.asset_hash == asset_hash) + .order_by(AssetCacheState.id.asc()) + .limit(1) + ) + ).scalars().first() + + +async def list_cache_states_by_asset_hash( + session: AsyncSession, *, asset_hash: str +) -> Union[list[AssetCacheState], Sequence[AssetCacheState]]: + """Return all cache rows for this asset ordered by oldest first.""" + return ( + await session.execute( + select(AssetCacheState) + .where(AssetCacheState.asset_hash == asset_hash) + .order_by(AssetCacheState.id.asc()) + ) + ).scalars().all() 
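+
+# NOTE: callers treat the oldest cache row (lowest id) as the asset's primary local path,
+# e.g. ensure_asset_cached() and the metadata["filename"] autofill in ingest_fs_asset().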
async def list_asset_locations( session: AsyncSession, *, asset_hash: str, provider: Optional[str] = None -) -> list[AssetLocation] | Sequence[AssetLocation]: +) -> Union[list[AssetLocation], Sequence[AssetLocation]]: stmt = select(AssetLocation).where(AssetLocation.asset_hash == asset_hash) if provider: stmt = stmt.where(AssetLocation.provider == provider) @@ -815,7 +846,6 @@ async def list_tags_with_usage( if not include_zero: q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0) - # Ordering if order == "name_asc": q = q.order_by(Tag.name.asc()) else: # default "count_desc" @@ -990,6 +1020,7 @@ def _apply_tag_filters( ) return stmt + def _apply_metadata_filter( stmt: sa.sql.Select, metadata_filter: Optional[dict], @@ -1050,7 +1081,7 @@ def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement: for k, v in metadata_filter.items(): if isinstance(v, list): # ANY-of (exists for any element) - ors = [ _exists_clause_for_value(k, elem) for elem in v ] + ors = [_exists_clause_for_value(k, elem) for elem in v] if ors: stmt = stmt.where(sa.or_(*ors)) else: @@ -1079,12 +1110,10 @@ def _project_kv(key: str, value: Any) -> list[dict]: """ rows: list[dict] = [] - # None if value is None: rows.append({"key": key, "ordinal": 0, "val_json": None}) return rows - # Scalars if _is_scalar(value): if isinstance(value, bool): rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)}) @@ -1099,9 +1128,7 @@ def _project_kv(key: str, value: Any) -> list[dict]: rows.append({"key": key, "ordinal": 0, "val_json": value}) return rows - # Lists if isinstance(value, list): - # list of scalars? if all(_is_scalar(x) for x in value): for i, x in enumerate(value): if x is None: From b8ef9bb92c41f62e940848db713cdbb681c5c54f Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Sun, 7 Sep 2025 16:49:39 +0300 Subject: [PATCH 30/82] add detection of the missing files for existing assets --- alembic_db/versions/0001_assets.py | 3 + app/assets_scanner.py | 58 ++++++- app/database/_helpers.py | 183 ++++++++++++++++++++ app/database/services.py | 266 ++++++++++------------------- 4 files changed, 334 insertions(+), 176 deletions(-) create mode 100644 app/database/_helpers.py diff --git a/alembic_db/versions/0001_assets.py b/alembic_db/versions/0001_assets.py index 681af26355d2..9481100b0322 100644 --- a/alembic_db/versions/0001_assets.py +++ b/alembic_db/versions/0001_assets.py @@ -149,6 +149,9 @@ def upgrade() -> None: # Extra basic tags {"name": "encoder", "tag_type": "system"}, {"name": "decoder", "tag_type": "system"}, + + # Special tags + {"name": "missing", "tag_type": "system"}, ], ) diff --git a/app/assets_scanner.py b/app/assets_scanner.py index 42cf123d26f6..33efbf047e89 100644 --- a/app/assets_scanner.py +++ b/app/assets_scanner.py @@ -14,7 +14,12 @@ from .api import schemas_out from ._assets_helpers import get_comfy_models_folders from .database.db import create_session -from .database.services import check_fs_asset_exists_quick +from .database.services import ( + check_fs_asset_exists_quick, + list_cache_states_under_prefixes, + add_missing_tag_for_asset_hash, + remove_missing_tag_for_asset_hash, +) LOGGER = logging.getLogger(__name__) @@ -239,7 +244,6 @@ async def _fast_reconcile_into_queue( checked = 0 clean = 0 - # Single session for the whole fast pass async with await create_session() as sess: while True: item = await files_iter.get() @@ -261,7 +265,6 @@ async def _fast_reconcile_into_queue( _append_error(prog, phase="fast", path=abs_path, message=str(e)) continue - # Known good -> count as processed 
immediately try: known = await check_fs_asset_exists_quick( sess, @@ -275,7 +278,7 @@ async def _fast_reconcile_into_queue( if known: clean += 1 - prog.processed += 1 # preserve original semantics + prog.processed += 1 else: await state.queue.put(abs_path) queued += 1 @@ -300,9 +303,56 @@ async def _fast_reconcile_into_queue( "discovered": prog.discovered, }) + await _reconcile_missing_tags_for_root(root, prog) state.closed = True +async def _reconcile_missing_tags_for_root(root: RootType, prog: ScanProgress) -> None: + """ + For every AssetCacheState under the root's base directories: + - if at least one recorded file_path exists for a hash -> remove 'missing' + - if none of the recorded file_paths exist for a hash -> add 'missing' + """ + if root == "models": + bases: list[str] = [] + for _bucket, paths in get_comfy_models_folders(): + bases.extend(paths) + elif root == "input": + bases = [folder_paths.get_input_directory()] + else: + bases = [folder_paths.get_output_directory()] + + try: + async with await create_session() as sess: + states = await list_cache_states_under_prefixes(sess, prefixes=bases) + + present: set[str] = set() + missing: set[str] = set() + + for s in states: + try: + if os.path.isfile(s.file_path): + present.add(s.asset_hash) + else: + missing.add(s.asset_hash) + except Exception as e: + _append_error(prog, phase="fast", path=s.file_path, message=f"stat error: {e}") + + only_missing = missing - present + + for h in present: + with contextlib.suppress(Exception): + await remove_missing_tag_for_asset_hash(sess, asset_hash=h) + + for h in only_missing: + with contextlib.suppress(Exception): + await add_missing_tag_for_asset_hash(sess, asset_hash=h, origin="automatic") + + await sess.commit() + except Exception as e: + _append_error(prog, phase="fast", path="", message=f"missing-tag reconcile failed: {e}") + + def _start_slow_workers( root: RootType, prog: ScanProgress, diff --git a/app/database/_helpers.py b/app/database/_helpers.py new file mode 100644 index 000000000000..5ce97207691e --- /dev/null +++ b/app/database/_helpers.py @@ -0,0 +1,183 @@ +from decimal import Decimal +from typing import Any, Sequence, Optional, Iterable + +import sqlalchemy as sa +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import select, exists + +from .models import AssetInfo, AssetInfoTag, Tag, AssetInfoMeta +from .._assets_helpers import normalize_tags + + +async def ensure_tags_exist(session: AsyncSession, names: Iterable[str], tag_type: str = "user") -> list[Tag]: + wanted = normalize_tags(list(names)) + if not wanted: + return [] + existing = (await session.execute(select(Tag).where(Tag.name.in_(wanted)))).scalars().all() + by_name = {t.name: t for t in existing} + to_create = [Tag(name=n, tag_type=tag_type) for n in wanted if n not in by_name] + if to_create: + session.add_all(to_create) + await session.flush() + by_name.update({t.name: t for t in to_create}) + return [by_name[n] for n in wanted] + + +def apply_tag_filters( + stmt: sa.sql.Select, + include_tags: Optional[Sequence[str]], + exclude_tags: Optional[Sequence[str]], +) -> sa.sql.Select: + """include_tags: every tag must be present; exclude_tags: none may be present.""" + include_tags = normalize_tags(include_tags) + exclude_tags = normalize_tags(exclude_tags) + + if include_tags: + for tag_name in include_tags: + stmt = stmt.where( + exists().where( + (AssetInfoTag.asset_info_id == AssetInfo.id) + & (AssetInfoTag.tag_name == tag_name) + ) + ) + + if exclude_tags: + stmt = stmt.where( + ~exists().where( 
+ (AssetInfoTag.asset_info_id == AssetInfo.id) + & (AssetInfoTag.tag_name.in_(exclude_tags)) + ) + ) + return stmt + + +def apply_metadata_filter( + stmt: sa.sql.Select, + metadata_filter: Optional[dict], +) -> sa.sql.Select: + """Apply metadata filters using the projection table asset_info_meta. + + Semantics: + - For scalar values: require EXISTS(asset_info_meta) with matching key + typed value. + - For None: key is missing OR key has explicit null (val_json IS NULL). + - For list values: ANY-of the list elements matches (EXISTS for any). + (Change to ALL-of by 'for each element: stmt = stmt.where(_meta_exists_clause(key, elem))') + """ + if not metadata_filter: + return stmt + + def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement: + subquery = ( + select(sa.literal(1)) + .select_from(AssetInfoMeta) + .where( + AssetInfoMeta.asset_info_id == AssetInfo.id, + AssetInfoMeta.key == key, + *preds, + ) + .limit(1) + ) + return sa.exists(subquery) + + def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement: + # Missing OR null: + if value is None: + # either: no row for key OR a row for key with explicit null + no_row_for_key = ~sa.exists( + select(sa.literal(1)) + .select_from(AssetInfoMeta) + .where( + AssetInfoMeta.asset_info_id == AssetInfo.id, + AssetInfoMeta.key == key, + ) + .limit(1) + ) + null_row = _exists_for_pred(key, AssetInfoMeta.val_json.is_(None)) + return sa.or_(no_row_for_key, null_row) + + # Typed scalar matches: + if isinstance(value, bool): + return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value)) + if isinstance(value, (int, float, Decimal)): + # store as Decimal for equality against NUMERIC(38,10) + num = value if isinstance(value, Decimal) else Decimal(str(value)) + return _exists_for_pred(key, AssetInfoMeta.val_num == num) + if isinstance(value, str): + return _exists_for_pred(key, AssetInfoMeta.val_str == value) + + # Complex: compare JSON (no index, but supported) + return _exists_for_pred(key, AssetInfoMeta.val_json == value) + + for k, v in metadata_filter.items(): + if isinstance(v, list): + # ANY-of (exists for any element) + ors = [_exists_clause_for_value(k, elem) for elem in v] + if ors: + stmt = stmt.where(sa.or_(*ors)) + else: + stmt = stmt.where(_exists_clause_for_value(k, v)) + return stmt + + +def is_scalar(v: Any) -> bool: + if v is None: # treat None as a value (explicit null) so it can be indexed for "is null" queries + return True + if isinstance(v, bool): + return True + if isinstance(v, (int, float, Decimal, str)): + return True + return False + + +def project_kv(key: str, value: Any) -> list[dict]: + """ + Turn a metadata key/value into one or more projection rows: + - scalar -> one row (ordinal=0) in the proper typed column + - list of scalars -> one row per element with ordinal=i + - dict or list with non-scalars -> single row with val_json (or one per element w/ val_json if list) + - None -> single row with val_json = None + Each row: {"key": key, "ordinal": i, "val_str"/"val_num"/"val_bool"/"val_json": ...} + """ + rows: list[dict] = [] + + if value is None: + rows.append({"key": key, "ordinal": 0, "val_json": None}) + return rows + + if is_scalar(value): + if isinstance(value, bool): + rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)}) + elif isinstance(value, (int, float, Decimal)): + # store numeric; SQLAlchemy will coerce to Numeric + num = value if isinstance(value, Decimal) else Decimal(str(value)) + rows.append({"key": key, "ordinal": 0, "val_num": num}) + elif isinstance(value, str): + 
rows.append({"key": key, "ordinal": 0, "val_str": value}) + else: + # Fallback to json + rows.append({"key": key, "ordinal": 0, "val_json": value}) + return rows + + if isinstance(value, list): + if all(is_scalar(x) for x in value): + for i, x in enumerate(value): + if x is None: + rows.append({"key": key, "ordinal": i, "val_json": None}) + elif isinstance(x, bool): + rows.append({"key": key, "ordinal": i, "val_bool": bool(x)}) + elif isinstance(x, (int, float, Decimal)): + num = x if isinstance(x, Decimal) else Decimal(str(x)) + rows.append({"key": key, "ordinal": i, "val_num": num}) + elif isinstance(x, str): + rows.append({"key": key, "ordinal": i, "val_str": x}) + else: + rows.append({"key": key, "ordinal": i, "val_json": x}) + return rows + # list contains objects -> one val_json per element + for i, x in enumerate(value): + rows.append({"key": key, "ordinal": i, "val_json": x}) + return rows + + # Dict or any other structure -> single json row + rows.append({"key": key, "ordinal": 0, "val_json": value}) + return rows diff --git a/app/database/services.py b/app/database/services.py index 94a9b7016c88..ceed3749a882 100644 --- a/app/database/services.py +++ b/app/database/services.py @@ -3,18 +3,18 @@ import logging from collections import defaultdict from datetime import datetime -from decimal import Decimal -from typing import Any, Sequence, Optional, Iterable, Union +from typing import Any, Sequence, Optional, Union import sqlalchemy as sa from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy import select, delete, exists, func +from sqlalchemy import select, delete, func from sqlalchemy.orm import contains_eager, noload from sqlalchemy.exc import IntegrityError from .models import Asset, AssetInfo, AssetInfoTag, AssetCacheState, Tag, AssetInfoMeta, AssetLocation from .timeutil import utcnow from .._assets_helpers import normalize_tags, visible_owner_clause, compute_model_relative_filename +from . import _helpers async def asset_exists_by_hash(session: AsyncSession, *, asset_hash: str) -> bool: @@ -221,7 +221,7 @@ async def ingest_fs_asset( norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()] if norm and out["asset_info_id"] is not None: if not require_existing_tags: - await _ensure_tags_exist(session, norm, tag_type="user") + await _helpers.ensure_tags_exist(session, norm, tag_type="user") # Which tags exist? 
existing_tag_names = set( @@ -296,6 +296,10 @@ async def ingest_fs_asset( user_metadata=new_meta, ) # end of adding metadata["filename"] + try: + await remove_missing_tag_for_asset_hash(session, asset_hash=asset_hash) + except Exception: + logging.exception("Failed to clear 'missing' tag for %s", asset_hash) return out @@ -376,8 +380,8 @@ async def list_asset_infos_page( if name_contains: base = base.where(AssetInfo.name.ilike(f"%{name_contains}%")) - base = _apply_tag_filters(base, include_tags, exclude_tags) - base = _apply_metadata_filter(base, metadata_filter) + base = _helpers.apply_tag_filters(base, include_tags, exclude_tags) + base = _helpers.apply_metadata_filter(base, metadata_filter) sort = (sort or "created_at").lower() order = (order or "desc").lower() @@ -401,8 +405,8 @@ async def list_asset_infos_page( ) if name_contains: count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{name_contains}%")) - count_stmt = _apply_tag_filters(count_stmt, include_tags, exclude_tags) - count_stmt = _apply_metadata_filter(count_stmt, metadata_filter) + count_stmt = _helpers.apply_tag_filters(count_stmt, include_tags, exclude_tags) + count_stmt = _helpers.apply_metadata_filter(count_stmt, metadata_filter) total = int((await session.execute(count_stmt)).scalar_one() or 0) @@ -646,7 +650,7 @@ async def set_asset_info_tags( to_remove = [t for t in current if t not in desired] if to_add: - await _ensure_tags_exist(session, to_add, tag_type="user") + await _helpers.ensure_tags_exist(session, to_add, tag_type="user") session.add_all([ AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow()) for t in to_add @@ -776,7 +780,7 @@ async def replace_asset_info_metadata_projection( rows: list[AssetInfoMeta] = [] for k, v in user_metadata.items(): - for r in _project_kv(k, v): + for r in _helpers.project_kv(k, v): rows.append( AssetInfoMeta( asset_info_id=asset_info_id, @@ -894,7 +898,7 @@ async def add_tags_to_asset_info( # Ensure tag rows exist if requested. 
if create_if_missing: - await _ensure_tags_exist(session, norm, tag_type="user") + await _helpers.ensure_tags_exist(session, norm, tag_type="user") # Snapshot current links current = { @@ -979,175 +983,93 @@ async def remove_tags_from_asset_info( return {"removed": to_remove, "not_present": not_present, "total_tags": total} -async def _ensure_tags_exist(session: AsyncSession, names: Iterable[str], tag_type: str = "user") -> list[Tag]: - wanted = normalize_tags(list(names)) - if not wanted: - return [] - existing = (await session.execute(select(Tag).where(Tag.name.in_(wanted)))).scalars().all() - by_name = {t.name: t for t in existing} - to_create = [Tag(name=n, tag_type=tag_type) for n in wanted if n not in by_name] - if to_create: - session.add_all(to_create) - await session.flush() - by_name.update({t.name: t for t in to_create}) - return [by_name[n] for n in wanted] - - -def _apply_tag_filters( - stmt: sa.sql.Select, - include_tags: Optional[Sequence[str]], - exclude_tags: Optional[Sequence[str]], -) -> sa.sql.Select: - """include_tags: every tag must be present; exclude_tags: none may be present.""" - include_tags = normalize_tags(include_tags) - exclude_tags = normalize_tags(exclude_tags) - - if include_tags: - for tag_name in include_tags: - stmt = stmt.where( - exists().where( - (AssetInfoTag.asset_info_id == AssetInfo.id) - & (AssetInfoTag.tag_name == tag_name) - ) - ) +async def add_missing_tag_for_asset_hash( + session: AsyncSession, + *, + asset_hash: str, + origin: str = "automatic", +) -> int: + """Ensure every AssetInfo referencing asset_hash has the 'missing' tag. + Returns number of AssetInfos newly tagged. + """ + ids = (await session.execute(select(AssetInfo.id).where(AssetInfo.asset_hash == asset_hash))).scalars().all() + if not ids: + return 0 - if exclude_tags: - stmt = stmt.where( - ~exists().where( - (AssetInfoTag.asset_info_id == AssetInfo.id) - & (AssetInfoTag.tag_name.in_(exclude_tags)) + existing = { + asset_info_id + for (asset_info_id,) in ( + await session.execute( + select(AssetInfoTag.asset_info_id).where( + AssetInfoTag.asset_info_id.in_(ids), + AssetInfoTag.tag_name == "missing", + ) ) - ) - return stmt + ).all() + } + to_add = [i for i in ids if i not in existing] + if not to_add: + return 0 + now = utcnow() + session.add_all( + [ + AssetInfoTag(asset_info_id=i, tag_name="missing", origin=origin, added_at=now) + for i in to_add + ] + ) + await session.flush() + return len(to_add) -def _apply_metadata_filter( - stmt: sa.sql.Select, - metadata_filter: Optional[dict], -) -> sa.sql.Select: - """Apply metadata filters using the projection table asset_info_meta. - Semantics: - - For scalar values: require EXISTS(asset_info_meta) with matching key + typed value. - - For None: key is missing OR key has explicit null (val_json IS NULL). - - For list values: ANY-of the list elements matches (EXISTS for any). - (Change to ALL-of by 'for each element: stmt = stmt.where(_meta_exists_clause(key, elem))') +async def remove_missing_tag_for_asset_hash( + session: AsyncSession, + *, + asset_hash: str, +) -> int: + """Remove the 'missing' tag from every AssetInfo referencing asset_hash. + Returns number of link rows removed. 
""" - if not metadata_filter: - return stmt - - def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement: - subquery = ( - select(sa.literal(1)) - .select_from(AssetInfoMeta) - .where( - AssetInfoMeta.asset_info_id == AssetInfo.id, - AssetInfoMeta.key == key, - *preds, - ) - .limit(1) + ids = (await session.execute(select(AssetInfo.id).where(AssetInfo.asset_hash == asset_hash))).scalars().all() + if not ids: + return 0 + + res = await session.execute( + delete(AssetInfoTag).where( + AssetInfoTag.asset_info_id.in_(ids), + AssetInfoTag.tag_name == "missing", ) - return sa.exists(subquery) - - def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement: - # Missing OR null: - if value is None: - # either: no row for key OR a row for key with explicit null - no_row_for_key = ~sa.exists( - select(sa.literal(1)) - .select_from(AssetInfoMeta) - .where( - AssetInfoMeta.asset_info_id == AssetInfo.id, - AssetInfoMeta.key == key, - ) - .limit(1) - ) - null_row = _exists_for_pred(key, AssetInfoMeta.val_json.is_(None)) - return sa.or_(no_row_for_key, null_row) - - # Typed scalar matches: - if isinstance(value, bool): - return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value)) - if isinstance(value, (int, float, Decimal)): - # store as Decimal for equality against NUMERIC(38,10) - num = value if isinstance(value, Decimal) else Decimal(str(value)) - return _exists_for_pred(key, AssetInfoMeta.val_num == num) - if isinstance(value, str): - return _exists_for_pred(key, AssetInfoMeta.val_str == value) - - # Complex: compare JSON (no index, but supported) - return _exists_for_pred(key, AssetInfoMeta.val_json == value) - - for k, v in metadata_filter.items(): - if isinstance(v, list): - # ANY-of (exists for any element) - ors = [_exists_clause_for_value(k, elem) for elem in v] - if ors: - stmt = stmt.where(sa.or_(*ors)) - else: - stmt = stmt.where(_exists_clause_for_value(k, v)) - return stmt + ) + await session.flush() + return int(res.rowcount or 0) -def _is_scalar(v: Any) -> bool: - if v is None: # treat None as a value (explicit null) so it can be indexed for "is null" queries - return True - if isinstance(v, bool): - return True - if isinstance(v, (int, float, Decimal, str)): - return True - return False +async def list_cache_states_under_prefixes( + session: AsyncSession, + *, + prefixes: Sequence[str], +) -> list[AssetCacheState]: + """Return AssetCacheState rows whose file_path starts with any of the given absolute prefixes.""" + if not prefixes: + return [] + conds = [] + for p in prefixes: + if not p: + continue + base = os.path.abspath(p) + if not base.endswith(os.sep): + base = base + os.sep + conds.append(AssetCacheState.file_path.like(base + "%")) + + if not conds: + return [] -def _project_kv(key: str, value: Any) -> list[dict]: - """ - Turn a metadata key/value into one or more projection rows: - - scalar -> one row (ordinal=0) in the proper typed column - - list of scalars -> one row per element with ordinal=i - - dict or list with non-scalars -> single row with val_json (or one per element w/ val_json if list) - - None -> single row with val_json = None - Each row: {"key": key, "ordinal": i, "val_str"/"val_num"/"val_bool"/"val_json": ...} - """ - rows: list[dict] = [] - - if value is None: - rows.append({"key": key, "ordinal": 0, "val_json": None}) - return rows - - if _is_scalar(value): - if isinstance(value, bool): - rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)}) - elif isinstance(value, (int, float, Decimal)): - # store numeric; SQLAlchemy will 
coerce to Numeric - num = value if isinstance(value, Decimal) else Decimal(str(value)) - rows.append({"key": key, "ordinal": 0, "val_num": num}) - elif isinstance(value, str): - rows.append({"key": key, "ordinal": 0, "val_str": value}) - else: - # Fallback to json - rows.append({"key": key, "ordinal": 0, "val_json": value}) - return rows - - if isinstance(value, list): - if all(_is_scalar(x) for x in value): - for i, x in enumerate(value): - if x is None: - rows.append({"key": key, "ordinal": i, "val_json": None}) - elif isinstance(x, bool): - rows.append({"key": key, "ordinal": i, "val_bool": bool(x)}) - elif isinstance(x, (int, float, Decimal)): - num = x if isinstance(x, Decimal) else Decimal(str(x)) - rows.append({"key": key, "ordinal": i, "val_num": num}) - elif isinstance(x, str): - rows.append({"key": key, "ordinal": i, "val_str": x}) - else: - rows.append({"key": key, "ordinal": i, "val_json": x}) - return rows - # list contains objects -> one val_json per element - for i, x in enumerate(value): - rows.append({"key": key, "ordinal": i, "val_json": x}) - return rows - - # Dict or any other structure -> single json row - rows.append({"key": key, "ordinal": 0, "val_json": value}) - return rows + rows = ( + await session.execute( + select(AssetCacheState) + .where(sa.or_(*conds)) + .order_by(AssetCacheState.id.asc()) + ) + ).scalars().all() + return list(rows) From 6282d495ca58cc20242782822f1d09008906275c Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Sun, 7 Sep 2025 17:53:18 +0300 Subject: [PATCH 31/82] corrected detection of missing files for assets --- app/assets_scanner.py | 67 +++++++++++++++++++++++----------------- app/database/services.py | 15 ++++----- 2 files changed, 47 insertions(+), 35 deletions(-) diff --git a/app/assets_scanner.py b/app/assets_scanner.py index 33efbf047e89..a77f877718c5 100644 --- a/app/assets_scanner.py +++ b/app/assets_scanner.py @@ -16,7 +16,7 @@ from .database.db import create_session from .database.services import ( check_fs_asset_exists_quick, - list_cache_states_under_prefixes, + list_cache_states_with_asset_under_prefixes, add_missing_tag_for_asset_hash, remove_missing_tag_for_asset_hash, ) @@ -194,6 +194,7 @@ async def _pipeline_for_root( prog.started_at = time.time() try: + await _reconcile_missing_tags_for_root(root, prog) await _fast_reconcile_into_queue(root, prog, state, progress_cb=progress_cb) _start_slow_workers(root, prog, state, progress_cb=progress_cb) await _await_workers_then_finish(root, prog, state, progress_cb=progress_cb) @@ -302,16 +303,17 @@ async def _fast_reconcile_into_queue( "queued": queued, "discovered": prog.discovered, }) - - await _reconcile_missing_tags_for_root(root, prog) state.closed = True async def _reconcile_missing_tags_for_root(root: RootType, prog: ScanProgress) -> None: """ - For every AssetCacheState under the root's base directories: - - if at least one recorded file_path exists for a hash -> remove 'missing' - - if none of the recorded file_paths exist for a hash -> add 'missing' + Logic for detecting missing Assets files: + - Clear 'missing' only if at least one cached path passes fast check: + exists AND mtime_ns matches AND size matches. + - Otherwise set 'missing'. + Files that exist but fail fast check will be slow-hashed by the normal pipeline, + and ingest_fs_asset will clear 'missing' if they truly match. 
""" if root == "models": bases: list[str] = [] @@ -324,33 +326,41 @@ async def _reconcile_missing_tags_for_root(root: RootType, prog: ScanProgress) - try: async with await create_session() as sess: - states = await list_cache_states_under_prefixes(sess, prefixes=bases) - - present: set[str] = set() - missing: set[str] = set() - - for s in states: + rows = await list_cache_states_with_asset_under_prefixes(sess, prefixes=bases) + + by_hash: dict[str, dict[str, bool]] = {} # {hash: {"any_fast_ok": bool}} + for state, size_db in rows: + h = state.asset_hash + acc = by_hash.get(h) + if acc is None: + acc = {"any_fast_ok": False} + by_hash[h] = acc + try: + st = os.stat(state.file_path, follow_symlinks=True) + actual_mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)) + fast_ok = False + if state.mtime_ns is not None and int(state.mtime_ns) == int(actual_mtime_ns): + if int(size_db) > 0 and int(st.st_size) == int(size_db): + fast_ok = True + if fast_ok: + acc["any_fast_ok"] = True + except FileNotFoundError: + pass # not fast_ok + except OSError as e: + _append_error(prog, phase="fast", path=state.file_path, message=str(e)) + + for h, acc in by_hash.items(): try: - if os.path.isfile(s.file_path): - present.add(s.asset_hash) + if acc["any_fast_ok"]: + await remove_missing_tag_for_asset_hash(sess, asset_hash=h) else: - missing.add(s.asset_hash) - except Exception as e: - _append_error(prog, phase="fast", path=s.file_path, message=f"stat error: {e}") - - only_missing = missing - present - - for h in present: - with contextlib.suppress(Exception): - await remove_missing_tag_for_asset_hash(sess, asset_hash=h) - - for h in only_missing: - with contextlib.suppress(Exception): - await add_missing_tag_for_asset_hash(sess, asset_hash=h, origin="automatic") + await add_missing_tag_for_asset_hash(sess, asset_hash=h, origin="automatic") + except Exception as ex: + _append_error(prog, phase="fast", path="", message=f"reconcile {h[:18]}: {ex}") await sess.commit() except Exception as e: - _append_error(prog, phase="fast", path="", message=f"missing-tag reconcile failed: {e}") + _append_error(prog, phase="fast", path="", message=f"reconcile failed: {e}") def _start_slow_workers( @@ -406,6 +416,7 @@ async def _await_workers_then_finish( ) -> None: if state.workers: await asyncio.gather(*state.workers, return_exceptions=True) + await _reconcile_missing_tags_for_root(root, prog) prog.finished_at = time.time() prog.status = "completed" if progress_cb: diff --git a/app/database/services.py b/app/database/services.py index ceed3749a882..0b7e3711cbd3 100644 --- a/app/database/services.py +++ b/app/database/services.py @@ -157,7 +157,7 @@ async def ingest_fs_asset( out["state_created"] = True if not out["state_created"]: - # most likely a unique(file_path) conflict; update that row + # unique(file_path) conflict -> update that row state = ( await session.execute( select(AssetCacheState).where(AssetCacheState.file_path == locator).limit(1) @@ -1044,12 +1044,12 @@ async def remove_missing_tag_for_asset_hash( return int(res.rowcount or 0) -async def list_cache_states_under_prefixes( +async def list_cache_states_with_asset_under_prefixes( session: AsyncSession, *, prefixes: Sequence[str], -) -> list[AssetCacheState]: - """Return AssetCacheState rows whose file_path starts with any of the given absolute prefixes.""" +) -> list[tuple[AssetCacheState, int]]: + """Return (AssetCacheState, size_bytes) tuples for rows whose file_path starts with any of the absolute prefixes.""" if not prefixes: return [] @@ 
-1067,9 +1067,10 @@ async def list_cache_states_under_prefixes( rows = ( await session.execute( - select(AssetCacheState) + select(AssetCacheState, Asset.size_bytes) + .join(Asset, Asset.hash == AssetCacheState.asset_hash) .where(sa.or_(*conds)) .order_by(AssetCacheState.id.asc()) ) - ).scalars().all() - return list(rows) + ).all() + return [(r[0], int(r[1] or 0)) for r in rows] From 3fa0fc496c19587f6d5c660492bfd978a1f363b8 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Mon, 8 Sep 2025 18:13:32 +0300 Subject: [PATCH 32/82] fix: use UPSERT to eliminate rare race condition during ingesting many small files in parallel --- app/database/services.py | 74 +++++++++++++++++++++++----------------- 1 file changed, 43 insertions(+), 31 deletions(-) diff --git a/app/database/services.py b/app/database/services.py index 0b7e3711cbd3..42f647d91bd9 100644 --- a/app/database/services.py +++ b/app/database/services.py @@ -10,6 +10,8 @@ from sqlalchemy import select, delete, func from sqlalchemy.orm import contains_eager, noload from sqlalchemy.exc import IntegrityError +from sqlalchemy.dialects import sqlite as d_sqlite +from sqlalchemy.dialects import postgresql as d_pg from .models import Asset, AssetInfo, AssetInfoTag, AssetCacheState, Tag, AssetInfoMeta, AssetLocation from .timeutil import utcnow @@ -143,39 +145,49 @@ async def ingest_fs_asset( else: logging.error("Asset %s not found after PK conflict; skipping update.", asset_hash) - # ---- Step 2: INSERT/UPDATE AssetCacheState (mtime_ns, file_path) ---- - with contextlib.suppress(IntegrityError): - async with session.begin_nested(): - session.add( - AssetCacheState( - asset_hash=asset_hash, - file_path=locator, - mtime_ns=int(mtime_ns), + # ---- Step 2: UPSERT AssetCacheState (mtime_ns, file_path) ---- + dialect = session.bind.dialect.name # "sqlite" or "postgresql" + vals = { + "asset_hash": asset_hash, + "file_path": locator, + "mtime_ns": int(mtime_ns), + } + # 2-step idempotent write so we can set flags deterministically: + # INSERT ... 
ON CONFLICT(file_path) DO NOTHING + # if conflicted, UPDATE only when values actually change + if dialect == "sqlite": + ins = ( + d_sqlite.insert(AssetCacheState) + .values(**vals) + .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path]) + ) + elif dialect == "postgresql": + ins = ( + d_pg.insert(AssetCacheState) + .values(**vals) + .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path]) + ) + else: + raise NotImplementedError(f"Unsupported database dialect: {dialect}") + res = await session.execute(ins) + if int(res.rowcount or 0) > 0: + out["state_created"] = True + else: + upd = ( + sa.update(AssetCacheState) + .where(AssetCacheState.file_path == locator) + .where( + sa.or_( + AssetCacheState.asset_hash != asset_hash, + AssetCacheState.mtime_ns.is_(None), + AssetCacheState.mtime_ns != int(mtime_ns), ) ) - await session.flush() - out["state_created"] = True - - if not out["state_created"]: - # unique(file_path) conflict -> update that row - state = ( - await session.execute( - select(AssetCacheState).where(AssetCacheState.file_path == locator).limit(1) - ) - ).scalars().first() - if state is not None: - changed = False - if state.asset_hash != asset_hash: - state.asset_hash = asset_hash - changed = True - if state.mtime_ns != int(mtime_ns): - state.mtime_ns = int(mtime_ns) - changed = True - if changed: - await session.flush() - out["state_updated"] = True - else: - logging.error("Locator state missing for %s after conflict; skipping update.", asset_hash) + .values(asset_hash=asset_hash, mtime_ns=int(mtime_ns)) + ) + res2 = await session.execute(upd) + if int(res2.rowcount or 0) > 0: + out["state_updated"] = True # ---- Optional: AssetInfo + tag links ---- if info_name: From e3311c9229f10c68437065af904f12ddbf68e282 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Mon, 8 Sep 2025 18:15:09 +0300 Subject: [PATCH 33/82] feat: support for in-memory SQLite databases --- app/database/db.py | 64 ++++++++++++++++++++++++++++++++++------------ 1 file changed, 47 insertions(+), 17 deletions(-) diff --git a/app/database/db.py b/app/database/db.py index 67ddf412ba3b..eaf6648dba5d 100644 --- a/app/database/db.py +++ b/app/database/db.py @@ -36,14 +36,39 @@ def _absolutize_sqlite_url(db_url: str) -> str: if not u.drivername.startswith("sqlite"): return db_url - # Make path absolute if relative - db_path = u.database or "" + db_path: str = u.database or "" + if isinstance(db_path, str) and db_path.startswith("file:"): + return str(u) # Do not touch SQLite URI databases like: "file:xxx?mode=memory&cache=shared" if not os.path.isabs(db_path): db_path = os.path.abspath(os.path.join(os.getcwd(), db_path)) u = u.set(database=db_path) return str(u) +def _normalize_sqlite_memory_url(db_url: str) -> tuple[str, bool]: + """ + If db_url points at an in-memory SQLite DB (":memory:" or file:... mode=memory), + rewrite it to a *named* shared in-memory URI and ensure 'uri=true' is present. + Returns: (normalized_url, is_memory) + """ + try: + u = make_url(db_url) + except Exception: + return db_url, False + if not u.drivername.startswith("sqlite"): + return db_url, False + + db = u.database or "" + if db == ":memory:": + u = u.set(database=f"file:comfyui_db_{os.getpid()}?mode=memory&cache=shared&uri=true") + return str(u), True + if isinstance(db, str) and db.startswith("file:") and "mode=memory" in db: + if "uri=true" not in db: + u = u.set(database=(db + ("&" if "?" 
in db else "?") + "uri=true")) + return str(u), True + return str(u), False + + def _to_sync_driver_url(async_url: str) -> str: """Convert an async SQLAlchemy URL to a sync URL for Alembic.""" u = make_url(async_url) @@ -70,7 +95,10 @@ def _get_sqlite_file_path(sync_url: str) -> Optional[str]: if not u.drivername.startswith("sqlite"): return None - return u.database + db_path = u.database + if isinstance(db_path, str) and db_path.startswith("file:"): + return None # Not a real file if it is a URI like "file:...?" + return db_path def _get_alembic_config(sync_url: str) -> Config: @@ -96,8 +124,8 @@ async def init_db_engine() -> None: if not raw_url: raise RuntimeError("Database URL is not configured.") - # Absolutize SQLite path for async engine - db_url = _absolutize_sqlite_url(raw_url) + db_url, is_mem = _normalize_sqlite_memory_url(raw_url) + db_url = _absolutize_sqlite_url(db_url) # Prepare async engine connect_args = {} @@ -106,6 +134,8 @@ async def init_db_engine() -> None: "check_same_thread": False, "timeout": 12, } + if is_mem: + connect_args["uri"] = True ENGINE = create_async_engine( db_url, @@ -117,18 +147,19 @@ async def init_db_engine() -> None: # Enforce SQLite pragmas on the async engine if db_url.startswith("sqlite"): async with ENGINE.begin() as conn: - # WAL for concurrency and durability, Foreign Keys for referential integrity - current_mode = (await conn.execute(text("PRAGMA journal_mode;"))).scalar() - if str(current_mode).lower() != "wal": - new_mode = (await conn.execute(text("PRAGMA journal_mode=WAL;"))).scalar() - if str(new_mode).lower() != "wal": - raise RuntimeError("Failed to set SQLite journal mode to WAL.") - LOGGER.info("SQLite journal mode set to WAL.") + if not is_mem: + # WAL for concurrency and durability, Foreign Keys for referential integrity + current_mode = (await conn.execute(text("PRAGMA journal_mode;"))).scalar() + if str(current_mode).lower() != "wal": + new_mode = (await conn.execute(text("PRAGMA journal_mode=WAL;"))).scalar() + if str(new_mode).lower() != "wal": + raise RuntimeError("Failed to set SQLite journal mode to WAL.") + LOGGER.info("SQLite journal mode set to WAL.") await conn.execute(text("PRAGMA foreign_keys = ON;")) await conn.execute(text("PRAGMA synchronous = NORMAL;")) - await _run_migrations(raw_url=db_url) + await _run_migrations(raw_url=db_url, connect_args=connect_args) SESSION = async_sessionmaker( bind=ENGINE, @@ -139,7 +170,7 @@ async def init_db_engine() -> None: ) -async def _run_migrations(raw_url: str) -> None: +async def _run_migrations(raw_url: str, connect_args: dict) -> None: """ Run Alembic migrations up to head. 
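A rough sketch of what the in-memory normalization above is meant to produce; the process-id suffix varies and the rendered URL string is approximate:

from app.database.db import _normalize_sqlite_memory_url

url, is_mem = _normalize_sqlite_memory_url("sqlite+aiosqlite:///:memory:")
# is_mem -> True
# url    -> roughly "sqlite+aiosqlite:///file:comfyui_db_<pid>?mode=memory&cache=shared&uri=true"
# A named, shared in-memory database lets the async engine, the sync Alembic engine used for
# migrations, and every pooled connection (connect_args={"uri": True}) see the same schema,
# whereas a plain ":memory:" URL would give each new connection its own empty database.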
@@ -148,12 +179,11 @@ async def _run_migrations(raw_url: str) -> None: """ # Convert to sync URL and make SQLite URL an absolute one sync_url = _to_sync_driver_url(raw_url) + sync_url, is_mem = _normalize_sqlite_memory_url(sync_url) sync_url = _absolutize_sqlite_url(sync_url) cfg = _get_alembic_config(sync_url) - - # Inspect current and target heads - engine = create_engine(sync_url, future=True) + engine = create_engine(sync_url, future=True, connect_args=connect_args) with engine.connect() as conn: context = MigrationContext.configure(conn) current_rev = context.get_current_revision() From 0e9de2b7c959925e6737779ebdcee920e152fba3 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Mon, 8 Sep 2025 20:43:45 +0300 Subject: [PATCH 34/82] feat: add first test --- tests-assets/conftest.py | 181 ++++++++++++++++++++++++++++++++++++ tests-assets/test_assets.py | 26 ++++++ 2 files changed, 207 insertions(+) create mode 100644 tests-assets/conftest.py create mode 100644 tests-assets/test_assets.py diff --git a/tests-assets/conftest.py b/tests-assets/conftest.py new file mode 100644 index 000000000000..82d02dc74ac7 --- /dev/null +++ b/tests-assets/conftest.py @@ -0,0 +1,181 @@ +import asyncio +import contextlib +import json +import os +import socket +import sys +import tempfile +import time +from pathlib import Path +from typing import AsyncIterator, Callable + +import aiohttp +import pytest +import pytest_asyncio +import subprocess + + +def _free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +def _make_base_dirs(root: Path) -> None: + for sub in ("models", "custom_nodes", "input", "output", "temp", "user"): + (root / sub).mkdir(parents=True, exist_ok=True) + + +async def _wait_http_ready(base: str, session: aiohttp.ClientSession, timeout: float = 90.0) -> None: + start = time.time() + last_err = None + while time.time() - start < timeout: + try: + async with session.get(base + "/api/assets") as r: + if r.status in (200, 400): + return + except Exception as e: + last_err = e + await asyncio.sleep(0.25) + raise RuntimeError(f"ComfyUI HTTP did not become ready: {last_err}") + + +@pytest.fixture(scope="session") +def event_loop(): + loop = asyncio.new_event_loop() + yield loop + loop.close() + + +@pytest.fixture(scope="session") +def comfy_tmp_base_dir() -> Path: + tmp = Path(tempfile.mkdtemp(prefix="comfyui-assets-tests-")) + _make_base_dirs(tmp) + yield tmp + # cleanup in a best-effort way; ComfyUI should not keep files open in this dir + with contextlib.suppress(Exception): + for p in sorted(tmp.rglob("*"), reverse=True): + if p.is_file() or p.is_symlink(): + p.unlink(missing_ok=True) + for p in sorted(tmp.glob("**/*"), reverse=True): + with contextlib.suppress(Exception): + p.rmdir() + tmp.rmdir() + + +@pytest.fixture(scope="session") +def comfy_url_and_proc(comfy_tmp_base_dir: Path): + """ + Boot ComfyUI subprocess with: + - sandbox base dir + - sqlite memory DB + - autoscan disabled + Returns (base_url, process, port) + """ + port = 8500 # _free_port() + db_url = "sqlite+aiosqlite:///:memory:" + + # stdout/stderr capturing for debugging if something goes wrong + logs_dir = comfy_tmp_base_dir / "logs" + logs_dir.mkdir(exist_ok=True) + out_log = open(logs_dir / "stdout.log", "w", buffering=1) + err_log = open(logs_dir / "stderr.log", "w", buffering=1) + + comfy_root = Path(__file__).resolve().parent.parent + if not (comfy_root / "main.py").is_file(): + raise FileNotFoundError(f"main.py not found under 
{comfy_root}") + + proc = subprocess.Popen( + args=[ + sys.executable, + "main.py", + f"--base-directory={str(comfy_tmp_base_dir)}", + f"--database-url={db_url}", + "--disable-assets-autoscan", + "--listen", + "127.0.0.1", + "--port", + str(port), + ], + stdout=out_log, + stderr=err_log, + cwd=str(comfy_root), + env={**os.environ}, + ) + + base_url = f"http://127.0.0.1:{port}" + try: + async def _probe(): + async with aiohttp.ClientSession() as s: + await _wait_http_ready(base_url, s, timeout=90.0) + + asyncio.run(_probe()) + yield base_url, proc, port + except Exception as e: + with contextlib.suppress(Exception): + proc.terminate() + proc.wait(timeout=10) + raise RuntimeError(f"ComfyUI did not become ready: {e}") + + if proc and proc.poll() is None: + with contextlib.suppress(Exception): + proc.terminate() + proc.wait(timeout=15) + out_log.close() + err_log.close() + + +@pytest_asyncio.fixture +async def http() -> AsyncIterator[aiohttp.ClientSession]: + timeout = aiohttp.ClientTimeout(total=120) + async with aiohttp.ClientSession(timeout=timeout) as s: + yield s + + +@pytest.fixture +def api_base(comfy_url_and_proc) -> str: + base_url, _proc, _port = comfy_url_and_proc + return base_url + + +@pytest.fixture +def make_asset_bytes() -> Callable[[str], bytes]: + def _make(name: str) -> bytes: + # Generate deterministic small content variations based on name + seed = sum(ord(c) for c in name) % 251 + data = bytes((i * 31 + seed) % 256 for i in range(8192)) + return data + return _make + + +async def _upload_asset(session: aiohttp.ClientSession, base: str, *, name: str, tags: list[str], meta: dict) -> dict: + make_asset_bytes = bytes((i % 251) for i in range(4096)) + form = aiohttp.FormData() + form.add_field("file", make_asset_bytes, filename=name, content_type="application/octet-stream") + form.add_field("tags", json.dumps(tags)) + form.add_field("name", name) + form.add_field("user_metadata", json.dumps(meta)) + async with session.post(base + "/api/assets", data=form) as r: + body = await r.json() + assert r.status in (200, 201), body + return body + + +@pytest_asyncio.fixture +async def seeded_asset(http: aiohttp.ClientSession, api_base: str) -> dict: + """ + Upload one asset into models/checkpoints/unit-tests/. + Returns response dict with id, asset_hash, tags, etc. 
+ """ + name = "unit_1_example.safetensors" + tags = ["models", "checkpoints", "unit-tests", "alpha"] + meta = {"purpose": "test", "epoch": 1, "flags": ["x", "y"], "nullable": None} + form = aiohttp.FormData() + form.add_field("file", b"A" * 4096, filename=name, content_type="application/octet-stream") + form.add_field("tags", json.dumps(tags)) + form.add_field("name", name) + form.add_field("user_metadata", json.dumps(meta)) + async with http.post(api_base + "/api/assets", data=form) as r: + body = await r.json() + assert r.status == 201, body + return body diff --git a/tests-assets/test_assets.py b/tests-assets/test_assets.py new file mode 100644 index 000000000000..dfcedc52cced --- /dev/null +++ b/tests-assets/test_assets.py @@ -0,0 +1,26 @@ +import aiohttp +import pytest + + +@pytest.mark.asyncio +async def test_tags_listing_endpoint(http: aiohttp.ClientSession, api_base: str): + # Include zero-usage tags by default + async with http.get(api_base + "/api/tags", params={"limit": "50"}) as r1: + body1 = await r1.json() + assert r1.status == 200 + names = [t["name"] for t in body1["tags"]] + # A few system tags from migration should exist: + assert "models" in names + assert "checkpoints" in names + + # Only used tags + async with http.get(api_base + "/api/tags", params={"include_zero": "false"}) as r2: + body2 = await r2.json() + assert r2.status == 200 + # Should contain no tags + assert not [t["name"] for t in body2["tags"]] + + # TODO-1: add some asset + # TODO-2: check that "used" tags are now non zero amount + + # TODO-3: do a global teardown, so the state of ComfyUI is clear after each test, and all test can be run solo or one-by-one without any problems. From dfb5703d40ef03713038e821106939dcbc2c4445 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Mon, 8 Sep 2025 22:37:39 +0300 Subject: [PATCH 35/82] feat: remove Asset when there is no references left + bugfixes + more tests --- app/api/assets_routes.py | 12 +++- app/assets_manager.py | 37 +++++++++- app/database/services.py | 34 ++++++++- tests-assets/conftest.py | 85 +++++++++++++++++----- tests-assets/test_assets.py | 26 ------- tests-assets/test_tags.py | 56 +++++++++++++++ tests-assets/test_uploads.py | 133 +++++++++++++++++++++++++++++++++++ 7 files changed, 332 insertions(+), 51 deletions(-) delete mode 100644 tests-assets/test_assets.py create mode 100644 tests-assets/test_tags.py create mode 100644 tests-assets/test_uploads.py diff --git a/app/api/assets_routes.py b/app/api/assets_routes.py index 71e99f2310a2..248f7a2f93e4 100644 --- a/app/api/assets_routes.py +++ b/app/api/assets_routes.py @@ -277,9 +277,13 @@ async def upload_asset(request: web.Request) -> web.Response: os.remove(tmp_path) msg = str(e) if "HASH_MISMATCH" in msg or msg.strip().upper() == "HASH_MISMATCH": - return _error_response(400, "HASH_MISMATCH", "Uploaded file hash does not match provided hash.") + return _error_response( + 400, + "HASH_MISMATCH", + "Uploaded file hash does not match provided hash.", + ) return _error_response(400, "BAD_REQUEST", "Invalid inputs.") - except Exception: + except Exception as e: if tmp_path and os.path.exists(tmp_path): os.remove(tmp_path) return _error_response(500, "INTERNAL", "Unexpected server error.") @@ -343,10 +347,14 @@ async def delete_asset(request: web.Request) -> web.Response: except Exception: return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid UUID.") + delete_content = request.query.get("delete_content") + delete_content = True if delete_content is None else 
delete_content.lower() not in {"0", "false", "no"} + try: deleted = await assets_manager.delete_asset_reference( asset_info_id=asset_info_id, owner_id=UserManager.get_request_user_id(request), + delete_content_if_orphan=delete_content, ) except Exception: return _error_response(500, "INTERNAL", "Unexpected server error.") diff --git a/app/assets_manager.py b/app/assets_manager.py index b84b61508762..a2a73773a8ff 100644 --- a/app/assets_manager.py +++ b/app/assets_manager.py @@ -27,6 +27,8 @@ create_asset_info_for_existing_asset, fetch_asset_info_asset_and_tags, get_asset_info_by_id, + list_cache_states_by_asset_hash, + asset_info_exists_for_hash, ) from .api import schemas_in, schemas_out from ._assets_helpers import ( @@ -371,11 +373,40 @@ async def update_asset( ) -async def delete_asset_reference(*, asset_info_id: str, owner_id: str) -> bool: +async def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool: + """Delete single AssetInfo. If this was the last reference to Asset and delete_content_if_orphan=True (default), + delete the Asset row as well and remove all cached files recorded for that asset_hash. + """ async with await create_session() as session: - r = await delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id) + info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id) + asset_hash = info_row.asset_hash if info_row else None + deleted = await delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id) + if not deleted: + await session.commit() + return False + + if not delete_content_if_orphan or not asset_hash: + await session.commit() + return True + + still_exists = await asset_info_exists_for_hash(session, asset_hash=asset_hash) + if still_exists: + await session.commit() + return True + + states = await list_cache_states_by_asset_hash(session, asset_hash=asset_hash) + file_paths = [s.file_path for s in (states or []) if getattr(s, "file_path", None)] + + asset_row = await get_asset_by_hash(session, asset_hash=asset_hash) + if asset_row is not None: + await session.delete(asset_row) + await session.commit() - return r + for p in file_paths: + with contextlib.suppress(Exception): + if p and os.path.isfile(p): + os.remove(p) + return True async def create_asset_from_hash( diff --git a/app/database/services.py b/app/database/services.py index 42f647d91bd9..842103e9ee0e 100644 --- a/app/database/services.py +++ b/app/database/services.py @@ -36,6 +36,17 @@ async def get_asset_info_by_id(session: AsyncSession, *, asset_info_id: str) -> return await session.get(AssetInfo, asset_info_id) +async def asset_info_exists_for_hash(session: AsyncSession, *, asset_hash: str) -> bool: + return ( + await session.execute( + sa.select(sa.literal(True)) + .select_from(AssetInfo) + .where(AssetInfo.asset_hash == asset_hash) + .limit(1) + ) + ).first() is not None + + async def check_fs_asset_exists_quick( session, *, @@ -586,7 +597,7 @@ async def create_asset_info_for_existing_asset( tag_origin: str = "manual", owner_id: str = "", ) -> AssetInfo: - """Create a new AssetInfo referencing an existing Asset (no content write).""" + """Create a new AssetInfo referencing an existing Asset. 
If row already exists, return it unchanged.""" now = utcnow() info = AssetInfo( owner_id=owner_id, @@ -597,8 +608,25 @@ async def create_asset_info_for_existing_asset( updated_at=now, last_access_time=now, ) - session.add(info) - await session.flush() # get info.id + try: + async with session.begin_nested(): + session.add(info) + await session.flush() # get info.id + except IntegrityError: + existing = ( + await session.execute( + select(AssetInfo) + .where( + AssetInfo.asset_hash == asset_hash, + AssetInfo.name == name, + AssetInfo.owner_id == owner_id, + ) + .limit(1) + ) + ).scalars().first() + if not existing: + raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.") + return existing # Uncomment next code, and remove code after it, once the hack with "metadata[filename" is not needed anymore # if user_metadata is not None: diff --git a/tests-assets/conftest.py b/tests-assets/conftest.py index 82d02dc74ac7..24eed0728362 100644 --- a/tests-assets/conftest.py +++ b/tests-assets/conftest.py @@ -52,7 +52,6 @@ def comfy_tmp_base_dir() -> Path: tmp = Path(tempfile.mkdtemp(prefix="comfyui-assets-tests-")) _make_base_dirs(tmp) yield tmp - # cleanup in a best-effort way; ComfyUI should not keep files open in this dir with contextlib.suppress(Exception): for p in sorted(tmp.rglob("*"), reverse=True): if p.is_file() or p.is_symlink(): @@ -72,10 +71,9 @@ def comfy_url_and_proc(comfy_tmp_base_dir: Path): - autoscan disabled Returns (base_url, process, port) """ - port = 8500 # _free_port() + port = _free_port() db_url = "sqlite+aiosqlite:///:memory:" - # stdout/stderr capturing for debugging if something goes wrong logs_dir = comfy_tmp_base_dir / "logs" logs_dir.mkdir(exist_ok=True) out_log = open(logs_dir / "stdout.log", "w", buffering=1) @@ -138,28 +136,59 @@ def api_base(comfy_url_and_proc) -> str: return base_url -@pytest.fixture -def make_asset_bytes() -> Callable[[str], bytes]: - def _make(name: str) -> bytes: - # Generate deterministic small content variations based on name - seed = sum(ord(c) for c in name) % 251 - data = bytes((i * 31 + seed) % 256 for i in range(8192)) - return data - return _make - - -async def _upload_asset(session: aiohttp.ClientSession, base: str, *, name: str, tags: list[str], meta: dict) -> dict: - make_asset_bytes = bytes((i % 251) for i in range(4096)) +async def _post_multipart_asset( + session: aiohttp.ClientSession, + base: str, + *, + name: str, + tags: list[str], + meta: dict, + data: bytes, + extra_fields: dict | None = None, +) -> tuple[int, dict]: form = aiohttp.FormData() - form.add_field("file", make_asset_bytes, filename=name, content_type="application/octet-stream") + form.add_field("file", data, filename=name, content_type="application/octet-stream") form.add_field("tags", json.dumps(tags)) form.add_field("name", name) form.add_field("user_metadata", json.dumps(meta)) + if extra_fields: + for k, v in extra_fields.items(): + form.add_field(k, v) async with session.post(base + "/api/assets", data=form) as r: body = await r.json() - assert r.status in (200, 201), body + return r.status, body + + +@pytest.fixture +def make_asset_bytes() -> Callable[[str, int], bytes]: + def _make(name: str, size: int = 8192) -> bytes: + seed = sum(ord(c) for c in name) % 251 + return bytes((i * 31 + seed) % 256 for i in range(size)) + return _make + + +@pytest_asyncio.fixture +async def asset_factory(http: aiohttp.ClientSession, api_base: str): + """ + Returns create(name, tags, meta, data) -> response dict + Tracks created ids and deletes them 
after the test. + """ + created: list[str] = [] + + async def create(name: str, tags: list[str], meta: dict, data: bytes) -> dict: + status, body = await _post_multipart_asset(http, api_base, name=name, tags=tags, meta=meta, data=data) + assert status in (200, 201), body + created.append(body["id"]) return body + yield create + + # cleanup by id + for aid in created: + with contextlib.suppress(Exception): + async with http.delete(f"{api_base}/api/assets/{aid}") as r: + await r.read() + @pytest_asyncio.fixture async def seeded_asset(http: aiohttp.ClientSession, api_base: str) -> dict: @@ -179,3 +208,25 @@ async def seeded_asset(http: aiohttp.ClientSession, api_base: str) -> dict: body = await r.json() assert r.status == 201, body return body + + +@pytest_asyncio.fixture(autouse=True) +async def autoclean_unit_test_assets(http: aiohttp.ClientSession, api_base: str): + """Ensure isolation by removing all AssetInfo rows tagged with 'unit-tests' after each test.""" + yield + + while True: + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests", "limit": "500", "sort": "name"}, + ) as r: + body = await r.json() + if r.status != 200: + break + ids = [a["id"] for a in body.get("assets", [])] + if not ids: + break + for aid in ids: + with contextlib.suppress(Exception): + async with http.delete(f"{api_base}/api/assets/{aid}") as dr: + await dr.read() diff --git a/tests-assets/test_assets.py b/tests-assets/test_assets.py deleted file mode 100644 index dfcedc52cced..000000000000 --- a/tests-assets/test_assets.py +++ /dev/null @@ -1,26 +0,0 @@ -import aiohttp -import pytest - - -@pytest.mark.asyncio -async def test_tags_listing_endpoint(http: aiohttp.ClientSession, api_base: str): - # Include zero-usage tags by default - async with http.get(api_base + "/api/tags", params={"limit": "50"}) as r1: - body1 = await r1.json() - assert r1.status == 200 - names = [t["name"] for t in body1["tags"]] - # A few system tags from migration should exist: - assert "models" in names - assert "checkpoints" in names - - # Only used tags - async with http.get(api_base + "/api/tags", params={"include_zero": "false"}) as r2: - body2 = await r2.json() - assert r2.status == 200 - # Should contain no tags - assert not [t["name"] for t in body2["tags"]] - - # TODO-1: add some asset - # TODO-2: check that "used" tags are now non zero amount - - # TODO-3: do a global teardown, so the state of ComfyUI is clear after each test, and all test can be run solo or one-by-one without any problems. 
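The new `delete_content` query flag on DELETE /api/assets/{id} could be exercised from this suite in the same aiohttp style; the helper names below are illustrative, and the success status code is left unchecked since it is not shown here:

import aiohttp

async def delete_reference_only(http: aiohttp.ClientSession, api_base: str, asset_info_id: str) -> None:
    # delete_content=false removes just this AssetInfo reference; the Asset row and any
    # cached files for its hash are kept even if this was the last reference.
    async with http.delete(f"{api_base}/api/assets/{asset_info_id}", params={"delete_content": "false"}) as r:
        await r.read()

async def delete_reference_and_orphaned_content(http: aiohttp.ClientSession, api_base: str, asset_info_id: str) -> None:
    # Default behaviour: when no other AssetInfo references the same hash, the Asset row is
    # deleted and every cached file recorded for that hash is removed from disk.
    async with http.delete(f"{api_base}/api/assets/{asset_info_id}") as r:
        await r.read()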
diff --git a/tests-assets/test_tags.py b/tests-assets/test_tags.py new file mode 100644 index 000000000000..c63df48bc70f --- /dev/null +++ b/tests-assets/test_tags.py @@ -0,0 +1,56 @@ +import aiohttp +import pytest + + +@pytest.mark.asyncio +async def test_tags_present(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict): + # Include zero-usage tags by default + async with http.get(api_base + "/api/tags", params={"limit": "50"}) as r1: + body1 = await r1.json() + assert r1.status == 200 + names = [t["name"] for t in body1["tags"]] + # A few system tags from migration should exist: + assert "models" in names + assert "checkpoints" in names + + # Only used tags before we add anything new from this test cycle + async with http.get(api_base + "/api/tags", params={"include_zero": "false"}) as r2: + body2 = await r2.json() + assert r2.status == 200 + # We already seeded one asset via fixture, so used tags must be non-empty + used_names = [t["name"] for t in body2["tags"]] + assert "models" in used_names + assert "checkpoints" in used_names + + # Prefix filter should refine the list + async with http.get(api_base + "/api/tags", params={"include_zero": "false", "prefix": "uni"}) as r3: + b3 = await r3.json() + assert r3.status == 200 + names3 = [t["name"] for t in b3["tags"]] + assert "unit-tests" in names3 + assert "models" not in names3 # filtered out by prefix + + # Order by name ascending should be stable + async with http.get(api_base + "/api/tags", params={"include_zero": "false", "order": "name_asc"}) as r4: + b4 = await r4.json() + assert r4.status == 200 + names4 = [t["name"] for t in b4["tags"]] + assert names4 == sorted(names4) + + +@pytest.mark.asyncio +async def test_tags_empty_usage(http: aiohttp.ClientSession, api_base: str): + # Include zero-usage tags by default + async with http.get(api_base + "/api/tags", params={"limit": "50"}) as r1: + body1 = await r1.json() + assert r1.status == 200 + names = [t["name"] for t in body1["tags"]] + # A few system tags from migration should exist: + assert "models" in names + assert "checkpoints" in names + + # With include_zero=False there should be no tags returned for the database without Assets. 
+ async with http.get(api_base + "/api/tags", params={"include_zero": "false"}) as r2: + body2 = await r2.json() + assert r2.status == 200 + assert not [t["name"] for t in body2["tags"]] diff --git a/tests-assets/test_uploads.py b/tests-assets/test_uploads.py new file mode 100644 index 000000000000..65c34f139ee0 --- /dev/null +++ b/tests-assets/test_uploads.py @@ -0,0 +1,133 @@ +import json +import aiohttp +import pytest + + +@pytest.mark.asyncio +async def test_upload_requires_multipart(http: aiohttp.ClientSession, api_base: str): + async with http.post(api_base + "/api/assets", json={"foo": "bar"}) as r: + body = await r.json() + assert r.status == 415 + assert body["error"]["code"] == "UNSUPPORTED_MEDIA_TYPE" + + +@pytest.mark.asyncio +async def test_upload_missing_file_and_hash(http: aiohttp.ClientSession, api_base: str): + form = aiohttp.FormData(default_to_multipart=True) + form.add_field("tags", json.dumps(["models", "checkpoints", "unit-tests"])) + form.add_field("name", "x.safetensors") + async with http.post(api_base + "/api/assets", data=form) as r: + body = await r.json() + assert r.status == 400 + assert body["error"]["code"] == "MISSING_FILE" + + +@pytest.mark.asyncio +async def test_upload_models_unknown_category(http: aiohttp.ClientSession, api_base: str): + form = aiohttp.FormData() + form.add_field("file", b"A" * 128, filename="m.safetensors", content_type="application/octet-stream") + form.add_field("tags", json.dumps(["models", "no_such_category", "unit-tests"])) + form.add_field("name", "m.safetensors") + async with http.post(api_base + "/api/assets", data=form) as r: + body = await r.json() + assert r.status == 400 + assert body["error"]["code"] == "INVALID_BODY" + assert "unknown models category" in body["error"]["message"] or "unknown model category" in body["error"]["message"] + + +@pytest.mark.asyncio +async def test_upload_tags_traversal_guard(http: aiohttp.ClientSession, api_base: str): + form = aiohttp.FormData() + form.add_field("file", b"A" * 256, filename="evil.safetensors", content_type="application/octet-stream") + # '..' 
should be rejected by destination resolver + form.add_field("tags", json.dumps(["models", "checkpoints", "unit-tests", "..", "zzz"])) + form.add_field("name", "evil.safetensors") + async with http.post(api_base + "/api/assets", data=form) as r: + body = await r.json() + assert r.status == 400 + assert body["error"]["code"] in ("BAD_REQUEST", "INVALID_BODY") + + +@pytest.mark.asyncio +async def test_upload_ok_duplicate_reference(http: aiohttp.ClientSession, api_base: str, make_asset_bytes): + name = "dup_a.safetensors" + tags = ["models", "checkpoints", "unit-tests", "alpha"] + meta = {"purpose": "dup"} + data = make_asset_bytes(name) + form1 = aiohttp.FormData() + form1.add_field("file", data, filename=name, content_type="application/octet-stream") + form1.add_field("tags", json.dumps(tags)) + form1.add_field("name", name) + form1.add_field("user_metadata", json.dumps(meta)) + async with http.post(api_base + "/api/assets", data=form1) as r1: + a1 = await r1.json() + assert r1.status == 201, a1 + assert a1["created_new"] is True + + # Second upload with the same data and name should return created_new == False and the same asset + form2 = aiohttp.FormData() + form2.add_field("file", data, filename=name, content_type="application/octet-stream") + form2.add_field("tags", json.dumps(tags)) + form2.add_field("name", name) + form2.add_field("user_metadata", json.dumps(meta)) + async with http.post(api_base + "/api/assets", data=form2) as r2: + a2 = await r2.json() + assert r2.status == 200, a2 + assert a2["created_new"] is False + assert a2["asset_hash"] == a1["asset_hash"] + assert a2["id"] == a1["id"] # old reference + + # Third upload with the same data but new name should return created_new == False and the new AssetReference + form3 = aiohttp.FormData() + form3.add_field("file", data, filename=name, content_type="application/octet-stream") + form3.add_field("tags", json.dumps(tags)) + form3.add_field("name", name + "_d") + form3.add_field("user_metadata", json.dumps(meta)) + async with http.post(api_base + "/api/assets", data=form3) as r2: + a3 = await r2.json() + assert r2.status == 200, a3 + assert a3["created_new"] is False + assert a3["asset_hash"] == a1["asset_hash"] + assert a3["id"] != a1["id"] # old reference + + +@pytest.mark.asyncio +async def test_upload_fastpath_from_existing_hash_no_file(http: aiohttp.ClientSession, api_base: str): + # Seed a small file first + name = "fastpath_seed.safetensors" + tags = ["models", "checkpoints", "unit-tests"] + meta = {} + form1 = aiohttp.FormData() + form1.add_field("file", b"B" * 1024, filename=name, content_type="application/octet-stream") + form1.add_field("tags", json.dumps(tags)) + form1.add_field("name", name) + form1.add_field("user_metadata", json.dumps(meta)) + async with http.post(api_base + "/api/assets", data=form1) as r1: + b1 = await r1.json() + assert r1.status == 201, b1 + h = b1["asset_hash"] + + # Now POST /api/assets with only hash and no file + form2 = aiohttp.FormData() + form2.add_field("hash", h) + form2.add_field("tags", json.dumps(tags)) + form2.add_field("name", "fastpath_copy.safetensors") + form2.add_field("user_metadata", json.dumps({"purpose": "copy"})) + async with http.post(api_base + "/api/assets", data=form2) as r2: + b2 = await r2.json() + assert r2.status == 200, b2 # fast path returns 200 with created_new == False + assert b2["created_new"] is False + assert b2["asset_hash"] == h + + +@pytest.mark.asyncio +async def test_create_from_hash_endpoint_404(http: aiohttp.ClientSession, api_base: str): + payload = 
{ + "hash": "blake3:" + "0" * 64, + "name": "nonexistent.bin", + "tags": ["models", "checkpoints", "unit-tests"], + } + async with http.post(api_base + "/api/assets/from-hash", json=payload) as r: + body = await r.json() + assert r.status == 404 + assert body["error"]["code"] == "ASSET_NOT_FOUND" From faa1e4de17fb248e3722666dcdf971337df16b37 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Tue, 9 Sep 2025 15:17:03 +0300 Subject: [PATCH 36/82] fixed another test --- tests-assets/test_uploads.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests-assets/test_uploads.py b/tests-assets/test_uploads.py index 65c34f139ee0..8c50beb0993b 100644 --- a/tests-assets/test_uploads.py +++ b/tests-assets/test_uploads.py @@ -108,7 +108,7 @@ async def test_upload_fastpath_from_existing_hash_no_file(http: aiohttp.ClientSe h = b1["asset_hash"] # Now POST /api/assets with only hash and no file - form2 = aiohttp.FormData() + form2 = aiohttp.FormData(default_to_multipart=True) form2.add_field("hash", h) form2.add_field("tags", json.dumps(tags)) form2.add_field("name", "fastpath_copy.safetensors") From 0ef73e95fd9f4c1b2fbe0489f8bb20b71410274a Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Tue, 9 Sep 2025 16:02:39 +0300 Subject: [PATCH 37/82] fixed validation error + more tests --- app/api/assets_routes.py | 2 +- tests-assets/test_list_filter.py | 85 ++++++++++++++++++++++++++++++++ tests-assets/test_tags.py | 44 +++++++++++++++++ 3 files changed, 130 insertions(+), 1 deletion(-) create mode 100644 tests-assets/test_list_filter.py diff --git a/app/api/assets_routes.py b/app/api/assets_routes.py index 248f7a2f93e4..252242eaeaa8 100644 --- a/app/api/assets_routes.py +++ b/app/api/assets_routes.py @@ -485,4 +485,4 @@ def _error_response(status: int, code: str, message: str, details: Optional[dict def _validation_error_response(code: str, ve: ValidationError) -> web.Response: - return _error_response(400, code, "Validation failed.", {"errors": ve.errors()}) + return _error_response(400, code, "Validation failed.", {"errors": ve.json()}) diff --git a/tests-assets/test_list_filter.py b/tests-assets/test_list_filter.py new file mode 100644 index 000000000000..82abf4ab39e0 --- /dev/null +++ b/tests-assets/test_list_filter.py @@ -0,0 +1,85 @@ +import json +import aiohttp +import pytest + + +@pytest.mark.asyncio +async def test_list_assets_paging_and_sort(http: aiohttp.ClientSession, api_base: str, asset_factory, make_asset_bytes): + names = ["a1_u.safetensors", "a2_u.safetensors", "a3_u.safetensors"] + for n in names: + await asset_factory( + n, + ["models", "checkpoints", "unit-tests", "paging"], + {"epoch": 1}, + make_asset_bytes(n, size=2048), + ) + + # name ascending for stable order + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,paging", "sort": "name", "order": "asc", "limit": "2", "offset": "0"}, + ) as r1: + b1 = await r1.json() + assert r1.status == 200 + got1 = [a["name"] for a in b1["assets"]] + assert got1 == sorted(names)[:2] + assert b1["has_more"] is True + + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,paging", "sort": "name", "order": "asc", "limit": "2", "offset": "2"}, + ) as r2: + b2 = await r2.json() + assert r2.status == 200 + got2 = [a["name"] for a in b2["assets"]] + assert got2 == sorted(names)[2:] + assert b2["has_more"] is False + + +@pytest.mark.asyncio +async def test_list_assets_include_exclude_and_name_contains(http: aiohttp.ClientSession, api_base: str, asset_factory): + a = await 
asset_factory("inc_a.safetensors", ["models", "checkpoints", "unit-tests", "alpha"], {}, b"X" * 1024) + b = await asset_factory("inc_b.safetensors", ["models", "checkpoints", "unit-tests", "beta"], {}, b"Y" * 1024) + + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,alpha", "exclude_tags": "beta", "limit": "50"}, + ) as r: + body = await r.json() + assert r.status == 200 + names = [x["name"] for x in body["assets"]] + assert a["name"] in names + assert b["name"] not in names + + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests", "name_contains": "inc_"}, + ) as r2: + body2 = await r2.json() + assert r2.status == 200 + names2 = [x["name"] for x in body2["assets"]] + assert a["name"] in names2 + assert b["name"] in names2 + + async with http.get( + api_base + "/api/assets", + params={"include_tags": "non-existing-tag"}, + ) as r2: + body3 = await r2.json() + assert r2.status == 200 + assert not body3["assets"] + + +@pytest.mark.asyncio +async def test_list_assets_invalid_query_rejected(http: aiohttp.ClientSession, api_base: str): + # limit too small + async with http.get(api_base + "/api/assets", params={"limit": "0"}) as r1: + b1 = await r1.json() + assert r1.status == 400 + assert b1["error"]["code"] == "INVALID_QUERY" + + # bad metadata JSON + async with http.get(api_base + "/api/assets", params={"metadata_filter": "{not json"}) as r2: + b2 = await r2.json() + assert r2.status == 400 + assert b2["error"]["code"] == "INVALID_QUERY" diff --git a/tests-assets/test_tags.py b/tests-assets/test_tags.py index c63df48bc70f..273efdd03d76 100644 --- a/tests-assets/test_tags.py +++ b/tests-assets/test_tags.py @@ -54,3 +54,47 @@ async def test_tags_empty_usage(http: aiohttp.ClientSession, api_base: str): body2 = await r2.json() assert r2.status == 200 assert not [t["name"] for t in body2["tags"]] + + +@pytest.mark.asyncio +async def test_add_and_remove_tags(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict): + aid = seeded_asset["id"] + + # Add tags with duplicates and mixed case + payload_add = {"tags": ["NewTag", "unit-tests", "newtag", "BETA"]} + async with http.post(f"{api_base}/api/assets/{aid}/tags", json=payload_add) as r1: + b1 = await r1.json() + assert r1.status == 200, b1 + # normalized and deduplicated + assert "newtag" in b1["added"] or "beta" in b1["added"] or "unit-tests" not in b1["added"] + + async with http.get(f"{api_base}/api/assets/{aid}") as rg: + g = await rg.json() + assert rg.status == 200 + tags_now = set(g["tags"]) + assert "newtag" in tags_now + assert "beta" in tags_now + + # Remove a tag and a non-existent tag + payload_del = {"tags": ["newtag", "does-not-exist"]} + async with http.delete(f"{api_base}/api/assets/{aid}/tags", json=payload_del) as r2: + b2 = await r2.json() + assert r2.status == 200 + assert "newtag" in b2["removed"] + assert "does-not-exist" in b2["not_present"] + + +@pytest.mark.asyncio +async def test_tags_list_order_and_prefix(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict): + # name ascending + async with http.get(api_base + "/api/tags", params={"order": "name_asc", "limit": "100"}) as r1: + b1 = await r1.json() + assert r1.status == 200 + names = [t["name"] for t in b1["tags"]] + assert names == sorted(names) + + # invalid limit rejected + async with http.get(api_base + "/api/tags", params={"limit": "1001"}) as r2: + b2 = await r2.json() + assert r2.status == 400 + assert b2["error"]["code"] == "INVALID_QUERY" From 
357193f7b5f7db017f1f6b7f407136c8faa537ec Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Tue, 9 Sep 2025 17:28:19 +0300 Subject: [PATCH 38/82] fixed metadata filtering + tests --- app/api/assets_routes.py | 2 +- app/database/_helpers.py | 41 +-- app/database/models.py | 4 +- tests-assets/test_metadata_filters.py | 378 ++++++++++++++++++++++++++ 4 files changed, 403 insertions(+), 22 deletions(-) create mode 100644 tests-assets/test_metadata_filters.py diff --git a/app/api/assets_routes.py b/app/api/assets_routes.py index 252242eaeaa8..001dfa324a97 100644 --- a/app/api/assets_routes.py +++ b/app/api/assets_routes.py @@ -283,7 +283,7 @@ async def upload_asset(request: web.Request) -> web.Response: "Uploaded file hash does not match provided hash.", ) return _error_response(400, "BAD_REQUEST", "Invalid inputs.") - except Exception as e: + except Exception: if tmp_path and os.path.exists(tmp_path): os.remove(tmp_path) return _error_response(500, "INTERNAL", "Unexpected server error.") diff --git a/app/database/_helpers.py b/app/database/_helpers.py index 5ce97207691e..a031e861cece 100644 --- a/app/database/_helpers.py +++ b/app/database/_helpers.py @@ -67,32 +67,29 @@ def apply_metadata_filter( return stmt def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement: - subquery = ( - select(sa.literal(1)) - .select_from(AssetInfoMeta) - .where( - AssetInfoMeta.asset_info_id == AssetInfo.id, - AssetInfoMeta.key == key, - *preds, - ) - .limit(1) + return sa.exists().where( + AssetInfoMeta.asset_info_id == AssetInfo.id, + AssetInfoMeta.key == key, + *preds, ) - return sa.exists(subquery) def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement: # Missing OR null: if value is None: # either: no row for key OR a row for key with explicit null - no_row_for_key = ~sa.exists( - select(sa.literal(1)) - .select_from(AssetInfoMeta) - .where( + no_row_for_key = sa.not_( + sa.exists().where( AssetInfoMeta.asset_info_id == AssetInfo.id, AssetInfoMeta.key == key, ) - .limit(1) ) - null_row = _exists_for_pred(key, AssetInfoMeta.val_json.is_(None)) + null_row = _exists_for_pred( + key, + AssetInfoMeta.val_json.is_(None), + AssetInfoMeta.val_str.is_(None), + AssetInfoMeta.val_num.is_(None), + AssetInfoMeta.val_bool.is_(None), + ) return sa.or_(no_row_for_key, null_row) # Typed scalar matches: @@ -135,13 +132,19 @@ def project_kv(key: str, value: Any) -> list[dict]: - scalar -> one row (ordinal=0) in the proper typed column - list of scalars -> one row per element with ordinal=i - dict or list with non-scalars -> single row with val_json (or one per element w/ val_json if list) - - None -> single row with val_json = None + - None -> single row with all value columns NULL Each row: {"key": key, "ordinal": i, "val_str"/"val_num"/"val_bool"/"val_json": ...} """ rows: list[dict] = [] + def _null_row(ordinal: int) -> dict: + return { + "key": key, "ordinal": ordinal, + "val_str": None, "val_num": None, "val_bool": None, "val_json": None + } + if value is None: - rows.append({"key": key, "ordinal": 0, "val_json": None}) + rows.append(_null_row(0)) return rows if is_scalar(value): @@ -162,7 +165,7 @@ def project_kv(key: str, value: Any) -> list[dict]: if all(is_scalar(x) for x in value): for i, x in enumerate(value): if x is None: - rows.append({"key": key, "ordinal": i, "val_json": None}) + rows.append(_null_row(i)) elif isinstance(x, bool): rows.append({"key": key, "ordinal": i, "val_bool": bool(x)}) elif isinstance(x, (int, float, Decimal)): diff --git a/app/database/models.py b/app/database/models.py 
index 2038674681fa..5bb3a09bc27b 100644 --- a/app/database/models.py +++ b/app/database/models.py @@ -145,7 +145,7 @@ class AssetInfo(Base): String(256), ForeignKey("assets.hash", ondelete="RESTRICT"), nullable=False ) preview_hash: Mapped[str | None] = mapped_column(String(256), ForeignKey("assets.hash", ondelete="SET NULL")) - user_metadata: Mapped[dict[str, Any] | None] = mapped_column(JSON) + user_metadata: Mapped[dict[str, Any] | None] = mapped_column(JSON(none_as_null=True)) created_at: Mapped[datetime] = mapped_column( DateTime(timezone=False), nullable=False, default=utcnow ) @@ -220,7 +220,7 @@ class AssetInfoMeta(Base): val_str: Mapped[Optional[str]] = mapped_column(String(2048), nullable=True) val_num: Mapped[Optional[float]] = mapped_column(Numeric(38, 10), nullable=True) val_bool: Mapped[Optional[bool]] = mapped_column(Boolean, nullable=True) - val_json: Mapped[Optional[Any]] = mapped_column(JSON, nullable=True) + val_json: Mapped[Optional[Any]] = mapped_column(JSON(none_as_null=True), nullable=True) asset_info: Mapped["AssetInfo"] = relationship(back_populates="metadata_entries") diff --git a/tests-assets/test_metadata_filters.py b/tests-assets/test_metadata_filters.py new file mode 100644 index 000000000000..39d00fa2df1f --- /dev/null +++ b/tests-assets/test_metadata_filters.py @@ -0,0 +1,378 @@ +import json +import aiohttp +import pytest + + +@pytest.mark.asyncio +async def test_meta_and_across_keys_and_types(http: aiohttp.ClientSession, api_base: str, asset_factory, make_asset_bytes): + name = "mf_and_mix.safetensors" + tags = ["models", "checkpoints", "unit-tests", "mf-and"] + meta = {"purpose": "mix", "epoch": 1, "active": True, "score": 1.23} + await asset_factory(name, tags, meta, make_asset_bytes(name, 4096)) + + # All keys must match (AND semantics) + f_ok = {"purpose": "mix", "epoch": 1, "active": True, "score": 1.23} + async with http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-and", + "metadata_filter": json.dumps(f_ok), + }, + ) as r1: + b1 = await r1.json() + assert r1.status == 200 + names = [a["name"] for a in b1["assets"]] + assert name in names + + # One key mismatched -> no result + f_bad = {"purpose": "mix", "epoch": 2, "active": True} + async with http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-and", + "metadata_filter": json.dumps(f_bad), + }, + ) as r2: + b2 = await r2.json() + assert r2.status == 200 + assert not b2["assets"] + + +@pytest.mark.asyncio +async def test_meta_type_strictness_int_vs_str_and_bool(http, api_base, asset_factory, make_asset_bytes): + name = "mf_types.safetensors" + tags = ["models", "checkpoints", "unit-tests", "mf-types"] + meta = {"epoch": 1, "active": True} + await asset_factory(name, tags, meta, make_asset_bytes(name)) + + # int filter matches numeric + async with http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-types", + "metadata_filter": json.dumps({"epoch": 1}), + }, + ) as r1: + b1 = await r1.json() + assert r1.status == 200 and any(a["name"] == name for a in b1["assets"]) + + # string "1" must NOT match numeric 1 + async with http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-types", + "metadata_filter": json.dumps({"epoch": "1"}), + }, + ) as r2: + b2 = await r2.json() + assert r2.status == 200 and not b2["assets"] + + # bool True matches, string "true" must NOT match + async with http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-types", + 
"metadata_filter": json.dumps({"active": True}), + }, + ) as r3: + b3 = await r3.json() + assert r3.status == 200 and any(a["name"] == name for a in b3["assets"]) + + async with http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-types", + "metadata_filter": json.dumps({"active": "true"}), + }, + ) as r4: + b4 = await r4.json() + assert r4.status == 200 and not b4["assets"] + + +@pytest.mark.asyncio +async def test_meta_any_of_list_of_scalars(http, api_base, asset_factory, make_asset_bytes): + name = "mf_list_scalars.safetensors" + tags = ["models", "checkpoints", "unit-tests", "mf-list"] + meta = {"flags": ["red", "green"]} + await asset_factory(name, tags, meta, make_asset_bytes(name, 3000)) + + # Any-of should match because "green" is present + filt_ok = {"flags": ["blue", "green"]} + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-list", "metadata_filter": json.dumps(filt_ok)}, + ) as r1: + b1 = await r1.json() + assert r1.status == 200 and any(a["name"] == name for a in b1["assets"]) + + # None of provided flags present -> no match + filt_miss = {"flags": ["blue", "yellow"]} + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-list", "metadata_filter": json.dumps(filt_miss)}, + ) as r2: + b2 = await r2.json() + assert r2.status == 200 and not b2["assets"] + + # Duplicates in list should not break matching + filt_dup = {"flags": ["green", "green", "green"]} + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-list", "metadata_filter": json.dumps(filt_dup)}, + ) as r3: + b3 = await r3.json() + assert r3.status == 200 and any(a["name"] == name for a in b3["assets"]) + + +@pytest.mark.asyncio +async def test_meta_none_semantics_missing_or_null_and_any_of_with_none(http, api_base, asset_factory, make_asset_bytes): + # a1: key missing; a2: explicit null; a3: concrete value + t = ["models", "checkpoints", "unit-tests", "mf-none"] + a1 = await asset_factory("mf_none_missing.safetensors", t, {"x": 1}, make_asset_bytes("a1")) + a2 = await asset_factory("mf_none_null.safetensors", t, {"maybe": None}, make_asset_bytes("a2")) + a3 = await asset_factory("mf_none_value.safetensors", t, {"maybe": "x"}, make_asset_bytes("a3")) + + # Filter {maybe: None} must match a1 and a2, not a3 + filt = {"maybe": None} + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-none", "metadata_filter": json.dumps(filt), "sort": "name"}, + ) as r1: + b1 = await r1.json() + assert r1.status == 200 + got = [a["name"] for a in b1["assets"]] + assert a1["name"] in got and a2["name"] in got and a3["name"] not in got + + # Any-of with None should include missing/null plus value matches + filt_any = {"maybe": [None, "x"]} + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-none", "metadata_filter": json.dumps(filt_any), "sort": "name"}, + ) as r2: + b2 = await r2.json() + assert r2.status == 200 + got2 = [a["name"] for a in b2["assets"]] + assert a1["name"] in got2 and a2["name"] in got2 and a3["name"] in got2 + + +@pytest.mark.asyncio +async def test_meta_nested_json_object_equality(http, api_base, asset_factory, make_asset_bytes): + name = "mf_nested_json.safetensors" + tags = ["models", "checkpoints", "unit-tests", "mf-nested"] + cfg = {"optimizer": "adam", "lr": 0.001, "schedule": {"type": "cosine", "warmup": 100}} + await asset_factory(name, tags, {"config": cfg}, make_asset_bytes(name, 2200)) + + 
# Exact JSON object equality (same structure) + async with http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-nested", + "metadata_filter": json.dumps({"config": cfg}), + }, + ) as r1: + b1 = await r1.json() + assert r1.status == 200 and any(a["name"] == name for a in b1["assets"]) + + # Different JSON object should not match + async with http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-nested", + "metadata_filter": json.dumps({"config": {"optimizer": "sgd"}}), + }, + ) as r2: + b2 = await r2.json() + assert r2.status == 200 and not b2["assets"] + + +@pytest.mark.asyncio +async def test_meta_list_of_objects_any_of(http, api_base, asset_factory, make_asset_bytes): + name = "mf_list_objects.safetensors" + tags = ["models", "checkpoints", "unit-tests", "mf-objlist"] + transforms = [{"type": "crop", "size": 128}, {"type": "flip", "p": 0.5}] + await asset_factory(name, tags, {"transforms": transforms}, make_asset_bytes(name, 2048)) + + # Any-of for list of objects should match when one element equals the filter object + async with http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-objlist", + "metadata_filter": json.dumps({"transforms": {"type": "flip", "p": 0.5}}), + }, + ) as r1: + b1 = await r1.json() + assert r1.status == 200 and any(a["name"] == name for a in b1["assets"]) + + # Non-matching object -> no match + async with http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-objlist", + "metadata_filter": json.dumps({"transforms": {"type": "rotate", "deg": 90}}), + }, + ) as r2: + b2 = await r2.json() + assert r2.status == 200 and not b2["assets"] + + +@pytest.mark.asyncio +async def test_meta_with_special_and_unicode_keys(http, api_base, asset_factory, make_asset_bytes): + name = "mf_keys_unicode.safetensors" + tags = ["models", "checkpoints", "unit-tests", "mf-keys"] + meta = { + "weird.key": "v1", + "path/like": 7, + "with:colon": True, + "ключ": "значение", + "emoji": "🐍", + } + await asset_factory(name, tags, meta, make_asset_bytes(name, 1500)) + + # Match all the special keys + filt = {"weird.key": "v1", "path/like": 7, "with:colon": True, "emoji": "🐍"} + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-keys", "metadata_filter": json.dumps(filt)}, + ) as r1: + b1 = await r1.json() + assert r1.status == 200 and any(a["name"] == name for a in b1["assets"]) + + # Unicode key match + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-keys", "metadata_filter": json.dumps({"ключ": "значение"})}, + ) as r2: + b2 = await r2.json() + assert r2.status == 200 and any(a["name"] == name for a in b2["assets"]) + + +@pytest.mark.asyncio +async def test_meta_with_zero_and_boolean_lists(http, api_base, asset_factory, make_asset_bytes): + t = ["models", "checkpoints", "unit-tests", "mf-zero-bool"] + a0 = await asset_factory("mf_zero_count.safetensors", t, {"count": 0}, make_asset_bytes("z", 1025)) + a1 = await asset_factory("mf_bool_list.safetensors", t, {"choices": [True, False]}, make_asset_bytes("b", 1026)) + + # count == 0 must match only a0 + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-zero-bool", "metadata_filter": json.dumps({"count": 0})}, + ) as r1: + b1 = await r1.json() + assert r1.status == 200 + names1 = [a["name"] for a in b1["assets"]] + assert a0["name"] in names1 and a1["name"] not in names1 + + # Any-of list of booleans: True matches 
second asset + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-zero-bool", "metadata_filter": json.dumps({"choices": True})}, + ) as r2: + b2 = await r2.json() + assert r2.status == 200 and any(a["name"] == a1["name"] for a in b2["assets"]) + + +@pytest.mark.asyncio +async def test_meta_mixed_list_types_and_strictness(http, api_base, asset_factory, make_asset_bytes): + name = "mf_mixed_list.safetensors" + tags = ["models", "checkpoints", "unit-tests", "mf-mixed"] + meta = {"mix": ["1", 1, True, None]} + await asset_factory(name, tags, meta, make_asset_bytes(name, 1999)) + + # Should match because 1 is present + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-mixed", "metadata_filter": json.dumps({"mix": [2, 1]})}, + ) as r1: + b1 = await r1.json() + assert r1.status == 200 and any(a["name"] == name for a in b1["assets"]) + + # Should NOT match for False + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-mixed", "metadata_filter": json.dumps({"mix": False})}, + ) as r2: + b2 = await r2.json() + assert r2.status == 200 and not b2["assets"] + + +@pytest.mark.asyncio +async def test_meta_unknown_key_and_none_behavior_with_scope_tags(http, api_base, asset_factory, make_asset_bytes): + # Use a unique scope tag to avoid interference + t = ["models", "checkpoints", "unit-tests", "mf-unknown-scope"] + x = await asset_factory("mf_unknown_a.safetensors", t, {"k1": 1}, make_asset_bytes("ua")) + y = await asset_factory("mf_unknown_b.safetensors", t, {"k2": 2}, make_asset_bytes("ub")) + + # Filtering by unknown key with None should return both (missing key OR null) + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-unknown-scope", "metadata_filter": json.dumps({"unknown": None})}, + ) as r1: + b1 = await r1.json() + assert r1.status == 200 + names = {a["name"] for a in b1["assets"]} + assert x["name"] in names and y["name"] in names + + # Filtering by unknown key with concrete value should return none + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-unknown-scope", "metadata_filter": json.dumps({"unknown": "x"})}, + ) as r2: + b2 = await r2.json() + assert r2.status == 200 and not b2["assets"] + + +@pytest.mark.asyncio +async def test_meta_with_tags_include_exclude_and_name_contains(http, api_base, asset_factory, make_asset_bytes): + # alpha matches epoch=1; beta has epoch=2 + a = await asset_factory( + "mf_tag_alpha.safetensors", + ["models", "checkpoints", "unit-tests", "mf-tag", "alpha"], + {"epoch": 1}, + make_asset_bytes("alpha"), + ) + b = await asset_factory( + "mf_tag_beta.safetensors", + ["models", "checkpoints", "unit-tests", "mf-tag", "beta"], + {"epoch": 2}, + make_asset_bytes("beta"), + ) + + params = { + "include_tags": "unit-tests,mf-tag,alpha", + "exclude_tags": "beta", + "name_contains": "mf_tag_", + "metadata_filter": json.dumps({"epoch": 1}), + } + async with http.get(api_base + "/api/assets", params=params) as r: + body = await r.json() + assert r.status == 200 + names = [x["name"] for x in body["assets"]] + assert a["name"] in names + assert b["name"] not in names + + +@pytest.mark.asyncio +async def test_meta_sort_and_paging_under_filter(http, api_base, asset_factory, make_asset_bytes): + # Three assets in same scope with different sizes and a common filter key + t = ["models", "checkpoints", "unit-tests", "mf-sort"] + n1, n2, n3 = "mf_sort_1.safetensors", 
"mf_sort_2.safetensors", "mf_sort_3.safetensors" + await asset_factory(n1, t, {"group": "g"}, make_asset_bytes(n1, 1024)) + await asset_factory(n2, t, {"group": "g"}, make_asset_bytes(n2, 2048)) + await asset_factory(n3, t, {"group": "g"}, make_asset_bytes(n3, 3072)) + + # Sort by size ascending with paging + q = {"include_tags": "unit-tests,mf-sort", "metadata_filter": json.dumps({"group": "g"}), "sort": "size", "order": "asc", "limit": "2"} + async with http.get(api_base + "/api/assets", params=q) as r1: + b1 = await r1.json() + assert r1.status == 200 + got1 = [a["name"] for a in b1["assets"]] + assert got1 == [n1, n2] + assert b1["has_more"] is True + + q2 = {**q, "offset": "2"} + async with http.get(api_base + "/api/assets", params=q2) as r2: + b2 = await r2.json() + assert r2.status == 200 + got2 = [a["name"] for a in b2["assets"]] + assert got2 == [n3] + assert b2["has_more"] is False From 1886f10e1904cdd5ba6f41b29dc90ae50096c460 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Tue, 9 Sep 2025 19:30:58 +0300 Subject: [PATCH 39/82] add download tests --- tests-assets/conftest.py | 9 ++++--- tests-assets/test_downloads.py | 45 ++++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 3 deletions(-) create mode 100644 tests-assets/test_downloads.py diff --git a/tests-assets/conftest.py b/tests-assets/conftest.py index 24eed0728362..16a43d56dce4 100644 --- a/tests-assets/conftest.py +++ b/tests-assets/conftest.py @@ -7,7 +7,7 @@ import tempfile import time from pathlib import Path -from typing import AsyncIterator, Callable +from typing import AsyncIterator, Callable, Optional import aiohttp import pytest @@ -191,13 +191,16 @@ async def create(name: str, tags: list[str], meta: dict, data: bytes) -> dict: @pytest_asyncio.fixture -async def seeded_asset(http: aiohttp.ClientSession, api_base: str) -> dict: +async def seeded_asset(request: pytest.FixtureRequest, http: aiohttp.ClientSession, api_base: str) -> dict: """ Upload one asset into models/checkpoints/unit-tests/. Returns response dict with id, asset_hash, tags, etc. 
""" name = "unit_1_example.safetensors" - tags = ["models", "checkpoints", "unit-tests", "alpha"] + p = getattr(request, "param", {}) or {} + tags: Optional[list[str]] = p.get("tags") + if tags is None: + tags = ["models", "checkpoints", "unit-tests", "alpha"] meta = {"purpose": "test", "epoch": 1, "flags": ["x", "y"], "nullable": None} form = aiohttp.FormData() form.add_field("file", b"A" * 4096, filename=name, content_type="application/octet-stream") diff --git a/tests-assets/test_downloads.py b/tests-assets/test_downloads.py new file mode 100644 index 000000000000..fad8e2cbb08e --- /dev/null +++ b/tests-assets/test_downloads.py @@ -0,0 +1,45 @@ +from pathlib import Path +import aiohttp +import pytest + + +@pytest.mark.asyncio +async def test_download_attachment_and_inline(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict): + aid = seeded_asset["id"] + + # default attachment + async with http.get(f"{api_base}/api/assets/{aid}/content") as r1: + data = await r1.read() + assert r1.status == 200 + cd = r1.headers.get("Content-Disposition", "") + assert "attachment" in cd + assert data and len(data) == 4096 + + # inline requested + async with http.get(f"{api_base}/api/assets/{aid}/content?disposition=inline") as r2: + await r2.read() + assert r2.status == 200 + cd2 = r2.headers.get("Content-Disposition", "") + assert "inline" in cd2 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("seeded_asset", [{"tags": ["models", "checkpoints"]}], indirect=True) +async def test_download_missing_file_returns_404( + http: aiohttp.ClientSession, api_base: str, comfy_tmp_base_dir: Path, seeded_asset: dict +): + # Remove the underlying file then attempt download. + # We initialize fixture without additional tags to know exactly the asset file path. + aid = seeded_asset["id"] + async with http.get(f"{api_base}/api/assets/{aid}") as rg: + detail = await rg.json() + assert rg.status == 200 + rel_inside_category = detail["name"] + abs_path = comfy_tmp_base_dir / "models" / "checkpoints" / rel_inside_category + if abs_path.exists(): + abs_path.unlink() + + async with http.get(f"{api_base}/api/assets/{aid}/content") as r2: + body = await r2.json() + assert r2.status == 404 + assert body["error"]["code"] == "FILE_NOT_FOUND" From 964de8a8adefc6c305b6e48855f5765654f56473 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Tue, 9 Sep 2025 20:35:18 +0300 Subject: [PATCH 40/82] add more list_assets tests + fix one found bug --- app/api/assets_routes.py | 11 +- tests-assets/test_list_filter.py | 220 ++++++++++++++++++++++++++++++- 2 files changed, 229 insertions(+), 2 deletions(-) diff --git a/app/api/assets_routes.py b/app/api/assets_routes.py index 001dfa324a97..47afbdb6dd3f 100644 --- a/app/api/assets_routes.py +++ b/app/api/assets_routes.py @@ -34,7 +34,16 @@ async def head_asset_by_hash(request: web.Request) -> web.Response: @ROUTES.get("/api/assets") async def list_assets(request: web.Request) -> web.Response: - query_dict = dict(request.rel_url.query) + qp = request.rel_url.query + query_dict = {} + if "include_tags" in qp: + query_dict["include_tags"] = qp.getall("include_tags") + if "exclude_tags" in qp: + query_dict["exclude_tags"] = qp.getall("exclude_tags") + for k in ("name_contains", "metadata_filter", "limit", "offset", "sort", "order"): + v = qp.get(k) + if v is not None: + query_dict[k] = v try: q = schemas_in.ListAssetsQuery.model_validate(query_dict) diff --git a/tests-assets/test_list_filter.py b/tests-assets/test_list_filter.py index 82abf4ab39e0..b0b476af5bf4 100644 --- 
a/tests-assets/test_list_filter.py +++ b/tests-assets/test_list_filter.py @@ -1,4 +1,5 @@ -import json +import asyncio + import aiohttp import pytest @@ -70,6 +71,223 @@ async def test_list_assets_include_exclude_and_name_contains(http: aiohttp.Clien assert not body3["assets"] +@pytest.mark.asyncio +async def test_list_assets_sort_by_size_both_orders(http, api_base, asset_factory, make_asset_bytes): + t = ["models", "checkpoints", "unit-tests", "lf-size"] + n1, n2, n3 = "sz1.safetensors", "sz2.safetensors", "sz3.safetensors" + await asset_factory(n1, t, {}, make_asset_bytes(n1, 1024)) + await asset_factory(n2, t, {}, make_asset_bytes(n2, 2048)) + await asset_factory(n3, t, {}, make_asset_bytes(n3, 3072)) + + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,lf-size", "sort": "size", "order": "asc"}, + ) as r1: + b1 = await r1.json() + names = [a["name"] for a in b1["assets"]] + assert names[:3] == [n1, n2, n3] + + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,lf-size", "sort": "size", "order": "desc"}, + ) as r2: + b2 = await r2.json() + names2 = [a["name"] for a in b2["assets"]] + assert names2[:3] == [n3, n2, n1] + + + +@pytest.mark.asyncio +async def test_list_assets_sort_by_updated_at_desc(http, api_base, asset_factory, make_asset_bytes): + t = ["models", "checkpoints", "unit-tests", "lf-upd"] + a1 = await asset_factory("upd_a.safetensors", t, {}, make_asset_bytes("upd_a", 1200)) + a2 = await asset_factory("upd_b.safetensors", t, {}, make_asset_bytes("upd_b", 1200)) + + # Rename the second asset to bump updated_at + async with http.put(f"{api_base}/api/assets/{a2['id']}", json={"name": "upd_b_renamed.safetensors"}) as rp: + upd = await rp.json() + assert rp.status == 200, upd + + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,lf-upd", "sort": "updated_at", "order": "desc"}, + ) as r: + body = await r.json() + assert r.status == 200 + names = [x["name"] for x in body["assets"]] + assert names[0] == "upd_b_renamed.safetensors" + assert a1["name"] in names + + + +@pytest.mark.asyncio +async def test_list_assets_sort_by_last_access_time_desc(http, api_base, asset_factory, make_asset_bytes): + t = ["models", "checkpoints", "unit-tests", "lf-access"] + await asset_factory("acc_a.safetensors", t, {}, make_asset_bytes("acc_a", 1100)) + await asyncio.sleep(0.02) + a2 = await asset_factory("acc_b.safetensors", t, {}, make_asset_bytes("acc_b", 1100)) + + # Touch last_access_time of b by downloading its content + await asyncio.sleep(0.02) + async with http.get(f"{api_base}/api/assets/{a2['id']}/content") as dl: + assert dl.status == 200 + await dl.read() + + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,lf-access", "sort": "last_access_time", "order": "desc"}, + ) as r: + body = await r.json() + assert r.status == 200 + names = [x["name"] for x in body["assets"]] + assert names[0] == a2["name"] + + +@pytest.mark.asyncio +async def test_list_assets_include_tags_variants_and_case(http, api_base, asset_factory, make_asset_bytes): + t = ["models", "checkpoints", "unit-tests", "lf-include"] + a = await asset_factory("incvar_alpha.safetensors", [*t, "alpha"], {}, make_asset_bytes("iva")) + await asset_factory("incvar_beta.safetensors", [*t, "beta"], {}, make_asset_bytes("ivb")) + + # CSV + case-insensitive + async with http.get( + api_base + "/api/assets", + params={"include_tags": "UNIT-TESTS,LF-INCLUDE,alpha"}, + ) as r1: + b1 = await r1.json() 
+ assert r1.status == 200 + names1 = [x["name"] for x in b1["assets"]] + assert a["name"] in names1 + assert not any("beta" in x for x in names1) + + # Repeated query params for include_tags + params_multi = [ + ("include_tags", "unit-tests"), + ("include_tags", "lf-include"), + ("include_tags", "alpha"), + ] + async with http.get(api_base + "/api/assets", params=params_multi) as r2: + b2 = await r2.json() + assert r2.status == 200 + names2 = [x["name"] for x in b2["assets"]] + assert a["name"] in names2 + assert not any("beta" in x for x in names2) + + # Duplicates and spaces in CSV + async with http.get( + api_base + "/api/assets", + params={"include_tags": " unit-tests , lf-include , alpha , alpha "}, + ) as r3: + b3 = await r3.json() + assert r3.status == 200 + names3 = [x["name"] for x in b3["assets"]] + assert a["name"] in names3 + + +@pytest.mark.asyncio +async def test_list_assets_exclude_tags_dedup_and_case(http, api_base, asset_factory, make_asset_bytes): + t = ["models", "checkpoints", "unit-tests", "lf-exclude"] + a = await asset_factory("ex_a_alpha.safetensors", [*t, "alpha"], {}, make_asset_bytes("exa", 900)) + await asset_factory("ex_b_beta.safetensors", [*t, "beta"], {}, make_asset_bytes("exb", 900)) + + # Exclude uppercase should work + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,lf-exclude", "exclude_tags": "BETA"}, + ) as r1: + b1 = await r1.json() + assert r1.status == 200 + names1 = [x["name"] for x in b1["assets"]] + assert a["name"] in names1 + # Repeated excludes with duplicates + params_multi = [ + ("include_tags", "unit-tests"), + ("include_tags", "lf-exclude"), + ("exclude_tags", "beta"), + ("exclude_tags", "beta"), + ] + async with http.get(api_base + "/api/assets", params=params_multi) as r2: + b2 = await r2.json() + assert r2.status == 200 + names2 = [x["name"] for x in b2["assets"]] + assert all("beta" not in x for x in names2) + + +@pytest.mark.asyncio +async def test_list_assets_name_contains_case_and_specials(http, api_base, asset_factory, make_asset_bytes): + t = ["models", "checkpoints", "unit-tests", "lf-name"] + a1 = await asset_factory("CaseMix.SAFE", t, {}, make_asset_bytes("cm", 800)) + a2 = await asset_factory("case-other.safetensors", t, {}, make_asset_bytes("co", 800)) + + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,lf-name", "name_contains": "casemix"}, + ) as r1: + b1 = await r1.json() + assert r1.status == 200 + names1 = [x["name"] for x in b1["assets"]] + assert a1["name"] in names1 + + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,lf-name", "name_contains": ".SAFE"}, + ) as r2: + b2 = await r2.json() + assert r2.status == 200 + names2 = [x["name"] for x in b2["assets"]] + assert a1["name"] in names2 + + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,lf-name", "name_contains": "case-"}, + ) as r3: + b3 = await r3.json() + assert r3.status == 200 + names3 = [x["name"] for x in b3["assets"]] + assert a2["name"] in names3 + + +@pytest.mark.asyncio +async def test_list_assets_offset_beyond_total_and_limit_boundary(http, api_base, asset_factory, make_asset_bytes): + t = ["models", "checkpoints", "unit-tests", "lf-pagelimits"] + await asset_factory("pl1.safetensors", t, {}, make_asset_bytes("pl1", 600)) + await asset_factory("pl2.safetensors", t, {}, make_asset_bytes("pl2", 600)) + await asset_factory("pl3.safetensors", t, {}, make_asset_bytes("pl3", 600)) + + # Offset far beyond 
total + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,lf-pagelimits", "limit": "2", "offset": "10"}, + ) as r1: + b1 = await r1.json() + assert r1.status == 200 + assert not b1["assets"] + assert b1["has_more"] is False + + # Boundary large limit (<=500 is valid) + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,lf-pagelimits", "limit": "500"}, + ) as r2: + b2 = await r2.json() + assert r2.status == 200 + assert len(b2["assets"]) == 3 + assert b2["has_more"] is False + + +@pytest.mark.asyncio +async def test_list_assets_offset_negative_and_limit_nonint_rejected(http, api_base): + async with http.get(api_base + "/api/assets", params={"offset": "-1"}) as r1: + b1 = await r1.json() + assert r1.status == 400 + assert b1["error"]["code"] == "INVALID_QUERY" + + async with http.get(api_base + "/api/assets", params={"limit": "abc"}) as r2: + b2 = await r2.json() + assert r2.status == 400 + assert b2["error"]["code"] == "INVALID_QUERY" + + @pytest.mark.asyncio async def test_list_assets_invalid_query_rejected(http: aiohttp.ClientSession, api_base: str): # limit too small From a9096f6c971f1d0180086a7ed464c38a7594b32f Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Tue, 9 Sep 2025 20:54:11 +0300 Subject: [PATCH 41/82] removed non-needed code, fix tests, +1 new test --- app/api/assets_routes.py | 43 +++++---------------------------- tests-assets/test_downloads.py | 29 +++++++++++++--------- tests-assets/test_validation.py | 23 ++++++++++++++++++ 3 files changed, 46 insertions(+), 49 deletions(-) create mode 100644 tests-assets/test_validation.py diff --git a/app/api/assets_routes.py b/app/api/assets_routes.py index 47afbdb6dd3f..e9a4ff97afe2 100644 --- a/app/api/assets_routes.py +++ b/app/api/assets_routes.py @@ -66,19 +66,13 @@ async def list_assets(request: web.Request) -> web.Response: @ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content") async def download_asset_content(request: web.Request) -> web.Response: - asset_info_id_raw = request.match_info.get("id", "") - try: - asset_info_id = str(uuid.UUID(asset_info_id_raw)) - except Exception: - return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid UUID.") - disposition = request.query.get("disposition", "attachment").lower().strip() if disposition not in {"inline", "attachment"}: disposition = "attachment" try: abs_path, content_type, filename = await assets_manager.resolve_asset_content_for_download( - asset_info_id=asset_info_id, + asset_info_id=str(uuid.UUID(request.match_info["id"])), owner_id=UserManager.get_request_user_id(request), ) except ValueError as ve: @@ -300,12 +294,7 @@ async def upload_asset(request: web.Request) -> web.Response: @ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}") async def get_asset(request: web.Request) -> web.Response: - asset_info_id_raw = request.match_info.get("id", "") - try: - asset_info_id = str(uuid.UUID(asset_info_id_raw)) - except Exception: - return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid UUID.") - + asset_info_id = str(uuid.UUID(request.match_info["id"])) try: result = await assets_manager.get_asset( asset_info_id=asset_info_id, @@ -320,12 +309,7 @@ async def get_asset(request: web.Request) -> web.Response: @ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}") async def update_asset(request: web.Request) -> web.Response: - asset_info_id_raw = request.match_info.get("id", "") - try: - asset_info_id = str(uuid.UUID(asset_info_id_raw)) - except Exception: - 
return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid UUID.") - + asset_info_id = str(uuid.UUID(request.match_info["id"])) try: body = schemas_in.UpdateAssetBody.model_validate(await request.json()) except ValidationError as ve: @@ -350,12 +334,7 @@ async def update_asset(request: web.Request) -> web.Response: @ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}") async def delete_asset(request: web.Request) -> web.Response: - asset_info_id_raw = request.match_info.get("id", "") - try: - asset_info_id = str(uuid.UUID(asset_info_id_raw)) - except Exception: - return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid UUID.") - + asset_info_id = str(uuid.UUID(request.match_info["id"])) delete_content = request.query.get("delete_content") delete_content = True if delete_content is None else delete_content.lower() not in {"0", "false", "no"} @@ -398,12 +377,7 @@ async def get_tags(request: web.Request) -> web.Response: @ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags") async def add_asset_tags(request: web.Request) -> web.Response: - asset_info_id_raw = request.match_info.get("id", "") - try: - asset_info_id = str(uuid.UUID(asset_info_id_raw)) - except Exception: - return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid UUID.") - + asset_info_id = str(uuid.UUID(request.match_info["id"])) try: payload = await request.json() data = schemas_in.TagsAdd.model_validate(payload) @@ -429,12 +403,7 @@ async def add_asset_tags(request: web.Request) -> web.Response: @ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags") async def delete_asset_tags(request: web.Request) -> web.Response: - asset_info_id_raw = request.match_info.get("id", "") - try: - asset_info_id = str(uuid.UUID(asset_info_id_raw)) - except Exception: - return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid UUID.") - + asset_info_id = str(uuid.UUID(request.match_info["id"])) try: payload = await request.json() data = schemas_in.TagsRemove.model_validate(payload) diff --git a/tests-assets/test_downloads.py b/tests-assets/test_downloads.py index fad8e2cbb08e..7a449dfe8cfc 100644 --- a/tests-assets/test_downloads.py +++ b/tests-assets/test_downloads.py @@ -30,16 +30,21 @@ async def test_download_missing_file_returns_404( ): # Remove the underlying file then attempt download. # We initialize fixture without additional tags to know exactly the asset file path. 
- aid = seeded_asset["id"] - async with http.get(f"{api_base}/api/assets/{aid}") as rg: - detail = await rg.json() - assert rg.status == 200 - rel_inside_category = detail["name"] - abs_path = comfy_tmp_base_dir / "models" / "checkpoints" / rel_inside_category - if abs_path.exists(): - abs_path.unlink() + try: + aid = seeded_asset["id"] + async with http.get(f"{api_base}/api/assets/{aid}") as rg: + detail = await rg.json() + assert rg.status == 200 + rel_inside_category = detail["name"] + abs_path = comfy_tmp_base_dir / "models" / "checkpoints" / rel_inside_category + if abs_path.exists(): + abs_path.unlink() - async with http.get(f"{api_base}/api/assets/{aid}/content") as r2: - body = await r2.json() - assert r2.status == 404 - assert body["error"]["code"] == "FILE_NOT_FOUND" + async with http.get(f"{api_base}/api/assets/{aid}/content") as r2: + body = await r2.json() + assert r2.status == 404 + assert body["error"]["code"] == "FILE_NOT_FOUND" + finally: + # We created asset without the "unit-tests" tag(see `autoclean_unit_test_assets`), we need to clear it manually. + async with http.delete(f"{api_base}/api/assets/{aid}") as dr: + await dr.read() diff --git a/tests-assets/test_validation.py b/tests-assets/test_validation.py new file mode 100644 index 000000000000..24ccc794f941 --- /dev/null +++ b/tests-assets/test_validation.py @@ -0,0 +1,23 @@ +import aiohttp +import pytest + + +@pytest.mark.asyncio +async def test_get_update_download_bad_ids(http: aiohttp.ClientSession, api_base: str): + # All endpoints should be not found, as we UUID regex directly in the route definition. + bad_id = "not-a-uuid" + + async with http.get(f"{api_base}/api/assets/{bad_id}") as r1: + assert r1.status == 404 + + async with http.get(f"{api_base}/api/assets/{bad_id}/content") as r3: + assert r3.status == 404 + + +@pytest.mark.asyncio +async def test_update_requires_at_least_one_field(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict): + aid = seeded_asset["id"] + async with http.put(f"{api_base}/api/assets/{aid}", json={}) as r: + body = await r.json() + assert r.status == 400 + assert body["error"]["code"] == "INVALID_BODY" From 6eaed072c75920dba672347424c3ee7e37300383 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Wed, 10 Sep 2025 09:51:06 +0300 Subject: [PATCH 42/82] add some logic tests --- tests-assets/test_crud.py | 63 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 tests-assets/test_crud.py diff --git a/tests-assets/test_crud.py b/tests-assets/test_crud.py new file mode 100644 index 000000000000..5e0d953d572c --- /dev/null +++ b/tests-assets/test_crud.py @@ -0,0 +1,63 @@ +import aiohttp +import pytest + + +@pytest.mark.asyncio +async def test_get_and_delete_asset(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict): + aid = seeded_asset["id"] + + # GET detail + async with http.get(f"{api_base}/api/assets/{aid}") as rg: + detail = await rg.json() + assert rg.status == 200, detail + assert detail["id"] == aid + assert "user_metadata" in detail + assert "filename" in detail["user_metadata"] + + # DELETE + async with http.delete(f"{api_base}/api/assets/{aid}") as rd: + assert rd.status == 204 + + # GET again -> 404 + async with http.get(f"{api_base}/api/assets/{aid}") as rg2: + body = await rg2.json() + assert rg2.status == 404 + assert body["error"]["code"] == "ASSET_NOT_FOUND" + + +@pytest.mark.asyncio +async def test_update_asset_fields(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict): + aid = seeded_asset["id"] + + 
payload = { + "name": "unit_1_renamed.safetensors", + "tags": ["models", "checkpoints", "unit-tests", "beta"], + "user_metadata": {"purpose": "updated", "epoch": 2}, + } + async with http.put(f"{api_base}/api/assets/{aid}", json=payload) as ru: + body = await ru.json() + assert ru.status == 200, body + assert body["name"] == payload["name"] + assert "beta" in body["tags"] + assert body["user_metadata"]["purpose"] == "updated" + # filename should still be present and normalized by server + assert "filename" in body["user_metadata"] + + +@pytest.mark.asyncio +async def test_head_asset_by_hash_and_invalids(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict): + h = seeded_asset["asset_hash"] + + # Existing + async with http.head(f"{api_base}/api/assets/hash/{h}") as rh1: + assert rh1.status == 200 + + # Non-existent + async with http.head(f"{api_base}/api/assets/hash/blake3:{'0'*64}") as rh2: + assert rh2.status == 404 + + # Invalid format + async with http.head(f"{api_base}/api/assets/hash/not_a_hash") as rh3: + jb = await rh3.json() + assert rh3.status == 400 + assert jb is None # HEAD request should not include "body" in response From 72548a8ac4026827dccb07108e4d5372ac084e03 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Wed, 10 Sep 2025 10:39:55 +0300 Subject: [PATCH 43/82] added additional tests; sorted tests --- tests-assets/test_crud.py | 116 ++++++++++++++++++- tests-assets/test_tags.py | 27 +++++ tests-assets/test_uploads.py | 192 ++++++++++++++++++++++++-------- tests-assets/test_validation.py | 23 ---- 4 files changed, 289 insertions(+), 69 deletions(-) delete mode 100644 tests-assets/test_validation.py diff --git a/tests-assets/test_crud.py b/tests-assets/test_crud.py index 5e0d953d572c..99ea329c52d9 100644 --- a/tests-assets/test_crud.py +++ b/tests-assets/test_crud.py @@ -1,7 +1,33 @@ +import uuid import aiohttp import pytest +@pytest.mark.asyncio +async def test_create_from_hash_success( + http: aiohttp.ClientSession, api_base: str, seeded_asset: dict +): + h = seeded_asset["asset_hash"] + payload = { + "hash": h, + "name": "from_hash_ok.safetensors", + "tags": ["models", "checkpoints", "unit-tests", "from-hash"], + "user_metadata": {"k": "v"}, + } + async with http.post(f"{api_base}/api/assets/from-hash", json=payload) as r1: + b1 = await r1.json() + assert r1.status == 201, b1 + assert b1["asset_hash"] == h + assert b1["created_new"] is False + aid = b1["id"] + + # Calling again with the same name should return the same AssetInfo id + async with http.post(f"{api_base}/api/assets/from-hash", json=payload) as r2: + b2 = await r2.json() + assert r2.status == 201, b2 + assert b2["id"] == aid + + @pytest.mark.asyncio async def test_get_and_delete_asset(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict): aid = seeded_asset["id"] @@ -25,6 +51,41 @@ async def test_get_and_delete_asset(http: aiohttp.ClientSession, api_base: str, assert body["error"]["code"] == "ASSET_NOT_FOUND" +@pytest.mark.asyncio +async def test_delete_upon_reference_count( + http: aiohttp.ClientSession, api_base: str, seeded_asset: dict +): + # Create a second reference to the same asset via from-hash + src_hash = seeded_asset["asset_hash"] + payload = { + "hash": src_hash, + "name": "unit_ref_copy.safetensors", + "tags": ["models", "checkpoints", "unit-tests", "del-flow"], + "user_metadata": {"note": "copy"}, + } + async with http.post(f"{api_base}/api/assets/from-hash", json=payload) as r2: + copy = await r2.json() + assert r2.status == 201, copy + assert copy["asset_hash"] == src_hash + 
assert copy["created_new"] is False + + # Delete original reference -> asset identity must remain + aid1 = seeded_asset["id"] + async with http.delete(f"{api_base}/api/assets/{aid1}") as rd1: + assert rd1.status == 204 + + async with http.head(f"{api_base}/api/assets/hash/{src_hash}") as rh1: + assert rh1.status == 200 # identity still present + + # Delete the last reference with default semantics -> identity and cached files removed + aid2 = copy["id"] + async with http.delete(f"{api_base}/api/assets/{aid2}") as rd2: + assert rd2.status == 204 + + async with http.head(f"{api_base}/api/assets/hash/{src_hash}") as rh2: + assert rh2.status == 404 # orphan content removed + + @pytest.mark.asyncio async def test_update_asset_fields(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict): aid = seeded_asset["id"] @@ -45,7 +106,7 @@ async def test_update_asset_fields(http: aiohttp.ClientSession, api_base: str, s @pytest.mark.asyncio -async def test_head_asset_by_hash_and_invalids(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict): +async def test_head_asset_by_hash(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict): h = seeded_asset["asset_hash"] # Existing @@ -56,6 +117,59 @@ async def test_head_asset_by_hash_and_invalids(http: aiohttp.ClientSession, api_ async with http.head(f"{api_base}/api/assets/hash/blake3:{'0'*64}") as rh2: assert rh2.status == 404 + +@pytest.mark.asyncio +async def test_delete_nonexistent_returns_404(http: aiohttp.ClientSession, api_base: str): + bogus = str(uuid.uuid4()) + async with http.delete(f"{api_base}/api/assets/{bogus}") as r: + body = await r.json() + assert r.status == 404 + assert body["error"]["code"] == "ASSET_NOT_FOUND" + + +@pytest.mark.asyncio +async def test_create_from_hash_invalids(http: aiohttp.ClientSession, api_base: str): + # Bad hash algorithm + bad = { + "hash": "sha256:" + "0" * 64, + "name": "x.bin", + "tags": ["models", "checkpoints", "unit-tests"], + } + async with http.post(f"{api_base}/api/assets/from-hash", json=bad) as r1: + b1 = await r1.json() + assert r1.status == 400 + assert b1["error"]["code"] == "INVALID_BODY" + + # Invalid JSON body + async with http.post(f"{api_base}/api/assets/from-hash", data=b"{not json}") as r2: + b2 = await r2.json() + assert r2.status == 400 + assert b2["error"]["code"] == "INVALID_JSON" + + +@pytest.mark.asyncio +async def test_get_update_download_bad_ids(http: aiohttp.ClientSession, api_base: str): + # All endpoints should be not found, as we UUID regex directly in the route definition. 
+ bad_id = "not-a-uuid" + + async with http.get(f"{api_base}/api/assets/{bad_id}") as r1: + assert r1.status == 404 + + async with http.get(f"{api_base}/api/assets/{bad_id}/content") as r3: + assert r3.status == 404 + + +@pytest.mark.asyncio +async def test_update_requires_at_least_one_field(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict): + aid = seeded_asset["id"] + async with http.put(f"{api_base}/api/assets/{aid}", json={}) as r: + body = await r.json() + assert r.status == 400 + assert body["error"]["code"] == "INVALID_BODY" + + +@pytest.mark.asyncio +async def test_head_asset_bad_hash(http: aiohttp.ClientSession, api_base: str): # Invalid format async with http.head(f"{api_base}/api/assets/hash/not_a_hash") as rh3: jb = await rh3.json() diff --git a/tests-assets/test_tags.py b/tests-assets/test_tags.py index 273efdd03d76..bba91581fb23 100644 --- a/tests-assets/test_tags.py +++ b/tests-assets/test_tags.py @@ -1,3 +1,4 @@ +import json import aiohttp import pytest @@ -98,3 +99,29 @@ async def test_tags_list_order_and_prefix(http: aiohttp.ClientSession, api_base: b2 = await r2.json() assert r2.status == 400 assert b2["error"]["code"] == "INVALID_QUERY" + + +@pytest.mark.asyncio +async def test_tags_endpoints_invalid_bodies(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict): + aid = seeded_asset["id"] + + # Add with empty list + async with http.post(f"{api_base}/api/assets/{aid}/tags", json={"tags": []}) as r1: + b1 = await r1.json() + assert r1.status == 400 + assert b1["error"]["code"] == "INVALID_BODY" + + # Remove with wrong type + async with http.delete(f"{api_base}/api/assets/{aid}/tags", json={"tags": [123]}) as r2: + b2 = await r2.json() + assert r2.status == 400 + assert b2["error"]["code"] == "INVALID_BODY" + + # metadata_filter provided as JSON array should be rejected (must be object) + async with http.get( + api_base + "/api/assets", + params={"metadata_filter": json.dumps([{"x": 1}])}, + ) as r3: + b3 = await r3.json() + assert r3.status == 400 + assert b3["error"]["code"] == "INVALID_QUERY" diff --git a/tests-assets/test_uploads.py b/tests-assets/test_uploads.py index 8c50beb0993b..1d8df4e40f3b 100644 --- a/tests-assets/test_uploads.py +++ b/tests-assets/test_uploads.py @@ -3,51 +3,6 @@ import pytest -@pytest.mark.asyncio -async def test_upload_requires_multipart(http: aiohttp.ClientSession, api_base: str): - async with http.post(api_base + "/api/assets", json={"foo": "bar"}) as r: - body = await r.json() - assert r.status == 415 - assert body["error"]["code"] == "UNSUPPORTED_MEDIA_TYPE" - - -@pytest.mark.asyncio -async def test_upload_missing_file_and_hash(http: aiohttp.ClientSession, api_base: str): - form = aiohttp.FormData(default_to_multipart=True) - form.add_field("tags", json.dumps(["models", "checkpoints", "unit-tests"])) - form.add_field("name", "x.safetensors") - async with http.post(api_base + "/api/assets", data=form) as r: - body = await r.json() - assert r.status == 400 - assert body["error"]["code"] == "MISSING_FILE" - - -@pytest.mark.asyncio -async def test_upload_models_unknown_category(http: aiohttp.ClientSession, api_base: str): - form = aiohttp.FormData() - form.add_field("file", b"A" * 128, filename="m.safetensors", content_type="application/octet-stream") - form.add_field("tags", json.dumps(["models", "no_such_category", "unit-tests"])) - form.add_field("name", "m.safetensors") - async with http.post(api_base + "/api/assets", data=form) as r: - body = await r.json() - assert r.status == 400 - assert body["error"]["code"] == 
"INVALID_BODY" - assert "unknown models category" in body["error"]["message"] or "unknown model category" in body["error"]["message"] - - -@pytest.mark.asyncio -async def test_upload_tags_traversal_guard(http: aiohttp.ClientSession, api_base: str): - form = aiohttp.FormData() - form.add_field("file", b"A" * 256, filename="evil.safetensors", content_type="application/octet-stream") - # '..' should be rejected by destination resolver - form.add_field("tags", json.dumps(["models", "checkpoints", "unit-tests", "..", "zzz"])) - form.add_field("name", "evil.safetensors") - async with http.post(api_base + "/api/assets", data=form) as r: - body = await r.json() - assert r.status == 400 - assert body["error"]["code"] in ("BAD_REQUEST", "INVALID_BODY") - - @pytest.mark.asyncio async def test_upload_ok_duplicate_reference(http: aiohttp.ClientSession, api_base: str, make_asset_bytes): name = "dup_a.safetensors" @@ -120,6 +75,56 @@ async def test_upload_fastpath_from_existing_hash_no_file(http: aiohttp.ClientSe assert b2["asset_hash"] == h +@pytest.mark.asyncio +async def test_upload_fastpath_with_known_hash_and_file( + http: aiohttp.ClientSession, api_base: str +): + # Seed + form1 = aiohttp.FormData() + form1.add_field("file", b"C" * 128, filename="seed.safetensors", content_type="application/octet-stream") + form1.add_field("tags", json.dumps(["models", "checkpoints", "unit-tests", "fp"])) + form1.add_field("name", "seed.safetensors") + form1.add_field("user_metadata", json.dumps({})) + async with http.post(api_base + "/api/assets", data=form1) as r1: + b1 = await r1.json() + assert r1.status == 201, b1 + h = b1["asset_hash"] + + # Send both file and hash of existing content -> server must drain file and create from hash (200) + form2 = aiohttp.FormData() + form2.add_field("file", b"ignored" * 10, filename="ignored.bin", content_type="application/octet-stream") + form2.add_field("hash", h) + form2.add_field("tags", json.dumps(["models", "checkpoints", "unit-tests", "fp"])) + form2.add_field("name", "copy_from_hash.safetensors") + form2.add_field("user_metadata", json.dumps({})) + async with http.post(api_base + "/api/assets", data=form2) as r2: + b2 = await r2.json() + assert r2.status == 200, b2 + assert b2["created_new"] is False + assert b2["asset_hash"] == h + + +@pytest.mark.asyncio +async def test_upload_multiple_tags_fields_are_merged(http: aiohttp.ClientSession, api_base: str): + form = aiohttp.FormData() + form.add_field("file", b"B" * 256, filename="merge.safetensors", content_type="application/octet-stream") + form.add_field("tags", "models,checkpoints") # CSV + form.add_field("tags", json.dumps(["unit-tests", "alpha"])) # JSON array in second field + form.add_field("name", "merge.safetensors") + form.add_field("user_metadata", json.dumps({"u": 1})) + async with http.post(api_base + "/api/assets", data=form) as r1: + created = await r1.json() + assert r1.status in (200, 201), created + aid = created["id"] + + # Verify all tags are present on the resource + async with http.get(f"{api_base}/api/assets/{aid}") as rg: + detail = await rg.json() + assert rg.status == 200, detail + tags = set(detail["tags"]) + assert {"models", "checkpoints", "unit-tests", "alpha"}.issubset(tags) + + @pytest.mark.asyncio async def test_create_from_hash_endpoint_404(http: aiohttp.ClientSession, api_base: str): payload = { @@ -131,3 +136,100 @@ async def test_create_from_hash_endpoint_404(http: aiohttp.ClientSession, api_ba body = await r.json() assert r.status == 404 assert body["error"]["code"] == 
"ASSET_NOT_FOUND" + + +@pytest.mark.asyncio +async def test_upload_zero_byte_rejected(http: aiohttp.ClientSession, api_base: str): + form = aiohttp.FormData() + form.add_field("file", b"", filename="empty.safetensors", content_type="application/octet-stream") + form.add_field("tags", json.dumps(["models", "checkpoints", "unit-tests", "edge"])) + form.add_field("name", "empty.safetensors") + form.add_field("user_metadata", json.dumps({})) + async with http.post(api_base + "/api/assets", data=form) as r: + body = await r.json() + assert r.status == 400 + assert body["error"]["code"] == "EMPTY_UPLOAD" + + +@pytest.mark.asyncio +async def test_upload_invalid_root_tag_rejected(http: aiohttp.ClientSession, api_base: str): + form = aiohttp.FormData() + form.add_field("file", b"A" * 64, filename="badroot.bin", content_type="application/octet-stream") + form.add_field("tags", json.dumps(["not-a-root", "whatever"])) + form.add_field("name", "badroot.bin") + form.add_field("user_metadata", json.dumps({})) + async with http.post(api_base + "/api/assets", data=form) as r: + body = await r.json() + assert r.status == 400 + assert body["error"]["code"] == "INVALID_BODY" + + +@pytest.mark.asyncio +async def test_upload_user_metadata_must_be_json(http: aiohttp.ClientSession, api_base: str): + form = aiohttp.FormData() + form.add_field("file", b"A" * 128, filename="badmeta.bin", content_type="application/octet-stream") + form.add_field("tags", json.dumps(["models", "checkpoints", "unit-tests", "edge"])) + form.add_field("name", "badmeta.bin") + form.add_field("user_metadata", "{not json}") # invalid + async with http.post(api_base + "/api/assets", data=form) as r: + body = await r.json() + assert r.status == 400 + assert body["error"]["code"] == "INVALID_BODY" + + +@pytest.mark.asyncio +async def test_upload_requires_multipart(http: aiohttp.ClientSession, api_base: str): + async with http.post(api_base + "/api/assets", json={"foo": "bar"}) as r: + body = await r.json() + assert r.status == 415 + assert body["error"]["code"] == "UNSUPPORTED_MEDIA_TYPE" + + +@pytest.mark.asyncio +async def test_upload_missing_file_and_hash(http: aiohttp.ClientSession, api_base: str): + form = aiohttp.FormData(default_to_multipart=True) + form.add_field("tags", json.dumps(["models", "checkpoints", "unit-tests"])) + form.add_field("name", "x.safetensors") + async with http.post(api_base + "/api/assets", data=form) as r: + body = await r.json() + assert r.status == 400 + assert body["error"]["code"] == "MISSING_FILE" + + +@pytest.mark.asyncio +async def test_upload_models_unknown_category(http: aiohttp.ClientSession, api_base: str): + form = aiohttp.FormData() + form.add_field("file", b"A" * 128, filename="m.safetensors", content_type="application/octet-stream") + form.add_field("tags", json.dumps(["models", "no_such_category", "unit-tests"])) + form.add_field("name", "m.safetensors") + async with http.post(api_base + "/api/assets", data=form) as r: + body = await r.json() + assert r.status == 400 + assert body["error"]["code"] == "INVALID_BODY" + assert "unknown models category" in body["error"]["message"] or "unknown model category" in body["error"]["message"] + + +@pytest.mark.asyncio +async def test_upload_models_requires_category(http: aiohttp.ClientSession, api_base: str): + form = aiohttp.FormData() + form.add_field("file", b"A" * 64, filename="nocat.safetensors", content_type="application/octet-stream") + form.add_field("tags", json.dumps(["models"])) # missing category + form.add_field("name", "nocat.safetensors") + 
form.add_field("user_metadata", json.dumps({})) + async with http.post(api_base + "/api/assets", data=form) as r: + body = await r.json() + assert r.status == 400 + assert body["error"]["code"] == "INVALID_BODY" + + +@pytest.mark.asyncio +async def test_upload_tags_traversal_guard(http: aiohttp.ClientSession, api_base: str): + form = aiohttp.FormData() + form.add_field("file", b"A" * 256, filename="evil.safetensors", content_type="application/octet-stream") + # '..' should be rejected by destination resolver + form.add_field("tags", json.dumps(["models", "checkpoints", "unit-tests", "..", "zzz"])) + form.add_field("name", "evil.safetensors") + async with http.post(api_base + "/api/assets", data=form) as r: + body = await r.json() + assert r.status == 400 + assert body["error"]["code"] in ("BAD_REQUEST", "INVALID_BODY") diff --git a/tests-assets/test_validation.py b/tests-assets/test_validation.py deleted file mode 100644 index 24ccc794f941..000000000000 --- a/tests-assets/test_validation.py +++ /dev/null @@ -1,23 +0,0 @@ -import aiohttp -import pytest - - -@pytest.mark.asyncio -async def test_get_update_download_bad_ids(http: aiohttp.ClientSession, api_base: str): - # All endpoints should be not found, as we UUID regex directly in the route definition. - bad_id = "not-a-uuid" - - async with http.get(f"{api_base}/api/assets/{bad_id}") as r1: - assert r1.status == 404 - - async with http.get(f"{api_base}/api/assets/{bad_id}/content") as r3: - assert r3.status == 404 - - -@pytest.mark.asyncio -async def test_update_requires_at_least_one_field(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict): - aid = seeded_asset["id"] - async with http.put(f"{api_base}/api/assets/{aid}", json={}) as r: - body = await r.json() - assert r.status == 400 - assert body["error"]["code"] == "INVALID_BODY" From 0df1ccac6fb1e33015784e12b903e01ee0fa6f9b Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Wed, 10 Sep 2025 11:45:03 +0300 Subject: [PATCH 44/82] GitHub CI test for Assets --- .github/workflows/test-assets.yml | 174 +++++++++++++++++++++++++++++ alembic_db/versions/0001_assets.py | 3 +- app/database/db.py | 45 +++----- app/database/models.py | 24 ++-- tests-assets/conftest.py | 61 +++++++--- 5 files changed, 250 insertions(+), 57 deletions(-) create mode 100644 .github/workflows/test-assets.yml diff --git a/.github/workflows/test-assets.yml b/.github/workflows/test-assets.yml new file mode 100644 index 000000000000..3b3a7c73f3a4 --- /dev/null +++ b/.github/workflows/test-assets.yml @@ -0,0 +1,174 @@ +name: Asset System Tests + +on: + push: + paths: + - 'app/**' + - 'alembic_db/**' + - 'tests-assets/**' + - '.github/workflows/test-assets.yml' + - 'requirements.txt' + pull_request: + branches: [master] + workflow_dispatch: + +permissions: + contents: read + +env: + PIP_DISABLE_PIP_VERSION_CHECK: '1' + PYTHONUNBUFFERED: '1' + +jobs: + sqlite: + name: SQLite (${{ matrix.sqlite_mode }}) • Python ${{ matrix.python }} + runs-on: ubuntu-latest + timeout-minutes: 40 + strategy: + fail-fast: false + matrix: + python: ['3.9', '3.12'] + sqlite_mode: ['memory', 'file'] + + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python }} + + - name: Install dependencies + run: | + python -m pip install -U pip wheel + pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu + pip install -r requirements.txt + pip install pytest pytest-aiohttp pytest-asyncio + + - name: Set deterministic test base dir + id: 
basedir + shell: bash + run: | + BASE="$RUNNER_TEMP/comfyui-assets-tests-${{ matrix.python }}-${{ matrix.sqlite_mode }}-${{ github.run_id }}-${{ github.run_attempt }}" + echo "ASSETS_TEST_BASE_DIR=$BASE" >> "$GITHUB_ENV" + echo "ASSETS_TEST_LOGS=$BASE/logs" >> "$GITHUB_ENV" + mkdir -p "$BASE/logs" + echo "ASSETS_TEST_BASE_DIR=$BASE" + + - name: Set DB URL for SQLite + id: setdb + shell: bash + run: | + if [ "${{ matrix.sqlite_mode }}" = "memory" ]; then + echo "ASSETS_TEST_DB_URL=sqlite+aiosqlite:///:memory:" >> "$GITHUB_ENV" + else + DBFILE="$RUNNER_TEMP/assets-tests.sqlite" + mkdir -p "$(dirname "$DBFILE")" + echo "ASSETS_TEST_DB_URL=sqlite+aiosqlite:///$DBFILE" >> "$GITHUB_ENV" + fi + + - name: Run tests + run: python -m pytest tests-assets + + - name: Show ComfyUI logs + if: always() + shell: bash + run: | + echo "==== ASSETS_TEST_BASE_DIR: $ASSETS_TEST_BASE_DIR ====" + echo "==== ASSETS_TEST_LOGS: $ASSETS_TEST_LOGS ====" + ls -la "$ASSETS_TEST_LOGS" || true + for f in "$ASSETS_TEST_LOGS"/stdout.log "$ASSETS_TEST_LOGS"/stderr.log; do + if [ -f "$f" ]; then + echo "----- BEGIN $f -----" + sed -n '1,400p' "$f" + echo "----- END $f -----" + fi + done + + - name: Upload ComfyUI logs + if: always() + uses: actions/upload-artifact@v4 + with: + name: asset-logs-sqlite-${{ matrix.sqlite_mode }}-py${{ matrix.python }} + path: ${{ env.ASSETS_TEST_LOGS }}/*.log + if-no-files-found: warn + + postgres: + name: PostgreSQL ${{ matrix.pgsql }} • Python ${{ matrix.python }} + runs-on: ubuntu-latest + timeout-minutes: 40 + strategy: + fail-fast: false + matrix: + python: ['3.9', '3.12'] + pgsql: ['14', '16'] + + services: + postgres: + image: postgres:${{ matrix.pgsql }} + env: + POSTGRES_DB: assets + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + ports: + - 5432:5432 + options: >- + --health-cmd "pg_isready -U postgres -d assets" + --health-interval 10s + --health-timeout 5s + --health-retries 12 + + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python }} + + - name: Install dependencies + run: | + python -m pip install -U pip wheel + pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu + pip install -r requirements.txt + pip install pytest pytest-aiohttp pytest-asyncio + pip install greenlet psycopg + + - name: Set deterministic test base dir + id: basedir + shell: bash + run: | + BASE="$RUNNER_TEMP/comfyui-assets-tests-${{ matrix.python }}-${{ matrix.sqlite_mode }}-${{ github.run_id }}-${{ github.run_attempt }}" + echo "ASSETS_TEST_BASE_DIR=$BASE" >> "$GITHUB_ENV" + echo "ASSETS_TEST_LOGS=$BASE/logs" >> "$GITHUB_ENV" + mkdir -p "$BASE/logs" + echo "ASSETS_TEST_BASE_DIR=$BASE" + + - name: Set DB URL for PostgreSQL + shell: bash + run: | + echo "ASSETS_TEST_DB_URL=postgresql+psycopg://postgres:postgres@localhost:5432/assets" >> "$GITHUB_ENV" + + - name: Run tests + run: python -m pytest tests-assets + + - name: Show ComfyUI logs + if: always() + shell: bash + run: | + echo "==== ASSETS_TEST_BASE_DIR: $ASSETS_TEST_BASE_DIR ====" + echo "==== ASSETS_TEST_LOGS: $ASSETS_TEST_LOGS ====" + ls -la "$ASSETS_TEST_LOGS" || true + for f in "$ASSETS_TEST_LOGS"/stdout.log "$ASSETS_TEST_LOGS"/stderr.log; do + if [ -f "$f" ]; then + echo "----- BEGIN $f -----" + sed -n '1,400p' "$f" + echo "----- END $f -----" + fi + done + + - name: Upload ComfyUI logs + if: always() + uses: actions/upload-artifact@v4 + with: + name: asset-logs-pgsql-${{ matrix.pgsql }}-py${{ matrix.python }} + path: ${{ 
env.ASSETS_TEST_LOGS }}/*.log + if-no-files-found: warn diff --git a/alembic_db/versions/0001_assets.py b/alembic_db/versions/0001_assets.py index 9481100b0322..f3b3ee0bfe46 100644 --- a/alembic_db/versions/0001_assets.py +++ b/alembic_db/versions/0001_assets.py @@ -7,6 +7,7 @@ from alembic import op import sqlalchemy as sa +from sqlalchemy.dialects import postgresql revision = "0001_assets" down_revision = None @@ -90,7 +91,7 @@ def upgrade() -> None: sa.Column("val_str", sa.String(length=2048), nullable=True), sa.Column("val_num", sa.Numeric(38, 10), nullable=True), sa.Column("val_bool", sa.Boolean(), nullable=True), - sa.Column("val_json", sa.JSON(), nullable=True), + sa.Column("val_json", sa.JSON().with_variant(postgresql.JSONB(), 'postgresql'), nullable=True), sa.PrimaryKeyConstraint("asset_info_id", "key", "ordinal", name="pk_asset_info_meta"), ) op.create_index("ix_asset_info_meta_key", "asset_info_meta", ["key"]) diff --git a/app/database/db.py b/app/database/db.py index eaf6648dba5d..8280272b0216 100644 --- a/app/database/db.py +++ b/app/database/db.py @@ -69,23 +69,6 @@ def _normalize_sqlite_memory_url(db_url: str) -> tuple[str, bool]: return str(u), False -def _to_sync_driver_url(async_url: str) -> str: - """Convert an async SQLAlchemy URL to a sync URL for Alembic.""" - u = make_url(async_url) - driver = u.drivername - - if driver.startswith("sqlite+aiosqlite"): - u = u.set(drivername="sqlite") - elif driver.startswith("postgresql+asyncpg"): - u = u.set(drivername="postgresql") - else: - # Generic: strip the async driver part if present - if "+" in driver: - u = u.set(drivername=driver.split("+", 1)[0]) - - return str(u) - - def _get_sqlite_file_path(sync_url: str) -> Optional[str]: """Return the on-disk path for a SQLite URL, else None.""" try: @@ -159,7 +142,7 @@ async def init_db_engine() -> None: await conn.execute(text("PRAGMA foreign_keys = ON;")) await conn.execute(text("PRAGMA synchronous = NORMAL;")) - await _run_migrations(raw_url=db_url, connect_args=connect_args) + await _run_migrations(database_url=db_url, connect_args=connect_args) SESSION = async_sessionmaker( bind=ENGINE, @@ -170,20 +153,18 @@ async def init_db_engine() -> None: ) -async def _run_migrations(raw_url: str, connect_args: dict) -> None: - """ - Run Alembic migrations up to head. - - We deliberately use a synchronous engine for migrations because Alembic's - programmatic API is synchronous by default and this path is robust. 
- """ - # Convert to sync URL and make SQLite URL an absolute one - sync_url = _to_sync_driver_url(raw_url) - sync_url, is_mem = _normalize_sqlite_memory_url(sync_url) - sync_url = _absolutize_sqlite_url(sync_url) +async def _run_migrations(database_url: str, connect_args: dict) -> None: + if database_url.find("postgresql+psycopg") == -1: + """SQLite: Convert an async SQLAlchemy URL to a sync URL for Alembic.""" + u = make_url(database_url) + driver = u.drivername + if not driver.startswith("sqlite+aiosqlite"): + raise ValueError(f"Unsupported DB driver: {driver}") + database_url, is_mem = _normalize_sqlite_memory_url(str(u.set(drivername="sqlite"))) + database_url = _absolutize_sqlite_url(database_url) - cfg = _get_alembic_config(sync_url) - engine = create_engine(sync_url, future=True, connect_args=connect_args) + cfg = _get_alembic_config(database_url) + engine = create_engine(database_url, future=True, connect_args=connect_args) with engine.connect() as conn: context = MigrationContext.configure(conn) current_rev = context.get_current_revision() @@ -203,7 +184,7 @@ async def _run_migrations(raw_url: str, connect_args: dict) -> None: # Optional backup for SQLite file DBs backup_path = None - sqlite_path = _get_sqlite_file_path(sync_url) + sqlite_path = _get_sqlite_file_path(database_url) if sqlite_path and os.path.exists(sqlite_path): backup_path = sqlite_path + ".bkp" try: diff --git a/app/database/models.py b/app/database/models.py index 5bb3a09bc27b..55fc08e512cc 100644 --- a/app/database/models.py +++ b/app/database/models.py @@ -1,5 +1,3 @@ -from __future__ import annotations - from datetime import datetime from typing import Any, Optional import uuid @@ -18,11 +16,15 @@ Numeric, Boolean, ) +from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship, foreign from .timeutil import utcnow +JSONB_V = JSON(none_as_null=True).with_variant(JSONB(none_as_null=True), 'postgresql') + + class Base(DeclarativeBase): pass @@ -46,7 +48,7 @@ class Asset(Base): hash: Mapped[str] = mapped_column(String(256), primary_key=True) size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0) - mime_type: Mapped[str | None] = mapped_column(String(255)) + mime_type: Mapped[Optional[str]] = mapped_column(String(255)) created_at: Mapped[datetime] = mapped_column( DateTime(timezone=False), nullable=False, default=utcnow ) @@ -97,7 +99,7 @@ class AssetCacheState(Base): id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) asset_hash: Mapped[str] = mapped_column(String(256), ForeignKey("assets.hash", ondelete="CASCADE"), nullable=False) file_path: Mapped[str] = mapped_column(Text, nullable=False) - mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True) + mtime_ns: Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True) asset: Mapped["Asset"] = relationship(back_populates="cache_states") @@ -122,9 +124,9 @@ class AssetLocation(Base): asset_hash: Mapped[str] = mapped_column(String(256), ForeignKey("assets.hash", ondelete="CASCADE"), nullable=False) provider: Mapped[str] = mapped_column(String(32), nullable=False) # "gcs" locator: Mapped[str] = mapped_column(Text, nullable=False) # "gs://bucket/object" - expected_size_bytes: Mapped[int | None] = mapped_column(BigInteger, nullable=True) - etag: Mapped[str | None] = mapped_column(String(256), nullable=True) - last_modified: Mapped[str | None] = mapped_column(String(128), nullable=True) + expected_size_bytes: 
Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True) + etag: Mapped[Optional[str]] = mapped_column(String(256), nullable=True) + last_modified: Mapped[Optional[str]] = mapped_column(String(128), nullable=True) asset: Mapped["Asset"] = relationship(back_populates="locations") @@ -144,8 +146,8 @@ class AssetInfo(Base): asset_hash: Mapped[str] = mapped_column( String(256), ForeignKey("assets.hash", ondelete="RESTRICT"), nullable=False ) - preview_hash: Mapped[str | None] = mapped_column(String(256), ForeignKey("assets.hash", ondelete="SET NULL")) - user_metadata: Mapped[dict[str, Any] | None] = mapped_column(JSON(none_as_null=True)) + preview_hash: Mapped[Optional[str]] = mapped_column(String(256), ForeignKey("assets.hash", ondelete="SET NULL")) + user_metadata: Mapped[Optional[dict[str, Any]]] = mapped_column(JSON(none_as_null=True)) created_at: Mapped[datetime] = mapped_column( DateTime(timezone=False), nullable=False, default=utcnow ) @@ -162,7 +164,7 @@ class AssetInfo(Base): back_populates="infos", foreign_keys=[asset_hash], ) - preview_asset: Mapped[Asset | None] = relationship( + preview_asset: Mapped[Optional[Asset]] = relationship( "Asset", back_populates="preview_of", foreign_keys=[preview_hash], @@ -220,7 +222,7 @@ class AssetInfoMeta(Base): val_str: Mapped[Optional[str]] = mapped_column(String(2048), nullable=True) val_num: Mapped[Optional[float]] = mapped_column(Numeric(38, 10), nullable=True) val_bool: Mapped[Optional[bool]] = mapped_column(Boolean, nullable=True) - val_json: Mapped[Optional[Any]] = mapped_column(JSON(none_as_null=True), nullable=True) + val_json: Mapped[Optional[Any]] = mapped_column(JSONB_V, nullable=True) asset_info: Mapped["AssetInfo"] = relationship(back_populates="metadata_entries") diff --git a/tests-assets/conftest.py b/tests-assets/conftest.py index 16a43d56dce4..88fc0a5f30c8 100644 --- a/tests-assets/conftest.py +++ b/tests-assets/conftest.py @@ -15,6 +15,22 @@ import subprocess +def pytest_addoption(parser: pytest.Parser) -> None: + """ + Allow overriding the database URL used by the spawned ComfyUI process. + Priority: + 1) --db-url command line option + 2) ASSETS_TEST_DB_URL environment variable (used by CI) + 3) default: sqlite in-memory + """ + parser.addoption( + "--db-url", + action="store", + default=os.environ.get("ASSETS_TEST_DB_URL", "sqlite+aiosqlite:///:memory:"), + help="Async SQLAlchemy DB URL (e.g. 
sqlite+aiosqlite:///:memory: or postgresql+psycopg://user:pass@host/db)", + ) + + def _free_port() -> int: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.bind(("127.0.0.1", 0)) @@ -49,30 +65,38 @@ def event_loop(): @pytest.fixture(scope="session") def comfy_tmp_base_dir() -> Path: - tmp = Path(tempfile.mkdtemp(prefix="comfyui-assets-tests-")) + env_base = os.environ.get("ASSETS_TEST_BASE_DIR") + created_by_fixture = False + if env_base: + tmp = Path(env_base) + tmp.mkdir(parents=True, exist_ok=True) + else: + tmp = Path(tempfile.mkdtemp(prefix="comfyui-assets-tests-")) + created_by_fixture = True _make_base_dirs(tmp) yield tmp - with contextlib.suppress(Exception): - for p in sorted(tmp.rglob("*"), reverse=True): - if p.is_file() or p.is_symlink(): - p.unlink(missing_ok=True) - for p in sorted(tmp.glob("**/*"), reverse=True): - with contextlib.suppress(Exception): - p.rmdir() - tmp.rmdir() + if created_by_fixture: + with contextlib.suppress(Exception): + for p in sorted(tmp.rglob("*"), reverse=True): + if p.is_file() or p.is_symlink(): + p.unlink(missing_ok=True) + for p in sorted(tmp.glob("**/*"), reverse=True): + with contextlib.suppress(Exception): + p.rmdir() + tmp.rmdir() @pytest.fixture(scope="session") -def comfy_url_and_proc(comfy_tmp_base_dir: Path): +def comfy_url_and_proc(comfy_tmp_base_dir: Path, request: pytest.FixtureRequest): """ Boot ComfyUI subprocess with: - sandbox base dir - - sqlite memory DB + - sqlite memory DB (default) - autoscan disabled Returns (base_url, process, port) """ port = _free_port() - db_url = "sqlite+aiosqlite:///:memory:" + db_url = request.config.getoption("--db-url") logs_dir = comfy_tmp_base_dir / "logs" logs_dir.mkdir(exist_ok=True) @@ -94,6 +118,7 @@ def comfy_url_and_proc(comfy_tmp_base_dir: Path): "127.0.0.1", "--port", str(port), + "--cpu", ], stdout=out_log, stderr=err_log, @@ -101,6 +126,13 @@ def comfy_url_and_proc(comfy_tmp_base_dir: Path): env={**os.environ}, ) + for _ in range(50): + if proc.poll() is not None: + out_log.flush() + err_log.flush() + raise RuntimeError(f"ComfyUI exited early with code {proc.returncode}") + time.sleep(0.1) + base_url = f"http://127.0.0.1:{port}" try: async def _probe(): @@ -113,6 +145,9 @@ async def _probe(): with contextlib.suppress(Exception): proc.terminate() proc.wait(timeout=10) + with contextlib.suppress(Exception): + out_log.flush() + err_log.flush() raise RuntimeError(f"ComfyUI did not become ready: {e}") if proc and proc.poll() is None: @@ -144,7 +179,7 @@ async def _post_multipart_asset( tags: list[str], meta: dict, data: bytes, - extra_fields: dict | None = None, + extra_fields: Optional[dict] = None, ) -> tuple[int, dict]: form = aiohttp.FormData() form.add_field("file", data, filename=name, content_type="application/octet-stream") From 934377ac1ea6d05be60628206ad9a9b3788aa5cc Mon Sep 17 00:00:00 2001 From: Alexander Piskun Date: Fri, 12 Sep 2025 14:46:22 +0300 Subject: [PATCH 45/82] removed currently unnecessary "asset_locations" functionality --- alembic_db/versions/0001_assets.py | 19 ---- app/assets_fetcher.py | 137 ----------------------------- app/assets_manager.py | 17 ++-- app/resolvers/__init__.py | 35 -------- 4 files changed, 9 insertions(+), 199 deletions(-) delete mode 100644 app/assets_fetcher.py delete mode 100644 app/resolvers/__init__.py diff --git a/alembic_db/versions/0001_assets.py b/alembic_db/versions/0001_assets.py index f3b3ee0bfe46..bc98b5acf316 100644 --- a/alembic_db/versions/0001_assets.py +++ b/alembic_db/versions/0001_assets.py @@ -99,21 +99,6 @@ 
def upgrade() -> None: op.create_index("ix_asset_info_meta_key_val_num", "asset_info_meta", ["key", "val_num"]) op.create_index("ix_asset_info_meta_key_val_bool", "asset_info_meta", ["key", "val_bool"]) - # ASSET_LOCATIONS: remote locations per asset - op.create_table( - "asset_locations", - sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True), - sa.Column("asset_hash", sa.String(length=256), sa.ForeignKey("assets.hash", ondelete="CASCADE"), nullable=False), - sa.Column("provider", sa.String(length=32), nullable=False), # e.g., "gcs" - sa.Column("locator", sa.Text(), nullable=False), # e.g., "gs://bucket/path/to/blob" - sa.Column("expected_size_bytes", sa.BigInteger(), nullable=True), - sa.Column("etag", sa.String(length=256), nullable=True), - sa.Column("last_modified", sa.String(length=128), nullable=True), - sa.UniqueConstraint("asset_hash", "provider", "locator", name="uq_asset_locations_triplet"), - ) - op.create_index("ix_asset_locations_hash", "asset_locations", ["asset_hash"]) - op.create_index("ix_asset_locations_provider", "asset_locations", ["provider"]) - # Tags vocabulary for models tags_table = sa.table( "tags", @@ -158,10 +143,6 @@ def upgrade() -> None: def downgrade() -> None: - op.drop_index("ix_asset_locations_provider", table_name="asset_locations") - op.drop_index("ix_asset_locations_hash", table_name="asset_locations") - op.drop_table("asset_locations") - op.drop_index("ix_asset_info_meta_key_val_bool", table_name="asset_info_meta") op.drop_index("ix_asset_info_meta_key_val_num", table_name="asset_info_meta") op.drop_index("ix_asset_info_meta_key_val_str", table_name="asset_info_meta") diff --git a/app/assets_fetcher.py b/app/assets_fetcher.py deleted file mode 100644 index 36fa64ca954b..000000000000 --- a/app/assets_fetcher.py +++ /dev/null @@ -1,137 +0,0 @@ -from __future__ import annotations -import asyncio -import os -import tempfile -from typing import Optional -import mimetypes -import aiohttp - -from .storage.hashing import blake3_hash_sync -from .database.db import create_session -from .database.services import ingest_fs_asset, list_cache_states_by_asset_hash -from .resolvers import resolve_asset -from ._assets_helpers import resolve_destination_from_tags, ensure_within_base - -_FETCH_LOCKS: dict[str, asyncio.Lock] = {} - - -def _sanitize_filename(name: str) -> str: - return os.path.basename((name or "").strip()) or "file" - - -async def ensure_asset_cached( - asset_hash: str, - *, - preferred_name: Optional[str] = None, - tags_hint: Optional[list[str]] = None, -) -> str: - """ - Ensure there is a verified local file for asset_hash in the correct Comfy folder. - - Fast path: - - If any cache_state row has a file_path that exists, return it immediately. - Preference order is the oldest ID first for stability. - - Slow path: - - Resolve remote location + placement tags. - - Download to the correct folder, verify hash, move into place. - - Ingest identity + cache state so future fast passes can skip hashing. 
- """ - lock = _FETCH_LOCKS.setdefault(asset_hash, asyncio.Lock()) - async with lock: - # 1) If we already have any cache_state path present on disk, use it (oldest-first) - async with await create_session() as sess: - states = await list_cache_states_by_asset_hash(sess, asset_hash=asset_hash) - for s in states: - if s and s.file_path and os.path.isfile(s.file_path): - return s.file_path - - # 2) Resolve remote location + placement hints (must include valid tags) - res = await resolve_asset(asset_hash) - if not res: - raise FileNotFoundError(f"No resolver/locations for {asset_hash}") - - placement_tags = tags_hint or res.tags - if not placement_tags: - raise ValueError(f"Resolver did not provide placement tags for {asset_hash}") - - name_hint = res.filename or preferred_name or asset_hash.replace(":", "_") - safe_name = _sanitize_filename(name_hint) - - # 3) Map tags -> destination (strict: raises if invalid root or models category) - base_dir, subdirs = resolve_destination_from_tags(placement_tags) # may raise - dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir - os.makedirs(dest_dir, exist_ok=True) - - final_path = os.path.abspath(os.path.join(dest_dir, safe_name)) - ensure_within_base(final_path, base_dir) - - # 4) If target path exists, try to reuse; else delete invalid cache - if os.path.exists(final_path) and os.path.isfile(final_path): - existing_digest = blake3_hash_sync(final_path) - if f"blake3:{existing_digest}" == asset_hash: - size_bytes = os.path.getsize(final_path) - mtime_ns = getattr(os.stat(final_path), "st_mtime_ns", int(os.path.getmtime(final_path) * 1_000_000_000)) - async with await create_session() as sess: - await ingest_fs_asset( - sess, - asset_hash=asset_hash, - abs_path=final_path, - size_bytes=size_bytes, - mtime_ns=mtime_ns, - mime_type=None, - info_name=None, - tags=(), - ) - await sess.commit() - return final_path - else: - # Invalid cache: remove before re-downloading - os.remove(final_path) - - # 5) Download to temp next to destination - timeout = aiohttp.ClientTimeout(total=60 * 30) - async with aiohttp.ClientSession(timeout=timeout) as session: - async with session.get(res.download_url, headers=dict(res.headers)) as resp: - resp.raise_for_status() - cl = resp.headers.get("Content-Length") - if res.expected_size and cl and int(cl) != int(res.expected_size): - raise ValueError("server Content-Length does not match expected size") - with tempfile.NamedTemporaryFile("wb", delete=False, dir=dest_dir) as tmp: - tmp_path = tmp.name - async for chunk in resp.content.iter_chunked(8 * 1024 * 1024): - if chunk: - tmp.write(chunk) - - # 6) Verify content hash - digest = blake3_hash_sync(tmp_path) - canonical = f"blake3:{digest}" - if canonical != asset_hash: - try: - os.remove(tmp_path) - finally: - raise ValueError(f"Hash mismatch: expected {asset_hash}, got {canonical}") - - # 7) Atomically move into place - if os.path.exists(final_path): - os.remove(final_path) - os.replace(tmp_path, final_path) - - # 8) Record identity + cache state (+ mime type) - size_bytes = os.path.getsize(final_path) - mtime_ns = getattr(os.stat(final_path), "st_mtime_ns", int(os.path.getmtime(final_path) * 1_000_000_000)) - mime_type = mimetypes.guess_type(safe_name, strict=False)[0] - async with await create_session() as sess: - await ingest_fs_asset( - sess, - asset_hash=asset_hash, - abs_path=final_path, - size_bytes=size_bytes, - mtime_ns=mtime_ns, - mime_type=mime_type, - info_name=None, - tags=(), - ) - await sess.commit() - - return final_path diff --git 
a/app/assets_manager.py b/app/assets_manager.py index a2a73773a8ff..423a860e0776 100644 --- a/app/assets_manager.py +++ b/app/assets_manager.py @@ -36,7 +36,6 @@ ensure_within_base, resolve_destination_from_tags, ) -from .assets_fetcher import ensure_asset_cached async def asset_exists(*, asset_hash: str) -> bool: @@ -180,9 +179,9 @@ async def resolve_asset_content_for_download( """ Returns (abs_path, content_type, download_name) for the given AssetInfo id and touches last_access_time. Also touches last_access_time (only_if_newer). - Ensures the local cache is present (uses resolver if needed). Raises: ValueError if AssetInfo cannot be found + FileNotFoundError if file for Asset cannot be found """ async with await create_session() as session: pair = await fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id) @@ -190,13 +189,15 @@ async def resolve_asset_content_for_download( raise ValueError(f"AssetInfo {asset_info_id} not found") info, asset = pair - tag_names = await get_asset_tags(session, asset_info_id=info.id) - - # Ensure cached (download if missing) - preferred_name = info.name or info.asset_hash.split(":", 1)[-1] - abs_path = await ensure_asset_cached(info.asset_hash, preferred_name=preferred_name, tags_hint=tag_names) + states = await list_cache_states_by_asset_hash(session, asset_hash=info.asset_hash) + abs_path = "" + for s in states: + if s and s.file_path and os.path.isfile(s.file_path): + abs_path = s.file_path + break + if not abs_path: + raise FileNotFoundError - async with await create_session() as session: await touch_asset_info_by_id(session, asset_info_id=asset_info_id) await session.commit() diff --git a/app/resolvers/__init__.py b/app/resolvers/__init__.py deleted file mode 100644 index c489ebad7c57..000000000000 --- a/app/resolvers/__init__.py +++ /dev/null @@ -1,35 +0,0 @@ -import contextlib -from dataclasses import dataclass -from typing import Protocol, Optional, Mapping - - -@dataclass -class ResolveResult: - provider: str # e.g., "gcs" - download_url: str # fully-qualified URL to fetch bytes - headers: Mapping[str, str] # optional auth headers etc - expected_size: Optional[int] = None - tags: Optional[list[str]] = None # e.g. ["models","vae","subdir"] - filename: Optional[str] = None # preferred basename - -class AssetResolver(Protocol): - provider: str - async def resolve(self, asset_hash: str) -> Optional[ResolveResult]: ... 
- - -_REGISTRY: list[AssetResolver] = [] - - -def register_resolver(resolver: AssetResolver) -> None: - """Append Resolver with simple de-dup per provider.""" - global _REGISTRY - _REGISTRY = [r for r in _REGISTRY if r.provider != resolver.provider] + [resolver] - - -async def resolve_asset(asset_hash: str) -> Optional[ResolveResult]: - for r in _REGISTRY: - with contextlib.suppress(Exception): # For Resolver failure we just try the next one - res = await r.resolve(asset_hash) - if res: - return res - return None From bb9ed04758369aa4c452d788d64ea7fed6a0dc68 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Fri, 12 Sep 2025 18:14:52 +0300 Subject: [PATCH 46/82] global refactoring; add support for Assets without the computed hash --- alembic_db/versions/0001_assets.py | 50 +- app/__init__.py | 6 +- app/_assets_helpers.py | 75 +- app/api/assets_routes.py | 27 +- app/api/schemas_in.py | 59 +- app/api/schemas_out.py | 12 +- app/assets_manager.py | 176 ++--- app/assets_scanner.py | 707 ++++++++---------- app/database/_helpers.py | 186 ----- app/database/db.py | 10 +- app/database/helpers/__init__.py | 23 + app/database/helpers/filters.py | 87 +++ app/database/helpers/ownership.py | 12 + app/database/helpers/projection.py | 64 ++ app/database/helpers/tags.py | 102 +++ app/database/models.py | 96 +-- app/database/services.py | 1116 ---------------------------- app/database/services/__init__.py | 56 ++ app/database/services/content.py | 746 +++++++++++++++++++ app/database/services/info.py | 579 +++++++++++++++ app/database/services/queries.py | 59 ++ comfy/cli_args.py | 1 - main.py | 4 +- server.py | 3 +- tests-assets/test_crud.py | 19 +- tests-assets/test_tags.py | 21 +- tests-assets/test_uploads.py | 2 +- 27 files changed, 2374 insertions(+), 1924 deletions(-) delete mode 100644 app/database/_helpers.py create mode 100644 app/database/helpers/__init__.py create mode 100644 app/database/helpers/filters.py create mode 100644 app/database/helpers/ownership.py create mode 100644 app/database/helpers/projection.py create mode 100644 app/database/helpers/tags.py delete mode 100644 app/database/services.py create mode 100644 app/database/services/__init__.py create mode 100644 app/database/services/content.py create mode 100644 app/database/services/info.py create mode 100644 app/database/services/queries.py diff --git a/alembic_db/versions/0001_assets.py b/alembic_db/versions/0001_assets.py index bc98b5acf316..1f5fb462280d 100644 --- a/alembic_db/versions/0001_assets.py +++ b/alembic_db/versions/0001_assets.py @@ -16,33 +16,44 @@ def upgrade() -> None: - # ASSETS: content identity (deduplicated by hash) + # ASSETS: content identity op.create_table( "assets", - sa.Column("hash", sa.String(length=256), primary_key=True), + sa.Column("id", sa.String(length=36), primary_key=True), + sa.Column("hash", sa.String(length=256), nullable=True), sa.Column("size_bytes", sa.BigInteger(), nullable=False, server_default="0"), sa.Column("mime_type", sa.String(length=255), nullable=True), sa.Column("created_at", sa.DateTime(timezone=False), nullable=False), sa.CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"), ) + if op.get_bind().dialect.name == "postgresql": + op.create_index( + "uq_assets_hash_not_null", + "assets", + ["hash"], + unique=True, + postgresql_where=sa.text("hash IS NOT NULL"), + ) + else: + op.create_index("uq_assets_hash", "assets", ["hash"], unique=True) op.create_index("ix_assets_mime_type", "assets", ["mime_type"]) - # ASSETS_INFO: user-visible references (mutable metadata) + # 
ASSETS_INFO: user-visible references op.create_table( "assets_info", sa.Column("id", sa.String(length=36), primary_key=True), sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""), sa.Column("name", sa.String(length=512), nullable=False), - sa.Column("asset_hash", sa.String(length=256), sa.ForeignKey("assets.hash", ondelete="RESTRICT"), nullable=False), - sa.Column("preview_hash", sa.String(length=256), sa.ForeignKey("assets.hash", ondelete="SET NULL"), nullable=True), + sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False), + sa.Column("preview_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="SET NULL"), nullable=True), sa.Column("user_metadata", sa.JSON(), nullable=True), sa.Column("created_at", sa.DateTime(timezone=False), nullable=False), sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False), sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False), - sa.UniqueConstraint("asset_hash", "owner_id", "name", name="uq_assets_info_hash_owner_name"), + sa.UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"), ) op.create_index("ix_assets_info_owner_id", "assets_info", ["owner_id"]) - op.create_index("ix_assets_info_asset_hash", "assets_info", ["asset_hash"]) + op.create_index("ix_assets_info_asset_id", "assets_info", ["asset_id"]) op.create_index("ix_assets_info_name", "assets_info", ["name"]) op.create_index("ix_assets_info_created_at", "assets_info", ["created_at"]) op.create_index("ix_assets_info_last_access_time", "assets_info", ["last_access_time"]) @@ -69,18 +80,19 @@ def upgrade() -> None: op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"]) op.create_index("ix_asset_info_tags_asset_info_id", "asset_info_tags", ["asset_info_id"]) - # ASSET_CACHE_STATE: N:1 local cache metadata rows per Asset + # ASSET_CACHE_STATE: N:1 local cache rows per Asset op.create_table( "asset_cache_state", sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True), - sa.Column("asset_hash", sa.String(length=256), sa.ForeignKey("assets.hash", ondelete="CASCADE"), nullable=False), + sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="CASCADE"), nullable=False), sa.Column("file_path", sa.Text(), nullable=False), # absolute local path to cached file sa.Column("mtime_ns", sa.BigInteger(), nullable=True), + sa.Column("needs_verify", sa.Boolean(), nullable=False, server_default=sa.text("false")), sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"), sa.UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"), ) op.create_index("ix_asset_cache_state_file_path", "asset_cache_state", ["file_path"]) - op.create_index("ix_asset_cache_state_asset_hash", "asset_cache_state", ["asset_hash"]) + op.create_index("ix_asset_cache_state_asset_id", "asset_cache_state", ["asset_id"]) # ASSET_INFO_META: typed KV projection of user_metadata for filtering/sorting op.create_table( @@ -99,7 +111,7 @@ def upgrade() -> None: op.create_index("ix_asset_info_meta_key_val_num", "asset_info_meta", ["key", "val_num"]) op.create_index("ix_asset_info_meta_key_val_bool", "asset_info_meta", ["key", "val_bool"]) - # Tags vocabulary for models + # Tags vocabulary tags_table = sa.table( "tags", sa.column("name", sa.String(length=512)), @@ -108,12 +120,10 @@ def upgrade() -> None: op.bulk_insert( tags_table, [ - # Root folder tags {"name": "models", "tag_type": "system"}, {"name": 
"input", "tag_type": "system"}, {"name": "output", "tag_type": "system"}, - # Core tags {"name": "configs", "tag_type": "system"}, {"name": "checkpoints", "tag_type": "system"}, {"name": "loras", "tag_type": "system"}, @@ -132,12 +142,11 @@ def upgrade() -> None: {"name": "photomaker", "tag_type": "system"}, {"name": "classifiers", "tag_type": "system"}, - # Extra basic tags {"name": "encoder", "tag_type": "system"}, {"name": "decoder", "tag_type": "system"}, - # Special tags {"name": "missing", "tag_type": "system"}, + {"name": "rescan", "tag_type": "system"}, ], ) @@ -149,8 +158,9 @@ def downgrade() -> None: op.drop_index("ix_asset_info_meta_key", table_name="asset_info_meta") op.drop_table("asset_info_meta") - op.drop_index("ix_asset_cache_state_asset_hash", table_name="asset_cache_state") + op.drop_index("ix_asset_cache_state_asset_id", table_name="asset_cache_state") op.drop_index("ix_asset_cache_state_file_path", table_name="asset_cache_state") + op.drop_constraint("uq_asset_cache_state_file_path", table_name="asset_cache_state") op.drop_table("asset_cache_state") op.drop_index("ix_asset_info_tags_asset_info_id", table_name="asset_info_tags") @@ -160,14 +170,18 @@ def downgrade() -> None: op.drop_index("ix_tags_tag_type", table_name="tags") op.drop_table("tags") - op.drop_constraint("uq_assets_info_hash_owner_name", table_name="assets_info") + op.drop_constraint("uq_assets_info_asset_owner_name", table_name="assets_info") op.drop_index("ix_assets_info_owner_name", table_name="assets_info") op.drop_index("ix_assets_info_last_access_time", table_name="assets_info") op.drop_index("ix_assets_info_created_at", table_name="assets_info") op.drop_index("ix_assets_info_name", table_name="assets_info") - op.drop_index("ix_assets_info_asset_hash", table_name="assets_info") + op.drop_index("ix_assets_info_asset_id", table_name="assets_info") op.drop_index("ix_assets_info_owner_id", table_name="assets_info") op.drop_table("assets_info") + if op.get_bind().dialect.name == "postgresql": + op.drop_index("uq_assets_hash_not_null", table_name="assets") + else: + op.drop_index("uq_assets_hash", table_name="assets") op.drop_index("ix_assets_mime_type", table_name="assets") op.drop_table("assets") diff --git a/app/__init__.py b/app/__init__.py index 5fade97a49dd..e8538bd29fb6 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -1,5 +1,5 @@ +from .assets_scanner import sync_seed_assets from .database.db import init_db_engine -from .assets_scanner import start_background_assets_scan +from .api.assets_routes import register_assets_system - -__all__ = ["init_db_engine", "start_background_assets_scan"] +__all__ = ["init_db_engine", "sync_seed_assets", "register_assets_system"] diff --git a/app/_assets_helpers.py b/app/_assets_helpers.py index 8fb88cd34e19..e0b982c985c0 100644 --- a/app/_assets_helpers.py +++ b/app/_assets_helpers.py @@ -1,12 +1,13 @@ +import contextlib import os +import uuid +from datetime import datetime, timezone from pathlib import Path -from typing import Optional, Literal, Sequence - -import sqlalchemy as sa +from typing import Literal, Optional, Sequence import folder_paths -from .database.models import AssetInfo +from .api import schemas_in def get_comfy_models_folders() -> list[tuple[str, list[str]]]: @@ -139,14 +140,6 @@ def ensure_within_base(candidate: str, base: str) -> None: raise ValueError("invalid destination path") -def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement: - """Build owner visibility predicate for reads.""" - owner_id = (owner_id or "").strip() - if 
owner_id == "": - return AssetInfo.owner_id == "" - return AssetInfo.owner_id.in_(["", owner_id]) - - def compute_model_relative_filename(file_path: str) -> Optional[str]: """ Return the model's path relative to the last well-known folder (the model category), @@ -172,3 +165,61 @@ def compute_model_relative_filename(file_path: str) -> Optional[str]: return None inside = parts[1:] if len(parts) > 1 else [parts[0]] return "/".join(inside) # normalize to POSIX style for portability + + +def list_tree(base_dir: str) -> list[str]: + out: list[str] = [] + base_abs = os.path.abspath(base_dir) + if not os.path.isdir(base_abs): + return out + for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False): + for name in filenames: + out.append(os.path.abspath(os.path.join(dirpath, name))) + return out + + +def prefixes_for_root(root: schemas_in.RootType) -> list[str]: + if root == "models": + bases: list[str] = [] + for _bucket, paths in get_comfy_models_folders(): + bases.extend(paths) + return [os.path.abspath(p) for p in bases] + if root == "input": + return [os.path.abspath(folder_paths.get_input_directory())] + if root == "output": + return [os.path.abspath(folder_paths.get_output_directory())] + return [] + + +def ts_to_iso(ts: Optional[float]) -> Optional[str]: + if ts is None: + return None + try: + return datetime.fromtimestamp(float(ts), tz=timezone.utc).replace(tzinfo=None).isoformat() + except Exception: + return None + + +def new_scan_id(root: schemas_in.RootType) -> str: + return f"scan-{root}-{uuid.uuid4().hex[:8]}" + + +def collect_models_files() -> list[str]: + out: list[str] = [] + for folder_name, bases in get_comfy_models_folders(): + rel_files = folder_paths.get_filename_list(folder_name) or [] + for rel_path in rel_files: + abs_path = folder_paths.get_full_path(folder_name, rel_path) + if not abs_path: + continue + abs_path = os.path.abspath(abs_path) + allowed = False + for b in bases: + base_abs = os.path.abspath(b) + with contextlib.suppress(Exception): + if os.path.commonpath([abs_path, base_abs]) == base_abs: + allowed = True + break + if allowed: + out.append(abs_path) + return out diff --git a/app/api/assets_routes.py b/app/api/assets_routes.py index e9a4ff97afe2..384c9f6c0a53 100644 --- a/app/api/assets_routes.py +++ b/app/api/assets_routes.py @@ -1,7 +1,7 @@ import contextlib import os -import uuid import urllib.parse +import uuid from typing import Optional from aiohttp import web @@ -12,7 +12,6 @@ from .. import assets_manager, assets_scanner, user_manager from . 
import schemas_in, schemas_out - ROUTES = web.RouteTableDef() UserManager: Optional[user_manager.UserManager] = None @@ -272,6 +271,7 @@ async def upload_asset(request: web.Request) -> web.Response: temp_path=tmp_path, client_filename=file_client_name, owner_id=owner_id, + expected_asset_hash=spec.hash, ) status = 201 if created.created_new else 200 return web.json_response(created.model_dump(mode="json"), status=status) @@ -332,6 +332,29 @@ async def update_asset(request: web.Request) -> web.Response: return web.json_response(result.model_dump(mode="json"), status=200) +@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}/preview") +async def set_asset_preview(request: web.Request) -> web.Response: + asset_info_id = str(uuid.UUID(request.match_info["id"])) + try: + body = schemas_in.SetPreviewBody.model_validate(await request.json()) + except ValidationError as ve: + return _validation_error_response("INVALID_BODY", ve) + except Exception: + return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") + + try: + result = await assets_manager.set_asset_preview( + asset_info_id=asset_info_id, + preview_asset_id=body.preview_id, + owner_id=UserManager.get_request_user_id(request), + ) + except (PermissionError, ValueError) as ve: + return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) + except Exception: + return _error_response(500, "INTERNAL", "Unexpected server error.") + return web.json_response(result.model_dump(mode="json"), status=200) + + @ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}") async def delete_asset(request: web.Request) -> web.Response: asset_info_id = str(uuid.UUID(request.match_info["id"])) diff --git a/app/api/schemas_in.py b/app/api/schemas_in.py index 412b72e3af0b..bc521b313d49 100644 --- a/app/api/schemas_in.py +++ b/app/api/schemas_in.py @@ -1,7 +1,15 @@ import json +import uuid +from typing import Any, Literal, Optional -from typing import Any, Optional, Literal -from pydantic import BaseModel, Field, ConfigDict, field_validator, model_validator, conint +from pydantic import ( + BaseModel, + ConfigDict, + Field, + conint, + field_validator, + model_validator, +) class ListAssetsQuery(BaseModel): @@ -148,30 +156,12 @@ class TagsRemove(TagsAdd): pass -class ScheduleAssetScanBody(BaseModel): - roots: list[Literal["models","input","output"]] = Field(default_factory=list) +RootType = Literal["models", "input", "output"] +ALLOWED_ROOTS: tuple[RootType, ...] = ("models", "input", "output") - @field_validator("roots", mode="before") - @classmethod - def _normalize_roots(cls, v): - if v is None: - return [] - if isinstance(v, str): - items = [x.strip().lower() for x in v.split(",")] - elif isinstance(v, list): - items = [] - for x in v: - if isinstance(x, str): - items.extend([p.strip().lower() for p in x.split(",")]) - else: - return [] - out = [] - seen = set() - for r in items: - if r in {"models","input","output"} and r not in seen: - out.append(r) - seen.add(r) - return out + +class ScheduleAssetScanBody(BaseModel): + roots: list[RootType] = Field(..., min_length=1) class UploadAssetSpec(BaseModel): @@ -281,3 +271,22 @@ def _validate_order(self): if len(self.tags) < 2: raise ValueError("models uploads require a category tag as the second tag") return self + + +class SetPreviewBody(BaseModel): + """Set or clear the preview for an AssetInfo. 
Provide an Asset.id or null.""" + preview_id: Optional[str] = None + + @field_validator("preview_id", mode="before") + @classmethod + def _norm_uuid(cls, v): + if v is None: + return None + s = str(v).strip() + if not s: + return None + try: + uuid.UUID(s) + except Exception: + raise ValueError("preview_id must be a UUID") + return s diff --git a/app/api/schemas_out.py b/app/api/schemas_out.py index 8bb34096bb75..cc7e9572be4a 100644 --- a/app/api/schemas_out.py +++ b/app/api/schemas_out.py @@ -1,12 +1,13 @@ from datetime import datetime from typing import Any, Literal, Optional + from pydantic import BaseModel, ConfigDict, Field, field_serializer class AssetSummary(BaseModel): id: str name: str - asset_hash: str + asset_hash: Optional[str] size: Optional[int] = None mime_type: Optional[str] = None tags: list[str] = Field(default_factory=list) @@ -31,7 +32,7 @@ class AssetsList(BaseModel): class AssetUpdated(BaseModel): id: str name: str - asset_hash: str + asset_hash: Optional[str] tags: list[str] = Field(default_factory=list) user_metadata: dict[str, Any] = Field(default_factory=dict) updated_at: Optional[datetime] = None @@ -46,12 +47,12 @@ def _ser_updated(self, v: Optional[datetime], _info): class AssetDetail(BaseModel): id: str name: str - asset_hash: str + asset_hash: Optional[str] size: Optional[int] = None mime_type: Optional[str] = None tags: list[str] = Field(default_factory=list) user_metadata: dict[str, Any] = Field(default_factory=dict) - preview_hash: Optional[str] = None + preview_id: Optional[str] = None created_at: Optional[datetime] = None last_access_time: Optional[datetime] = None @@ -95,7 +96,6 @@ class TagsRemove(BaseModel): class AssetScanError(BaseModel): path: str message: str - phase: Literal["fast", "slow"] at: Optional[str] = Field(None, description="ISO timestamp") @@ -108,8 +108,6 @@ class AssetScanStatus(BaseModel): finished_at: Optional[str] = None discovered: int = 0 processed: int = 0 - slow_queue_total: int = 0 - slow_queue_finished: int = 0 file_errors: list[AssetScanError] = Field(default_factory=list) diff --git a/app/assets_manager.py b/app/assets_manager.py index 423a860e0776..9d2424ce6065 100644 --- a/app/assets_manager.py +++ b/app/assets_manager.py @@ -4,38 +4,39 @@ import os from typing import Optional, Sequence -from comfy.cli_args import args from comfy_api.internal import async_to_sync +from ._assets_helpers import ( + ensure_within_base, + get_name_and_tags_from_asset_path, + resolve_destination_from_tags, +) +from .api import schemas_in, schemas_out from .database.db import create_session -from .storage import hashing +from .database.models import Asset from .database.services import ( + add_tags_to_asset_info, + asset_exists_by_hash, + asset_info_exists_for_asset_id, check_fs_asset_exists_quick, + create_asset_info_for_existing_asset, + delete_asset_info_by_id, + fetch_asset_info_and_asset, + fetch_asset_info_asset_and_tags, + get_asset_by_hash, + get_asset_info_by_id, + get_asset_tags, ingest_fs_asset, - touch_asset_infos_by_fs_path, list_asset_infos_page, - update_asset_info_full, - get_asset_tags, + list_cache_states_by_asset_id, list_tags_with_usage, - add_tags_to_asset_info, remove_tags_from_asset_info, - fetch_asset_info_and_asset, + set_asset_info_preview, touch_asset_info_by_id, - delete_asset_info_by_id, - asset_exists_by_hash, - get_asset_by_hash, - create_asset_info_for_existing_asset, - fetch_asset_info_asset_and_tags, - get_asset_info_by_id, - list_cache_states_by_asset_hash, - asset_info_exists_for_hash, -) -from .api import 
schemas_in, schemas_out -from ._assets_helpers import ( - get_name_and_tags_from_asset_path, - ensure_within_base, - resolve_destination_from_tags, + touch_asset_infos_by_fs_path, + update_asset_info_full, ) +from .storage import hashing async def asset_exists(*, asset_hash: str) -> bool: @@ -44,29 +45,21 @@ async def asset_exists(*, asset_hash: str) -> bool: def populate_db_with_asset(file_path: str, tags: Optional[list[str]] = None) -> None: - if not args.enable_model_processing: - if tags is None: - tags = [] - try: - asset_name, path_tags = get_name_and_tags_from_asset_path(file_path) - async_to_sync.AsyncToSyncConverter.run_async_in_thread( - add_local_asset, - tags=list(dict.fromkeys([*path_tags, *tags])), - file_name=asset_name, - file_path=file_path, - ) - except ValueError as e: - logging.warning("Skipping non-asset path %s: %s", file_path, e) + if tags is None: + tags = [] + try: + asset_name, path_tags = get_name_and_tags_from_asset_path(file_path) + async_to_sync.AsyncToSyncConverter.run_async_in_thread( + add_local_asset, + tags=list(dict.fromkeys([*path_tags, *tags])), + file_name=asset_name, + file_path=file_path, + ) + except ValueError as e: + logging.warning("Skipping non-asset path %s: %s", file_path, e) async def add_local_asset(tags: list[str], file_name: str, file_path: str) -> None: - """Adds a local asset to the DB. If already present and unchanged, does nothing. - - Notes: - - Uses absolute path as the canonical locator for the cache backend. - - Computes BLAKE3 only when the fast existence check indicates it's needed. - - This function ensures the identity row and seeds mtime in asset_cache_state. - """ abs_path = os.path.abspath(file_path) size_bytes, mtime_ns = _get_size_mtime_ns(abs_path) if not size_bytes: @@ -132,7 +125,7 @@ async def list_assets( schemas_out.AssetSummary( id=info.id, name=info.name, - asset_hash=info.asset_hash, + asset_hash=asset.hash if asset else None, size=int(asset.size_bytes) if asset else None, mime_type=asset.mime_type if asset else None, tags=tags, @@ -156,16 +149,17 @@ async def get_asset(*, asset_info_id: str, owner_id: str = "") -> schemas_out.As if not res: raise ValueError(f"AssetInfo {asset_info_id} not found") info, asset, tag_names = res + preview_id = info.preview_id return schemas_out.AssetDetail( id=info.id, name=info.name, - asset_hash=info.asset_hash, + asset_hash=asset.hash if asset else None, size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None, mime_type=asset.mime_type if asset else None, tags=tag_names, - preview_hash=info.preview_hash, user_metadata=info.user_metadata or {}, + preview_id=preview_id, created_at=info.created_at, last_access_time=info.last_access_time, ) @@ -176,20 +170,13 @@ async def resolve_asset_content_for_download( asset_info_id: str, owner_id: str = "", ) -> tuple[str, str, str]: - """ - Returns (abs_path, content_type, download_name) for the given AssetInfo id and touches last_access_time. - Also touches last_access_time (only_if_newer). 
- Raises: - ValueError if AssetInfo cannot be found - FileNotFoundError if file for Asset cannot be found - """ async with await create_session() as session: pair = await fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id) if not pair: raise ValueError(f"AssetInfo {asset_info_id} not found") info, asset = pair - states = await list_cache_states_by_asset_hash(session, asset_hash=info.asset_hash) + states = await list_cache_states_by_asset_id(session, asset_id=asset.id) abs_path = "" for s in states: if s and s.file_path and os.path.isfile(s.file_path): @@ -214,16 +201,6 @@ async def upload_asset_from_temp_path( owner_id: str = "", expected_asset_hash: Optional[str] = None, ) -> schemas_out.AssetCreated: - """ - Finalize an uploaded temp file: - - compute blake3 hash - - if expected_asset_hash provided, verify equality (400 on mismatch at caller) - - if an Asset with the same hash exists: discard temp, create AssetInfo only (no write) - - else resolve destination from tags and atomically move into place - - ingest into DB (assets, locator state, asset_info + tags) - Returns a populated AssetCreated payload. - """ - try: digest = await hashing.blake3_hash(temp_path) except Exception as e: @@ -233,7 +210,6 @@ async def upload_asset_from_temp_path( if expected_asset_hash and asset_hash != expected_asset_hash.strip().lower(): raise ValueError("HASH_MISMATCH") - # Fast path: content already known --> no writes, just create a reference async with await create_session() as session: existing = await get_asset_by_hash(session, asset_hash=asset_hash) if existing is not None: @@ -257,43 +233,37 @@ async def upload_asset_from_temp_path( return schemas_out.AssetCreated( id=info.id, name=info.name, - asset_hash=info.asset_hash, + asset_hash=existing.hash, size=int(existing.size_bytes) if existing.size_bytes is not None else None, mime_type=existing.mime_type, tags=tag_names, user_metadata=info.user_metadata or {}, - preview_hash=info.preview_hash, + preview_id=info.preview_id, created_at=info.created_at, last_access_time=info.last_access_time, created_new=False, ) - # Resolve destination (only for truly new content) base_dir, subdirs = resolve_destination_from_tags(spec.tags) dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir os.makedirs(dest_dir, exist_ok=True) - # Decide filename desired_name = _safe_filename(spec.name or (client_filename or ""), fallback=digest) dest_abs = os.path.abspath(os.path.join(dest_dir, desired_name)) ensure_within_base(dest_abs, base_dir) - # Content type based on final name content_type = mimetypes.guess_type(desired_name, strict=False)[0] or "application/octet-stream" - # Atomic move into place try: os.replace(temp_path, dest_abs) except Exception as e: raise RuntimeError(f"failed to move uploaded file into place: {e}") - # Stat final file try: size_bytes, mtime_ns = _get_size_mtime_ns(dest_abs) except OSError as e: raise RuntimeError(f"failed to stat destination file: {e}") - # Ingest + build response async with await create_session() as session: result = await ingest_fs_asset( session, @@ -304,7 +274,7 @@ async def upload_asset_from_temp_path( mime_type=content_type, info_name=os.path.basename(dest_abs), owner_id=owner_id, - preview_hash=None, + preview_id=None, user_metadata=spec.user_metadata or {}, tags=spec.tags, tag_origin="manual", @@ -324,12 +294,12 @@ async def upload_asset_from_temp_path( return schemas_out.AssetCreated( id=info.id, name=info.name, - asset_hash=info.asset_hash, + asset_hash=asset.hash, 
size=int(asset.size_bytes), mime_type=asset.mime_type, tags=tag_names, user_metadata=info.user_metadata or {}, - preview_hash=info.preview_hash, + preview_id=info.preview_id, created_at=info.created_at, last_access_time=info.last_access_time, created_new=result["asset_created"], @@ -367,38 +337,74 @@ async def update_asset( return schemas_out.AssetUpdated( id=info.id, name=info.name, - asset_hash=info.asset_hash, + asset_hash=info.asset.hash if info.asset else None, tags=tag_names, user_metadata=info.user_metadata or {}, updated_at=info.updated_at, ) +async def set_asset_preview( + *, + asset_info_id: str, + preview_asset_id: Optional[str], + owner_id: str = "", +) -> schemas_out.AssetDetail: + async with await create_session() as session: + info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id) + if not info_row: + raise ValueError(f"AssetInfo {asset_info_id} not found") + if info_row.owner_id and info_row.owner_id != owner_id: + raise PermissionError("not owner") + + await set_asset_info_preview( + session, + asset_info_id=asset_info_id, + preview_asset_id=preview_asset_id, + ) + + res = await fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id) + if not res: + raise RuntimeError("State changed during preview update") + info, asset, tags = res + await session.commit() + + return schemas_out.AssetDetail( + id=info.id, + name=info.name, + asset_hash=asset.hash if asset else None, + size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None, + mime_type=asset.mime_type if asset else None, + tags=tags, + user_metadata=info.user_metadata or {}, + preview_id=info.preview_id, + created_at=info.created_at, + last_access_time=info.last_access_time, + ) + + async def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool: - """Delete single AssetInfo. If this was the last reference to Asset and delete_content_if_orphan=True (default), - delete the Asset row as well and remove all cached files recorded for that asset_hash. 
- """ async with await create_session() as session: info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id) - asset_hash = info_row.asset_hash if info_row else None + asset_id = info_row.asset_id if info_row else None deleted = await delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id) if not deleted: await session.commit() return False - if not delete_content_if_orphan or not asset_hash: + if not delete_content_if_orphan or not asset_id: await session.commit() return True - still_exists = await asset_info_exists_for_hash(session, asset_hash=asset_hash) + still_exists = await asset_info_exists_for_asset_id(session, asset_id=asset_id) if still_exists: await session.commit() return True - states = await list_cache_states_by_asset_hash(session, asset_hash=asset_hash) + states = await list_cache_states_by_asset_id(session, asset_id=asset_id) file_paths = [s.file_path for s in (states or []) if getattr(s, "file_path", None)] - asset_row = await get_asset_by_hash(session, asset_hash=asset_hash) + asset_row = await session.get(Asset, asset_id) if asset_row is not None: await session.delete(asset_row) @@ -439,12 +445,12 @@ async def create_asset_from_hash( return schemas_out.AssetCreated( id=info.id, name=info.name, - asset_hash=info.asset_hash, + asset_hash=asset.hash, size=int(asset.size_bytes), mime_type=asset.mime_type, tags=tag_names, user_metadata=info.user_metadata or {}, - preview_hash=info.preview_hash, + preview_id=info.preview_id, created_at=info.created_at, last_access_time=info.last_access_time, created_new=False, diff --git a/app/assets_scanner.py b/app/assets_scanner.py index a77f877718c5..6cca5b16571c 100644 --- a/app/assets_scanner.py +++ b/app/assets_scanner.py @@ -1,52 +1,55 @@ import asyncio -import contextlib import logging import os -import uuid import time from dataclasses import dataclass, field -from datetime import datetime, timezone -from typing import Callable, Literal, Optional, Sequence +from typing import Literal, Optional + +import sqlalchemy as sa import folder_paths -from . import assets_manager -from .api import schemas_out -from ._assets_helpers import get_comfy_models_folders +from ._assets_helpers import ( + collect_models_files, + get_comfy_models_folders, + get_name_and_tags_from_asset_path, + list_tree, + new_scan_id, + prefixes_for_root, + ts_to_iso, +) +from .api import schemas_in, schemas_out from .database.db import create_session +from .database.helpers import ( + add_missing_tag_for_asset_id, + remove_missing_tag_for_asset_id, +) +from .database.models import Asset, AssetCacheState, AssetInfo from .database.services import ( - check_fs_asset_exists_quick, + compute_hash_and_dedup_for_cache_state, + ensure_seed_for_path, + list_cache_states_by_asset_id, list_cache_states_with_asset_under_prefixes, - add_missing_tag_for_asset_hash, - remove_missing_tag_for_asset_hash, + list_unhashed_candidates_under_prefixes, + list_verify_candidates_under_prefixes, ) LOGGER = logging.getLogger(__name__) -RootType = Literal["models", "input", "output"] -ALLOWED_ROOTS: tuple[RootType, ...] 
= ("models", "input", "output") - SLOW_HASH_CONCURRENCY = 1 @dataclass class ScanProgress: scan_id: str - root: RootType + root: schemas_in.RootType status: Literal["scheduled", "running", "completed", "failed", "cancelled"] = "scheduled" scheduled_at: float = field(default_factory=lambda: time.time()) started_at: Optional[float] = None finished_at: Optional[float] = None - discovered: int = 0 processed: int = 0 - slow_queue_total: int = 0 - slow_queue_finished: int = 0 - file_errors: list[dict] = field(default_factory=list) # {"path","message","phase","at"} - - # Internal diagnostics for logs - _fast_total_seen: int = 0 - _fast_clean: int = 0 + file_errors: list[dict] = field(default_factory=list) @dataclass @@ -56,18 +59,14 @@ class SlowQueueState: closed: bool = False -RUNNING_TASKS: dict[RootType, asyncio.Task] = {} -PROGRESS_BY_ROOT: dict[RootType, ScanProgress] = {} -SLOW_STATE_BY_ROOT: dict[RootType, SlowQueueState] = {} - - -async def start_background_assets_scan(): - await fast_reconcile_and_kickoff(progress_cb=_console_cb) +RUNNING_TASKS: dict[schemas_in.RootType, asyncio.Task] = {} +PROGRESS_BY_ROOT: dict[schemas_in.RootType, ScanProgress] = {} +SLOW_STATE_BY_ROOT: dict[schemas_in.RootType, SlowQueueState] = {} def current_statuses() -> schemas_out.AssetScanStatusResponse: scans = [] - for root in ALLOWED_ROOTS: + for root in schemas_in.ALLOWED_ROOTS: prog = PROGRESS_BY_ROOT.get(root) if not prog: continue @@ -75,83 +74,65 @@ def current_statuses() -> schemas_out.AssetScanStatusResponse: return schemas_out.AssetScanStatusResponse(scans=scans) -async def schedule_scans(roots: Sequence[str]) -> schemas_out.AssetScanStatusResponse: - """Schedule scans for the provided roots; returns progress snapshots. - - Rules: - - Only roots in {models, input, output} are accepted. - - If a root is already scanning, we do NOT enqueue another one. Status returned as-is. - - Otherwise a new task is created and started immediately. - - Files with zero size are skipped. - """ - normalized: list[RootType] = [] - seen = set() - for r in roots or []: - rr = r.strip().lower() - if rr in ALLOWED_ROOTS and rr not in seen: - normalized.append(rr) # type: ignore - seen.add(rr) - if not normalized: - normalized = list(ALLOWED_ROOTS) # schedule all by default - +async def schedule_scans(roots: list[schemas_in.RootType]) -> schemas_out.AssetScanStatusResponse: results: list[ScanProgress] = [] - for root in normalized: + for root in roots: if root in RUNNING_TASKS and not RUNNING_TASKS[root].done(): results.append(PROGRESS_BY_ROOT[root]) continue - prog = ScanProgress(scan_id=_new_scan_id(root), root=root, status="scheduled") + prog = ScanProgress(scan_id=new_scan_id(root), root=root, status="scheduled") PROGRESS_BY_ROOT[root] = prog - SLOW_STATE_BY_ROOT[root] = SlowQueueState(queue=asyncio.Queue()) + state = SlowQueueState(queue=asyncio.Queue()) + SLOW_STATE_BY_ROOT[root] = state RUNNING_TASKS[root] = asyncio.create_task( - _pipeline_for_root(root, prog, progress_cb=None), + _run_hash_verify_pipeline(root, prog, state), name=f"asset-scan:{root}", ) results.append(prog) return _status_response_for(results) -async def fast_reconcile_and_kickoff( - roots: Optional[Sequence[str]] = None, - *, - progress_cb: Optional[Callable[[str, str, int, bool, dict], None]] = None, -) -> schemas_out.AssetScanStatusResponse: - """ - Startup helper: do the fast pass now (so we know queue size), - start slow hashing in the background, return immediately. 
- """ - normalized = [*ALLOWED_ROOTS] if not roots else [r for r in roots if r in ALLOWED_ROOTS] - snaps: list[ScanProgress] = [] - - for root in normalized: - if root in RUNNING_TASKS and not RUNNING_TASKS[root].done(): - snaps.append(PROGRESS_BY_ROOT[root]) - continue - - prog = ScanProgress(scan_id=_new_scan_id(root), root=root, status="scheduled") - PROGRESS_BY_ROOT[root] = prog - state = SlowQueueState(queue=asyncio.Queue()) - SLOW_STATE_BY_ROOT[root] = state +async def sync_seed_assets(roots: list[schemas_in.RootType]) -> None: + for r in roots: + try: + await _fast_db_consistency_pass(r) + except Exception as ex: + LOGGER.exception("fast DB reconciliation failed for %s: %s", r, ex) + + paths: list[str] = [] + if "models" in roots: + paths.extend(collect_models_files()) + if "input" in roots: + paths.extend(list_tree(folder_paths.get_input_directory())) + if "output" in roots: + paths.extend(list_tree(folder_paths.get_output_directory())) - prog.status = "running" - prog.started_at = time.time() + for p in paths: try: - await _fast_reconcile_into_queue(root, prog, state, progress_cb=progress_cb) - except Exception as e: - _append_error(prog, phase="fast", path="", message=str(e)) - prog.status = "failed" - prog.finished_at = time.time() - LOGGER.exception("Fast reconcile failed for %s", root) - snaps.append(prog) + st = os.stat(p, follow_symlinks=True) + if not int(st.st_size or 0): + continue + size_bytes = int(st.st_size) + mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)) + name, tags = get_name_and_tags_from_asset_path(p) + await _seed_one_async(p, size_bytes, mtime_ns, name, tags) + except OSError: continue - _start_slow_workers(root, prog, state, progress_cb=progress_cb) - RUNNING_TASKS[root] = asyncio.create_task( - _await_workers_then_finish(root, prog, state, progress_cb=progress_cb), - name=f"asset-hash:{root}", + +async def _seed_one_async(p: str, size_bytes: int, mtime_ns: int, name: str, tags: list[str]) -> None: + async with await create_session() as sess: + await ensure_seed_for_path( + sess, + abs_path=p, + size_bytes=size_bytes, + mtime_ns=mtime_ns, + info_name=name, + tags=tags, + owner_id="", ) - snaps.append(prog) - return _status_response_for(snaps) + await sess.commit() def _status_response_for(progresses: list[ScanProgress]) -> schemas_out.AssetScanStatusResponse: @@ -163,18 +144,15 @@ def _scan_progress_to_scan_status_model(progress: ScanProgress) -> schemas_out.A scan_id=progress.scan_id, root=progress.root, status=progress.status, - scheduled_at=_ts_to_iso(progress.scheduled_at), - started_at=_ts_to_iso(progress.started_at), - finished_at=_ts_to_iso(progress.finished_at), + scheduled_at=ts_to_iso(progress.scheduled_at), + started_at=ts_to_iso(progress.started_at), + finished_at=ts_to_iso(progress.finished_at), discovered=progress.discovered, processed=progress.processed, - slow_queue_total=progress.slow_queue_total, - slow_queue_finished=progress.slow_queue_finished, file_errors=[ schemas_out.AssetScanError( path=e.get("path", ""), message=e.get("message", ""), - phase=e.get("phase", "slow"), at=e.get("at"), ) for e in (progress.file_errors or []) @@ -182,27 +160,100 @@ def _scan_progress_to_scan_status_model(progress: ScanProgress) -> schemas_out.A ) -async def _pipeline_for_root( - root: RootType, - prog: ScanProgress, - progress_cb: Optional[Callable[[str, str, int, bool, dict], None]], -) -> None: - state = SLOW_STATE_BY_ROOT.get(root) or SlowQueueState(queue=asyncio.Queue()) - SLOW_STATE_BY_ROOT[root] = state +async def 
_refresh_verify_flags_for_root(root: schemas_in.RootType, prog: ScanProgress) -> None: + """Fast pass to mark verify candidates by comparing stored mtime_ns with on-disk mtime.""" + prefixes = prefixes_for_root(root) + if not prefixes: + return + + conds = [] + for p in prefixes: + base = os.path.abspath(p) + if not base.endswith(os.sep): + base += os.sep + conds.append(AssetCacheState.file_path.like(base + "%")) + + async with await create_session() as sess: + rows = ( + await sess.execute( + sa.select( + AssetCacheState.id, + AssetCacheState.mtime_ns, + AssetCacheState.needs_verify, + Asset.hash, + AssetCacheState.file_path, + ) + .join(Asset, Asset.id == AssetCacheState.asset_id) + .where(sa.or_(*conds)) + ) + ).all() + + to_set = [] + to_clear = [] + for sid, mtime_db, needs_verify, a_hash, fp in rows: + try: + st = os.stat(fp, follow_symlinks=True) + except OSError: + # Missing files are handled by missing-tag reconciliation later. + continue + + actual_mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)) + if a_hash is not None: + if mtime_db is None or int(mtime_db) != int(actual_mtime_ns): + if not needs_verify: + to_set.append(sid) + else: + if needs_verify: + to_clear.append(sid) + + if to_set: + await sess.execute( + sa.update(AssetCacheState) + .where(AssetCacheState.id.in_(to_set)) + .values(needs_verify=True) + ) + if to_clear: + await sess.execute( + sa.update(AssetCacheState) + .where(AssetCacheState.id.in_(to_clear)) + .values(needs_verify=False) + ) + await sess.commit() + +async def _run_hash_verify_pipeline(root: schemas_in.RootType, prog: ScanProgress, state: SlowQueueState) -> None: prog.status = "running" prog.started_at = time.time() - try: - await _reconcile_missing_tags_for_root(root, prog) - await _fast_reconcile_into_queue(root, prog, state, progress_cb=progress_cb) - _start_slow_workers(root, prog, state, progress_cb=progress_cb) - await _await_workers_then_finish(root, prog, state, progress_cb=progress_cb) + prefixes = prefixes_for_root(root) + + await _refresh_verify_flags_for_root(root, prog) + + # collect candidates from DB + async with await create_session() as sess: + verify_ids = await list_verify_candidates_under_prefixes(sess, prefixes=prefixes) + unhashed_ids = await list_unhashed_candidates_under_prefixes(sess, prefixes=prefixes) + # dedupe: prioritize verification first + seen = set() + ordered: list[int] = [] + for lst in (verify_ids, unhashed_ids): + for sid in lst: + if sid not in seen: + seen.add(sid); ordered.append(sid) + + prog.discovered = len(ordered) + + # queue up work + for sid in ordered: + await state.queue.put(sid) + state.closed = True + _start_state_workers(root, prog, state) + await _await_state_workers_then_finish(root, prog, state) except asyncio.CancelledError: prog.status = "cancelled" raise except Exception as exc: - _append_error(prog, phase="slow", path="", message=str(exc)) + _append_error(prog, path="", message=str(exc)) prog.status = "failed" prog.finished_at = time.time() LOGGER.exception("Asset scan failed for %s", root) @@ -210,110 +261,13 @@ async def _pipeline_for_root( RUNNING_TASKS.pop(root, None) -async def _fast_reconcile_into_queue( - root: RootType, - prog: ScanProgress, - state: SlowQueueState, - *, - progress_cb: Optional[Callable[[str, str, int, bool, dict], None]], -) -> None: - """ - Enumerate files, set 'discovered' to total files seen, increment 'processed' for fast-matched files, - and queue the rest for slow hashing. 
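A minimal sketch (hypothetical helper, not part of this diff): the candidate ordering built in `_run_hash_verify_pipeline` above (verify candidates first, then never-hashed files, deduplicated while preserving order) is equivalent to:

    def order_candidates(verify_ids: list[int], unhashed_ids: list[int]) -> list[int]:
        # dict.fromkeys keeps only the first occurrence and preserves insertion
        # order, so verification work stays ahead of first-time hashing.
        return list(dict.fromkeys([*verify_ids, *unhashed_ids]))
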
+async def _reconcile_missing_tags_for_root(root: schemas_in.RootType, prog: ScanProgress) -> None: """ - if root == "models": - files = _collect_models_files() - preset_discovered = _count_nonzero_in_list(files) - files_iter = asyncio.Queue() - for p in files: - await files_iter.put(p) - await files_iter.put(None) # sentinel for our local draining loop - elif root == "input": - base = folder_paths.get_input_directory() - preset_discovered = _count_files_in_tree(os.path.abspath(base), only_nonzero=True) - files_iter = await _queue_tree_files(base) - elif root == "output": - base = folder_paths.get_output_directory() - preset_discovered = _count_files_in_tree(os.path.abspath(base), only_nonzero=True) - files_iter = await _queue_tree_files(base) - else: - raise RuntimeError(f"Unsupported root: {root}") - - prog.discovered = int(preset_discovered or 0) - - queued = 0 - checked = 0 - clean = 0 - - async with await create_session() as sess: - while True: - item = await files_iter.get() - files_iter.task_done() - if item is None: - break - - abs_path = item - checked += 1 - - # Stat; skip empty/unreadable - try: - st = os.stat(abs_path, follow_symlinks=True) - if not st.st_size: - continue - size_bytes = int(st.st_size) - mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)) - except OSError as e: - _append_error(prog, phase="fast", path=abs_path, message=str(e)) - continue - - try: - known = await check_fs_asset_exists_quick( - sess, - file_path=abs_path, - size_bytes=size_bytes, - mtime_ns=mtime_ns, - ) - except Exception as e: - _append_error(prog, phase="fast", path=abs_path, message=str(e)) - known = False + Detect missing files quickly and toggle 'missing' tag per asset_id. - if known: - clean += 1 - prog.processed += 1 - else: - await state.queue.put(abs_path) - queued += 1 - prog.slow_queue_total += 1 - - if progress_cb: - progress_cb(root, "fast", prog.processed, False, { - "checked": checked, - "clean": clean, - "queued": queued, - "discovered": prog.discovered, - }) - - prog._fast_total_seen = checked - prog._fast_clean = clean - - if progress_cb: - progress_cb(root, "fast", prog.processed, True, { - "checked": checked, - "clean": clean, - "queued": queued, - "discovered": prog.discovered, - }) - state.closed = True - - -async def _reconcile_missing_tags_for_root(root: RootType, prog: ScanProgress) -> None: - """ - Logic for detecting missing Assets files: - - Clear 'missing' only if at least one cached path passes fast check: - exists AND mtime_ns matches AND size matches. - - Otherwise set 'missing'. - Files that exist but fail fast check will be slow-hashed by the normal pipeline, - and ingest_fs_asset will clear 'missing' if they truly match. + Rules: + - Only hashed assets (assets.hash != NULL) participate in missing tagging. + - We consider ALL cache states of the asset (across roots) before tagging. 
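+    - One cache state anywhere that passes the fast check (stored mtime_ns and size match the file on disk) is enough to clear the 'missing' tag; if none passes, the tag is added.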
""" if root == "models": bases: list[str] = [] @@ -326,232 +280,217 @@ async def _reconcile_missing_tags_for_root(root: RootType, prog: ScanProgress) - try: async with await create_session() as sess: + # state + hash + size for the current root rows = await list_cache_states_with_asset_under_prefixes(sess, prefixes=bases) - by_hash: dict[str, dict[str, bool]] = {} # {hash: {"any_fast_ok": bool}} - for state, size_db in rows: - h = state.asset_hash - acc = by_hash.get(h) + # Track fast_ok within the scanned root and whether the asset is hashed + by_asset: dict[str, dict[str, bool]] = {} + for state, a_hash, size_db in rows: + aid = state.asset_id + acc = by_asset.get(aid) if acc is None: - acc = {"any_fast_ok": False} - by_hash[h] = acc + acc = {"any_fast_ok_here": False, "hashed": (a_hash is not None), "size_db": int(size_db or 0)} + by_asset[aid] = acc try: st = os.stat(state.file_path, follow_symlinks=True) actual_mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)) fast_ok = False - if state.mtime_ns is not None and int(state.mtime_ns) == int(actual_mtime_ns): - if int(size_db) > 0 and int(st.st_size) == int(size_db): - fast_ok = True + if acc["hashed"]: + if state.mtime_ns is not None and int(state.mtime_ns) == int(actual_mtime_ns): + if int(acc["size_db"]) > 0 and int(st.st_size) == int(acc["size_db"]): + fast_ok = True if fast_ok: - acc["any_fast_ok"] = True + acc["any_fast_ok_here"] = True except FileNotFoundError: - pass # not fast_ok + pass except OSError as e: - _append_error(prog, phase="fast", path=state.file_path, message=str(e)) + _append_error(prog, path=state.file_path, message=str(e)) - for h, acc in by_hash.items(): + # Decide per asset, considering ALL its states (not just this root) + for aid, acc in by_asset.items(): try: - if acc["any_fast_ok"]: - await remove_missing_tag_for_asset_hash(sess, asset_hash=h) + if not acc["hashed"]: + # Never tag seed assets as missing + continue + + any_fast_ok_global = acc["any_fast_ok_here"] + if not any_fast_ok_global: + # Check other states outside this root + others = await list_cache_states_by_asset_id(sess, asset_id=aid) + for st in others: + try: + s = os.stat(st.file_path, follow_symlinks=True) + actual_mtime_ns = getattr(s, "st_mtime_ns", int(s.st_mtime * 1_000_000_000)) + if st.mtime_ns is not None and int(st.mtime_ns) == int(actual_mtime_ns): + if acc["size_db"] > 0 and int(s.st_size) == acc["size_db"]: + any_fast_ok_global = True + break + except OSError: + continue + + if any_fast_ok_global: + await remove_missing_tag_for_asset_id(sess, asset_id=aid) else: - await add_missing_tag_for_asset_hash(sess, asset_hash=h, origin="automatic") + await add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic") except Exception as ex: - _append_error(prog, phase="fast", path="", message=f"reconcile {h[:18]}: {ex}") + _append_error(prog, path="", message=f"reconcile {aid[:8]}: {ex}") await sess.commit() except Exception as e: - _append_error(prog, phase="fast", path="", message=f"reconcile failed: {e}") + _append_error(prog, path="", message=f"reconcile failed: {e}") -def _start_slow_workers( - root: RootType, - prog: ScanProgress, - state: SlowQueueState, - *, - progress_cb: Optional[Callable[[str, str, int, bool, dict], None]], -) -> None: +def _start_state_workers(root: schemas_in.RootType, prog: ScanProgress, state: SlowQueueState) -> None: if state.workers: return - async def _worker(_worker_id: int): + async def _worker(_wid: int): while True: - item = await state.queue.get() + sid = await 
state.queue.get() try: - if item is None: + if sid is None: return try: - await asyncio.to_thread(assets_manager.populate_db_with_asset, item) - except Exception as e: - _append_error(prog, phase="slow", path=item, message=str(e)) + async with await create_session() as sess: + # Optional: fetch path for better error messages + st = await sess.get(AssetCacheState, sid) + try: + await compute_hash_and_dedup_for_cache_state(sess, state_id=sid) + await sess.commit() + except Exception as e: + path = st.file_path if st else f"state:{sid}" + _append_error(prog, path=path, message=str(e)) + raise + except Exception: + pass finally: - # Slow queue finished for this item; also counts toward overall processed - prog.slow_queue_finished += 1 prog.processed += 1 - if progress_cb: - progress_cb(root, "slow", prog.processed, False, { - "slow_queue_finished": prog.slow_queue_finished, - "slow_queue_total": prog.slow_queue_total, - }) finally: state.queue.task_done() - state.workers = [asyncio.create_task(_worker(i), name=f"asset-hash:{root}:{i}") for i in range(SLOW_HASH_CONCURRENCY)] + state.workers = [ + asyncio.create_task(_worker(i), name=f"asset-hash:{root}:{i}") + for i in range(SLOW_HASH_CONCURRENCY) + ] - async def _close_when_empty(): - # When the fast phase closed the queue, push sentinels to end workers + async def _close_when_ready(): while not state.closed: await asyncio.sleep(0.05) for _ in range(SLOW_HASH_CONCURRENCY): await state.queue.put(None) - asyncio.create_task(_close_when_empty()) + asyncio.create_task(_close_when_ready()) -async def _await_workers_then_finish( - root: RootType, - prog: ScanProgress, - state: SlowQueueState, - *, - progress_cb: Optional[Callable[[str, str, int, bool, dict], None]], -) -> None: +async def _await_state_workers_then_finish(root: schemas_in.RootType, prog: ScanProgress, state: SlowQueueState) -> None: if state.workers: await asyncio.gather(*state.workers, return_exceptions=True) await _reconcile_missing_tags_for_root(root, prog) prog.finished_at = time.time() prog.status = "completed" - if progress_cb: - progress_cb(root, "slow", prog.processed, True, { - "slow_queue_finished": prog.slow_queue_finished, - "slow_queue_total": prog.slow_queue_total, - }) - - -def _collect_models_files() -> list[str]: - """Collect absolute file paths from configured model buckets under models_dir.""" - out: list[str] = [] - for folder_name, bases in get_comfy_models_folders(): - rel_files = folder_paths.get_filename_list(folder_name) or [] - for rel_path in rel_files: - abs_path = folder_paths.get_full_path(folder_name, rel_path) - if not abs_path: - continue - abs_path = os.path.abspath(abs_path) - # ensure within allowed bases - allowed = False - for b in bases: - base_abs = os.path.abspath(b) - with contextlib.suppress(Exception): - if os.path.commonpath([abs_path, base_abs]) == base_abs: - allowed = True - break - if allowed: - out.append(abs_path) - return out - - -def _count_files_in_tree(base_abs: str, *, only_nonzero: bool = False) -> int: - if not os.path.isdir(base_abs): - return 0 - total = 0 - for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False): - if not only_nonzero: - total += len(filenames) - else: - for name in filenames: - with contextlib.suppress(OSError): - st = os.stat(os.path.join(dirpath, name), follow_symlinks=True) - if st.st_size: - total += 1 - return total - - -def _count_nonzero_in_list(paths: list[str]) -> int: - cnt = 0 - for p in paths: - with contextlib.suppress(OSError): - st = os.stat(p, 
follow_symlinks=True) - if st.st_size: - cnt += 1 - return cnt -async def _queue_tree_files(base_dir: str) -> asyncio.Queue: - """ - Walk base_dir in a worker thread and return a queue prefilled with all paths, - terminated by a single None sentinel for the draining loop in fast reconcile. - """ - q: asyncio.Queue = asyncio.Queue() - base_abs = os.path.abspath(base_dir) - if not os.path.isdir(base_abs): - await q.put(None) - return q - - def _walk_list(): - paths: list[str] = [] - for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False): - for name in filenames: - paths.append(os.path.abspath(os.path.join(dirpath, name))) - return paths - - for p in await asyncio.to_thread(_walk_list): - await q.put(p) - await q.put(None) - return q - - -def _append_error(prog: ScanProgress, *, phase: Literal["fast", "slow"], path: str, message: str) -> None: +def _append_error(prog: ScanProgress, *, path: str, message: str) -> None: prog.file_errors.append({ "path": path, "message": message, - "phase": phase, - "at": _ts_to_iso(time.time()), + "at": ts_to_iso(time.time()), }) -def _ts_to_iso(ts: Optional[float]) -> Optional[str]: - if ts is None: - return None - # interpret ts as seconds since epoch UTC and return naive UTC (consistent with other models) - try: - return datetime.fromtimestamp(float(ts), tz=timezone.utc).replace(tzinfo=None).isoformat() - except Exception: - return None - - -def _new_scan_id(root: RootType) -> str: - return f"scan-{root}-{uuid.uuid4().hex[:8]}" +async def _fast_db_consistency_pass(root: schemas_in.RootType) -> None: + """ + Quick pass over asset_cache_state for `root`: + - If file missing and Asset.hash is NULL and the Asset has no other states, delete the Asset and its infos. + - If file missing and Asset.hash is NOT NULL: + * If at least one state for this Asset is fast-ok, delete the missing state. + * If none are fast-ok, add 'missing' tag to all AssetInfos for this Asset. + - If at least one state becomes fast-ok for a hashed Asset, remove the 'missing' tag. 
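+    - States of seed (unhashed) Assets whose file still exists are left alone; the slow scan will verify or rehash them.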
+ """ + prefixes = prefixes_for_root(root) + if not prefixes: + return + conds = [] + for p in prefixes: + base = os.path.abspath(p) + if not base.endswith(os.sep): + base += os.sep + conds.append(AssetCacheState.file_path.like(base + "%")) -def _console_cb(root: str, phase: str, total_processed: int, finished: bool, e: dict): - if phase == "fast": - if finished: - logging.info( - "[assets][%s] fast done: processed=%s/%s queued=%s", - root, - total_processed, - e["discovered"], - e["queued"], - ) - elif e.get("checked", 0) % 1000 == 0: # do not spam with fast progress - logging.info( - "[assets][%s] fast progress: processed=%s/%s", - root, - total_processed, - e["discovered"], - ) - elif phase == "slow": - if finished: - if e.get("slow_queue_finished", 0) or e.get("slow_queue_total", 0): - logging.info( - "[assets][%s] slow done: %s/%s", - root, - e.get("slow_queue_finished", 0), - e.get("slow_queue_total", 0), - ) - elif e.get('slow_queue_finished', 0) % 3 == 0: - logging.info( - "[assets][%s] slow progress: %s/%s", - root, - e.get("slow_queue_finished", 0), - e.get("slow_queue_total", 0), + async with await create_session() as sess: + if not conds: + return + + rows = ( + await sess.execute( + sa.select(AssetCacheState, Asset.hash, Asset.size_bytes) + .join(Asset, Asset.id == AssetCacheState.asset_id) + .where(sa.or_(*conds)) + .order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc()) ) + ).all() + + # Group by asset_id with status per state + by_asset: dict[str, dict] = {} + for st, a_hash, a_size in rows: + aid = st.asset_id + acc = by_asset.get(aid) + if acc is None: + acc = {"hash": a_hash, "size_db": int(a_size or 0), "states": []} + by_asset[aid] = acc + + exists = False + fast_ok = False + try: + s = os.stat(st.file_path, follow_symlinks=True) + exists = True + actual_mtime_ns = getattr(s, "st_mtime_ns", int(s.st_mtime * 1_000_000_000)) + if st.mtime_ns is not None and int(st.mtime_ns) == int(actual_mtime_ns): + if acc["size_db"] == 0 or int(s.st_size) == acc["size_db"]: + fast_ok = True + except FileNotFoundError: + exists = False + except OSError as ex: + exists = False + LOGGER.debug("fast pass stat error for %s: %s", st.file_path, ex) + + acc["states"].append({"obj": st, "exists": exists, "fast_ok": fast_ok}) + + # Apply actions + for aid, acc in by_asset.items(): + a_hash = acc["hash"] + states = acc["states"] + any_fast_ok = any(s["fast_ok"] for s in states) + all_missing = all(not s["exists"] for s in states) + missing_states = [s["obj"] for s in states if not s["exists"]] + + if a_hash is None: + # Seed asset: if all states gone (and in practice there is only one), remove the whole Asset + if states and all_missing: + await sess.execute(sa.delete(AssetInfo).where(AssetInfo.asset_id == aid)) + asset = await sess.get(Asset, aid) + if asset: + await sess.delete(asset) + # else leave it for the slow scan to verify/rehash + else: + if any_fast_ok: + # Remove 'missing' and delete just the stale state rows + for st in missing_states: + try: + await sess.delete(await sess.get(AssetCacheState, st.id)) + except Exception: + pass + try: + await remove_missing_tag_for_asset_id(sess, asset_id=aid) + except Exception: + pass + else: + # No fast-ok path: mark as missing + try: + await add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic") + except Exception: + pass + + await sess.flush() + await sess.commit() diff --git a/app/database/_helpers.py b/app/database/_helpers.py deleted file mode 100644 index a031e861cece..000000000000 --- a/app/database/_helpers.py +++ 
/dev/null @@ -1,186 +0,0 @@ -from decimal import Decimal -from typing import Any, Sequence, Optional, Iterable - -import sqlalchemy as sa -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy import select, exists - -from .models import AssetInfo, AssetInfoTag, Tag, AssetInfoMeta -from .._assets_helpers import normalize_tags - - -async def ensure_tags_exist(session: AsyncSession, names: Iterable[str], tag_type: str = "user") -> list[Tag]: - wanted = normalize_tags(list(names)) - if not wanted: - return [] - existing = (await session.execute(select(Tag).where(Tag.name.in_(wanted)))).scalars().all() - by_name = {t.name: t for t in existing} - to_create = [Tag(name=n, tag_type=tag_type) for n in wanted if n not in by_name] - if to_create: - session.add_all(to_create) - await session.flush() - by_name.update({t.name: t for t in to_create}) - return [by_name[n] for n in wanted] - - -def apply_tag_filters( - stmt: sa.sql.Select, - include_tags: Optional[Sequence[str]], - exclude_tags: Optional[Sequence[str]], -) -> sa.sql.Select: - """include_tags: every tag must be present; exclude_tags: none may be present.""" - include_tags = normalize_tags(include_tags) - exclude_tags = normalize_tags(exclude_tags) - - if include_tags: - for tag_name in include_tags: - stmt = stmt.where( - exists().where( - (AssetInfoTag.asset_info_id == AssetInfo.id) - & (AssetInfoTag.tag_name == tag_name) - ) - ) - - if exclude_tags: - stmt = stmt.where( - ~exists().where( - (AssetInfoTag.asset_info_id == AssetInfo.id) - & (AssetInfoTag.tag_name.in_(exclude_tags)) - ) - ) - return stmt - - -def apply_metadata_filter( - stmt: sa.sql.Select, - metadata_filter: Optional[dict], -) -> sa.sql.Select: - """Apply metadata filters using the projection table asset_info_meta. - - Semantics: - - For scalar values: require EXISTS(asset_info_meta) with matching key + typed value. - - For None: key is missing OR key has explicit null (val_json IS NULL). - - For list values: ANY-of the list elements matches (EXISTS for any). 
- (Change to ALL-of by 'for each element: stmt = stmt.where(_meta_exists_clause(key, elem))') - """ - if not metadata_filter: - return stmt - - def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement: - return sa.exists().where( - AssetInfoMeta.asset_info_id == AssetInfo.id, - AssetInfoMeta.key == key, - *preds, - ) - - def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement: - # Missing OR null: - if value is None: - # either: no row for key OR a row for key with explicit null - no_row_for_key = sa.not_( - sa.exists().where( - AssetInfoMeta.asset_info_id == AssetInfo.id, - AssetInfoMeta.key == key, - ) - ) - null_row = _exists_for_pred( - key, - AssetInfoMeta.val_json.is_(None), - AssetInfoMeta.val_str.is_(None), - AssetInfoMeta.val_num.is_(None), - AssetInfoMeta.val_bool.is_(None), - ) - return sa.or_(no_row_for_key, null_row) - - # Typed scalar matches: - if isinstance(value, bool): - return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value)) - if isinstance(value, (int, float, Decimal)): - # store as Decimal for equality against NUMERIC(38,10) - num = value if isinstance(value, Decimal) else Decimal(str(value)) - return _exists_for_pred(key, AssetInfoMeta.val_num == num) - if isinstance(value, str): - return _exists_for_pred(key, AssetInfoMeta.val_str == value) - - # Complex: compare JSON (no index, but supported) - return _exists_for_pred(key, AssetInfoMeta.val_json == value) - - for k, v in metadata_filter.items(): - if isinstance(v, list): - # ANY-of (exists for any element) - ors = [_exists_clause_for_value(k, elem) for elem in v] - if ors: - stmt = stmt.where(sa.or_(*ors)) - else: - stmt = stmt.where(_exists_clause_for_value(k, v)) - return stmt - - -def is_scalar(v: Any) -> bool: - if v is None: # treat None as a value (explicit null) so it can be indexed for "is null" queries - return True - if isinstance(v, bool): - return True - if isinstance(v, (int, float, Decimal, str)): - return True - return False - - -def project_kv(key: str, value: Any) -> list[dict]: - """ - Turn a metadata key/value into one or more projection rows: - - scalar -> one row (ordinal=0) in the proper typed column - - list of scalars -> one row per element with ordinal=i - - dict or list with non-scalars -> single row with val_json (or one per element w/ val_json if list) - - None -> single row with all value columns NULL - Each row: {"key": key, "ordinal": i, "val_str"/"val_num"/"val_bool"/"val_json": ...} - """ - rows: list[dict] = [] - - def _null_row(ordinal: int) -> dict: - return { - "key": key, "ordinal": ordinal, - "val_str": None, "val_num": None, "val_bool": None, "val_json": None - } - - if value is None: - rows.append(_null_row(0)) - return rows - - if is_scalar(value): - if isinstance(value, bool): - rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)}) - elif isinstance(value, (int, float, Decimal)): - # store numeric; SQLAlchemy will coerce to Numeric - num = value if isinstance(value, Decimal) else Decimal(str(value)) - rows.append({"key": key, "ordinal": 0, "val_num": num}) - elif isinstance(value, str): - rows.append({"key": key, "ordinal": 0, "val_str": value}) - else: - # Fallback to json - rows.append({"key": key, "ordinal": 0, "val_json": value}) - return rows - - if isinstance(value, list): - if all(is_scalar(x) for x in value): - for i, x in enumerate(value): - if x is None: - rows.append(_null_row(i)) - elif isinstance(x, bool): - rows.append({"key": key, "ordinal": i, "val_bool": bool(x)}) - elif isinstance(x, (int, float, Decimal)): - num 
= x if isinstance(x, Decimal) else Decimal(str(x)) - rows.append({"key": key, "ordinal": i, "val_num": num}) - elif isinstance(x, str): - rows.append({"key": key, "ordinal": i, "val_str": x}) - else: - rows.append({"key": key, "ordinal": i, "val_json": x}) - return rows - # list contains objects -> one val_json per element - for i, x in enumerate(value): - rows.append({"key": key, "ordinal": i, "val_json": x}) - return rows - - # Dict or any other structure -> single json row - rows.append({"key": key, "ordinal": 0, "val_json": value}) - return rows diff --git a/app/database/db.py b/app/database/db.py index 8280272b0216..82c9cc737c39 100644 --- a/app/database/db.py +++ b/app/database/db.py @@ -4,14 +4,20 @@ from contextlib import asynccontextmanager from typing import Optional -from comfy.cli_args import args from alembic import command from alembic.config import Config from alembic.runtime.migration import MigrationContext from alembic.script import ScriptDirectory from sqlalchemy import create_engine, text from sqlalchemy.engine import make_url -from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.ext.asyncio import ( + AsyncEngine, + AsyncSession, + async_sessionmaker, + create_async_engine, +) + +from comfy.cli_args import args LOGGER = logging.getLogger(__name__) ENGINE: Optional[AsyncEngine] = None diff --git a/app/database/helpers/__init__.py b/app/database/helpers/__init__.py new file mode 100644 index 000000000000..19d7507fa44f --- /dev/null +++ b/app/database/helpers/__init__.py @@ -0,0 +1,23 @@ +from .filters import apply_metadata_filter, apply_tag_filters +from .ownership import visible_owner_clause +from .projection import is_scalar, project_kv +from .tags import ( + add_missing_tag_for_asset_hash, + add_missing_tag_for_asset_id, + ensure_tags_exist, + remove_missing_tag_for_asset_hash, + remove_missing_tag_for_asset_id, +) + +__all__ = [ + "apply_tag_filters", + "apply_metadata_filter", + "is_scalar", + "project_kv", + "ensure_tags_exist", + "add_missing_tag_for_asset_id", + "add_missing_tag_for_asset_hash", + "remove_missing_tag_for_asset_id", + "remove_missing_tag_for_asset_hash", + "visible_owner_clause", +] diff --git a/app/database/helpers/filters.py b/app/database/helpers/filters.py new file mode 100644 index 000000000000..0b6d85b8d572 --- /dev/null +++ b/app/database/helpers/filters.py @@ -0,0 +1,87 @@ +from typing import Optional, Sequence + +import sqlalchemy as sa +from sqlalchemy import exists + +from ..._assets_helpers import normalize_tags +from ..models import AssetInfo, AssetInfoMeta, AssetInfoTag + + +def apply_tag_filters( + stmt: sa.sql.Select, + include_tags: Optional[Sequence[str]], + exclude_tags: Optional[Sequence[str]], +) -> sa.sql.Select: + """include_tags: every tag must be present; exclude_tags: none may be present.""" + include_tags = normalize_tags(include_tags) + exclude_tags = normalize_tags(exclude_tags) + + if include_tags: + for tag_name in include_tags: + stmt = stmt.where( + exists().where( + (AssetInfoTag.asset_info_id == AssetInfo.id) + & (AssetInfoTag.tag_name == tag_name) + ) + ) + + if exclude_tags: + stmt = stmt.where( + ~exists().where( + (AssetInfoTag.asset_info_id == AssetInfo.id) + & (AssetInfoTag.tag_name.in_(exclude_tags)) + ) + ) + return stmt + + +def apply_metadata_filter( + stmt: sa.sql.Select, + metadata_filter: Optional[dict], +) -> sa.sql.Select: + """Apply filters using asset_info_meta projection table.""" + if not metadata_filter: + return stmt + + def 
_exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement: + return sa.exists().where( + AssetInfoMeta.asset_info_id == AssetInfo.id, + AssetInfoMeta.key == key, + *preds, + ) + + def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement: + if value is None: + no_row_for_key = sa.not_( + sa.exists().where( + AssetInfoMeta.asset_info_id == AssetInfo.id, + AssetInfoMeta.key == key, + ) + ) + null_row = _exists_for_pred( + key, + AssetInfoMeta.val_json.is_(None), + AssetInfoMeta.val_str.is_(None), + AssetInfoMeta.val_num.is_(None), + AssetInfoMeta.val_bool.is_(None), + ) + return sa.or_(no_row_for_key, null_row) + + if isinstance(value, bool): + return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value)) + if isinstance(value, (int, float)): + from decimal import Decimal + num = value if isinstance(value, Decimal) else Decimal(str(value)) + return _exists_for_pred(key, AssetInfoMeta.val_num == num) + if isinstance(value, str): + return _exists_for_pred(key, AssetInfoMeta.val_str == value) + return _exists_for_pred(key, AssetInfoMeta.val_json == value) + + for k, v in metadata_filter.items(): + if isinstance(v, list): + ors = [_exists_clause_for_value(k, elem) for elem in v] + if ors: + stmt = stmt.where(sa.or_(*ors)) + else: + stmt = stmt.where(_exists_clause_for_value(k, v)) + return stmt diff --git a/app/database/helpers/ownership.py b/app/database/helpers/ownership.py new file mode 100644 index 000000000000..c0073160831e --- /dev/null +++ b/app/database/helpers/ownership.py @@ -0,0 +1,12 @@ +import sqlalchemy as sa + +from ..models import AssetInfo + + +def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement: + """Build owner visibility predicate for reads. Owner-less rows are visible to everyone.""" + + owner_id = (owner_id or "").strip() + if owner_id == "": + return AssetInfo.owner_id == "" + return AssetInfo.owner_id.in_(["", owner_id]) diff --git a/app/database/helpers/projection.py b/app/database/helpers/projection.py new file mode 100644 index 000000000000..687802d1803c --- /dev/null +++ b/app/database/helpers/projection.py @@ -0,0 +1,64 @@ +from decimal import Decimal + + +def is_scalar(v): + if v is None: + return True + if isinstance(v, bool): + return True + if isinstance(v, (int, float, Decimal, str)): + return True + return False + + +def project_kv(key: str, value): + """ + Turn a metadata key/value into typed projection rows. 
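+    Example (arbitrary key): project_kv("steps", 3) -> [{"key": "steps", "ordinal": 0, "val_num": Decimal("3")}]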
+ Returns list[dict] with keys: + key, ordinal, and one of val_str / val_num / val_bool / val_json (others None) + """ + rows: list[dict] = [] + + def _null_row(ordinal: int) -> dict: + return { + "key": key, "ordinal": ordinal, + "val_str": None, "val_num": None, "val_bool": None, "val_json": None + } + + if value is None: + rows.append(_null_row(0)) + return rows + + if is_scalar(value): + if isinstance(value, bool): + rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)}) + elif isinstance(value, (int, float, Decimal)): + num = value if isinstance(value, Decimal) else Decimal(str(value)) + rows.append({"key": key, "ordinal": 0, "val_num": num}) + elif isinstance(value, str): + rows.append({"key": key, "ordinal": 0, "val_str": value}) + else: + rows.append({"key": key, "ordinal": 0, "val_json": value}) + return rows + + if isinstance(value, list): + if all(is_scalar(x) for x in value): + for i, x in enumerate(value): + if x is None: + rows.append(_null_row(i)) + elif isinstance(x, bool): + rows.append({"key": key, "ordinal": i, "val_bool": bool(x)}) + elif isinstance(x, (int, float, Decimal)): + num = x if isinstance(x, Decimal) else Decimal(str(x)) + rows.append({"key": key, "ordinal": i, "val_num": num}) + elif isinstance(x, str): + rows.append({"key": key, "ordinal": i, "val_str": x}) + else: + rows.append({"key": key, "ordinal": i, "val_json": x}) + return rows + for i, x in enumerate(value): + rows.append({"key": key, "ordinal": i, "val_json": x}) + return rows + + rows.append({"key": key, "ordinal": 0, "val_json": value}) + return rows diff --git a/app/database/helpers/tags.py b/app/database/helpers/tags.py new file mode 100644 index 000000000000..47934309654a --- /dev/null +++ b/app/database/helpers/tags.py @@ -0,0 +1,102 @@ +from typing import Iterable + +from sqlalchemy import delete, select +from sqlalchemy.ext.asyncio import AsyncSession + +from ..._assets_helpers import normalize_tags +from ..models import Asset, AssetInfo, AssetInfoTag, Tag +from ..timeutil import utcnow + + +async def ensure_tags_exist(session: AsyncSession, names: Iterable[str], tag_type: str = "user") -> list[Tag]: + wanted = normalize_tags(list(names)) + if not wanted: + return [] + existing = (await session.execute(select(Tag).where(Tag.name.in_(wanted)))).scalars().all() + by_name = {t.name: t for t in existing} + to_create = [Tag(name=n, tag_type=tag_type) for n in wanted if n not in by_name] + if to_create: + session.add_all(to_create) + await session.flush() + by_name.update({t.name: t for t in to_create}) + return [by_name[n] for n in wanted] + + +async def add_missing_tag_for_asset_id( + session: AsyncSession, + *, + asset_id: str, + origin: str = "automatic", +) -> int: + """Ensure every AssetInfo for asset_id has 'missing' tag.""" + ids = (await session.execute(select(AssetInfo.id).where(AssetInfo.asset_id == asset_id))).scalars().all() + if not ids: + return 0 + + existing = { + asset_info_id + for (asset_info_id,) in ( + await session.execute( + select(AssetInfoTag.asset_info_id).where( + AssetInfoTag.asset_info_id.in_(ids), + AssetInfoTag.tag_name == "missing", + ) + ) + ).all() + } + to_add = [i for i in ids if i not in existing] + if not to_add: + return 0 + + now = utcnow() + session.add_all( + [ + AssetInfoTag(asset_info_id=i, tag_name="missing", origin=origin, added_at=now) + for i in to_add + ] + ) + await session.flush() + return len(to_add) + + +async def add_missing_tag_for_asset_hash( + session: AsyncSession, + *, + asset_hash: str, + origin: str = "automatic", +) -> int: 
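+    """Look up the Asset by content hash and delegate to add_missing_tag_for_asset_id; returns 0 if no Asset has that hash."""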
+ asset = (await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))).scalars().first() + if not asset: + return 0 + return await add_missing_tag_for_asset_id(session, asset_id=asset.id, origin=origin) + + +async def remove_missing_tag_for_asset_id( + session: AsyncSession, + *, + asset_id: str, +) -> int: + """Remove the 'missing' tag from all AssetInfos for asset_id.""" + ids = (await session.execute(select(AssetInfo.id).where(AssetInfo.asset_id == asset_id))).scalars().all() + if not ids: + return 0 + + res = await session.execute( + delete(AssetInfoTag).where( + AssetInfoTag.asset_info_id.in_(ids), + AssetInfoTag.tag_name == "missing", + ) + ) + await session.flush() + return int(res.rowcount or 0) + + +async def remove_missing_tag_for_asset_hash( + session: AsyncSession, + *, + asset_hash: str, +) -> int: + asset = (await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))).scalars().first() + if not asset: + return 0 + return await remove_missing_tag_for_asset_id(session, asset_id=asset.id) diff --git a/app/database/models.py b/app/database/models.py index 55fc08e512cc..6a6798bcfd5a 100644 --- a/app/database/models.py +++ b/app/database/models.py @@ -1,27 +1,26 @@ +import uuid from datetime import datetime from typing import Any, Optional -import uuid from sqlalchemy import ( - Integer, + JSON, BigInteger, + Boolean, + CheckConstraint, DateTime, ForeignKey, Index, - UniqueConstraint, - JSON, + Integer, + Numeric, String, Text, - CheckConstraint, - Numeric, - Boolean, + UniqueConstraint, ) from sqlalchemy.dialects.postgresql import JSONB -from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship, foreign +from sqlalchemy.orm import DeclarativeBase, Mapped, foreign, mapped_column, relationship from .timeutil import utcnow - JSONB_V = JSON(none_as_null=True).with_variant(JSONB(none_as_null=True), 'postgresql') @@ -46,7 +45,8 @@ def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]: class Asset(Base): __tablename__ = "assets" - hash: Mapped[str] = mapped_column(String(256), primary_key=True) + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + hash: Mapped[Optional[str]] = mapped_column(String(256), nullable=True) size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0) mime_type: Mapped[Optional[str]] = mapped_column(String(255)) created_at: Mapped[datetime] = mapped_column( @@ -56,8 +56,8 @@ class Asset(Base): infos: Mapped[list["AssetInfo"]] = relationship( "AssetInfo", back_populates="asset", - primaryjoin=lambda: Asset.hash == foreign(AssetInfo.asset_hash), - foreign_keys=lambda: [AssetInfo.asset_hash], + primaryjoin=lambda: Asset.id == foreign(AssetInfo.asset_id), + foreign_keys=lambda: [AssetInfo.asset_id], cascade="all,delete-orphan", passive_deletes=True, ) @@ -65,8 +65,8 @@ class Asset(Base): preview_of: Mapped[list["AssetInfo"]] = relationship( "AssetInfo", back_populates="preview_asset", - primaryjoin=lambda: Asset.hash == foreign(AssetInfo.preview_hash), - foreign_keys=lambda: [AssetInfo.preview_hash], + primaryjoin=lambda: Asset.id == foreign(AssetInfo.preview_id), + foreign_keys=lambda: [AssetInfo.preview_id], viewonly=True, ) @@ -76,36 +76,32 @@ class Asset(Base): passive_deletes=True, ) - locations: Mapped[list["AssetLocation"]] = relationship( - back_populates="asset", - cascade="all, delete-orphan", - passive_deletes=True, - ) - __table_args__ = ( Index("ix_assets_mime_type", "mime_type"), + CheckConstraint("size_bytes >= 0", 
name="ck_assets_size_nonneg"), ) def to_dict(self, include_none: bool = False) -> dict[str, Any]: return to_dict(self, include_none=include_none) def __repr__(self) -> str: - return f"" + return f"" class AssetCacheState(Base): __tablename__ = "asset_cache_state" id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - asset_hash: Mapped[str] = mapped_column(String(256), ForeignKey("assets.hash", ondelete="CASCADE"), nullable=False) + asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False) file_path: Mapped[str] = mapped_column(Text, nullable=False) mtime_ns: Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True) + needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) asset: Mapped["Asset"] = relationship(back_populates="cache_states") __table_args__ = ( Index("ix_asset_cache_state_file_path", "file_path"), - Index("ix_asset_cache_state_asset_hash", "asset_hash"), + Index("ix_asset_cache_state_asset_id", "asset_id"), CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"), UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"), ) @@ -114,27 +110,7 @@ def to_dict(self, include_none: bool = False) -> dict[str, Any]: return to_dict(self, include_none=include_none) def __repr__(self) -> str: - return f"" - - -class AssetLocation(Base): - __tablename__ = "asset_locations" - - id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - asset_hash: Mapped[str] = mapped_column(String(256), ForeignKey("assets.hash", ondelete="CASCADE"), nullable=False) - provider: Mapped[str] = mapped_column(String(32), nullable=False) # "gcs" - locator: Mapped[str] = mapped_column(Text, nullable=False) # "gs://bucket/object" - expected_size_bytes: Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True) - etag: Mapped[Optional[str]] = mapped_column(String(256), nullable=True) - last_modified: Mapped[Optional[str]] = mapped_column(String(128), nullable=True) - - asset: Mapped["Asset"] = relationship(back_populates="locations") - - __table_args__ = ( - UniqueConstraint("asset_hash", "provider", "locator", name="uq_asset_locations_triplet"), - Index("ix_asset_locations_hash", "asset_hash"), - Index("ix_asset_locations_provider", "provider"), - ) + return f"" class AssetInfo(Base): @@ -143,31 +119,23 @@ class AssetInfo(Base): id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="") name: Mapped[str] = mapped_column(String(512), nullable=False) - asset_hash: Mapped[str] = mapped_column( - String(256), ForeignKey("assets.hash", ondelete="RESTRICT"), nullable=False - ) - preview_hash: Mapped[Optional[str]] = mapped_column(String(256), ForeignKey("assets.hash", ondelete="SET NULL")) + asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False) + preview_id: Mapped[Optional[str]] = mapped_column(String(36), ForeignKey("assets.id", ondelete="SET NULL")) user_metadata: Mapped[Optional[dict[str, Any]]] = mapped_column(JSON(none_as_null=True)) - created_at: Mapped[datetime] = mapped_column( - DateTime(timezone=False), nullable=False, default=utcnow - ) - updated_at: Mapped[datetime] = mapped_column( - DateTime(timezone=False), nullable=False, default=utcnow - ) - last_access_time: Mapped[datetime] = mapped_column( - DateTime(timezone=False), nullable=False, default=utcnow 
- ) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow) + updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow) + last_access_time: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow) - # Relationships asset: Mapped[Asset] = relationship( "Asset", back_populates="infos", - foreign_keys=[asset_hash], + foreign_keys=[asset_id], + lazy="selectin", ) preview_asset: Mapped[Optional[Asset]] = relationship( "Asset", back_populates="preview_of", - foreign_keys=[preview_hash], + foreign_keys=[preview_id], ) metadata_entries: Mapped[list["AssetInfoMeta"]] = relationship( @@ -186,16 +154,16 @@ class AssetInfo(Base): tags: Mapped[list["Tag"]] = relationship( secondary="asset_info_tags", back_populates="asset_infos", - lazy="joined", + lazy="selectin", viewonly=True, overlaps="tag_links,asset_info_links,asset_infos,tag", ) __table_args__ = ( - UniqueConstraint("asset_hash", "owner_id", "name", name="uq_assets_info_hash_owner_name"), + UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"), Index("ix_assets_info_owner_name", "owner_id", "name"), Index("ix_assets_info_owner_id", "owner_id"), - Index("ix_assets_info_asset_hash", "asset_hash"), + Index("ix_assets_info_asset_id", "asset_id"), Index("ix_assets_info_name", "name"), Index("ix_assets_info_created_at", "created_at"), Index("ix_assets_info_last_access_time", "last_access_time"), @@ -207,7 +175,7 @@ def to_dict(self, include_none: bool = False) -> dict[str, Any]: return data def __repr__(self) -> str: - return f"" + return f"" class AssetInfoMeta(Base): diff --git a/app/database/services.py b/app/database/services.py deleted file mode 100644 index 842103e9ee0e..000000000000 --- a/app/database/services.py +++ /dev/null @@ -1,1116 +0,0 @@ -import contextlib -import os -import logging -from collections import defaultdict -from datetime import datetime -from typing import Any, Sequence, Optional, Union - -import sqlalchemy as sa -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy import select, delete, func -from sqlalchemy.orm import contains_eager, noload -from sqlalchemy.exc import IntegrityError -from sqlalchemy.dialects import sqlite as d_sqlite -from sqlalchemy.dialects import postgresql as d_pg - -from .models import Asset, AssetInfo, AssetInfoTag, AssetCacheState, Tag, AssetInfoMeta, AssetLocation -from .timeutil import utcnow -from .._assets_helpers import normalize_tags, visible_owner_clause, compute_model_relative_filename -from . 
import _helpers - - -async def asset_exists_by_hash(session: AsyncSession, *, asset_hash: str) -> bool: - row = ( - await session.execute( - select(sa.literal(True)).select_from(Asset).where(Asset.hash == asset_hash).limit(1) - ) - ).first() - return row is not None - - -async def get_asset_by_hash(session: AsyncSession, *, asset_hash: str) -> Optional[Asset]: - return await session.get(Asset, asset_hash) - - -async def get_asset_info_by_id(session: AsyncSession, *, asset_info_id: str) -> Optional[AssetInfo]: - return await session.get(AssetInfo, asset_info_id) - - -async def asset_info_exists_for_hash(session: AsyncSession, *, asset_hash: str) -> bool: - return ( - await session.execute( - sa.select(sa.literal(True)) - .select_from(AssetInfo) - .where(AssetInfo.asset_hash == asset_hash) - .limit(1) - ) - ).first() is not None - - -async def check_fs_asset_exists_quick( - session, - *, - file_path: str, - size_bytes: Optional[int] = None, - mtime_ns: Optional[int] = None, -) -> bool: - """ - Returns 'True' if there is already AssetCacheState record that matches this absolute path, - AND (if provided) mtime_ns matches stored locator-state, - AND (if provided) size_bytes matches verified size when known. - """ - locator = os.path.abspath(file_path) - - stmt = select(sa.literal(True)).select_from(AssetCacheState).join( - Asset, Asset.hash == AssetCacheState.asset_hash - ).where(AssetCacheState.file_path == locator).limit(1) - - conds = [] - if mtime_ns is not None: - conds.append(AssetCacheState.mtime_ns == int(mtime_ns)) - if size_bytes is not None: - conds.append(sa.or_(Asset.size_bytes == 0, Asset.size_bytes == int(size_bytes))) - - if conds: - stmt = stmt.where(*conds) - - row = (await session.execute(stmt)).first() - return row is not None - - -async def ingest_fs_asset( - session: AsyncSession, - *, - asset_hash: str, - abs_path: str, - size_bytes: int, - mtime_ns: int, - mime_type: Optional[str] = None, - info_name: Optional[str] = None, - owner_id: str = "", - preview_hash: Optional[str] = None, - user_metadata: Optional[dict] = None, - tags: Sequence[str] = (), - tag_origin: str = "manual", - require_existing_tags: bool = False, -) -> dict: - """ - Upsert Asset identity row + cache state(s) pointing at local file. - - Always: - - Insert Asset if missing; - - Insert AssetCacheState if missing; else update mtime_ns and asset_hash if different. - - Optionally (when info_name is provided): - - Create or update an AssetInfo on (asset_hash, owner_id, name). - - Link provided tags to that AssetInfo. - * If the require_existing_tags=True, raises ValueError if any tag does not exist in `tags` table. - * If False (default), create unknown tags. 
- - Returns flags and ids: - { - "asset_created": bool, - "asset_updated": bool, - "state_created": bool, - "state_updated": bool, - "asset_info_id": str | None, - } - """ - locator = os.path.abspath(abs_path) - datetime_now = utcnow() - - out: dict[str, Any] = { - "asset_created": False, - "asset_updated": False, - "state_created": False, - "state_updated": False, - "asset_info_id": None, - } - - # ---- Step 1: INSERT Asset or UPDATE size_bytes/updated_at if exists ---- - with contextlib.suppress(IntegrityError): - async with session.begin_nested(): - session.add( - Asset( - hash=asset_hash, - size_bytes=int(size_bytes), - mime_type=mime_type, - created_at=datetime_now, - ) - ) - await session.flush() - out["asset_created"] = True - - if not out["asset_created"]: - existing = await session.get(Asset, asset_hash) - if existing is not None: - changed = False - if existing.size_bytes != size_bytes: - existing.size_bytes = size_bytes - changed = True - if mime_type and existing.mime_type != mime_type: - existing.mime_type = mime_type - changed = True - if changed: - out["asset_updated"] = True - else: - logging.error("Asset %s not found after PK conflict; skipping update.", asset_hash) - - # ---- Step 2: UPSERT AssetCacheState (mtime_ns, file_path) ---- - dialect = session.bind.dialect.name # "sqlite" or "postgresql" - vals = { - "asset_hash": asset_hash, - "file_path": locator, - "mtime_ns": int(mtime_ns), - } - # 2-step idempotent write so we can set flags deterministically: - # INSERT ... ON CONFLICT(file_path) DO NOTHING - # if conflicted, UPDATE only when values actually change - if dialect == "sqlite": - ins = ( - d_sqlite.insert(AssetCacheState) - .values(**vals) - .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path]) - ) - elif dialect == "postgresql": - ins = ( - d_pg.insert(AssetCacheState) - .values(**vals) - .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path]) - ) - else: - raise NotImplementedError(f"Unsupported database dialect: {dialect}") - res = await session.execute(ins) - if int(res.rowcount or 0) > 0: - out["state_created"] = True - else: - upd = ( - sa.update(AssetCacheState) - .where(AssetCacheState.file_path == locator) - .where( - sa.or_( - AssetCacheState.asset_hash != asset_hash, - AssetCacheState.mtime_ns.is_(None), - AssetCacheState.mtime_ns != int(mtime_ns), - ) - ) - .values(asset_hash=asset_hash, mtime_ns=int(mtime_ns)) - ) - res2 = await session.execute(upd) - if int(res2.rowcount or 0) > 0: - out["state_updated"] = True - - # ---- Optional: AssetInfo + tag links ---- - if info_name: - # 2a) Upsert AssetInfo idempotently on (asset_hash, owner_id, name) - with contextlib.suppress(IntegrityError): - async with session.begin_nested(): - info = AssetInfo( - owner_id=owner_id, - name=info_name, - asset_hash=asset_hash, - preview_hash=preview_hash, - created_at=datetime_now, - updated_at=datetime_now, - last_access_time=datetime_now, - ) - session.add(info) - await session.flush() # get info.id (UUID) - out["asset_info_id"] = info.id - - existing_info = ( - await session.execute( - select(AssetInfo) - .where( - AssetInfo.asset_hash == asset_hash, - AssetInfo.name == info_name, - (AssetInfo.owner_id == owner_id), - ) - .limit(1) - ) - ).unique().scalar_one_or_none() - if not existing_info: - raise RuntimeError("Failed to update or insert AssetInfo.") - - if preview_hash is not None and existing_info.preview_hash != preview_hash: - existing_info.preview_hash = preview_hash - existing_info.updated_at = datetime_now - if 
existing_info.last_access_time < datetime_now: - existing_info.last_access_time = datetime_now - await session.flush() - out["asset_info_id"] = existing_info.id - - # 2b) Link tags (if any). We DO NOT create new Tag rows here by default. - norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()] - if norm and out["asset_info_id"] is not None: - if not require_existing_tags: - await _helpers.ensure_tags_exist(session, norm, tag_type="user") - - # Which tags exist? - existing_tag_names = set( - name for (name,) in (await session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all() - ) - missing = [t for t in norm if t not in existing_tag_names] - if missing and require_existing_tags: - raise ValueError(f"Unknown tags: {missing}") - - # Which links already exist? - existing_links = set( - tag_name - for (tag_name,) in ( - await session.execute( - select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"]) - ) - ).all() - ) - to_add = [t for t in norm if t in existing_tag_names and t not in existing_links] - if to_add: - session.add_all( - [ - AssetInfoTag( - asset_info_id=out["asset_info_id"], - tag_name=t, - origin=tag_origin, - added_at=datetime_now, - ) - for t in to_add - ] - ) - await session.flush() - - # 2c) Rebuild metadata projection if provided - # Uncomment next code, and remove code after it, once the hack with "metadata[filename" is not needed anymore - # if user_metadata is not None and out["asset_info_id"] is not None: - # await replace_asset_info_metadata_projection( - # session, - # asset_info_id=out["asset_info_id"], - # user_metadata=user_metadata, - # ) - # start of adding metadata["filename"] - if out["asset_info_id"] is not None: - primary_path = ( - await session.execute( - select(AssetCacheState.file_path) - .where(AssetCacheState.asset_hash == asset_hash) - .order_by(AssetCacheState.id.asc()) - .limit(1) - ) - ).scalars().first() - computed_filename = compute_model_relative_filename(primary_path) if primary_path else None - - # Start from current metadata on this AssetInfo, if any - current_meta = existing_info.user_metadata or {} - new_meta = dict(current_meta) - - # Merge caller-provided metadata, if any (caller keys override current) - if user_metadata is not None: - for k, v in user_metadata.items(): - new_meta[k] = v - - # Enforce correct model-relative filename when known - if computed_filename: - new_meta["filename"] = computed_filename - - # Only write when there is a change - if new_meta != current_meta: - await replace_asset_info_metadata_projection( - session, - asset_info_id=out["asset_info_id"], - user_metadata=new_meta, - ) - # end of adding metadata["filename"] - try: - await remove_missing_tag_for_asset_hash(session, asset_hash=asset_hash) - except Exception: - logging.exception("Failed to clear 'missing' tag for %s", asset_hash) - return out - - -async def touch_asset_infos_by_fs_path( - session: AsyncSession, - *, - file_path: str, - ts: Optional[datetime] = None, - only_if_newer: bool = True, -) -> int: - locator = os.path.abspath(file_path) - ts = ts or utcnow() - - stmt = sa.update(AssetInfo).where( - sa.exists( - sa.select(sa.literal(1)) - .select_from(AssetCacheState) - .where( - AssetCacheState.asset_hash == AssetInfo.asset_hash, - AssetCacheState.file_path == locator, - ) - ) - ) - - if only_if_newer: - stmt = stmt.where( - sa.or_( - AssetInfo.last_access_time.is_(None), - AssetInfo.last_access_time < ts, - ) - ) - - stmt = stmt.values(last_access_time=ts) - - res = await session.execute(stmt) - 
return int(res.rowcount or 0) - - -async def touch_asset_info_by_id( - session: AsyncSession, - *, - asset_info_id: str, - ts: Optional[datetime] = None, - only_if_newer: bool = True, -) -> int: - ts = ts or utcnow() - stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id) - if only_if_newer: - stmt = stmt.where( - sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts) - ) - stmt = stmt.values(last_access_time=ts) - res = await session.execute(stmt) - return int(res.rowcount or 0) - - -async def list_asset_infos_page( - session: AsyncSession, - *, - owner_id: str = "", - include_tags: Optional[Sequence[str]] = None, - exclude_tags: Optional[Sequence[str]] = None, - name_contains: Optional[str] = None, - metadata_filter: Optional[dict] = None, - limit: int = 20, - offset: int = 0, - sort: str = "created_at", - order: str = "desc", -) -> tuple[list[AssetInfo], dict[str, list[str]], int]: - """Return page of AssetInfo rows in the viewers visibility.""" - base = ( - select(AssetInfo) - .join(Asset, Asset.hash == AssetInfo.asset_hash) - .options(contains_eager(AssetInfo.asset)) - .where(visible_owner_clause(owner_id)) - ) - - if name_contains: - base = base.where(AssetInfo.name.ilike(f"%{name_contains}%")) - - base = _helpers.apply_tag_filters(base, include_tags, exclude_tags) - base = _helpers.apply_metadata_filter(base, metadata_filter) - - sort = (sort or "created_at").lower() - order = (order or "desc").lower() - sort_map = { - "name": AssetInfo.name, - "created_at": AssetInfo.created_at, - "updated_at": AssetInfo.updated_at, - "last_access_time": AssetInfo.last_access_time, - "size": Asset.size_bytes, - } - sort_col = sort_map.get(sort, AssetInfo.created_at) - sort_exp = sort_col.desc() if order == "desc" else sort_col.asc() - - base = base.order_by(sort_exp).limit(limit).offset(offset) - - count_stmt = ( - select(func.count()) - .select_from(AssetInfo) - .join(Asset, Asset.hash == AssetInfo.asset_hash) - .where(visible_owner_clause(owner_id)) - ) - if name_contains: - count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{name_contains}%")) - count_stmt = _helpers.apply_tag_filters(count_stmt, include_tags, exclude_tags) - count_stmt = _helpers.apply_metadata_filter(count_stmt, metadata_filter) - - total = int((await session.execute(count_stmt)).scalar_one() or 0) - - infos = (await session.execute(base)).scalars().unique().all() - - # Collect tags in bulk - id_list: list[str] = [i.id for i in infos] - tag_map: dict[str, list[str]] = defaultdict(list) - if id_list: - rows = await session.execute( - select(AssetInfoTag.asset_info_id, Tag.name) - .join(Tag, Tag.name == AssetInfoTag.tag_name) - .where(AssetInfoTag.asset_info_id.in_(id_list)) - ) - for aid, tag_name in rows.all(): - tag_map[aid].append(tag_name) - - return infos, tag_map, total - - -async def fetch_asset_info_and_asset( - session: AsyncSession, - *, - asset_info_id: str, - owner_id: str = "", -) -> Optional[tuple[AssetInfo, Asset]]: - stmt = ( - select(AssetInfo, Asset) - .join(Asset, Asset.hash == AssetInfo.asset_hash) - .where( - AssetInfo.id == asset_info_id, - visible_owner_clause(owner_id), - ) - .limit(1) - ) - row = await session.execute(stmt) - pair = row.first() - if not pair: - return None - return pair[0], pair[1] - - -async def fetch_asset_info_asset_and_tags( - session: AsyncSession, - *, - asset_info_id: str, - owner_id: str = "", -) -> Optional[tuple[AssetInfo, Asset, list[str]]]: - stmt = ( - select(AssetInfo, Asset, Tag.name) - .join(Asset, Asset.hash == AssetInfo.asset_hash) 
- .join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True) - .join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True) - .where( - AssetInfo.id == asset_info_id, - visible_owner_clause(owner_id), - ) - .options(noload(AssetInfo.tags)) - .order_by(Tag.name.asc()) - ) - - rows = (await session.execute(stmt)).all() - if not rows: - return None - - # First row contains the mapped entities; tags may repeat across rows - first_info, first_asset, _ = rows[0] - tags: list[str] = [] - seen: set[str] = set() - for _info, _asset, tag_name in rows: - if tag_name and tag_name not in seen: - seen.add(tag_name) - tags.append(tag_name) - return first_info, first_asset, tags - - -async def get_cache_state_by_asset_hash(session: AsyncSession, *, asset_hash: str) -> Optional[AssetCacheState]: - """Return the oldest cache row for this asset.""" - return ( - await session.execute( - select(AssetCacheState) - .where(AssetCacheState.asset_hash == asset_hash) - .order_by(AssetCacheState.id.asc()) - .limit(1) - ) - ).scalars().first() - - -async def list_cache_states_by_asset_hash( - session: AsyncSession, *, asset_hash: str -) -> Union[list[AssetCacheState], Sequence[AssetCacheState]]: - """Return all cache rows for this asset ordered by oldest first.""" - return ( - await session.execute( - select(AssetCacheState) - .where(AssetCacheState.asset_hash == asset_hash) - .order_by(AssetCacheState.id.asc()) - ) - ).scalars().all() - - -async def list_asset_locations( - session: AsyncSession, *, asset_hash: str, provider: Optional[str] = None -) -> Union[list[AssetLocation], Sequence[AssetLocation]]: - stmt = select(AssetLocation).where(AssetLocation.asset_hash == asset_hash) - if provider: - stmt = stmt.where(AssetLocation.provider == provider) - return (await session.execute(stmt)).scalars().all() - - -async def upsert_asset_location( - session: AsyncSession, - *, - asset_hash: str, - provider: str, - locator: str, - expected_size_bytes: Optional[int] = None, - etag: Optional[str] = None, - last_modified: Optional[str] = None, -) -> AssetLocation: - loc = ( - await session.execute( - select(AssetLocation).where( - AssetLocation.asset_hash == asset_hash, - AssetLocation.provider == provider, - AssetLocation.locator == locator, - ).limit(1) - ) - ).scalars().first() - if loc: - changed = False - if expected_size_bytes is not None and loc.expected_size_bytes != expected_size_bytes: - loc.expected_size_bytes = expected_size_bytes - changed = True - if etag is not None and loc.etag != etag: - loc.etag = etag - changed = True - if last_modified is not None and loc.last_modified != last_modified: - loc.last_modified = last_modified - changed = True - if changed: - await session.flush() - return loc - - loc = AssetLocation( - asset_hash=asset_hash, - provider=provider, - locator=locator, - expected_size_bytes=expected_size_bytes, - etag=etag, - last_modified=last_modified, - ) - session.add(loc) - await session.flush() - return loc - - -async def create_asset_info_for_existing_asset( - session: AsyncSession, - *, - asset_hash: str, - name: str, - user_metadata: Optional[dict] = None, - tags: Optional[Sequence[str]] = None, - tag_origin: str = "manual", - owner_id: str = "", -) -> AssetInfo: - """Create a new AssetInfo referencing an existing Asset. 
If row already exists, return it unchanged.""" - now = utcnow() - info = AssetInfo( - owner_id=owner_id, - name=name, - asset_hash=asset_hash, - preview_hash=None, - created_at=now, - updated_at=now, - last_access_time=now, - ) - try: - async with session.begin_nested(): - session.add(info) - await session.flush() # get info.id - except IntegrityError: - existing = ( - await session.execute( - select(AssetInfo) - .where( - AssetInfo.asset_hash == asset_hash, - AssetInfo.name == name, - AssetInfo.owner_id == owner_id, - ) - .limit(1) - ) - ).scalars().first() - if not existing: - raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.") - return existing - - # Uncomment next code, and remove code after it, once the hack with "metadata[filename" is not needed anymore - # if user_metadata is not None: - # await replace_asset_info_metadata_projection( - # session, asset_info_id=info.id, user_metadata=user_metadata - # ) - - # start of adding metadata["filename"] - new_meta = dict(user_metadata or {}) - - computed_filename = None - try: - state = await get_cache_state_by_asset_hash(session, asset_hash=asset_hash) - if state and state.file_path: - computed_filename = compute_model_relative_filename(state.file_path) - except Exception: - computed_filename = None - - if computed_filename: - new_meta["filename"] = computed_filename - - if new_meta: - await replace_asset_info_metadata_projection( - session, - asset_info_id=info.id, - user_metadata=new_meta, - ) - # end of adding metadata["filename"] - - if tags is not None: - await set_asset_info_tags( - session, - asset_info_id=info.id, - tags=tags, - origin=tag_origin, - ) - return info - - -async def set_asset_info_tags( - session: AsyncSession, - *, - asset_info_id: str, - tags: Sequence[str], - origin: str = "manual", -) -> dict: - """ - Replace the tag set on an AssetInfo with `tags`. Idempotent. - Creates missing tag names as 'user'. - """ - desired = normalize_tags(tags) - - # current links - current = set( - tag_name for (tag_name,) in ( - await session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)) - ).all() - ) - - to_add = [t for t in desired if t not in current] - to_remove = [t for t in current if t not in desired] - - if to_add: - await _helpers.ensure_tags_exist(session, to_add, tag_type="user") - session.add_all([ - AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow()) - for t in to_add - ]) - await session.flush() - - if to_remove: - await session.execute( - delete(AssetInfoTag) - .where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove)) - ) - await session.flush() - - return {"added": to_add, "removed": to_remove, "total": desired} - - -async def update_asset_info_full( - session: AsyncSession, - *, - asset_info_id: str, - name: Optional[str] = None, - tags: Optional[Sequence[str]] = None, - user_metadata: Optional[dict] = None, - tag_origin: str = "manual", - asset_info_row: Any = None, -) -> AssetInfo: - """ - Update AssetInfo fields: - - name (if provided) - - user_metadata blob + rebuild projection (if provided) - - replace tags with provided set (if provided) - Returns the updated AssetInfo. 
- """ - if not asset_info_row: - info = await session.get(AssetInfo, asset_info_id) - if not info: - raise ValueError(f"AssetInfo {asset_info_id} not found") - else: - info = asset_info_row - - touched = False - if name is not None and name != info.name: - info.name = name - touched = True - - # Uncomment next code, and remove code after it, once the hack with "metadata[filename" is not needed anymore - # if user_metadata is not None: - # await replace_asset_info_metadata_projection( - # session, asset_info_id=asset_info_id, user_metadata=user_metadata - # ) - # touched = True - - # start of adding metadata["filename"] - computed_filename = None - try: - state = await get_cache_state_by_asset_hash(session, asset_hash=info.asset_hash) - if state and state.file_path: - computed_filename = compute_model_relative_filename(state.file_path) - except Exception: - computed_filename = None - - if user_metadata is not None: - new_meta = dict(user_metadata) - if computed_filename: - new_meta["filename"] = computed_filename - await replace_asset_info_metadata_projection( - session, asset_info_id=asset_info_id, user_metadata=new_meta - ) - touched = True - else: - if computed_filename: - current_meta = info.user_metadata or {} - if current_meta.get("filename") != computed_filename: - new_meta = dict(current_meta) - new_meta["filename"] = computed_filename - await replace_asset_info_metadata_projection( - session, asset_info_id=asset_info_id, user_metadata=new_meta - ) - touched = True - # end of adding metadata["filename"] - - if tags is not None: - await set_asset_info_tags( - session, - asset_info_id=asset_info_id, - tags=tags, - origin=tag_origin, - ) - touched = True - - if touched and user_metadata is None: - info.updated_at = utcnow() - await session.flush() - - return info - - -async def delete_asset_info_by_id(session: AsyncSession, *, asset_info_id: str, owner_id: str) -> bool: - """Delete the user-visible AssetInfo row. 
Cascades clear tags and metadata.""" - res = await session.execute(delete(AssetInfo).where( - AssetInfo.id == asset_info_id, - visible_owner_clause(owner_id), - )) - return bool(res.rowcount) - - -async def replace_asset_info_metadata_projection( - session: AsyncSession, - *, - asset_info_id: str, - user_metadata: Optional[dict], -) -> None: - """Replaces the `assets_info.user_metadata` AND rebuild the projection rows in `asset_info_meta`.""" - info = await session.get(AssetInfo, asset_info_id) - if not info: - raise ValueError(f"AssetInfo {asset_info_id} not found") - - info.user_metadata = user_metadata or {} - info.updated_at = utcnow() - await session.flush() - - await session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id)) - await session.flush() - - if not user_metadata: - return - - rows: list[AssetInfoMeta] = [] - for k, v in user_metadata.items(): - for r in _helpers.project_kv(k, v): - rows.append( - AssetInfoMeta( - asset_info_id=asset_info_id, - key=r["key"], - ordinal=int(r["ordinal"]), - val_str=r.get("val_str"), - val_num=r.get("val_num"), - val_bool=r.get("val_bool"), - val_json=r.get("val_json"), - ) - ) - if rows: - session.add_all(rows) - await session.flush() - - -async def get_asset_tags(session: AsyncSession, *, asset_info_id: str) -> list[str]: - return [ - tag_name - for (tag_name,) in ( - await session.execute( - sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) - ) - ).all() - ] - - -async def list_tags_with_usage( - session: AsyncSession, - *, - prefix: Optional[str] = None, - limit: int = 100, - offset: int = 0, - include_zero: bool = True, - order: str = "count_desc", # "count_desc" | "name_asc" - owner_id: str = "", -) -> tuple[list[tuple[str, str, int]], int]: - # Subquery with counts by tag_name and owner_id - counts_sq = ( - select( - AssetInfoTag.tag_name.label("tag_name"), - func.count(AssetInfoTag.asset_info_id).label("cnt"), - ) - .select_from(AssetInfoTag) - .join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id) - .where(visible_owner_clause(owner_id)) - .group_by(AssetInfoTag.tag_name) - .subquery() - ) - - # Base select with LEFT JOIN so we can include zero-usage tags - q = ( - select( - Tag.name, - Tag.tag_type, - func.coalesce(counts_sq.c.cnt, 0).label("count"), - ) - .select_from(Tag) - .join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True) - ) - - # Prefix filter (tags are lowercase by check constraint) - if prefix: - q = q.where(Tag.name.like(prefix.strip().lower() + "%")) - - # Include_zero toggles: if False, drop zero-usage tags - if not include_zero: - q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0) - - if order == "name_asc": - q = q.order_by(Tag.name.asc()) - else: # default "count_desc" - q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc()) - - # Total (without limit/offset, same filters) - total_q = select(func.count()).select_from(Tag) - if prefix: - total_q = total_q.where(Tag.name.like(prefix.strip().lower() + "%")) - if not include_zero: - # count only names that appear in counts subquery - total_q = total_q.where( - Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name)) - ) - - rows = (await session.execute(q.limit(limit).offset(offset))).all() - total = (await session.execute(total_q)).scalar_one() - - # Normalize counts to int for SQLite/Postgres consistency - rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows] - return rows_norm, int(total or 0) - - -async def 
add_tags_to_asset_info( - session: AsyncSession, - *, - asset_info_id: str, - tags: Sequence[str], - origin: str = "manual", - create_if_missing: bool = True, - asset_info_row: Any = None, -) -> dict: - """Adds tags to an AssetInfo. - If create_if_missing=True, missing tag rows are created as 'user'. - Returns: {"added": [...], "already_present": [...], "total_tags": [...]} - """ - if not asset_info_row: - info = await session.get(AssetInfo, asset_info_id) - if not info: - raise ValueError(f"AssetInfo {asset_info_id} not found") - - norm = normalize_tags(tags) - if not norm: - total = await get_asset_tags(session, asset_info_id=asset_info_id) - return {"added": [], "already_present": [], "total_tags": total} - - # Ensure tag rows exist if requested. - if create_if_missing: - await _helpers.ensure_tags_exist(session, norm, tag_type="user") - - # Snapshot current links - current = { - tag_name - for (tag_name,) in ( - await session.execute( - sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) - ) - ).all() - } - - want = set(norm) - to_add = sorted(want - current) - - if to_add: - async with session.begin_nested() as nested: - try: - session.add_all( - [ - AssetInfoTag( - asset_info_id=asset_info_id, - tag_name=t, - origin=origin, - added_at=utcnow(), - ) - for t in to_add - ] - ) - await session.flush() - except IntegrityError: - await nested.rollback() - - after = set(await get_asset_tags(session, asset_info_id=asset_info_id)) - return { - "added": sorted(((after - current) & want)), - "already_present": sorted(want & current), - "total_tags": sorted(after), - } - - -async def remove_tags_from_asset_info( - session: AsyncSession, - *, - asset_info_id: str, - tags: Sequence[str], -) -> dict: - """Removes tags from an AssetInfo. - Returns: {"removed": [...], "not_present": [...], "total_tags": [...]} - """ - info = await session.get(AssetInfo, asset_info_id) - if not info: - raise ValueError(f"AssetInfo {asset_info_id} not found") - - norm = normalize_tags(tags) - if not norm: - total = await get_asset_tags(session, asset_info_id=asset_info_id) - return {"removed": [], "not_present": [], "total_tags": total} - - existing = { - tag_name - for (tag_name,) in ( - await session.execute( - sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) - ) - ).all() - } - - to_remove = sorted(set(t for t in norm if t in existing)) - not_present = sorted(set(t for t in norm if t not in existing)) - - if to_remove: - await session.execute( - delete(AssetInfoTag) - .where( - AssetInfoTag.asset_info_id == asset_info_id, - AssetInfoTag.tag_name.in_(to_remove), - ) - ) - await session.flush() - - total = await get_asset_tags(session, asset_info_id=asset_info_id) - return {"removed": to_remove, "not_present": not_present, "total_tags": total} - - -async def add_missing_tag_for_asset_hash( - session: AsyncSession, - *, - asset_hash: str, - origin: str = "automatic", -) -> int: - """Ensure every AssetInfo referencing asset_hash has the 'missing' tag. - Returns number of AssetInfos newly tagged. 
- """ - ids = (await session.execute(select(AssetInfo.id).where(AssetInfo.asset_hash == asset_hash))).scalars().all() - if not ids: - return 0 - - existing = { - asset_info_id - for (asset_info_id,) in ( - await session.execute( - select(AssetInfoTag.asset_info_id).where( - AssetInfoTag.asset_info_id.in_(ids), - AssetInfoTag.tag_name == "missing", - ) - ) - ).all() - } - to_add = [i for i in ids if i not in existing] - if not to_add: - return 0 - - now = utcnow() - session.add_all( - [ - AssetInfoTag(asset_info_id=i, tag_name="missing", origin=origin, added_at=now) - for i in to_add - ] - ) - await session.flush() - return len(to_add) - - -async def remove_missing_tag_for_asset_hash( - session: AsyncSession, - *, - asset_hash: str, -) -> int: - """Remove the 'missing' tag from every AssetInfo referencing asset_hash. - Returns number of link rows removed. - """ - ids = (await session.execute(select(AssetInfo.id).where(AssetInfo.asset_hash == asset_hash))).scalars().all() - if not ids: - return 0 - - res = await session.execute( - delete(AssetInfoTag).where( - AssetInfoTag.asset_info_id.in_(ids), - AssetInfoTag.tag_name == "missing", - ) - ) - await session.flush() - return int(res.rowcount or 0) - - -async def list_cache_states_with_asset_under_prefixes( - session: AsyncSession, - *, - prefixes: Sequence[str], -) -> list[tuple[AssetCacheState, int]]: - """Return (AssetCacheState, size_bytes) tuples for rows whose file_path starts with any of the absolute prefixes.""" - if not prefixes: - return [] - - conds = [] - for p in prefixes: - if not p: - continue - base = os.path.abspath(p) - if not base.endswith(os.sep): - base = base + os.sep - conds.append(AssetCacheState.file_path.like(base + "%")) - - if not conds: - return [] - - rows = ( - await session.execute( - select(AssetCacheState, Asset.size_bytes) - .join(Asset, Asset.hash == AssetCacheState.asset_hash) - .where(sa.or_(*conds)) - .order_by(AssetCacheState.id.asc()) - ) - ).all() - return [(r[0], int(r[1] or 0)) for r in rows] diff --git a/app/database/services/__init__.py b/app/database/services/__init__.py new file mode 100644 index 000000000000..aed8815a67a3 --- /dev/null +++ b/app/database/services/__init__.py @@ -0,0 +1,56 @@ +from .content import ( + check_fs_asset_exists_quick, + compute_hash_and_dedup_for_cache_state, + ensure_seed_for_path, + ingest_fs_asset, + list_cache_states_with_asset_under_prefixes, + list_unhashed_candidates_under_prefixes, + list_verify_candidates_under_prefixes, + redirect_all_references_then_delete_asset, + touch_asset_infos_by_fs_path, +) +from .info import ( + add_tags_to_asset_info, + create_asset_info_for_existing_asset, + delete_asset_info_by_id, + fetch_asset_info_and_asset, + fetch_asset_info_asset_and_tags, + get_asset_tags, + list_asset_infos_page, + list_tags_with_usage, + remove_tags_from_asset_info, + replace_asset_info_metadata_projection, + set_asset_info_preview, + set_asset_info_tags, + touch_asset_info_by_id, + update_asset_info_full, +) +from .queries import ( + asset_exists_by_hash, + asset_info_exists_for_asset_id, + get_asset_by_hash, + get_asset_info_by_id, + get_cache_state_by_asset_id, + list_cache_states_by_asset_id, +) + +__all__ = [ + # queries + "asset_exists_by_hash", "get_asset_by_hash", "get_asset_info_by_id", "asset_info_exists_for_asset_id", + "get_cache_state_by_asset_id", + "list_cache_states_by_asset_id", + # info + "list_asset_infos_page", "create_asset_info_for_existing_asset", "set_asset_info_tags", + "update_asset_info_full", 
"replace_asset_info_metadata_projection", + "touch_asset_info_by_id", "delete_asset_info_by_id", + "add_tags_to_asset_info", "remove_tags_from_asset_info", + "get_asset_tags", "list_tags_with_usage", "set_asset_info_preview", + "fetch_asset_info_and_asset", "fetch_asset_info_asset_and_tags", + # content + "check_fs_asset_exists_quick", "ensure_seed_for_path", + "redirect_all_references_then_delete_asset", + "compute_hash_and_dedup_for_cache_state", + "list_unhashed_candidates_under_prefixes", "list_verify_candidates_under_prefixes", + "ingest_fs_asset", "touch_asset_infos_by_fs_path", + "list_cache_states_with_asset_under_prefixes", +] diff --git a/app/database/services/content.py b/app/database/services/content.py new file mode 100644 index 000000000000..6cf440342218 --- /dev/null +++ b/app/database/services/content.py @@ -0,0 +1,746 @@ +import logging +import os +from datetime import datetime +from typing import Any, Optional, Sequence, Union + +import sqlalchemy as sa +from sqlalchemy import select +from sqlalchemy.dialects import postgresql as d_pg +from sqlalchemy.dialects import sqlite as d_sqlite +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import noload + +from ..._assets_helpers import compute_model_relative_filename, normalize_tags +from ...storage import hashing as hashing_mod +from ..helpers import ( + ensure_tags_exist, + remove_missing_tag_for_asset_id, +) +from ..models import Asset, AssetCacheState, AssetInfo, AssetInfoTag, Tag +from ..timeutil import utcnow +from .info import replace_asset_info_metadata_projection + + +async def check_fs_asset_exists_quick( + session: AsyncSession, + *, + file_path: str, + size_bytes: Optional[int] = None, + mtime_ns: Optional[int] = None, +) -> bool: + """Return True if a cache row exists for this absolute path and (optionally) mtime/size match.""" + locator = os.path.abspath(file_path) + + stmt = ( + sa.select(sa.literal(True)) + .select_from(AssetCacheState) + .join(Asset, Asset.id == AssetCacheState.asset_id) + .where(AssetCacheState.file_path == locator) + .limit(1) + ) + + conds = [] + if mtime_ns is not None: + conds.append(AssetCacheState.mtime_ns == int(mtime_ns)) + if size_bytes is not None: + conds.append(sa.or_(Asset.size_bytes == 0, Asset.size_bytes == int(size_bytes))) + + if conds: + stmt = stmt.where(*conds) + + row = (await session.execute(stmt)).first() + return row is not None + + +async def ensure_seed_for_path( + session: AsyncSession, + *, + abs_path: str, + size_bytes: int, + mtime_ns: int, + info_name: str, + tags: Sequence[str], + owner_id: str = "", +) -> str: + """Ensure: Asset(hash=NULL), AssetCacheState(file_path), and AssetInfo exist for the path. 
Returns asset_id.""" + locator = os.path.abspath(abs_path) + now = utcnow() + + state = ( + await session.execute( + sa.select(AssetCacheState, Asset) + .join(Asset, Asset.id == AssetCacheState.asset_id) + .where(AssetCacheState.file_path == locator) + .limit(1) + ) + ).first() + if state: + state_row: AssetCacheState = state[0] + asset_row: Asset = state[1] + changed = state_row.mtime_ns is None or int(state_row.mtime_ns) != int(mtime_ns) + if changed: + state_row.mtime_ns = int(mtime_ns) + state_row.needs_verify = True + if asset_row.size_bytes == 0 and size_bytes > 0: + asset_row.size_bytes = int(size_bytes) + return asset_row.id + + # Create new asset (hash=NULL) + asset = Asset(hash=None, size_bytes=int(size_bytes), mime_type=None, created_at=now) + session.add(asset) + await session.flush() # to get id + + cs = AssetCacheState(asset_id=asset.id, file_path=locator, mtime_ns=int(mtime_ns), needs_verify=False) + session.add(cs) + + info = AssetInfo( + owner_id=owner_id, + name=info_name, + asset_id=asset.id, + preview_id=None, + created_at=now, + updated_at=now, + last_access_time=now, + ) + session.add(info) + await session.flush() + + # Attach tags + want = normalize_tags(tags) + if want: + await ensure_tags_exist(session, want, tag_type="user") + session.add_all([ + AssetInfoTag(asset_info_id=info.id, tag_name=t, origin="automatic", added_at=now) + for t in want + ]) + + await session.flush() + return asset.id + + +async def redirect_all_references_then_delete_asset( + session: AsyncSession, + *, + duplicate_asset_id: str, + canonical_asset_id: str, +) -> None: + """ + Safely migrate all references from duplicate_asset_id to canonical_asset_id. + + - If an AssetInfo for (owner_id, name) already exists on the canonical asset, + merge tags, metadata, times, and preview, then delete the duplicate AssetInfo. + - Otherwise, simply repoint the AssetInfo.asset_id. + - Always retarget AssetCacheState rows. + - Finally delete the duplicate Asset row. + """ + if duplicate_asset_id == canonical_asset_id: + return + + # 1) Migrate AssetInfo rows one-by-one to avoid UNIQUE conflicts. 
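+    #    A bulk UPDATE of AssetInfo.asset_id could violate the unique
+    #    (asset_id, owner_id, name) constraint when the canonical asset already
+    #    has an AssetInfo with the same owner and name, so each duplicate row is
+    #    merged into that existing row or, when there is no collision, repointed.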
+ dup_infos = ( + await session.execute( + select(AssetInfo).options(noload(AssetInfo.tags)).where(AssetInfo.asset_id == duplicate_asset_id) + ) + ).unique().scalars().all() + + for info in dup_infos: + # Try to find an existing collision on canonical + existing = ( + await session.execute( + select(AssetInfo) + .options(noload(AssetInfo.tags)) + .where( + AssetInfo.asset_id == canonical_asset_id, + AssetInfo.owner_id == info.owner_id, + AssetInfo.name == info.name, + ) + .limit(1) + ) + ).unique().scalars().first() + + if existing: + # Merge metadata (prefer existing keys, fill gaps from duplicate) + merged_meta = dict(existing.user_metadata or {}) + other_meta = info.user_metadata or {} + for k, v in other_meta.items(): + if k not in merged_meta: + merged_meta[k] = v + if merged_meta != (existing.user_metadata or {}): + await replace_asset_info_metadata_projection( + session, + asset_info_id=existing.id, + user_metadata=merged_meta, + ) + + # Merge tags (union) + existing_tags = { + t for (t,) in ( + await session.execute( + select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == existing.id) + ) + ).all() + } + from_tags = { + t for (t,) in ( + await session.execute( + select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == info.id) + ) + ).all() + } + to_add = sorted(from_tags - existing_tags) + if to_add: + await ensure_tags_exist(session, to_add, tag_type="user") + now = utcnow() + session.add_all([ + AssetInfoTag(asset_info_id=existing.id, tag_name=t, origin="automatic", added_at=now) + for t in to_add + ]) + await session.flush() + + # Merge preview and times + if existing.preview_id is None and info.preview_id is not None: + existing.preview_id = info.preview_id + if info.last_access_time and ( + existing.last_access_time is None or info.last_access_time > existing.last_access_time + ): + existing.last_access_time = info.last_access_time + existing.updated_at = utcnow() + await session.flush() + + # Delete the duplicate AssetInfo (cascades will clean its tags/meta) + await session.delete(info) + await session.flush() + else: + # Simple retarget + info.asset_id = canonical_asset_id + info.updated_at = utcnow() + await session.flush() + + # 2) Repoint cache states and previews + await session.execute( + sa.update(AssetCacheState) + .where(AssetCacheState.asset_id == duplicate_asset_id) + .values(asset_id=canonical_asset_id) + ) + await session.execute( + sa.update(AssetInfo) + .where(AssetInfo.preview_id == duplicate_asset_id) + .values(preview_id=canonical_asset_id) + ) + + # 3) Remove duplicate Asset + dup = await session.get(Asset, duplicate_asset_id) + if dup: + await session.delete(dup) + await session.flush() + + +async def compute_hash_and_dedup_for_cache_state( + session: AsyncSession, + *, + state_id: int, +) -> Optional[str]: + """ + Compute hash for the given cache state, deduplicate, and settle verify cases. + + Returns the asset_id that this state ends up pointing to, or None if file disappeared. + """ + state = await session.get(AssetCacheState, state_id) + if not state: + return None + + path = state.file_path + try: + if not os.path.isfile(path): + # File vanished: drop the state. If the Asset was a seed (hash NULL) + # and has no other states, drop the Asset too. 
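+            # Only a seed Asset (hash=None) with no remaining cache states is
+            # dropped here; a hashed Asset is kept because AssetInfo rows may
+            # still reference it by content.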
+ asset = await session.get(Asset, state.asset_id) + await session.delete(state) + await session.flush() + + if asset and asset.hash is None: + remaining = ( + await session.execute( + sa.select(sa.func.count()) + .select_from(AssetCacheState) + .where(AssetCacheState.asset_id == asset.id) + ) + ).scalar_one() + if int(remaining or 0) == 0: + await session.delete(asset) + await session.flush() + return None + + digest = await hashing_mod.blake3_hash(path) + new_hash = f"blake3:{digest}" + + st = os.stat(path, follow_symlinks=True) + new_size = int(st.st_size) + mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)) + + # Current asset of this state + this_asset = await session.get(Asset, state.asset_id) + + # If the state got orphaned somehow (race), just reattach appropriately. + if not this_asset: + canonical = ( + await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1)) + ).scalars().first() + if canonical: + state.asset_id = canonical.id + else: + now = utcnow() + new_asset = Asset(hash=new_hash, size_bytes=new_size, mime_type=None, created_at=now) + session.add(new_asset) + await session.flush() + state.asset_id = new_asset.id + state.mtime_ns = mtime_ns + state.needs_verify = False + try: + await remove_missing_tag_for_asset_id(session, asset_id=state.asset_id) + except Exception: + pass + await session.flush() + return state.asset_id + + # 1) Seed asset case (hash is NULL): claim or merge into canonical + if this_asset.hash is None: + canonical = ( + await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1)) + ).scalars().first() + + if canonical and canonical.id != this_asset.id: + # Merge seed asset into canonical (safe, collision-aware) + await redirect_all_references_then_delete_asset( + session, + duplicate_asset_id=this_asset.id, + canonical_asset_id=canonical.id, + ) + # Refresh state after the merge + state = await session.get(AssetCacheState, state_id) + if state: + state.mtime_ns = mtime_ns + state.needs_verify = False + try: + await remove_missing_tag_for_asset_id(session, asset_id=canonical.id) + except Exception: + pass + await session.flush() + return canonical.id + + # No canonical: try to claim the hash; handle races with a SAVEPOINT + try: + async with session.begin_nested(): + this_asset.hash = new_hash + if int(this_asset.size_bytes or 0) == 0 and new_size > 0: + this_asset.size_bytes = new_size + await session.flush() + except IntegrityError: + # Someone else claimed it concurrently; fetch canonical and merge + canonical = ( + await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1)) + ).scalars().first() + if canonical and canonical.id != this_asset.id: + await redirect_all_references_then_delete_asset( + session, + duplicate_asset_id=this_asset.id, + canonical_asset_id=canonical.id, + ) + state = await session.get(AssetCacheState, state_id) + if state: + state.mtime_ns = mtime_ns + state.needs_verify = False + try: + await remove_missing_tag_for_asset_id(session, asset_id=canonical.id) + except Exception: + pass + await session.flush() + return canonical.id + # If we got here, the integrity error was not about hash uniqueness + raise + + # Claimed successfully + state.mtime_ns = mtime_ns + state.needs_verify = False + try: + await remove_missing_tag_for_asset_id(session, asset_id=this_asset.id) + except Exception: + pass + await session.flush() + return this_asset.id + + # 2) Verify case for hashed assets + if this_asset.hash == new_hash: + # Content unchanged; tidy up sizes/mtime + 
if int(this_asset.size_bytes or 0) == 0 and new_size > 0: + this_asset.size_bytes = new_size + state.mtime_ns = mtime_ns + state.needs_verify = False + try: + await remove_missing_tag_for_asset_id(session, asset_id=this_asset.id) + except Exception: + pass + await session.flush() + return this_asset.id + + # Content changed on this path only: retarget THIS state, do not move AssetInfo rows + canonical = ( + await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1)) + ).scalars().first() + if canonical: + target_id = canonical.id + else: + now = utcnow() + new_asset = Asset(hash=new_hash, size_bytes=new_size, mime_type=None, created_at=now) + session.add(new_asset) + await session.flush() + target_id = new_asset.id + + state.asset_id = target_id + state.mtime_ns = mtime_ns + state.needs_verify = False + try: + await remove_missing_tag_for_asset_id(session, asset_id=target_id) + except Exception: + pass + await session.flush() + return target_id + + except Exception: + # Propagate; caller records the error and continues the worker. + raise + + +async def list_unhashed_candidates_under_prefixes( + session: AsyncSession, *, prefixes: Sequence[str] +) -> list[int]: + if not prefixes: + return [] + + conds = [] + for p in prefixes: + base = os.path.abspath(p) + if not base.endswith(os.sep): + base += os.sep + conds.append(AssetCacheState.file_path.like(base + "%")) + + rows = ( + await session.execute( + sa.select(AssetCacheState.id) + .join(Asset, Asset.id == AssetCacheState.asset_id) + .where(Asset.hash.is_(None)) + .where(sa.or_(*conds)) + .order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc()) + ) + ).scalars().all() + seen = set() + result: list[int] = [] + for sid in rows: + st = await session.get(AssetCacheState, sid) + if st and st.asset_id not in seen: + seen.add(st.asset_id) + result.append(sid) + return result + + +async def list_verify_candidates_under_prefixes( + session: AsyncSession, *, prefixes: Sequence[str] +) -> Union[list[int], Sequence[int]]: + if not prefixes: + return [] + conds = [] + for p in prefixes: + base = os.path.abspath(p) + if not base.endswith(os.sep): + base += os.sep + conds.append(AssetCacheState.file_path.like(base + "%")) + + return ( + await session.execute( + sa.select(AssetCacheState.id) + .where(AssetCacheState.needs_verify.is_(True)) + .where(sa.or_(*conds)) + .order_by(AssetCacheState.id.asc()) + ) + ).scalars().all() + + +async def ingest_fs_asset( + session: AsyncSession, + *, + asset_hash: str, + abs_path: str, + size_bytes: int, + mtime_ns: int, + mime_type: Optional[str] = None, + info_name: Optional[str] = None, + owner_id: str = "", + preview_id: Optional[str] = None, + user_metadata: Optional[dict] = None, + tags: Sequence[str] = (), + tag_origin: str = "manual", + require_existing_tags: bool = False, +) -> dict: + """ + Idempotently upsert: + - Asset by content hash (create if missing) + - AssetCacheState(file_path) pointing to asset_id + - Optionally AssetInfo + tag links and metadata projection + Returns flags and ids. 
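+
+    Returned keys: asset_created, asset_updated, state_created, state_updated,
+    asset_info_id (None unless info_name is given).
+
+    Example (illustrative only; the hash, path and tags below are placeholder values):
+
+        flags = await ingest_fs_asset(
+            session,
+            asset_hash="blake3:<hex digest>",
+            abs_path="/models/checkpoints/example.safetensors",
+            size_bytes=1024,
+            mtime_ns=0,
+            info_name="example.safetensors",
+            tags=("models", "checkpoints"),
+        )
+        if flags["asset_created"]:
+            ...  # first time this content hash was seen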
+ """ + locator = os.path.abspath(abs_path) + now = utcnow() + + if preview_id: + if not await session.get(Asset, preview_id): + preview_id = None + + out: dict[str, Any] = { + "asset_created": False, + "asset_updated": False, + "state_created": False, + "state_updated": False, + "asset_info_id": None, + } + + # 1) Asset by hash + asset = ( + await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1)) + ).scalars().first() + if not asset: + async with session.begin_nested(): + asset = Asset(hash=asset_hash, size_bytes=int(size_bytes), mime_type=mime_type, created_at=now) + session.add(asset) + await session.flush() + out["asset_created"] = True + else: + changed = False + if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0: + asset.size_bytes = int(size_bytes) + changed = True + if mime_type and asset.mime_type != mime_type: + asset.mime_type = mime_type + changed = True + if changed: + out["asset_updated"] = True + + # 2) AssetCacheState upsert by file_path (unique) + vals = { + "asset_id": asset.id, + "file_path": locator, + "mtime_ns": int(mtime_ns), + } + dialect = session.bind.dialect.name + if dialect == "sqlite": + ins = ( + d_sqlite.insert(AssetCacheState) + .values(**vals) + .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path]) + ) + elif dialect == "postgresql": + ins = ( + d_pg.insert(AssetCacheState) + .values(**vals) + .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path]) + ) + else: + raise NotImplementedError(f"Unsupported database dialect: {dialect}") + + res = await session.execute(ins) + if int(res.rowcount or 0) > 0: + out["state_created"] = True + else: + upd = ( + sa.update(AssetCacheState) + .where(AssetCacheState.file_path == locator) + .where( + sa.or_( + AssetCacheState.asset_id != asset.id, + AssetCacheState.mtime_ns.is_(None), + AssetCacheState.mtime_ns != int(mtime_ns), + ) + ) + .values(asset_id=asset.id, mtime_ns=int(mtime_ns)) + ) + res2 = await session.execute(upd) + if int(res2.rowcount or 0) > 0: + out["state_updated"] = True + + # 3) Optional AssetInfo + tags + metadata + if info_name: + # upsert by (asset_id, owner_id, name) + try: + async with session.begin_nested(): + info = AssetInfo( + owner_id=owner_id, + name=info_name, + asset_id=asset.id, + preview_id=preview_id, + created_at=now, + updated_at=now, + last_access_time=now, + ) + session.add(info) + await session.flush() + out["asset_info_id"] = info.id + except IntegrityError: + pass + + existing_info = ( + await session.execute( + select(AssetInfo) + .where( + AssetInfo.asset_id == asset.id, + AssetInfo.name == info_name, + (AssetInfo.owner_id == owner_id), + ) + .limit(1) + ) + ).unique().scalar_one_or_none() + if not existing_info: + raise RuntimeError("Failed to update or insert AssetInfo.") + + if preview_id and existing_info.preview_id != preview_id: + existing_info.preview_id = preview_id + + existing_info.updated_at = now + if existing_info.last_access_time < now: + existing_info.last_access_time = now + await session.flush() + out["asset_info_id"] = existing_info.id + + norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()] + if norm and out["asset_info_id"] is not None: + if not require_existing_tags: + await ensure_tags_exist(session, norm, tag_type="user") + + existing_tag_names = set( + name for (name,) in (await session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all() + ) + missing = [t for t in norm if t not in existing_tag_names] + if missing and require_existing_tags: + raise ValueError(f"Unknown 
tags: {missing}") + + existing_links = set( + tag_name + for (tag_name,) in ( + await session.execute( + select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"]) + ) + ).all() + ) + to_add = [t for t in norm if t in existing_tag_names and t not in existing_links] + if to_add: + session.add_all( + [ + AssetInfoTag( + asset_info_id=out["asset_info_id"], + tag_name=t, + origin=tag_origin, + added_at=now, + ) + for t in to_add + ] + ) + await session.flush() + + # metadata["filename"] hack + if out["asset_info_id"] is not None: + primary_path = ( + await session.execute( + select(AssetCacheState.file_path) + .where(AssetCacheState.asset_id == asset.id) + .order_by(AssetCacheState.id.asc()) + .limit(1) + ) + ).scalars().first() + computed_filename = compute_model_relative_filename(primary_path) if primary_path else None + + current_meta = existing_info.user_metadata or {} + new_meta = dict(current_meta) + if user_metadata is not None: + for k, v in user_metadata.items(): + new_meta[k] = v + if computed_filename: + new_meta["filename"] = computed_filename + + if new_meta != current_meta: + await replace_asset_info_metadata_projection( + session, + asset_info_id=out["asset_info_id"], + user_metadata=new_meta, + ) + + try: + await remove_missing_tag_for_asset_id(session, asset_id=asset.id) + except Exception: + logging.exception("Failed to clear 'missing' tag for asset %s", asset.id) + return out + + +async def touch_asset_infos_by_fs_path( + session: AsyncSession, + *, + file_path: str, + ts: Optional[datetime] = None, + only_if_newer: bool = True, +) -> int: + locator = os.path.abspath(file_path) + ts = ts or utcnow() + + stmt = sa.update(AssetInfo).where( + sa.exists( + sa.select(sa.literal(1)) + .select_from(AssetCacheState) + .where( + AssetCacheState.asset_id == AssetInfo.asset_id, + AssetCacheState.file_path == locator, + ) + ) + ) + + if only_if_newer: + stmt = stmt.where( + sa.or_( + AssetInfo.last_access_time.is_(None), + AssetInfo.last_access_time < ts, + ) + ) + + stmt = stmt.values(last_access_time=ts) + + res = await session.execute(stmt) + return int(res.rowcount or 0) + + +async def list_cache_states_with_asset_under_prefixes( + session: AsyncSession, + *, + prefixes: Sequence[str], +) -> list[tuple[AssetCacheState, Optional[str], int]]: + """Return (AssetCacheState, asset_hash, size_bytes) for rows under any prefix.""" + if not prefixes: + return [] + + conds = [] + for p in prefixes: + if not p: + continue + base = os.path.abspath(p) + if not base.endswith(os.sep): + base = base + os.sep + conds.append(AssetCacheState.file_path.like(base + "%")) + + if not conds: + return [] + + rows = ( + await session.execute( + select(AssetCacheState, Asset.hash, Asset.size_bytes) + .join(Asset, Asset.id == AssetCacheState.asset_id) + .where(sa.or_(*conds)) + .order_by(AssetCacheState.id.asc()) + ) + ).all() + return [(r[0], r[1], int(r[2] or 0)) for r in rows] diff --git a/app/database/services/info.py b/app/database/services/info.py new file mode 100644 index 000000000000..e3da1bc8ee3b --- /dev/null +++ b/app/database/services/info.py @@ -0,0 +1,579 @@ +from collections import defaultdict +from datetime import datetime +from typing import Any, Optional, Sequence + +import sqlalchemy as sa +from sqlalchemy import delete, func, select +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import contains_eager, noload + +from ..._assets_helpers import compute_model_relative_filename, normalize_tags +from 
..helpers import ( + apply_metadata_filter, + apply_tag_filters, + ensure_tags_exist, + project_kv, + visible_owner_clause, +) +from ..models import Asset, AssetInfo, AssetInfoMeta, AssetInfoTag, Tag +from ..timeutil import utcnow +from .queries import get_asset_by_hash, get_cache_state_by_asset_id + + +async def list_asset_infos_page( + session: AsyncSession, + *, + owner_id: str = "", + include_tags: Optional[Sequence[str]] = None, + exclude_tags: Optional[Sequence[str]] = None, + name_contains: Optional[str] = None, + metadata_filter: Optional[dict] = None, + limit: int = 20, + offset: int = 0, + sort: str = "created_at", + order: str = "desc", +) -> tuple[list[AssetInfo], dict[str, list[str]], int]: + base = ( + select(AssetInfo) + .join(Asset, Asset.id == AssetInfo.asset_id) + .options(contains_eager(AssetInfo.asset), noload(AssetInfo.tags)) + .where(visible_owner_clause(owner_id)) + ) + + if name_contains: + base = base.where(AssetInfo.name.ilike(f"%{name_contains}%")) + + base = apply_tag_filters(base, include_tags, exclude_tags) + base = apply_metadata_filter(base, metadata_filter) + + sort = (sort or "created_at").lower() + order = (order or "desc").lower() + sort_map = { + "name": AssetInfo.name, + "created_at": AssetInfo.created_at, + "updated_at": AssetInfo.updated_at, + "last_access_time": AssetInfo.last_access_time, + "size": Asset.size_bytes, + } + sort_col = sort_map.get(sort, AssetInfo.created_at) + sort_exp = sort_col.desc() if order == "desc" else sort_col.asc() + + base = base.order_by(sort_exp).limit(limit).offset(offset) + + count_stmt = ( + select(func.count()) + .select_from(AssetInfo) + .join(Asset, Asset.id == AssetInfo.asset_id) + .where(visible_owner_clause(owner_id)) + ) + if name_contains: + count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{name_contains}%")) + count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags) + count_stmt = apply_metadata_filter(count_stmt, metadata_filter) + + total = int((await session.execute(count_stmt)).scalar_one() or 0) + + infos = (await session.execute(base)).unique().scalars().all() + + id_list: list[str] = [i.id for i in infos] + tag_map: dict[str, list[str]] = defaultdict(list) + if id_list: + rows = await session.execute( + select(AssetInfoTag.asset_info_id, Tag.name) + .join(Tag, Tag.name == AssetInfoTag.tag_name) + .where(AssetInfoTag.asset_info_id.in_(id_list)) + ) + for aid, tag_name in rows.all(): + tag_map[aid].append(tag_name) + + return infos, tag_map, total + + +async def fetch_asset_info_and_asset( + session: AsyncSession, + *, + asset_info_id: str, + owner_id: str = "", +) -> Optional[tuple[AssetInfo, Asset]]: + stmt = ( + select(AssetInfo, Asset) + .join(Asset, Asset.id == AssetInfo.asset_id) + .where( + AssetInfo.id == asset_info_id, + visible_owner_clause(owner_id), + ) + .limit(1) + .options(noload(AssetInfo.tags)) + ) + row = await session.execute(stmt) + pair = row.first() + if not pair: + return None + return pair[0], pair[1] + + +async def fetch_asset_info_asset_and_tags( + session: AsyncSession, + *, + asset_info_id: str, + owner_id: str = "", +) -> Optional[tuple[AssetInfo, Asset, list[str]]]: + stmt = ( + select(AssetInfo, Asset, Tag.name) + .join(Asset, Asset.id == AssetInfo.asset_id) + .join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True) + .join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True) + .where( + AssetInfo.id == asset_info_id, + visible_owner_clause(owner_id), + ) + .options(noload(AssetInfo.tags)) + .order_by(Tag.name.asc()) + ) + + rows = 
(await session.execute(stmt)).all() + if not rows: + return None + + first_info, first_asset, _ = rows[0] + tags: list[str] = [] + seen: set[str] = set() + for _info, _asset, tag_name in rows: + if tag_name and tag_name not in seen: + seen.add(tag_name) + tags.append(tag_name) + return first_info, first_asset, tags + + +async def create_asset_info_for_existing_asset( + session: AsyncSession, + *, + asset_hash: str, + name: str, + user_metadata: Optional[dict] = None, + tags: Optional[Sequence[str]] = None, + tag_origin: str = "manual", + owner_id: str = "", +) -> AssetInfo: + """Create or return an existing AssetInfo for an Asset identified by asset_hash.""" + now = utcnow() + asset = await get_asset_by_hash(session, asset_hash=asset_hash) + if not asset: + raise ValueError(f"Unknown asset hash {asset_hash}") + + info = AssetInfo( + owner_id=owner_id, + name=name, + asset_id=asset.id, + preview_id=None, + created_at=now, + updated_at=now, + last_access_time=now, + ) + try: + async with session.begin_nested(): + session.add(info) + await session.flush() + except IntegrityError: + existing = ( + await session.execute( + select(AssetInfo) + .options(noload(AssetInfo.tags)) + .where( + AssetInfo.asset_id == asset.id, + AssetInfo.name == name, + AssetInfo.owner_id == owner_id, + ) + .limit(1) + ) + ).unique().scalars().first() + if not existing: + raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.") + return existing + + # metadata["filename"] hack + new_meta = dict(user_metadata or {}) + computed_filename = None + try: + state = await get_cache_state_by_asset_id(session, asset_id=asset.id) + if state and state.file_path: + computed_filename = compute_model_relative_filename(state.file_path) + except Exception: + computed_filename = None + if computed_filename: + new_meta["filename"] = computed_filename + if new_meta: + await replace_asset_info_metadata_projection( + session, + asset_info_id=info.id, + user_metadata=new_meta, + ) + + if tags is not None: + await set_asset_info_tags( + session, + asset_info_id=info.id, + tags=tags, + origin=tag_origin, + ) + return info + + +async def set_asset_info_tags( + session: AsyncSession, + *, + asset_info_id: str, + tags: Sequence[str], + origin: str = "manual", +) -> dict: + desired = normalize_tags(tags) + + current = set( + tag_name for (tag_name,) in ( + await session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)) + ).all() + ) + + to_add = [t for t in desired if t not in current] + to_remove = [t for t in current if t not in desired] + + if to_add: + await ensure_tags_exist(session, to_add, tag_type="user") + session.add_all([ + AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow()) + for t in to_add + ]) + await session.flush() + + if to_remove: + await session.execute( + delete(AssetInfoTag) + .where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove)) + ) + await session.flush() + + return {"added": to_add, "removed": to_remove, "total": desired} + + +async def update_asset_info_full( + session: AsyncSession, + *, + asset_info_id: str, + name: Optional[str] = None, + tags: Optional[Sequence[str]] = None, + user_metadata: Optional[dict] = None, + tag_origin: str = "manual", + asset_info_row: Any = None, +) -> AssetInfo: + if not asset_info_row: + info = await session.get(AssetInfo, asset_info_id) + if not info: + raise ValueError(f"AssetInfo {asset_info_id} not found") + else: + info = asset_info_row + + touched 
= False + if name is not None and name != info.name: + info.name = name + touched = True + + computed_filename = None + try: + state = await get_cache_state_by_asset_id(session, asset_id=info.asset_id) + if state and state.file_path: + computed_filename = compute_model_relative_filename(state.file_path) + except Exception: + computed_filename = None + + if user_metadata is not None: + new_meta = dict(user_metadata) + if computed_filename: + new_meta["filename"] = computed_filename + await replace_asset_info_metadata_projection( + session, asset_info_id=asset_info_id, user_metadata=new_meta + ) + touched = True + else: + if computed_filename: + current_meta = info.user_metadata or {} + if current_meta.get("filename") != computed_filename: + new_meta = dict(current_meta) + new_meta["filename"] = computed_filename + await replace_asset_info_metadata_projection( + session, asset_info_id=asset_info_id, user_metadata=new_meta + ) + touched = True + + if tags is not None: + await set_asset_info_tags( + session, + asset_info_id=asset_info_id, + tags=tags, + origin=tag_origin, + ) + touched = True + + if touched and user_metadata is None: + info.updated_at = utcnow() + await session.flush() + + return info + + +async def replace_asset_info_metadata_projection( + session: AsyncSession, + *, + asset_info_id: str, + user_metadata: Optional[dict], +) -> None: + info = await session.get(AssetInfo, asset_info_id) + if not info: + raise ValueError(f"AssetInfo {asset_info_id} not found") + + info.user_metadata = user_metadata or {} + info.updated_at = utcnow() + await session.flush() + + await session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id)) + await session.flush() + + if not user_metadata: + return + + rows: list[AssetInfoMeta] = [] + for k, v in user_metadata.items(): + for r in project_kv(k, v): + rows.append( + AssetInfoMeta( + asset_info_id=asset_info_id, + key=r["key"], + ordinal=int(r["ordinal"]), + val_str=r.get("val_str"), + val_num=r.get("val_num"), + val_bool=r.get("val_bool"), + val_json=r.get("val_json"), + ) + ) + if rows: + session.add_all(rows) + await session.flush() + + +async def touch_asset_info_by_id( + session: AsyncSession, + *, + asset_info_id: str, + ts: Optional[datetime] = None, + only_if_newer: bool = True, +) -> int: + ts = ts or utcnow() + stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id) + if only_if_newer: + stmt = stmt.where( + sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts) + ) + stmt = stmt.values(last_access_time=ts) + res = await session.execute(stmt) + return int(res.rowcount or 0) + + +async def delete_asset_info_by_id(session: AsyncSession, *, asset_info_id: str, owner_id: str) -> bool: + res = await session.execute(delete(AssetInfo).where( + AssetInfo.id == asset_info_id, + visible_owner_clause(owner_id), + )) + return bool(res.rowcount) + + +async def add_tags_to_asset_info( + session: AsyncSession, + *, + asset_info_id: str, + tags: Sequence[str], + origin: str = "manual", + create_if_missing: bool = True, + asset_info_row: Any = None, +) -> dict: + if not asset_info_row: + info = await session.get(AssetInfo, asset_info_id) + if not info: + raise ValueError(f"AssetInfo {asset_info_id} not found") + + norm = normalize_tags(tags) + if not norm: + total = await get_asset_tags(session, asset_info_id=asset_info_id) + return {"added": [], "already_present": [], "total_tags": total} + + if create_if_missing: + await ensure_tags_exist(session, norm, tag_type="user") + + current = { + 
tag_name + for (tag_name,) in ( + await session.execute( + sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) + ) + ).all() + } + + want = set(norm) + to_add = sorted(want - current) + + if to_add: + async with session.begin_nested() as nested: + try: + session.add_all( + [ + AssetInfoTag( + asset_info_id=asset_info_id, + tag_name=t, + origin=origin, + added_at=utcnow(), + ) + for t in to_add + ] + ) + await session.flush() + except IntegrityError: + await nested.rollback() + + after = set(await get_asset_tags(session, asset_info_id=asset_info_id)) + return { + "added": sorted(((after - current) & want)), + "already_present": sorted(want & current), + "total_tags": sorted(after), + } + + +async def remove_tags_from_asset_info( + session: AsyncSession, + *, + asset_info_id: str, + tags: Sequence[str], +) -> dict: + info = await session.get(AssetInfo, asset_info_id) + if not info: + raise ValueError(f"AssetInfo {asset_info_id} not found") + + norm = normalize_tags(tags) + if not norm: + total = await get_asset_tags(session, asset_info_id=asset_info_id) + return {"removed": [], "not_present": [], "total_tags": total} + + existing = { + tag_name + for (tag_name,) in ( + await session.execute( + sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) + ) + ).all() + } + + to_remove = sorted(set(t for t in norm if t in existing)) + not_present = sorted(set(t for t in norm if t not in existing)) + + if to_remove: + await session.execute( + delete(AssetInfoTag) + .where( + AssetInfoTag.asset_info_id == asset_info_id, + AssetInfoTag.tag_name.in_(to_remove), + ) + ) + await session.flush() + + total = await get_asset_tags(session, asset_info_id=asset_info_id) + return {"removed": to_remove, "not_present": not_present, "total_tags": total} + + +async def list_tags_with_usage( + session: AsyncSession, + *, + prefix: Optional[str] = None, + limit: int = 100, + offset: int = 0, + include_zero: bool = True, + order: str = "count_desc", + owner_id: str = "", +) -> tuple[list[tuple[str, str, int]], int]: + counts_sq = ( + select( + AssetInfoTag.tag_name.label("tag_name"), + func.count(AssetInfoTag.asset_info_id).label("cnt"), + ) + .select_from(AssetInfoTag) + .join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id) + .where(visible_owner_clause(owner_id)) + .group_by(AssetInfoTag.tag_name) + .subquery() + ) + + q = ( + select( + Tag.name, + Tag.tag_type, + func.coalesce(counts_sq.c.cnt, 0).label("count"), + ) + .select_from(Tag) + .join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True) + ) + + if prefix: + q = q.where(Tag.name.like(prefix.strip().lower() + "%")) + + if not include_zero: + q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0) + + if order == "name_asc": + q = q.order_by(Tag.name.asc()) + else: + q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc()) + + total_q = select(func.count()).select_from(Tag) + if prefix: + total_q = total_q.where(Tag.name.like(prefix.strip().lower() + "%")) + if not include_zero: + total_q = total_q.where( + Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name)) + ) + + rows = (await session.execute(q.limit(limit).offset(offset))).all() + total = (await session.execute(total_q)).scalar_one() + + rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows] + return rows_norm, int(total or 0) + + +async def get_asset_tags(session: AsyncSession, *, asset_info_id: str) -> list[str]: + return [ + tag_name + for (tag_name,) in ( + await 
session.execute( + sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) + ) + ).all() + ] + + +async def set_asset_info_preview( + session: AsyncSession, + *, + asset_info_id: str, + preview_asset_id: Optional[str], +) -> None: + """Set or clear preview_id and bump updated_at. Raises on unknown IDs.""" + info = await session.get(AssetInfo, asset_info_id) + if not info: + raise ValueError(f"AssetInfo {asset_info_id} not found") + + if preview_asset_id is None: + info.preview_id = None + else: + # validate preview asset exists + if not await session.get(Asset, preview_asset_id): + raise ValueError(f"Preview Asset {preview_asset_id} not found") + info.preview_id = preview_asset_id + + info.updated_at = utcnow() + await session.flush() diff --git a/app/database/services/queries.py b/app/database/services/queries.py new file mode 100644 index 000000000000..81649b7f4c10 --- /dev/null +++ b/app/database/services/queries.py @@ -0,0 +1,59 @@ +from typing import Optional, Sequence, Union + +import sqlalchemy as sa +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from ..models import Asset, AssetCacheState, AssetInfo + + +async def asset_exists_by_hash(session: AsyncSession, *, asset_hash: str) -> bool: + row = ( + await session.execute( + select(sa.literal(True)).select_from(Asset).where(Asset.hash == asset_hash).limit(1) + ) + ).first() + return row is not None + + +async def get_asset_by_hash(session: AsyncSession, *, asset_hash: str) -> Optional[Asset]: + return ( + await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1)) + ).scalars().first() + + +async def get_asset_info_by_id(session: AsyncSession, *, asset_info_id: str) -> Optional[AssetInfo]: + return await session.get(AssetInfo, asset_info_id) + + +async def asset_info_exists_for_asset_id(session: AsyncSession, *, asset_id: str) -> bool: + q = ( + select(sa.literal(True)) + .select_from(AssetInfo) + .where(AssetInfo.asset_id == asset_id) + .limit(1) + ) + return (await session.execute(q)).first() is not None + + +async def get_cache_state_by_asset_id(session: AsyncSession, *, asset_id: str) -> Optional[AssetCacheState]: + return ( + await session.execute( + select(AssetCacheState) + .where(AssetCacheState.asset_id == asset_id) + .order_by(AssetCacheState.id.asc()) + .limit(1) + ) + ).scalars().first() + + +async def list_cache_states_by_asset_id( + session: AsyncSession, *, asset_id: str +) -> Union[list[AssetCacheState], Sequence[AssetCacheState]]: + return ( + await session.execute( + select(AssetCacheState) + .where(AssetCacheState.asset_id == asset_id) + .order_by(AssetCacheState.id.asc()) + ) + ).scalars().all() diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 5e301b505de0..d814e453a5b3 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -212,7 +212,6 @@ def is_valid_directory(path: str) -> str: os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db") ) parser.add_argument("--database-url", type=str, default=f"sqlite+aiosqlite:///{database_default_path}", help="Specify the database URL, e.g. 
for an in-memory database you can use 'sqlite+aiosqlite:///:memory:'.") -parser.add_argument("--enable-model-processing", action="store_true", help="Enable automatic processing of the model file, such as calculating hashes and populating the database.") parser.add_argument("--disable-assets-autoscan", action="store_true", help="Disable asset scanning on startup for database synchronization.") if comfy.options.args_parsing: diff --git a/main.py b/main.py index 3485a7c76c1d..db0ee04f5059 100644 --- a/main.py +++ b/main.py @@ -279,11 +279,11 @@ def cleanup_temp(): shutil.rmtree(temp_dir, ignore_errors=True) async def setup_database(): - from app import init_db_engine, start_background_assets_scan + from app import init_db_engine, sync_seed_assets await init_db_engine() if not args.disable_assets_autoscan: - await start_background_assets_scan() + await sync_seed_assets(["models", "input", "output"]) def start_comfyui(asyncio_loop=None): diff --git a/server.py b/server.py index d3a0f8628c68..ddd188ebcb78 100644 --- a/server.py +++ b/server.py @@ -37,7 +37,7 @@ from app.custom_node_manager import CustomNodeManager from typing import Optional, Union from api_server.routes.internal.internal_routes import InternalRoutes -from app.api.assets_routes import register_assets_system +from app import sync_seed_assets, register_assets_system from protocol import BinaryEventTypes async def send_socket_catch_exception(function, message): @@ -629,6 +629,7 @@ def node_info(node_class): @routes.get("/object_info") async def get_object_info(request): + await sync_seed_assets(["models"]) with folder_paths.cache_helper: out = {} for x in nodes.NODE_CLASS_MAPPINGS: diff --git a/tests-assets/test_crud.py b/tests-assets/test_crud.py index 99ea329c52d9..1e59281509f2 100644 --- a/tests-assets/test_crud.py +++ b/tests-assets/test_crud.py @@ -118,6 +118,16 @@ async def test_head_asset_by_hash(http: aiohttp.ClientSession, api_base: str, se assert rh2.status == 404 +@pytest.mark.asyncio +async def test_head_asset_bad_hash_returns_400_and_no_body(http: aiohttp.ClientSession, api_base: str): + # Invalid format; handler returns a JSON error, but HEAD responses must not carry a payload. + # aiohttp exposes an empty body for HEAD, so validate status and that there is no payload. 
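+    # (Unknown-but-well-formed hashes are already covered by test_head_asset_by_hash above; this case targets only the format check.)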
+ async with http.head(f"{api_base}/api/assets/hash/not_a_hash") as rh: + assert rh.status == 400 + body = await rh.read() + assert body == b"" + + @pytest.mark.asyncio async def test_delete_nonexistent_returns_404(http: aiohttp.ClientSession, api_base: str): bogus = str(uuid.uuid4()) @@ -166,12 +176,3 @@ async def test_update_requires_at_least_one_field(http: aiohttp.ClientSession, a body = await r.json() assert r.status == 400 assert body["error"]["code"] == "INVALID_BODY" - - -@pytest.mark.asyncio -async def test_head_asset_bad_hash(http: aiohttp.ClientSession, api_base: str): - # Invalid format - async with http.head(f"{api_base}/api/assets/hash/not_a_hash") as rh3: - jb = await rh3.json() - assert rh3.status == 400 - assert jb is None # HEAD request should not include "body" in response diff --git a/tests-assets/test_tags.py b/tests-assets/test_tags.py index bba91581fb23..aede764daf43 100644 --- a/tests-assets/test_tags.py +++ b/tests-assets/test_tags.py @@ -66,23 +66,32 @@ async def test_add_and_remove_tags(http: aiohttp.ClientSession, api_base: str, s async with http.post(f"{api_base}/api/assets/{aid}/tags", json=payload_add) as r1: b1 = await r1.json() assert r1.status == 200, b1 - # normalized and deduplicated - assert "newtag" in b1["added"] or "beta" in b1["added"] or "unit-tests" not in b1["added"] + # normalized, deduplicated; 'unit-tests' was already present from the seed + assert set(b1["added"]) == {"newtag", "beta"} + assert set(b1["already_present"]) == {"unit-tests"} + assert "newtag" in b1["total_tags"] and "beta" in b1["total_tags"] async with http.get(f"{api_base}/api/assets/{aid}") as rg: g = await rg.json() assert rg.status == 200 tags_now = set(g["tags"]) - assert "newtag" in tags_now - assert "beta" in tags_now + assert {"newtag", "beta"}.issubset(tags_now) # Remove a tag and a non-existent tag payload_del = {"tags": ["newtag", "does-not-exist"]} async with http.delete(f"{api_base}/api/assets/{aid}/tags", json=payload_del) as r2: b2 = await r2.json() assert r2.status == 200 - assert "newtag" in b2["removed"] - assert "does-not-exist" in b2["not_present"] + assert set(b2["removed"]) == {"newtag"} + assert set(b2["not_present"]) == {"does-not-exist"} + + # Verify remaining tags after deletion + async with http.get(f"{api_base}/api/assets/{aid}") as rg2: + g2 = await rg2.json() + assert rg2.status == 200 + tags_later = set(g2["tags"]) + assert "newtag" not in tags_later + assert "beta" in tags_later # still present @pytest.mark.asyncio diff --git a/tests-assets/test_uploads.py b/tests-assets/test_uploads.py index 1d8df4e40f3b..3bfb62ca4b32 100644 --- a/tests-assets/test_uploads.py +++ b/tests-assets/test_uploads.py @@ -206,7 +206,7 @@ async def test_upload_models_unknown_category(http: aiohttp.ClientSession, api_b body = await r.json() assert r.status == 400 assert body["error"]["code"] == "INVALID_BODY" - assert "unknown models category" in body["error"]["message"] or "unknown model category" in body["error"]["message"] + assert body["error"]["message"].startswith("unknown models category") @pytest.mark.asyncio From 9b8e88ba6e14141f8545ca4d2abe6caad34bcb93 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Sat, 13 Sep 2025 20:09:45 +0300 Subject: [PATCH 47/82] added more tests for the Assets logic --- server.py | 2 +- tests-assets/conftest.py | 33 ++++ tests-assets/test_assets_missing_sync.py | 196 +++++++++++++++++++++++ 3 files changed, 230 insertions(+), 1 deletion(-) create mode 100644 tests-assets/test_assets_missing_sync.py diff --git a/server.py b/server.py index 
ddd188ebcb78..c3a688a75499 100644 --- a/server.py +++ b/server.py @@ -629,7 +629,7 @@ def node_info(node_class): @routes.get("/object_info") async def get_object_info(request): - await sync_seed_assets(["models"]) + await sync_seed_assets(["models", "input", "output"]) with folder_paths.cache_helper: out = {} for x in nodes.NODE_CLASS_MAPPINGS: diff --git a/tests-assets/conftest.py b/tests-assets/conftest.py index 88fc0a5f30c8..c133195db26f 100644 --- a/tests-assets/conftest.py +++ b/tests-assets/conftest.py @@ -268,3 +268,36 @@ async def autoclean_unit_test_assets(http: aiohttp.ClientSession, api_base: str) with contextlib.suppress(Exception): async with http.delete(f"{api_base}/api/assets/{aid}") as dr: await dr.read() + + +async def trigger_sync_seed_assets(session: aiohttp.ClientSession, base_url: str) -> None: + """Force a fast sync/seed pass by calling the ComfyUI '/object_info' endpoint.""" + async with session.get(base_url + "/object_info") as r: + await r.read() + await asyncio.sleep(0.05) # tiny yield to the event loop to let any final DB commits flush + + +@pytest_asyncio.fixture +async def run_scan_and_wait(http: aiohttp.ClientSession, api_base: str): + """Schedule an asset scan for a given root and wait until it finishes.""" + async def _run(root: str, timeout: float = 120.0): + async with http.post(api_base + "/api/assets/scan/schedule", json={"roots": [root]}) as r: + # we ignore body; scheduling returns 202 with a status payload + await r.read() + + start = time.time() + while True: + async with http.get(api_base + "/api/assets/scan", params={"root": root}) as st: + body = await st.json() + scans = (body or {}).get("scans", []) + status = None + if scans: + status = scans[-1].get("status") + if status in {"completed", "failed", "cancelled"}: + if status != "completed": + raise RuntimeError(f"Scan for root={root} finished with status={status}") + return + if time.time() - start > timeout: + raise TimeoutError(f"Timed out waiting for scan of root={root}") + await asyncio.sleep(0.1) + return _run diff --git a/tests-assets/test_assets_missing_sync.py b/tests-assets/test_assets_missing_sync.py new file mode 100644 index 000000000000..d73e500a7272 --- /dev/null +++ b/tests-assets/test_assets_missing_sync.py @@ -0,0 +1,196 @@ +from pathlib import Path +import uuid + +import aiohttp +import pytest + +from conftest import trigger_sync_seed_assets + + +@pytest.mark.asyncio +async def test_seed_asset_removed_when_file_is_deleted( + http: aiohttp.ClientSession, + api_base: str, + comfy_tmp_base_dir: Path, +): + """Asset without hash (seed) whose file disappears: + after triggering sync_seed_assets, Asset + AssetInfo disappear. 
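+    ("Seed" here means an Asset row created by sync_seed_assets with hash=NULL; the test confirms this below via asset_hash is None.)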
+ """ + # Create a file directly under input/unit-tests/ so tags include "unit-tests" + case_dir = comfy_tmp_base_dir / "input" / "unit-tests" / "syncseed" + case_dir.mkdir(parents=True, exist_ok=True) + name = f"seed_{uuid.uuid4().hex[:8]}.bin" + fp = case_dir / name + fp.write_bytes(b"Z" * 2048) + + # Trigger a seed sync so DB sees this path (seed asset => hash is NULL) + await trigger_sync_seed_assets(http, api_base) + + # Verify it is visible via API and carries no hash (seed) + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,syncseed", "name_contains": name}, + ) as r1: + body1 = await r1.json() + assert r1.status == 200 + # there should be exactly one with that name + matches = [a for a in body1.get("assets", []) if a.get("name") == name] + assert matches + assert matches[0].get("asset_hash") is None + asset_info_id = matches[0]["id"] + + # Remove the underlying file and sync again + if fp.exists(): + fp.unlink() + + await trigger_sync_seed_assets(http, api_base) + + # It should disappear (AssetInfo and seed Asset gone) + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,syncseed", "name_contains": name}, + ) as r2: + body2 = await r2.json() + assert r2.status == 200 + matches2 = [a for a in body2.get("assets", []) if a.get("name") == name] + assert not matches2, f"Seed asset {asset_info_id} should be gone after sync" + + +@pytest.mark.asyncio +async def test_hashed_asset_missing_tag_added_then_removed_after_scan( + http: aiohttp.ClientSession, + api_base: str, + comfy_tmp_base_dir: Path, + asset_factory, + make_asset_bytes, + run_scan_and_wait, +): + """Hashed asset with a single cache_state: + 1. delete its file -> sync adds 'missing' + 2. restore file -> scan removes 'missing' + """ + name = "missing_tag_test.png" + tags = ["input", "unit-tests", "msync2"] + data = make_asset_bytes(name, 4096) + a = await asset_factory(name, tags, {}, data) + + # Compute its on-disk path and remove it + dest = comfy_tmp_base_dir / "input" / "unit-tests" / "msync2" / name + assert dest.exists(), f"Expected asset file at {dest}" + dest.unlink() + + # Fast sync should add 'missing' to the AssetInfo + await trigger_sync_seed_assets(http, api_base) + + async with http.get(f"{api_base}/api/assets/{a['id']}") as g1: + d1 = await g1.json() + assert g1.status == 200, d1 + assert "missing" in set(d1.get("tags", [])), "Expected 'missing' tag after deletion" + + # Restore the file with the exact same content and re-hash/verify via scan + dest.parent.mkdir(parents=True, exist_ok=True) + dest.write_bytes(data) + + await run_scan_and_wait("input") + + async with http.get(f"{api_base}/api/assets/{a['id']}") as g2: + d2 = await g2.json() + assert g2.status == 200, d2 + assert "missing" not in set(d2.get("tags", [])), "Missing tag should be cleared after verify" + + +@pytest.mark.asyncio +async def test_hashed_asset_two_assetinfos_both_get_missing( + http: aiohttp.ClientSession, + api_base: str, + comfy_tmp_base_dir: Path, + asset_factory, +): + """Hashed asset with a single cache_state, but two AssetInfo rows: + deleting the single file then syncing should add 'missing' to both infos. 
+ """ + # Upload one hashed asset + name = "two_infos_one_path.png" + base_tags = ["input", "unit-tests", "multiinfo"] + created = await asset_factory(name, base_tags, {}, b"A" * 2048) + + # Create second AssetInfo for the same Asset via from-hash + payload = { + "hash": created["asset_hash"], + "name": "two_infos_one_path_copy.png", + "tags": base_tags, # keep it in our unit-tests scope for cleanup + "user_metadata": {"k": "v"}, + } + async with http.post(api_base + "/api/assets/from-hash", json=payload) as r2: + b2 = await r2.json() + assert r2.status == 201, b2 + second_id = b2["id"] + + # Remove the single underlying file + p = comfy_tmp_base_dir / "input" / "unit-tests" / "multiinfo" / name + assert p.exists() + p.unlink() + + # Sync -> both AssetInfos for this asset must receive 'missing' + await trigger_sync_seed_assets(http, api_base) + + async with http.get(f"{api_base}/api/assets/{created['id']}") as ga: + da = await ga.json() + assert ga.status == 200, da + assert "missing" in set(da.get("tags", [])) + + async with http.get(f"{api_base}/api/assets/{second_id}") as gb: + db = await gb.json() + assert gb.status == 200, db + assert "missing" in set(db.get("tags", [])) + + +@pytest.mark.asyncio +async def test_hashed_asset_two_cache_states_partial_delete_then_full_delete( + http: aiohttp.ClientSession, + api_base: str, + comfy_tmp_base_dir: Path, + asset_factory, + make_asset_bytes, + run_scan_and_wait, +): + """Hashed asset with two cache_state rows: + 1. delete one file -> sync should NOT add 'missing' + 2. delete second file -> sync should add 'missing' + """ + name = "two_cache_states_partial_delete.png" + tags = ["input", "unit-tests", "dual"] + data = make_asset_bytes(name, 3072) + + created = await asset_factory(name, tags, {}, data) + path1 = comfy_tmp_base_dir / "input" / "unit-tests" / "dual" / name + assert path1.exists() + + # Create a second on-disk copy under the same root but different subfolder + path2 = comfy_tmp_base_dir / "input" / "unit-tests" / "dual_copy" / name + path2.parent.mkdir(parents=True, exist_ok=True) + path2.write_bytes(data) + + # Fast seed so the second path appears (as a seed initially) + await trigger_sync_seed_assets(http, api_base) + + # Now run a 'models' scan so the seed copy is hashed and deduped + await run_scan_and_wait("input") + + # Remove only one file and sync -> asset should still be healthy (no 'missing') + path1.unlink() + await trigger_sync_seed_assets(http, api_base) + + async with http.get(f"{api_base}/api/assets/{created['id']}") as g1: + d1 = await g1.json() + assert g1.status == 200, d1 + assert "missing" not in set(d1.get("tags", [])), "Should not be missing while one valid path remains" + + # Remove the second (last) file and sync -> now we expect 'missing' + path2.unlink() + await trigger_sync_seed_assets(http, api_base) + + async with http.get(f"{api_base}/api/assets/{created['id']}") as g2: + d2 = await g2.json() + assert g2.status == 200, d2 + assert "missing" in set(d2.get("tags", [])), "Missing must be set once no valid paths remain" From 4a713654cd6be098ae104daf5c2a731520e031ca Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Sat, 13 Sep 2025 21:12:33 +0300 Subject: [PATCH 48/82] added more tests for the Assets logic --- app/__init__.py | 2 +- app/database/services/content.py | 10 +- tests-assets/conftest.py | 2 +- tests-assets/test_assets_missing_sync.py | 3 +- tests-assets/test_crud.py | 1 + tests-assets/test_downloads.py | 1 + tests-assets/test_metadata_filters.py | 1 + tests-assets/test_scans.py | 464 
+++++++++++++++++++++++ tests-assets/test_tags.py | 107 +++++- 9 files changed, 558 insertions(+), 33 deletions(-) create mode 100644 tests-assets/test_scans.py diff --git a/app/__init__.py b/app/__init__.py index e8538bd29fb6..f73951107da2 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -1,5 +1,5 @@ +from .api.assets_routes import register_assets_system from .assets_scanner import sync_seed_assets from .database.db import init_db_engine -from .api.assets_routes import register_assets_system __all__ = ["init_db_engine", "sync_seed_assets", "register_assets_system"] diff --git a/app/database/services/content.py b/app/database/services/content.py index 6cf440342218..f8c43abfd4ea 100644 --- a/app/database/services/content.py +++ b/app/database/services/content.py @@ -86,7 +86,6 @@ async def ensure_seed_for_path( asset_row.size_bytes = int(size_bytes) return asset_row.id - # Create new asset (hash=NULL) asset = Asset(hash=None, size_bytes=int(size_bytes), mime_type=None, created_at=now) session.add(asset) await session.flush() # to get id @@ -106,7 +105,6 @@ async def ensure_seed_for_path( session.add(info) await session.flush() - # Attach tags want = normalize_tags(tags) if want: await ensure_tags_exist(session, want, tag_type="user") @@ -160,7 +158,6 @@ async def redirect_all_references_then_delete_asset( ).unique().scalars().first() if existing: - # Merge metadata (prefer existing keys, fill gaps from duplicate) merged_meta = dict(existing.user_metadata or {}) other_meta = info.user_metadata or {} for k, v in other_meta.items(): @@ -173,7 +170,6 @@ async def redirect_all_references_then_delete_asset( user_metadata=merged_meta, ) - # Merge tags (union) existing_tags = { t for (t,) in ( await session.execute( @@ -198,7 +194,6 @@ async def redirect_all_references_then_delete_asset( ]) await session.flush() - # Merge preview and times if existing.preview_id is None and info.preview_id is not None: existing.preview_id = info.preview_id if info.last_access_time and ( @@ -253,8 +248,7 @@ async def compute_hash_and_dedup_for_cache_state( path = state.file_path try: if not os.path.isfile(path): - # File vanished: drop the state. If the Asset was a seed (hash NULL) - # and has no other states, drop the Asset too. + # File vanished: drop the state. If the Asset has hash=NULL and has no other states, drop the Asset too. 
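+        # (Any other cache-state rows for the same asset are left untouched here; only the state for this vanished path is removed.)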
asset = await session.get(Asset, state.asset_id) await session.delete(state) await session.flush() @@ -372,7 +366,6 @@ async def compute_hash_and_dedup_for_cache_state( # 2) Verify case for hashed assets if this_asset.hash == new_hash: - # Content unchanged; tidy up sizes/mtime if int(this_asset.size_bytes or 0) == 0 and new_size > 0: this_asset.size_bytes = new_size state.mtime_ns = mtime_ns @@ -569,7 +562,6 @@ async def ingest_fs_asset( # 3) Optional AssetInfo + tags + metadata if info_name: - # upsert by (asset_id, owner_id, name) try: async with session.begin_nested(): info = AssetInfo( diff --git a/tests-assets/conftest.py b/tests-assets/conftest.py index c133195db26f..84f8747874f3 100644 --- a/tests-assets/conftest.py +++ b/tests-assets/conftest.py @@ -3,6 +3,7 @@ import json import os import socket +import subprocess import sys import tempfile import time @@ -12,7 +13,6 @@ import aiohttp import pytest import pytest_asyncio -import subprocess def pytest_addoption(parser: pytest.Parser) -> None: diff --git a/tests-assets/test_assets_missing_sync.py b/tests-assets/test_assets_missing_sync.py index d73e500a7272..aec6606d60c0 100644 --- a/tests-assets/test_assets_missing_sync.py +++ b/tests-assets/test_assets_missing_sync.py @@ -1,9 +1,8 @@ -from pathlib import Path import uuid +from pathlib import Path import aiohttp import pytest - from conftest import trigger_sync_seed_assets diff --git a/tests-assets/test_crud.py b/tests-assets/test_crud.py index 1e59281509f2..8836cc686f05 100644 --- a/tests-assets/test_crud.py +++ b/tests-assets/test_crud.py @@ -1,4 +1,5 @@ import uuid + import aiohttp import pytest diff --git a/tests-assets/test_downloads.py b/tests-assets/test_downloads.py index 7a449dfe8cfc..9cb1b9486036 100644 --- a/tests-assets/test_downloads.py +++ b/tests-assets/test_downloads.py @@ -1,4 +1,5 @@ from pathlib import Path + import aiohttp import pytest diff --git a/tests-assets/test_metadata_filters.py b/tests-assets/test_metadata_filters.py index 39d00fa2df1f..360a24527d55 100644 --- a/tests-assets/test_metadata_filters.py +++ b/tests-assets/test_metadata_filters.py @@ -1,4 +1,5 @@ import json + import aiohttp import pytest diff --git a/tests-assets/test_scans.py b/tests-assets/test_scans.py new file mode 100644 index 000000000000..7774569cadd7 --- /dev/null +++ b/tests-assets/test_scans.py @@ -0,0 +1,464 @@ +import asyncio +import os +import uuid +from pathlib import Path + +import aiohttp +import pytest +from conftest import trigger_sync_seed_assets + + +def _base_for(root: str, comfy_tmp_base_dir: Path) -> Path: + assert root in ("input", "output") + return comfy_tmp_base_dir / root + + +def _mkbytes(label: str, size: int) -> bytes: + seed = sum(label.encode("utf-8")) % 251 + return bytes((i * 31 + seed) % 256 for i in range(size)) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("root", ["input", "output"]) +async def test_scan_schedule_idempotent_while_running( + root: str, + http, + api_base: str, + comfy_tmp_base_dir: Path, + run_scan_and_wait, +): + """Idempotent schedule while running.""" + scope = f"idem-{uuid.uuid4().hex[:6]}" + base = _base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope + base.mkdir(parents=True, exist_ok=True) + + # Create several seed files (non-zero) to ensure the scan runs long enough + for i in range(8): + (base / f"f{i}.bin").write_bytes(_mkbytes(f"{scope}-{i}", 2 * 1024 * 1024)) # ~2 MiB each + + # Seed -> states with hash=NULL + await trigger_sync_seed_assets(http, api_base) + + # Schedule once + async with http.post(api_base + 
"/api/assets/scan/schedule", json={"roots": [root]}) as r1: + b1 = await r1.json() + assert r1.status == 202, b1 + scans1 = {s["root"]: s for s in b1.get("scans", [])} + s1 = scans1.get(root) + assert s1 and s1["status"] in {"scheduled", "running"} + sid1 = s1["scan_id"] + + # Schedule again immediately — must return the same scan entry (no new worker) + async with http.post(api_base + "/api/assets/scan/schedule", json={"roots": [root]}) as r2: + b2 = await r2.json() + assert r2.status == 202, b2 + scans2 = {s["root"]: s for s in b2.get("scans", [])} + s2 = scans2.get(root) + assert s2 and s2["scan_id"] == sid1 + + # Filtered GET must show exactly one scan for this root + async with http.get(api_base + "/api/assets/scan", params={"root": root}) as gs: + bs = await gs.json() + assert gs.status == 200, bs + scans = bs.get("scans", []) + assert len(scans) == 1 and scans[0]["scan_id"] == sid1 + + # Let it finish to avoid cross-test interference + await run_scan_and_wait(root) + + +@pytest.mark.asyncio +async def test_scan_status_filter_by_root_and_file_errors( + http, + api_base: str, + comfy_tmp_base_dir: Path, + run_scan_and_wait, + asset_factory, +): + """Filtering get scan status by root (schedule for both input and output) + file_errors presence.""" + # Create one hashed asset in input under a dir we will chmod to 000 to force PermissionError in reconcile stage + in_scope = f"filter-in-{uuid.uuid4().hex[:6]}" + protected_dir = _base_for("input", comfy_tmp_base_dir) / "unit-tests" / in_scope / "deny" + protected_dir.mkdir(parents=True, exist_ok=True) + name_in = "protected.bin" + + data = b"A" * 4096 + await asset_factory(name_in, ["input", "unit-tests", in_scope, "deny"], {}, data) + try: + os.chmod(protected_dir, 0x000) + + # Also schedule a scan for output root (no errors there) + out_scope = f"filter-out-{uuid.uuid4().hex[:6]}" + out_dir = _base_for("output", comfy_tmp_base_dir) / "unit-tests" / out_scope + out_dir.mkdir(parents=True, exist_ok=True) + (out_dir / "ok.bin").write_bytes(b"B" * 1024) + await trigger_sync_seed_assets(http, api_base) # seed output file + + # Schedule both roots + async with http.post(api_base + "/api/assets/scan/schedule", json={"roots": ["input"]}) as r_in: + assert r_in.status == 202 + async with http.post(api_base + "/api/assets/scan/schedule", json={"roots": ["output"]}) as r_out: + assert r_out.status == 202 + + # Wait both to complete, input last (we want its errors) + await run_scan_and_wait("output") + await run_scan_and_wait("input") + + # Filter by root=input: only input scan listed and must have file_errors + async with http.get(api_base + "/api/assets/scan", params={"root": "input"}) as gs: + body = await gs.json() + assert gs.status == 200, body + scans = body.get("scans", []) + assert len(scans) == 1 + errs = scans[0].get("file_errors", []) + # Must contain at least one error with a message + assert errs and any(e.get("message") for e in errs) + finally: + # Restore perms so cleanup can remove files/dirs + try: + os.chmod(protected_dir, 0o755) + except Exception: + pass + + +@pytest.mark.asyncio +@pytest.mark.parametrize("root", ["input", "output"]) +@pytest.mark.skipif(os.name == "nt", reason="Permission-based file_errors are unreliable on Windows") +async def test_scan_records_file_errors_permission_denied( + root: str, + http, + api_base: str, + comfy_tmp_base_dir: Path, + asset_factory, + run_scan_and_wait, +): + """file_errors recording (permission denied) for input/output""" + scope = f"errs-{uuid.uuid4().hex[:6]}" + deny_dir = 
_base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope / "deny" + deny_dir.mkdir(parents=True, exist_ok=True) + name = "deny.bin" + + await asset_factory(name, [root, "unit-tests", scope, "deny"], {}, b"X" * 2048) + try: + os.chmod(deny_dir, 0x000) + async with http.post(api_base + "/api/assets/scan/schedule", json={"roots": [root]}) as r: + assert r.status == 202 + await run_scan_and_wait(root) + + async with http.get(api_base + "/api/assets/scan", params={"root": root}) as gs: + body = await gs.json() + assert gs.status == 200, body + scans = body.get("scans", []) + assert len(scans) == 1 + errs = scans[0].get("file_errors", []) + # Should contain at least one PermissionError-like record + assert errs and any(e.get("path", "").endswith(name) and e.get("message") for e in errs) + finally: + try: + os.chmod(deny_dir, 0o755) + except Exception: + pass + + +@pytest.mark.asyncio +@pytest.mark.parametrize("root", ["input", "output"]) +async def test_missing_tag_created_and_visible_in_tags( + root: str, + http, + api_base: str, + comfy_tmp_base_dir: Path, + asset_factory, +): + """Missing tag appears in tags list and increments count (input/output)""" + # Baseline count of 'missing' tag (may be absent) + async with http.get(api_base + "/api/tags", params={"limit": "1000"}) as r0: + t0 = await r0.json() + assert r0.status == 200, t0 + byname = {t["name"]: t for t in t0.get("tags", [])} + old_count = int(byname.get("missing", {}).get("count", 0)) + + scope = f"miss-{uuid.uuid4().hex[:6]}" + name = "missing_me.bin" + created = await asset_factory(name, [root, "unit-tests", scope], {}, b"Y" * 4096) + + # Remove the only file and trigger fast pass + p = _base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope / name + assert p.exists() + p.unlink() + await trigger_sync_seed_assets(http, api_base) + + # Asset has 'missing' tag + async with http.get(f"{api_base}/api/assets/{created['id']}") as g1: + d1 = await g1.json() + assert g1.status == 200, d1 + assert "missing" in set(d1.get("tags", [])) + + # Tag list now contains 'missing' with increased count + async with http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}) as r1: + t1 = await r1.json() + assert r1.status == 200, t1 + byname1 = {t["name"]: t for t in t1.get("tags", [])} + assert "missing" in byname1 + assert int(byname1["missing"]["count"]) >= old_count + 1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("root", ["input", "output"]) +async def test_missing_reapplies_after_manual_removal( + root: str, + http, + api_base: str, + comfy_tmp_base_dir: Path, + asset_factory, +): + """Manual removal of 'missing' does not block automatic re-apply (input/output)""" + scope = f"reapply-{uuid.uuid4().hex[:6]}" + name = "reapply.bin" + created = await asset_factory(name, [root, "unit-tests", scope], {}, b"Z" * 1024) + + # Make it missing + p = _base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope / name + p.unlink() + await trigger_sync_seed_assets(http, api_base) + + # Remove the 'missing' tag manually + async with http.delete(f"{api_base}/api/assets/{created['id']}/tags", json={"tags": ["missing"]}) as rdel: + b = await rdel.json() + assert rdel.status == 200, b + assert "missing" in set(b.get("removed", [])) + + # Next sync must re-add it + await trigger_sync_seed_assets(http, api_base) + async with http.get(f"{api_base}/api/assets/{created['id']}") as g2: + d2 = await g2.json() + assert g2.status == 200, d2 + assert "missing" in set(d2.get("tags", [])) + + +@pytest.mark.asyncio 
+@pytest.mark.parametrize("root", ["input", "output"]) +async def test_delete_one_assetinfo_of_missing_asset_keeps_identity( + root: str, + http, + api_base: str, + comfy_tmp_base_dir: Path, + asset_factory, +): + """Delete one AssetInfo of a missing asset while another exists (input/output)""" + scope = f"twoinfos-{uuid.uuid4().hex[:6]}" + name = "twoinfos.bin" + a1 = await asset_factory(name, [root, "unit-tests", scope], {}, b"W" * 2048) + + # Second AssetInfo for the same content under same root (different name to avoid collision) + a2 = await asset_factory("copy_" + name, [root, "unit-tests", scope], {}, b"W" * 2048) + + # Remove file of the first (both point to the same Asset, but we know on-disk path name for a1) + p1 = _base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope / name + p1.unlink() + await trigger_sync_seed_assets(http, api_base) + + # Both infos should be marked missing + async with http.get(f"{api_base}/api/assets/{a1['id']}") as g1: + d1 = await g1.json(); assert "missing" in set(d1.get("tags", [])) + async with http.get(f"{api_base}/api/assets/{a2['id']}") as g2: + d2 = await g2.json(); assert "missing" in set(d2.get("tags", [])) + + # Delete one info + async with http.delete(f"{api_base}/api/assets/{a1['id']}") as rd: + assert rd.status == 204 + + # Asset identity still exists (by hash) + h = a1["asset_hash"] + async with http.head(f"{api_base}/api/assets/hash/{h}") as rh: + assert rh.status == 200 + + # Remaining info still reflects 'missing' + async with http.get(f"{api_base}/api/assets/{a2['id']}") as g3: + d3 = await g3.json() + assert g3.status == 200 and "missing" in set(d3.get("tags", [])) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("keep_root", ["input", "output"]) +async def test_delete_last_assetinfo_false_keeps_asset_and_states_multiroot( + keep_root: str, + http, + api_base: str, + comfy_tmp_base_dir: Path, + make_asset_bytes, + asset_factory, +): + """Delete last AssetInfo with delete_content_if_orphan=false keeps asset and the underlying on-disk content.""" + other_root = "output" if keep_root == "input" else "input" + scope = f"delfalse-{uuid.uuid4().hex[:6]}" + name1, name2 = "keep1.bin", "keep2.bin" + data = make_asset_bytes(scope, 3072) + + # First upload creates the physical file + a1 = await asset_factory(name1, [keep_root, "unit-tests", scope], {}, data) + # Second upload (other root) is deduped to the same content; no new file on disk + a2 = await asset_factory(name2, [other_root, "unit-tests", scope], {}, data) + + h = a1["asset_hash"] + p1 = _base_for(keep_root, comfy_tmp_base_dir) / "unit-tests" / scope / name1 + p2 = _base_for(other_root, comfy_tmp_base_dir) / "unit-tests" / scope / name2 + + # De-dup semantics: only the first physical file exists + assert p1.exists(), "Expected the first physical file to exist" + assert not p2.exists(), "Second duplicate must not create another physical file" + + # Delete both AssetInfos; keep content on the very last delete + async with http.delete(f"{api_base}/api/assets/{a2['id']}") as rfirst: + assert rfirst.status == 204 + async with http.delete(f"{api_base}/api/assets/{a1['id']}?delete_content=false") as rlast: + assert rlast.status == 204 + + # Asset identity remains and physical content is still present + async with http.head(f"{api_base}/api/assets/hash/{h}") as rh: + assert rh.status == 200 + assert p1.exists(), "Content file should remain after keep-content delete" + assert not p2.exists(), "There was never a second physical file" + + # Cleanup: re-create a reference by hash and then 
delete to purge content + payload = { + "hash": h, + "name": "cleanup.bin", + "tags": [keep_root, "unit-tests", scope, "cleanup"], + "user_metadata": {}, + } + async with http.post(f"{api_base}/api/assets/from-hash", json=payload) as rfh: + ref = await rfh.json() + assert rfh.status == 201, ref + cid = ref["id"] + async with http.delete(f"{api_base}/api/assets/{cid}") as rdel: + assert rdel.status == 204 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("root", ["input", "output"]) +async def test_sync_seed_ignores_zero_byte_files( + root: str, + http, + api_base: str, + comfy_tmp_base_dir: Path, +): + scope = f"zero-{uuid.uuid4().hex[:6]}" + base = _base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope + base.mkdir(parents=True, exist_ok=True) + z = base / "empty.dat" + z.write_bytes(b"") # zero bytes + + await trigger_sync_seed_assets(http, api_base) + + # No AssetInfo created for this zero-byte file + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests," + scope, "name_contains": "empty.dat"}, + ) as r: + body = await r.json() + assert r.status == 200, body + assert not [a for a in body.get("assets", []) if a.get("name") == "empty.dat"] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("root", ["input", "output"]) +async def test_sync_seed_idempotency( + root: str, + http, + api_base: str, + comfy_tmp_base_dir: Path, +): + scope = f"idemseed-{uuid.uuid4().hex[:6]}" + base = _base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope + base.mkdir(parents=True, exist_ok=True) + files = [f"f{i}.dat" for i in range(3)] + for i, n in enumerate(files): + (base / n).write_bytes(_mkbytes(n, 1500 + i * 10)) + + await trigger_sync_seed_assets(http, api_base) + async with http.get(api_base + "/api/assets", params={"include_tags": "unit-tests," + scope}) as r1: + b1 = await r1.json() + assert r1.status == 200, b1 + c1 = len(b1.get("assets", [])) + + # Seed again -> count must stay the same + await trigger_sync_seed_assets(http, api_base) + async with http.get(api_base + "/api/assets", params={"include_tags": "unit-tests," + scope}) as r2: + b2 = await r2.json() + assert r2.status == 200, b2 + c2 = len(b2.get("assets", [])) + assert c1 == c2 == len(files) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("root", ["input", "output"]) +async def test_sync_seed_nested_dirs_produce_parent_tags( + root: str, + http, + api_base: str, + comfy_tmp_base_dir: Path, +): + scope = f"nest-{uuid.uuid4().hex[:6]}" + # nested: unit-tests / scope / a / b / c / deep.txt + deep_dir = _base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope / "a" / "b" / "c" + deep_dir.mkdir(parents=True, exist_ok=True) + (deep_dir / "deep.txt").write_bytes(b"content") + + await trigger_sync_seed_assets(http, api_base) + + async with http.get( + api_base + "/api/assets", + params={"include_tags": f"unit-tests,{scope}", "name_contains": "deep.txt"}, + ) as r: + body = await r.json() + assert r.status == 200, body + assets = body.get("assets", []) + assert assets, "seeded asset not found" + tags = set(assets[0].get("tags", [])) + # Must include all parent parts as tags + the root + for must in {root, "unit-tests", scope, "a", "b", "c"}: + assert must in tags + + +@pytest.mark.asyncio +@pytest.mark.parametrize("root", ["input", "output"]) +async def test_concurrent_seed_hashing_same_file_no_dupes( + root: str, + http: aiohttp.ClientSession, + api_base: str, + comfy_tmp_base_dir: Path, + run_scan_and_wait, +): + """ + Create a single seed file, then schedule two scans back-to-back. 
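+    (Scheduling the same root twice is expected to reuse the in-flight scan rather than start a second worker, as exercised in test_scan_schedule_idempotent_while_running.)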
+ Expect: no duplicate AssetInfos, a single hashed asset, and no scan failure. + """ + scope = f"conc-seed-{uuid.uuid4().hex[:6]}" + name = "seed_concurrent.bin" + + base = _base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope + base.mkdir(parents=True, exist_ok=True) + (base / name).write_bytes(b"Z" * 2048) + + await trigger_sync_seed_assets(http, api_base) + + s1, s2 = await asyncio.gather( + http.post(api_base + "/api/assets/scan/schedule", json={"roots": [root]}), + http.post(api_base + "/api/assets/scan/schedule", json={"roots": [root]}), + ) + await s1.read() + await s2.read() + assert s1.status in (200, 202) + assert s2.status in (200, 202) + + await run_scan_and_wait(root) + + async with http.get( + api_base + "/api/assets", + params={"include_tags": f"unit-tests,{scope}", "name_contains": name}, + ) as r: + b = await r.json() + assert r.status == 200, b + matches = [a for a in b.get("assets", []) if a.get("name") == name] + assert len(matches) == 1 + assert matches[0].get("asset_hash"), "Seed should have been hashed into an Asset" diff --git a/tests-assets/test_tags.py b/tests-assets/test_tags.py index aede764daf43..9ad3c3f860a6 100644 --- a/tests-assets/test_tags.py +++ b/tests-assets/test_tags.py @@ -1,4 +1,6 @@ import json +import uuid + import aiohttp import pytest @@ -40,21 +42,49 @@ async def test_tags_present(http: aiohttp.ClientSession, api_base: str, seeded_a @pytest.mark.asyncio -async def test_tags_empty_usage(http: aiohttp.ClientSession, api_base: str): - # Include zero-usage tags by default - async with http.get(api_base + "/api/tags", params={"limit": "50"}) as r1: +async def test_tags_empty_usage(http: aiohttp.ClientSession, api_base: str, asset_factory, make_asset_bytes): + # Baseline: system tags exist when include_zero (default) is true + async with http.get(api_base + "/api/tags", params={"limit": "500"}) as r1: body1 = await r1.json() assert r1.status == 200 names = [t["name"] for t in body1["tags"]] - # A few system tags from migration should exist: - assert "models" in names - assert "checkpoints" in names - - # With include_zero=False there should be no tags returned for the database without Assets. 
- async with http.get(api_base + "/api/tags", params={"include_zero": "false"}) as r2: + assert "models" in names and "checkpoints" in names + + # Create a short-lived asset under input with a unique custom tag + scope = f"tags-empty-usage-{uuid.uuid4().hex[:6]}" + custom_tag = f"temp-{uuid.uuid4().hex[:8]}" + name = "tag_seed.bin" + _asset = await asset_factory( + name, + ["input", "unit-tests", scope, custom_tag], + {}, + make_asset_bytes(name, 512), + ) + + # While the asset exists, the custom tag must appear when include_zero=false + async with http.get( + api_base + "/api/tags", + params={"include_zero": "false", "prefix": custom_tag, "limit": "50"}, + ) as r2: body2 = await r2.json() assert r2.status == 200 - assert not [t["name"] for t in body2["tags"]] + used_names = [t["name"] for t in body2["tags"]] + assert custom_tag in used_names + + # Delete the asset so the tag usage drops to zero + async with http.delete(f"{api_base}/api/assets/{_asset['id']}") as rd: + assert rd.status == 204 + + # Now the custom tag must not be returned when include_zero=false + async with http.get( + api_base + "/api/tags", + params={"include_zero": "false", "prefix": custom_tag, "limit": "50"}, + ) as r3: + body3 = await r3.json() + assert r3.status == 200 + names_after = [t["name"] for t in body3["tags"]] + assert custom_tag not in names_after + assert not names_after # filtered view should be empty now @pytest.mark.asyncio @@ -96,18 +126,55 @@ async def test_add_and_remove_tags(http: aiohttp.ClientSession, api_base: str, s @pytest.mark.asyncio async def test_tags_list_order_and_prefix(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict): - # name ascending - async with http.get(api_base + "/api/tags", params={"order": "name_asc", "limit": "100"}) as r1: + aid = seeded_asset["id"] + h = seeded_asset["asset_hash"] + + # Add both tags to the seeded asset (usage: orderaaa=1, orderbbb=1) + async with http.post(f"{api_base}/api/assets/{aid}/tags", json={"tags": ["orderaaa", "orderbbb"]}) as r_add: + add_body = await r_add.json() + assert r_add.status == 200, add_body + + # Create another AssetInfo from the same content but tagged ONLY with 'orderbbb'. + payload = { + "hash": h, + "name": "order_only_bbb.safetensors", + "tags": ["input", "unit-tests", "orderbbb"], + "user_metadata": {}, + } + async with http.post(f"{api_base}/api/assets/from-hash", json=payload) as r_copy: + copy_body = await r_copy.json() + assert r_copy.status == 201, copy_body + + # 1) Default order (count_desc): 'orderbbb' should come before 'orderaaa' + # because it has higher usage (2 vs 1). 
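+    #    (Ties fall back to Tag.name ascending, per the count_desc order_by in list_tags_with_usage.)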
+ async with http.get(api_base + "/api/tags", params={"prefix": "order", "include_zero": "false"}) as r1: b1 = await r1.json() - assert r1.status == 200 - names = [t["name"] for t in b1["tags"]] - assert names == sorted(names) - - # invalid limit rejected - async with http.get(api_base + "/api/tags", params={"limit": "1001"}) as r2: + assert r1.status == 200, b1 + names1 = [t["name"] for t in b1["tags"]] + counts1 = {t["name"]: t["count"] for t in b1["tags"]} + # Both must be present within the prefix subset + assert "orderaaa" in names1 and "orderbbb" in names1 + # Usage of 'orderbbb' must be >= 'orderaaa'; in our setup it's 2 vs 1 + assert counts1["orderbbb"] >= counts1["orderaaa"] + # And with count_desc, 'orderbbb' appears earlier than 'orderaaa' + assert names1.index("orderbbb") < names1.index("orderaaa") + + # 2) name_asc: lexical order should flip the relative order + async with http.get( + api_base + "/api/tags", + params={"prefix": "order", "include_zero": "false", "order": "name_asc"}, + ) as r2: b2 = await r2.json() - assert r2.status == 400 - assert b2["error"]["code"] == "INVALID_QUERY" + assert r2.status == 200, b2 + names2 = [t["name"] for t in b2["tags"]] + assert "orderaaa" in names2 and "orderbbb" in names2 + assert names2.index("orderaaa") < names2.index("orderbbb") + + # 3) invalid limit rejected (existing negative case retained) + async with http.get(api_base + "/api/tags", params={"limit": "1001"}) as r3: + b3 = await r3.json() + assert r3.status == 400 + assert b3["error"]["code"] == "INVALID_QUERY" @pytest.mark.asyncio From 975650060f743cdff5a9f2bc08d9dd5f14b60135 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Sun, 14 Sep 2025 09:35:47 +0300 Subject: [PATCH 49/82] concurrency upload test + fixed 2 related bugs --- app/database/helpers/tags.py | 30 +++++++++++++++---- app/database/services/content.py | 34 +++++++++++++++++---- tests-assets/test_uploads.py | 51 ++++++++++++++++++++++++++++++++ 3 files changed, 104 insertions(+), 11 deletions(-) diff --git a/app/database/helpers/tags.py b/app/database/helpers/tags.py index 47934309654a..c8f9e50749b4 100644 --- a/app/database/helpers/tags.py +++ b/app/database/helpers/tags.py @@ -1,6 +1,8 @@ from typing import Iterable from sqlalchemy import delete, select +from sqlalchemy.dialects import postgresql as d_pg +from sqlalchemy.dialects import sqlite as d_sqlite from sqlalchemy.ext.asyncio import AsyncSession from ..._assets_helpers import normalize_tags @@ -13,13 +15,29 @@ async def ensure_tags_exist(session: AsyncSession, names: Iterable[str], tag_typ if not wanted: return [] existing = (await session.execute(select(Tag).where(Tag.name.in_(wanted)))).scalars().all() + existing_names = {t.name for t in existing} + missing = [n for n in wanted if n not in existing_names] + if missing: + dialect = session.bind.dialect.name + rows = [{"name": n, "tag_type": tag_type} for n in missing] + if dialect == "sqlite": + ins = ( + d_sqlite.insert(Tag) + .values(rows) + .on_conflict_do_nothing(index_elements=[Tag.name]) + ) + elif dialect == "postgresql": + ins = ( + d_pg.insert(Tag) + .values(rows) + .on_conflict_do_nothing(index_elements=[Tag.name]) + ) + else: + raise NotImplementedError(f"Unsupported database dialect: {dialect}") + await session.execute(ins) + existing = (await session.execute(select(Tag).where(Tag.name.in_(wanted)))).scalars().all() by_name = {t.name: t for t in existing} - to_create = [Tag(name=n, tag_type=tag_type) for n in wanted if n not in by_name] - if to_create: - session.add_all(to_create) - await 
session.flush() - by_name.update({t.name: t for t in to_create}) - return [by_name[n] for n in wanted] + return [by_name[n] for n in wanted if n in by_name] async def add_missing_tag_for_asset_id( diff --git a/app/database/services/content.py b/app/database/services/content.py index f8c43abfd4ea..8388e524d6c8 100644 --- a/app/database/services/content.py +++ b/app/database/services/content.py @@ -484,6 +484,7 @@ async def ingest_fs_asset( """ locator = os.path.abspath(abs_path) now = utcnow() + dialect = session.bind.dialect.name if preview_id: if not await session.get(Asset, preview_id): @@ -502,10 +503,34 @@ async def ingest_fs_asset( await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1)) ).scalars().first() if not asset: - async with session.begin_nested(): - asset = Asset(hash=asset_hash, size_bytes=int(size_bytes), mime_type=mime_type, created_at=now) - session.add(asset) - await session.flush() + vals = { + "hash": asset_hash, + "size_bytes": int(size_bytes), + "mime_type": mime_type, + "created_at": now, + } + if dialect == "sqlite": + ins = ( + d_sqlite.insert(Asset) + .values(**vals) + .on_conflict_do_nothing(index_elements=[Asset.hash]) + ) + elif dialect == "postgresql": + ins = ( + d_pg.insert(Asset) + .values(**vals) + .on_conflict_do_nothing(index_elements=[Asset.hash]) + ) + else: + raise NotImplementedError(f"Unsupported database dialect: {dialect}") + res = await session.execute(ins) + rowcount = int(res.rowcount or 0) + asset = ( + await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1)) + ).scalars().first() + if not asset: + raise RuntimeError("Asset row not found after upsert.") + if rowcount > 0: out["asset_created"] = True else: changed = False @@ -524,7 +549,6 @@ async def ingest_fs_asset( "file_path": locator, "mtime_ns": int(mtime_ns), } - dialect = session.bind.dialect.name if dialect == "sqlite": ins = ( d_sqlite.insert(AssetCacheState) diff --git a/tests-assets/test_uploads.py b/tests-assets/test_uploads.py index 3bfb62ca4b32..c1a962d263d4 100644 --- a/tests-assets/test_uploads.py +++ b/tests-assets/test_uploads.py @@ -1,4 +1,7 @@ +import asyncio import json +import uuid + import aiohttp import pytest @@ -125,6 +128,54 @@ async def test_upload_multiple_tags_fields_are_merged(http: aiohttp.ClientSessio assert {"models", "checkpoints", "unit-tests", "alpha"}.issubset(tags) +@pytest.mark.asyncio +@pytest.mark.parametrize("root", ["input", "output"]) +async def test_concurrent_upload_identical_bytes_different_names( + root: str, + http: aiohttp.ClientSession, + api_base: str, + make_asset_bytes, +): + """ + Two concurrent uploads of identical bytes but different names. + Expect a single Asset (same hash), two AssetInfo rows, and exactly one created_new=True. 
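+    (Content-hash dedup collapses both uploads onto one Asset; the asserts below check equal asset_hash values but distinct AssetInfo ids.)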
+ """ + scope = f"concupload-{uuid.uuid4().hex[:6]}" + name1, name2 = "cu_a.bin", "cu_b.bin" + data = make_asset_bytes("concurrent", 4096) + tags = [root, "unit-tests", scope] + + def _form(name: str) -> aiohttp.FormData: + f = aiohttp.FormData() + f.add_field("file", data, filename=name, content_type="application/octet-stream") + f.add_field("tags", json.dumps(tags)) + f.add_field("name", name) + f.add_field("user_metadata", json.dumps({})) + return f + + r1, r2 = await asyncio.gather( + http.post(api_base + "/api/assets", data=_form(name1)), + http.post(api_base + "/api/assets", data=_form(name2)), + ) + b1, b2 = await r1.json(), await r2.json() + assert r1.status in (200, 201), b1 + assert r2.status in (200, 201), b2 + assert b1["asset_hash"] == b2["asset_hash"] + assert b1["id"] != b2["id"] + + created_flags = sorted([bool(b1.get("created_new")), bool(b2.get("created_new"))]) + assert created_flags == [False, True] + + async with http.get( + api_base + "/api/assets", + params={"include_tags": f"unit-tests,{scope}", "sort": "name"}, + ) as rl: + bl = await rl.json() + assert rl.status == 200, bl + names = [a["name"] for a in bl.get("assets", [])] + assert set([name1, name2]).issubset(names) + + @pytest.mark.asyncio async def test_create_from_hash_endpoint_404(http: aiohttp.ClientSession, api_base: str): payload = { From 37b81e66580f2903ef761c4ef768792deb63fe55 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Sun, 14 Sep 2025 10:12:33 +0300 Subject: [PATCH 50/82] fixed new PgSQL bug --- app/api/assets_routes.py | 62 ++++++++++++++++++++++++-------- app/database/services/content.py | 18 ++++++++-- 2 files changed, 63 insertions(+), 17 deletions(-) diff --git a/app/api/assets_routes.py b/app/api/assets_routes.py index 384c9f6c0a53..3c3ea5d257d3 100644 --- a/app/api/assets_routes.py +++ b/app/api/assets_routes.py @@ -1,4 +1,5 @@ import contextlib +import logging import os import urllib.parse import uuid @@ -13,7 +14,8 @@ from . 
import schemas_in, schemas_out ROUTES = web.RouteTableDef() -UserManager: Optional[user_manager.UserManager] = None +USER_MANAGER: Optional[user_manager.UserManager] = None +LOGGER = logging.getLogger(__name__) # UUID regex (canonical hyphenated form, case-insensitive) UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}" @@ -58,7 +60,7 @@ async def list_assets(request: web.Request) -> web.Response: offset=q.offset, sort=q.sort, order=q.order, - owner_id=UserManager.get_request_user_id(request), + owner_id=USER_MANAGER.get_request_user_id(request), ) return web.json_response(payload.model_dump(mode="json")) @@ -72,7 +74,7 @@ async def download_asset_content(request: web.Request) -> web.Response: try: abs_path, content_type, filename = await assets_manager.resolve_asset_content_for_download( asset_info_id=str(uuid.UUID(request.match_info["id"])), - owner_id=UserManager.get_request_user_id(request), + owner_id=USER_MANAGER.get_request_user_id(request), ) except ValueError as ve: return _error_response(404, "ASSET_NOT_FOUND", str(ve)) @@ -105,7 +107,7 @@ async def create_asset_from_hash(request: web.Request) -> web.Response: name=body.name, tags=body.tags, user_metadata=body.user_metadata, - owner_id=UserManager.get_request_user_id(request), + owner_id=USER_MANAGER.get_request_user_id(request), ) if result is None: return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist") @@ -234,7 +236,7 @@ async def upload_asset(request: web.Request) -> web.Response: 400, "INVALID_BODY", f"unknown models category '{spec.tags[1] if len(spec.tags) >= 2 else ''}'" ) - owner_id = UserManager.get_request_user_id(request) + owner_id = USER_MANAGER.get_request_user_id(request) # Fast path: if a valid provided hash exists, create AssetInfo without writing anything if spec.hash and provided_hash_exists is True: @@ -247,6 +249,7 @@ async def upload_asset(request: web.Request) -> web.Response: owner_id=owner_id, ) except Exception: + LOGGER.exception("create_asset_from_hash failed for hash=%s, owner_id=%s", spec.hash, owner_id) return _error_response(500, "INTERNAL", "Unexpected server error.") if result is None: @@ -289,6 +292,7 @@ async def upload_asset(request: web.Request) -> web.Response: except Exception: if tmp_path and os.path.exists(tmp_path): os.remove(tmp_path) + LOGGER.exception("upload_asset_from_temp_path failed for tmp_path=%s, owner_id=%s", tmp_path, owner_id) return _error_response(500, "INTERNAL", "Unexpected server error.") @@ -298,11 +302,16 @@ async def get_asset(request: web.Request) -> web.Response: try: result = await assets_manager.get_asset( asset_info_id=asset_info_id, - owner_id=UserManager.get_request_user_id(request), + owner_id=USER_MANAGER.get_request_user_id(request), ) except ValueError as ve: return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) except Exception: + LOGGER.exception( + "get_asset failed for asset_info_id=%s, owner_id=%s", + asset_info_id, + USER_MANAGER.get_request_user_id(request), + ) return _error_response(500, "INTERNAL", "Unexpected server error.") return web.json_response(result.model_dump(mode="json"), status=200) @@ -323,11 +332,16 @@ async def update_asset(request: web.Request) -> web.Response: name=body.name, tags=body.tags, user_metadata=body.user_metadata, - owner_id=UserManager.get_request_user_id(request), + owner_id=USER_MANAGER.get_request_user_id(request), ) except (ValueError, PermissionError) as ve: return _error_response(404, "ASSET_NOT_FOUND", str(ve), 
{"id": asset_info_id}) except Exception: + LOGGER.exception( + "update_asset failed for asset_info_id=%s, owner_id=%s", + asset_info_id, + USER_MANAGER.get_request_user_id(request), + ) return _error_response(500, "INTERNAL", "Unexpected server error.") return web.json_response(result.model_dump(mode="json"), status=200) @@ -346,11 +360,16 @@ async def set_asset_preview(request: web.Request) -> web.Response: result = await assets_manager.set_asset_preview( asset_info_id=asset_info_id, preview_asset_id=body.preview_id, - owner_id=UserManager.get_request_user_id(request), + owner_id=USER_MANAGER.get_request_user_id(request), ) except (PermissionError, ValueError) as ve: return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) except Exception: + LOGGER.exception( + "set_asset_preview failed for asset_info_id=%s, owner_id=%s", + asset_info_id, + USER_MANAGER.get_request_user_id(request), + ) return _error_response(500, "INTERNAL", "Unexpected server error.") return web.json_response(result.model_dump(mode="json"), status=200) @@ -364,10 +383,15 @@ async def delete_asset(request: web.Request) -> web.Response: try: deleted = await assets_manager.delete_asset_reference( asset_info_id=asset_info_id, - owner_id=UserManager.get_request_user_id(request), + owner_id=USER_MANAGER.get_request_user_id(request), delete_content_if_orphan=delete_content, ) except Exception: + LOGGER.exception( + "delete_asset_reference failed for asset_info_id=%s, owner_id=%s", + asset_info_id, + USER_MANAGER.get_request_user_id(request), + ) return _error_response(500, "INTERNAL", "Unexpected server error.") if not deleted: @@ -393,7 +417,7 @@ async def get_tags(request: web.Request) -> web.Response: offset=query.offset, order=query.order, include_zero=query.include_zero, - owner_id=UserManager.get_request_user_id(request), + owner_id=USER_MANAGER.get_request_user_id(request), ) return web.json_response(result.model_dump(mode="json")) @@ -414,11 +438,16 @@ async def add_asset_tags(request: web.Request) -> web.Response: asset_info_id=asset_info_id, tags=data.tags, origin="manual", - owner_id=UserManager.get_request_user_id(request), + owner_id=USER_MANAGER.get_request_user_id(request), ) except (ValueError, PermissionError) as ve: return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) except Exception: + LOGGER.exception( + "add_tags_to_asset failed for asset_info_id=%s, owner_id=%s", + asset_info_id, + USER_MANAGER.get_request_user_id(request), + ) return _error_response(500, "INTERNAL", "Unexpected server error.") return web.json_response(result.model_dump(mode="json"), status=200) @@ -439,11 +468,16 @@ async def delete_asset_tags(request: web.Request) -> web.Response: result = await assets_manager.remove_tags_from_asset( asset_info_id=asset_info_id, tags=data.tags, - owner_id=UserManager.get_request_user_id(request), + owner_id=USER_MANAGER.get_request_user_id(request), ) except ValueError as ve: return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) except Exception: + LOGGER.exception( + "remove_tags_from_asset failed for asset_info_id=%s, owner_id=%s", + asset_info_id, + USER_MANAGER.get_request_user_id(request), + ) return _error_response(500, "INTERNAL", "Unexpected server error.") return web.json_response(result.model_dump(mode="json"), status=200) @@ -476,8 +510,8 @@ async def get_asset_scan_status(request: web.Request) -> web.Response: def register_assets_system(app: web.Application, user_manager_instance: user_manager.UserManager) -> None: - 
global UserManager - UserManager = user_manager_instance + global USER_MANAGER + USER_MANAGER = user_manager_instance app.add_routes(ROUTES) diff --git a/app/database/services/content.py b/app/database/services/content.py index 8388e524d6c8..8547a1bffafc 100644 --- a/app/database/services/content.py +++ b/app/database/services/content.py @@ -514,24 +514,36 @@ async def ingest_fs_asset( d_sqlite.insert(Asset) .values(**vals) .on_conflict_do_nothing(index_elements=[Asset.hash]) + .returning(Asset.id) ) elif dialect == "postgresql": ins = ( d_pg.insert(Asset) .values(**vals) - .on_conflict_do_nothing(index_elements=[Asset.hash]) + .on_conflict_do_nothing( + index_elements=[Asset.hash], + index_where=Asset.__table__.c.hash.isnot(None), + ) + .returning(Asset.id) ) else: raise NotImplementedError(f"Unsupported database dialect: {dialect}") res = await session.execute(ins) - rowcount = int(res.rowcount or 0) + inserted_id = res.scalar_one_or_none() asset = ( await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1)) ).scalars().first() if not asset: raise RuntimeError("Asset row not found after upsert.") - if rowcount > 0: + if inserted_id: out["asset_created"] = True + asset = await session.get(Asset, inserted_id) + else: + asset = ( + await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1)) + ).scalars().first() + if not asset: + raise RuntimeError("Asset row not found after upsert.") else: changed = False if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0: From cdd8d160757aa3a00cf0d01a72e3eed7156026df Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Sun, 14 Sep 2025 14:52:42 +0300 Subject: [PATCH 51/82] +2 tests for checking Asset downloading logic --- tests-assets/test_downloads.py | 117 +++++++++++++++++++++++++++++++++ 1 file changed, 117 insertions(+) diff --git a/tests-assets/test_downloads.py b/tests-assets/test_downloads.py index 9cb1b9486036..cb8b36220177 100644 --- a/tests-assets/test_downloads.py +++ b/tests-assets/test_downloads.py @@ -1,7 +1,12 @@ +import asyncio +import uuid +from datetime import datetime from pathlib import Path +from typing import Optional import aiohttp import pytest +from conftest import trigger_sync_seed_assets @pytest.mark.asyncio @@ -24,6 +29,73 @@ async def test_download_attachment_and_inline(http: aiohttp.ClientSession, api_b assert "inline" in cd2 +@pytest.mark.asyncio +@pytest.mark.parametrize("root", ["input", "output"]) +async def test_download_chooses_existing_state_and_updates_access_time( + root: str, + http: aiohttp.ClientSession, + api_base: str, + comfy_tmp_base_dir: Path, + asset_factory, + make_asset_bytes, + run_scan_and_wait, +): + """ + Hashed asset with two state paths: if the first one disappears, + GET /content still serves from the remaining path and bumps last_access_time. 
+ """ + scope = f"dl-first-{uuid.uuid4().hex[:6]}" + name = "first_existing_state.bin" + data = make_asset_bytes(name, 3072) + + # Upload -> path1 + a = await asset_factory(name, [root, "unit-tests", scope], {}, data) + aid = a["id"] + + base = comfy_tmp_base_dir / root / "unit-tests" / scope + path1 = base / name + assert path1.exists() + + # Seed path2 by copying, then scan to dedupe into a second state + path2 = base / "alt" / name + path2.parent.mkdir(parents=True, exist_ok=True) + path2.write_bytes(data) + await trigger_sync_seed_assets(http, api_base) + await run_scan_and_wait(root) + + # Remove path1 so server must fall back to path2 + path1.unlink() + + # last_access_time before + async with http.get(f"{api_base}/api/assets/{aid}") as rg0: + d0 = await rg0.json() + assert rg0.status == 200, d0 + ts0 = d0.get("last_access_time") + + await asyncio.sleep(0.05) + async with http.get(f"{api_base}/api/assets/{aid}/content") as r: + blob = await r.read() + assert r.status == 200 + assert blob == data # must serve from the surviving state (same bytes) + + async with http.get(f"{api_base}/api/assets/{aid}") as rg1: + d1 = await rg1.json() + assert rg1.status == 200, d1 + ts1 = d1.get("last_access_time") + + def _parse_iso8601(s: Optional[str]) -> Optional[float]: + if not s: + return None + s = s[:-1] if s.endswith("Z") else s + return datetime.fromisoformat(s).timestamp() + + t0 = _parse_iso8601(ts0) + t1 = _parse_iso8601(ts1) + assert t1 is not None + if t0 is not None: + assert t1 > t0 + + @pytest.mark.asyncio @pytest.mark.parametrize("seeded_asset", [{"tags": ["models", "checkpoints"]}], indirect=True) async def test_download_missing_file_returns_404( @@ -49,3 +121,48 @@ async def test_download_missing_file_returns_404( # We created asset without the "unit-tests" tag(see `autoclean_unit_test_assets`), we need to clear it manually. 
async with http.delete(f"{api_base}/api/assets/{aid}") as dr: await dr.read() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("root", ["input", "output"]) +async def test_download_404_if_all_states_missing( + root: str, + http: aiohttp.ClientSession, + api_base: str, + comfy_tmp_base_dir: Path, + asset_factory, + make_asset_bytes, + run_scan_and_wait, +): + """Multi-state asset: after the last remaining on-disk file is removed, download must return 404.""" + scope = f"dl-404-{uuid.uuid4().hex[:6]}" + name = "missing_all_states.bin" + data = make_asset_bytes(name, 2048) + + # Upload -> path1 + a = await asset_factory(name, [root, "unit-tests", scope], {}, data) + aid = a["id"] + + base = comfy_tmp_base_dir / root / "unit-tests" / scope + p1 = base / name + assert p1.exists() + + # Seed a second state and dedupe + p2 = base / "copy" / name + p2.parent.mkdir(parents=True, exist_ok=True) + p2.write_bytes(data) + await trigger_sync_seed_assets(http, api_base) + await run_scan_and_wait(root) + + # Remove first file -> download should still work via the second state + p1.unlink() + async with http.get(f"{api_base}/api/assets/{aid}/content") as ok1: + b1 = await ok1.read() + assert ok1.status == 200 and b1 == data + + # Remove the last file -> download must 404 + p2.unlink() + async with http.get(f"{api_base}/api/assets/{aid}/content") as r2: + body = await r2.json() + assert r2.status == 404 + assert body["error"]["code"] == "FILE_NOT_FOUND" From 47f7c7ee8cf7de5f8dd05673101e8bf501bfc625 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Sun, 14 Sep 2025 15:00:32 +0300 Subject: [PATCH 52/82] rework + add test for concurrent AssetInfo delete --- app/database/services/info.py | 11 +++++----- tests-assets/test_crud.py | 41 +++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 5 deletions(-) diff --git a/app/database/services/info.py b/app/database/services/info.py index e3da1bc8ee3b..687431d595ff 100644 --- a/app/database/services/info.py +++ b/app/database/services/info.py @@ -379,11 +379,12 @@ async def touch_asset_info_by_id( async def delete_asset_info_by_id(session: AsyncSession, *, asset_info_id: str, owner_id: str) -> bool: - res = await session.execute(delete(AssetInfo).where( - AssetInfo.id == asset_info_id, - visible_owner_clause(owner_id), - )) - return bool(res.rowcount) + return ( + await session.execute(delete(AssetInfo).where( + AssetInfo.id == asset_info_id, + visible_owner_clause(owner_id), + ).returning(AssetInfo.id)) + ).scalar_one_or_none() is not None async def add_tags_to_asset_info( diff --git a/tests-assets/test_crud.py b/tests-assets/test_crud.py index 8836cc686f05..ba7f23f6763a 100644 --- a/tests-assets/test_crud.py +++ b/tests-assets/test_crud.py @@ -1,3 +1,4 @@ +import asyncio import uuid import aiohttp @@ -177,3 +178,43 @@ async def test_update_requires_at_least_one_field(http: aiohttp.ClientSession, a body = await r.json() assert r.status == 400 assert body["error"]["code"] == "INVALID_BODY" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("root", ["input", "output"]) +async def test_concurrent_delete_same_asset_info_single_204( + root: str, + http: aiohttp.ClientSession, + api_base: str, + asset_factory, + make_asset_bytes, +): + """ + Many concurrent DELETE for the same AssetInfo should result in: + - exactly one 204 No Content (the one that actually deleted) + - all others 404 Not Found (row already gone) + """ + scope = f"conc-del-{uuid.uuid4().hex[:6]}" + name = "to_delete.bin" + data = make_asset_bytes(name, 1536) + + created = await 
asset_factory(name, [root, "unit-tests", scope], {}, data) + aid = created["id"] + + # Hit the same endpoint N times in parallel. + n_tests = 4 + url = f"{api_base}/api/assets/{aid}?delete_content=false" + tasks = [asyncio.create_task(http.delete(url)) for _ in range(n_tests)] + responses = await asyncio.gather(*tasks) + + statuses = [r.status for r in responses] + # Drain bodies to close connections (optional but avoids warnings). + await asyncio.gather(*[r.read() for r in responses]) + + # Exactly one actual delete, the rest must be 404 + assert statuses.count(204) == 1, f"Expected exactly one 204; got: {statuses}" + assert statuses.count(404) == n_tests - 1, f"Expected {n_tests-1} 404; got: {statuses}" + + # The resource must be gone. + async with http.get(f"{api_base}/api/assets/{aid}") as rg: + assert rg.status == 404 From 0b795dc7a70104bd301730ed4c0bfb23a0ee02a4 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Sun, 14 Sep 2025 15:14:24 +0300 Subject: [PATCH 53/82] removed non-needed code --- app/database/helpers/__init__.py | 4 ---- app/database/helpers/tags.py | 25 +------------------------ 2 files changed, 1 insertion(+), 28 deletions(-) diff --git a/app/database/helpers/__init__.py b/app/database/helpers/__init__.py index 19d7507fa44f..310583607c80 100644 --- a/app/database/helpers/__init__.py +++ b/app/database/helpers/__init__.py @@ -2,10 +2,8 @@ from .ownership import visible_owner_clause from .projection import is_scalar, project_kv from .tags import ( - add_missing_tag_for_asset_hash, add_missing_tag_for_asset_id, ensure_tags_exist, - remove_missing_tag_for_asset_hash, remove_missing_tag_for_asset_id, ) @@ -16,8 +14,6 @@ "project_kv", "ensure_tags_exist", "add_missing_tag_for_asset_id", - "add_missing_tag_for_asset_hash", "remove_missing_tag_for_asset_id", - "remove_missing_tag_for_asset_hash", "visible_owner_clause", ] diff --git a/app/database/helpers/tags.py b/app/database/helpers/tags.py index c8f9e50749b4..b3e006c0e80a 100644 --- a/app/database/helpers/tags.py +++ b/app/database/helpers/tags.py @@ -6,7 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from ..._assets_helpers import normalize_tags -from ..models import Asset, AssetInfo, AssetInfoTag, Tag +from ..models import AssetInfo, AssetInfoTag, Tag from ..timeutil import utcnow @@ -77,18 +77,6 @@ async def add_missing_tag_for_asset_id( return len(to_add) -async def add_missing_tag_for_asset_hash( - session: AsyncSession, - *, - asset_hash: str, - origin: str = "automatic", -) -> int: - asset = (await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))).scalars().first() - if not asset: - return 0 - return await add_missing_tag_for_asset_id(session, asset_id=asset.id, origin=origin) - - async def remove_missing_tag_for_asset_id( session: AsyncSession, *, @@ -107,14 +95,3 @@ async def remove_missing_tag_for_asset_id( ) await session.flush() return int(res.rowcount or 0) - - -async def remove_missing_tag_for_asset_hash( - session: AsyncSession, - *, - asset_hash: str, -) -> int: - asset = (await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))).scalars().first() - if not asset: - return 0 - return await remove_missing_tag_for_asset_id(session, asset_id=asset.id) From a2ec1f7637ab8864a94e28027417c060c2ef58df Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Sun, 14 Sep 2025 15:31:42 +0300 Subject: [PATCH 54/82] simplify code --- app/database/services/content.py | 26 +++++++------------------- app/database/services/info.py | 7 +++---- 2 files changed, 10 insertions(+), 23 
deletions(-) diff --git a/app/database/services/content.py b/app/database/services/content.py index 8547a1bffafc..58fc6df0471a 100644 --- a/app/database/services/content.py +++ b/app/database/services/content.py @@ -1,3 +1,4 @@ +import contextlib import logging import os from datetime import datetime @@ -291,10 +292,8 @@ async def compute_hash_and_dedup_for_cache_state( state.asset_id = new_asset.id state.mtime_ns = mtime_ns state.needs_verify = False - try: + with contextlib.suppress(Exception): await remove_missing_tag_for_asset_id(session, asset_id=state.asset_id) - except Exception: - pass await session.flush() return state.asset_id @@ -311,15 +310,12 @@ async def compute_hash_and_dedup_for_cache_state( duplicate_asset_id=this_asset.id, canonical_asset_id=canonical.id, ) - # Refresh state after the merge state = await session.get(AssetCacheState, state_id) if state: state.mtime_ns = mtime_ns state.needs_verify = False - try: + with contextlib.suppress(Exception): await remove_missing_tag_for_asset_id(session, asset_id=canonical.id) - except Exception: - pass await session.flush() return canonical.id @@ -345,10 +341,8 @@ async def compute_hash_and_dedup_for_cache_state( if state: state.mtime_ns = mtime_ns state.needs_verify = False - try: + with contextlib.suppress(Exception): await remove_missing_tag_for_asset_id(session, asset_id=canonical.id) - except Exception: - pass await session.flush() return canonical.id # If we got here, the integrity error was not about hash uniqueness @@ -357,10 +351,8 @@ async def compute_hash_and_dedup_for_cache_state( # Claimed successfully state.mtime_ns = mtime_ns state.needs_verify = False - try: + with contextlib.suppress(Exception): await remove_missing_tag_for_asset_id(session, asset_id=this_asset.id) - except Exception: - pass await session.flush() return this_asset.id @@ -370,10 +362,8 @@ async def compute_hash_and_dedup_for_cache_state( this_asset.size_bytes = new_size state.mtime_ns = mtime_ns state.needs_verify = False - try: + with contextlib.suppress(Exception): await remove_missing_tag_for_asset_id(session, asset_id=this_asset.id) - except Exception: - pass await session.flush() return this_asset.id @@ -393,10 +383,8 @@ async def compute_hash_and_dedup_for_cache_state( state.asset_id = target_id state.mtime_ns = mtime_ns state.needs_verify = False - try: + with contextlib.suppress(Exception): await remove_missing_tag_for_asset_id(session, asset_id=target_id) - except Exception: - pass await session.flush() return target_id diff --git a/app/database/services/info.py b/app/database/services/info.py index 687431d595ff..a31818f0bc5d 100644 --- a/app/database/services/info.py +++ b/app/database/services/info.py @@ -366,16 +366,15 @@ async def touch_asset_info_by_id( asset_info_id: str, ts: Optional[datetime] = None, only_if_newer: bool = True, -) -> int: +) -> bool: ts = ts or utcnow() stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id) if only_if_newer: stmt = stmt.where( sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts) ) - stmt = stmt.values(last_access_time=ts) - res = await session.execute(stmt) - return int(res.rowcount or 0) + stmt = stmt.values(last_access_time=ts).returning(AssetInfo.id) + return (await session.execute(stmt)).scalar_one_or_none() is not None async def delete_asset_info_by_id(session: AsyncSession, *, asset_info_id: str, owner_id: str) -> bool: From 6cfa94ec58e5dc550796960ac1caff70678bfb42 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Sun, 14 Sep 2025 15:54:16 +0300 Subject: [PATCH 
55/82] fixed metadata[filename] feature + new tests for this --- app/_assets_helpers.py | 14 ++--- app/assets_manager.py | 7 +-- app/database/services/__init__.py | 2 + app/database/services/content.py | 55 ++++++++++++++---- app/database/services/info.py | 20 ++++--- app/database/services/queries.py | 17 ++++++ tests-assets/test_crud.py | 96 +++++++++++++++++++++++++++++++ 7 files changed, 179 insertions(+), 32 deletions(-) diff --git a/app/_assets_helpers.py b/app/_assets_helpers.py index e0b982c985c0..98761284547c 100644 --- a/app/_assets_helpers.py +++ b/app/_assets_helpers.py @@ -140,7 +140,7 @@ def ensure_within_base(candidate: str, base: str) -> None: raise ValueError("invalid destination path") -def compute_model_relative_filename(file_path: str) -> Optional[str]: +def compute_relative_filename(file_path: str) -> Optional[str]: """ Return the model's path relative to the last well-known folder (the model category), using forward slashes, eg: @@ -155,16 +155,16 @@ def compute_model_relative_filename(file_path: str) -> Optional[str]: except ValueError: return None - if root_category != "models": - return None - p = Path(rel_path) - # parts[0] is the well-known category (eg "checkpoints" or "text_encoders") parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)] if not parts: return None - inside = parts[1:] if len(parts) > 1 else [parts[0]] - return "/".join(inside) # normalize to POSIX style for portability + + if root_category == "models": + # parts[0] is the category ("checkpoints", "vae", etc) – drop it + inside = parts[1:] if len(parts) > 1 else [parts[0]] + return "/".join(inside) + return "/".join(parts) # input/output: keep all parts def list_tree(base_dir: str) -> list[str]: diff --git a/app/assets_manager.py b/app/assets_manager.py index 9d2424ce6065..f3da06633eab 100644 --- a/app/assets_manager.py +++ b/app/assets_manager.py @@ -30,6 +30,7 @@ list_asset_infos_page, list_cache_states_by_asset_id, list_tags_with_usage, + pick_best_live_path, remove_tags_from_asset_info, set_asset_info_preview, touch_asset_info_by_id, @@ -177,11 +178,7 @@ async def resolve_asset_content_for_download( info, asset = pair states = await list_cache_states_by_asset_id(session, asset_id=asset.id) - abs_path = "" - for s in states: - if s and s.file_path and os.path.isfile(s.file_path): - abs_path = s.file_path - break + abs_path = pick_best_live_path(states) if not abs_path: raise FileNotFoundError diff --git a/app/database/services/__init__.py b/app/database/services/__init__.py index aed8815a67a3..88e97bfb049f 100644 --- a/app/database/services/__init__.py +++ b/app/database/services/__init__.py @@ -32,6 +32,7 @@ get_asset_info_by_id, get_cache_state_by_asset_id, list_cache_states_by_asset_id, + pick_best_live_path, ) __all__ = [ @@ -39,6 +40,7 @@ "asset_exists_by_hash", "get_asset_by_hash", "get_asset_info_by_id", "asset_info_exists_for_asset_id", "get_cache_state_by_asset_id", "list_cache_states_by_asset_id", + "pick_best_live_path", # info "list_asset_infos_page", "create_asset_info_for_existing_asset", "set_asset_info_tags", "update_asset_info_full", "replace_asset_info_metadata_projection", diff --git a/app/database/services/content.py b/app/database/services/content.py index 58fc6df0471a..546cc7bd1930 100644 --- a/app/database/services/content.py +++ b/app/database/services/content.py @@ -12,7 +12,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import noload -from ..._assets_helpers import compute_model_relative_filename, normalize_tags +from 
..._assets_helpers import compute_relative_filename, normalize_tags from ...storage import hashing as hashing_mod from ..helpers import ( ensure_tags_exist, @@ -21,6 +21,7 @@ from ..models import Asset, AssetCacheState, AssetInfo, AssetInfoTag, Tag from ..timeutil import utcnow from .info import replace_asset_info_metadata_projection +from .queries import list_cache_states_by_asset_id, pick_best_live_path async def check_fs_asset_exists_quick( @@ -106,6 +107,15 @@ async def ensure_seed_for_path( session.add(info) await session.flush() + with contextlib.suppress(Exception): + computed = compute_relative_filename(locator) + if computed: + await replace_asset_info_metadata_projection( + session, + asset_info_id=info.id, + user_metadata={"filename": computed}, + ) + want = normalize_tags(tags) if want: await ensure_tags_exist(session, want, tag_type="user") @@ -265,6 +275,8 @@ async def compute_hash_and_dedup_for_cache_state( if int(remaining or 0) == 0: await session.delete(asset) await session.flush() + else: + await _recompute_and_apply_filename_for_asset(session, asset_id=asset.id) return None digest = await hashing_mod.blake3_hash(path) @@ -316,6 +328,7 @@ async def compute_hash_and_dedup_for_cache_state( state.needs_verify = False with contextlib.suppress(Exception): await remove_missing_tag_for_asset_id(session, asset_id=canonical.id) + await _recompute_and_apply_filename_for_asset(session, asset_id=canonical.id) await session.flush() return canonical.id @@ -343,6 +356,7 @@ async def compute_hash_and_dedup_for_cache_state( state.needs_verify = False with contextlib.suppress(Exception): await remove_missing_tag_for_asset_id(session, asset_id=canonical.id) + await _recompute_and_apply_filename_for_asset(session, asset_id=canonical.id) await session.flush() return canonical.id # If we got here, the integrity error was not about hash uniqueness @@ -353,6 +367,7 @@ async def compute_hash_and_dedup_for_cache_state( state.needs_verify = False with contextlib.suppress(Exception): await remove_missing_tag_for_asset_id(session, asset_id=this_asset.id) + await _recompute_and_apply_filename_for_asset(session, asset_id=this_asset.id) await session.flush() return this_asset.id @@ -364,6 +379,7 @@ async def compute_hash_and_dedup_for_cache_state( state.needs_verify = False with contextlib.suppress(Exception): await remove_missing_tag_for_asset_id(session, asset_id=this_asset.id) + await _recompute_and_apply_filename_for_asset(session, asset_id=this_asset.id) await session.flush() return this_asset.id @@ -385,11 +401,10 @@ async def compute_hash_and_dedup_for_cache_state( state.needs_verify = False with contextlib.suppress(Exception): await remove_missing_tag_for_asset_id(session, asset_id=target_id) + await _recompute_and_apply_filename_for_asset(session, asset_id=target_id) await session.flush() return target_id - except Exception: - # Propagate; caller records the error and continues the worker. 
raise @@ -663,15 +678,8 @@ async def ingest_fs_asset( # metadata["filename"] hack if out["asset_info_id"] is not None: - primary_path = ( - await session.execute( - select(AssetCacheState.file_path) - .where(AssetCacheState.asset_id == asset.id) - .order_by(AssetCacheState.id.asc()) - .limit(1) - ) - ).scalars().first() - computed_filename = compute_model_relative_filename(primary_path) if primary_path else None + primary_path = pick_best_live_path(await list_cache_states_by_asset_id(session, asset_id=asset.id)) + computed_filename = compute_relative_filename(primary_path) if primary_path else None current_meta = existing_info.user_metadata or {} new_meta = dict(current_meta) @@ -760,3 +768,26 @@ async def list_cache_states_with_asset_under_prefixes( ) ).all() return [(r[0], r[1], int(r[2] or 0)) for r in rows] + + +async def _recompute_and_apply_filename_for_asset(session: AsyncSession, *, asset_id: str) -> None: + """Compute filename from the first *existing* cache state path and apply it to all AssetInfo (if changed).""" + try: + primary_path = pick_best_live_path(await list_cache_states_by_asset_id(session, asset_id=asset_id)) + if not primary_path: + return + new_filename = compute_relative_filename(primary_path) + if not new_filename: + return + infos = ( + await session.execute(select(AssetInfo).where(AssetInfo.asset_id == asset_id)) + ).scalars().all() + for info in infos: + current_meta = info.user_metadata or {} + if current_meta.get("filename") == new_filename: + continue + updated = dict(current_meta) + updated["filename"] = new_filename + await replace_asset_info_metadata_projection(session, asset_info_id=info.id, user_metadata=updated) + except Exception: + logging.exception("Failed to recompute filename metadata for asset %s", asset_id) diff --git a/app/database/services/info.py b/app/database/services/info.py index a31818f0bc5d..5c7e3c92fe9c 100644 --- a/app/database/services/info.py +++ b/app/database/services/info.py @@ -8,7 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import contains_eager, noload -from ..._assets_helpers import compute_model_relative_filename, normalize_tags +from ..._assets_helpers import compute_relative_filename, normalize_tags from ..helpers import ( apply_metadata_filter, apply_tag_filters, @@ -18,7 +18,11 @@ ) from ..models import Asset, AssetInfo, AssetInfoMeta, AssetInfoTag, Tag from ..timeutil import utcnow -from .queries import get_asset_by_hash, get_cache_state_by_asset_id +from .queries import ( + get_asset_by_hash, + list_cache_states_by_asset_id, + pick_best_live_path, +) async def list_asset_infos_page( @@ -196,9 +200,9 @@ async def create_asset_info_for_existing_asset( new_meta = dict(user_metadata or {}) computed_filename = None try: - state = await get_cache_state_by_asset_id(session, asset_id=asset.id) - if state and state.file_path: - computed_filename = compute_model_relative_filename(state.file_path) + p = pick_best_live_path(await list_cache_states_by_asset_id(session, asset_id=asset.id)) + if p: + computed_filename = compute_relative_filename(p) except Exception: computed_filename = None if computed_filename: @@ -280,9 +284,9 @@ async def update_asset_info_full( computed_filename = None try: - state = await get_cache_state_by_asset_id(session, asset_id=info.asset_id) - if state and state.file_path: - computed_filename = compute_model_relative_filename(state.file_path) + p = pick_best_live_path(await list_cache_states_by_asset_id(session, asset_id=info.asset_id)) + if p: + computed_filename = 
compute_relative_filename(p) except Exception: computed_filename = None diff --git a/app/database/services/queries.py b/app/database/services/queries.py index 81649b7f4c10..fc05e5cbf2a0 100644 --- a/app/database/services/queries.py +++ b/app/database/services/queries.py @@ -1,3 +1,4 @@ +import os from typing import Optional, Sequence, Union import sqlalchemy as sa @@ -57,3 +58,19 @@ async def list_cache_states_by_asset_id( .order_by(AssetCacheState.id.asc()) ) ).scalars().all() + + +def pick_best_live_path(states: Union[list[AssetCacheState], Sequence[AssetCacheState]]) -> str: + """ + Return the best on-disk path among cache states: + 1) Prefer a path that exists with needs_verify == False (already verified). + 2) Otherwise, pick the first path that exists. + 3) Otherwise return empty string. + """ + alive = [s for s in states if getattr(s, "file_path", None) and os.path.isfile(s.file_path)] + if not alive: + return "" + for s in alive: + if not getattr(s, "needs_verify", False): + return s.file_path + return alive[0].file_path diff --git a/tests-assets/test_crud.py b/tests-assets/test_crud.py index ba7f23f6763a..ad435d65bf25 100644 --- a/tests-assets/test_crud.py +++ b/tests-assets/test_crud.py @@ -1,8 +1,10 @@ import asyncio import uuid +from pathlib import Path import aiohttp import pytest +from conftest import trigger_sync_seed_assets @pytest.mark.asyncio @@ -218,3 +220,97 @@ async def test_concurrent_delete_same_asset_info_single_204( # The resource must be gone. async with http.get(f"{api_base}/api/assets/{aid}") as rg: assert rg.status == 404 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("root", ["input", "output"]) +async def test_metadata_filename_is_set_for_seed_asset_without_hash( + root: str, + http: aiohttp.ClientSession, + api_base: str, + comfy_tmp_base_dir: Path, +): + """Seed ingest (no hash yet) must compute user_metadata['filename'] immediately.""" + scope = f"seedmeta-{uuid.uuid4().hex[:6]}" + name = "seed_filename.bin" + + base = comfy_tmp_base_dir / root / "unit-tests" / scope / "a" / "b" + base.mkdir(parents=True, exist_ok=True) + fp = base / name + fp.write_bytes(b"Z" * 2048) + + await trigger_sync_seed_assets(http, api_base) + + async with http.get( + api_base + "/api/assets", + params={"include_tags": f"unit-tests,{scope}", "name_contains": name}, + ) as r1: + body = await r1.json() + assert r1.status == 200, body + matches = [a for a in body.get("assets", []) if a.get("name") == name] + assert matches, "Seed asset should be visible after sync" + assert matches[0].get("asset_hash") is None # still a seed + aid = matches[0]["id"] + + async with http.get(f"{api_base}/api/assets/{aid}") as r2: + detail = await r2.json() + assert r2.status == 200, detail + filename = (detail.get("user_metadata") or {}).get("filename") + expected = str(fp.relative_to(comfy_tmp_base_dir / root)).replace("\\", "/") + assert filename == expected, f"expected filename={expected}, got {filename!r}" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("root", ["input", "output"]) +async def test_metadata_filename_computed_and_updated_on_retarget( + root: str, + http: aiohttp.ClientSession, + api_base: str, + comfy_tmp_base_dir: Path, + asset_factory, + make_asset_bytes, + run_scan_and_wait, +): + """ + 1) Ingest under {root}/unit-tests//a/b/ -> filename reflects relative path. + 2) Retarget by copying to {root}/unit-tests//x/, remove old file, + run fast pass + scan -> filename updates to new relative path. 
+ """ + scope = f"meta-fn-{uuid.uuid4().hex[:6]}" + name1 = "compute_metadata_filename.png" + name2 = "compute_changed_metadata_filename.png" + data = make_asset_bytes(name1, 2100) + + # Upload into nested path a/b + a = await asset_factory(name1, [root, "unit-tests", scope, "a", "b"], {}, data) + aid = a["id"] + + root_base = comfy_tmp_base_dir / root + p1 = root_base / "unit-tests" / scope / "a" / "b" / name1 + assert p1.exists() + + # filename at ingest should be the path relative to root + rel1 = str(p1.relative_to(root_base)).replace("\\", "/") + async with http.get(f"{api_base}/api/assets/{aid}") as g1: + d1 = await g1.json() + assert g1.status == 200, d1 + fn1 = d1["user_metadata"].get("filename") + assert fn1 == rel1 + + # Retarget: copy to x/, remove old, then sync+scan + p2 = root_base / "unit-tests" / scope / "x" / name2 + p2.parent.mkdir(parents=True, exist_ok=True) + p2.write_bytes(data) + if p1.exists(): + p1.unlink() + + await trigger_sync_seed_assets(http, api_base) # seed the new path + await run_scan_and_wait(root) # verify/hash and reconcile + + # filename should now point at x/ + rel2 = str(p2.relative_to(root_base)).replace("\\", "/") + async with http.get(f"{api_base}/api/assets/{aid}") as g2: + d2 = await g2.json() + assert g2.status == 200, d2 + fn2 = d2["user_metadata"].get("filename") + assert fn2 == rel2 From a7f2546558ded224951f1a6713f6046ca74f9c3a Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Sun, 14 Sep 2025 17:55:02 +0300 Subject: [PATCH 56/82] fix: use ".rowcount" instead of ".returning" on SQLite --- app/database/services/content.py | 36 +++++++++++++++++--------------- app/database/services/info.py | 19 ++++++++++------- 2 files changed, 30 insertions(+), 25 deletions(-) diff --git a/app/database/services/content.py b/app/database/services/content.py index 546cc7bd1930..ead2e2389276 100644 --- a/app/database/services/content.py +++ b/app/database/services/content.py @@ -513,14 +513,20 @@ async def ingest_fs_asset( "created_at": now, } if dialect == "sqlite": - ins = ( + res = await session.execute( d_sqlite.insert(Asset) .values(**vals) .on_conflict_do_nothing(index_elements=[Asset.hash]) - .returning(Asset.id) ) + if int(res.rowcount or 0) > 0: + out["asset_created"] = True + asset = ( + await session.execute( + select(Asset).where(Asset.hash == asset_hash).limit(1) + ) + ).scalars().first() elif dialect == "postgresql": - ins = ( + res = await session.execute( d_pg.insert(Asset) .values(**vals) .on_conflict_do_nothing( @@ -529,24 +535,20 @@ async def ingest_fs_asset( ) .returning(Asset.id) ) + inserted_id = res.scalar_one_or_none() + if inserted_id: + out["asset_created"] = True + asset = await session.get(Asset, inserted_id) + else: + asset = ( + await session.execute( + select(Asset).where(Asset.hash == asset_hash).limit(1) + ) + ).scalars().first() else: raise NotImplementedError(f"Unsupported database dialect: {dialect}") - res = await session.execute(ins) - inserted_id = res.scalar_one_or_none() - asset = ( - await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1)) - ).scalars().first() if not asset: raise RuntimeError("Asset row not found after upsert.") - if inserted_id: - out["asset_created"] = True - asset = await session.get(Asset, inserted_id) - else: - asset = ( - await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1)) - ).scalars().first() - if not asset: - raise RuntimeError("Asset row not found after upsert.") else: changed = False if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0: 
diff --git a/app/database/services/info.py b/app/database/services/info.py index 5c7e3c92fe9c..d2fd1f503f33 100644 --- a/app/database/services/info.py +++ b/app/database/services/info.py @@ -377,17 +377,20 @@ async def touch_asset_info_by_id( stmt = stmt.where( sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts) ) - stmt = stmt.values(last_access_time=ts).returning(AssetInfo.id) - return (await session.execute(stmt)).scalar_one_or_none() is not None + stmt = stmt.values(last_access_time=ts) + if session.bind.dialect.name == "postgresql": + return (await session.execute(stmt.returning(AssetInfo.id))).scalar_one_or_none() is not None + return int((await session.execute(stmt)).rowcount or 0) > 0 async def delete_asset_info_by_id(session: AsyncSession, *, asset_info_id: str, owner_id: str) -> bool: - return ( - await session.execute(delete(AssetInfo).where( - AssetInfo.id == asset_info_id, - visible_owner_clause(owner_id), - ).returning(AssetInfo.id)) - ).scalar_one_or_none() is not None + stmt = sa.delete(AssetInfo).where( + AssetInfo.id == asset_info_id, + visible_owner_clause(owner_id), + ) + if session.bind.dialect.name == "postgresql": + return (await session.execute(stmt.returning(AssetInfo.id))).scalar_one_or_none() is not None + return int((await session.execute(stmt)).rowcount or 0) > 0 async def add_tags_to_asset_info( From a2fc2bbae4be3170e4dafc27bf51a1daf70ce5e8 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Sun, 14 Sep 2025 18:12:00 +0300 Subject: [PATCH 57/82] corrected formatting --- app/api/schemas_in.py | 6 ++++-- app/assets_scanner.py | 7 +++++-- tests-assets/test_metadata_filters.py | 14 +++++++++++--- tests-assets/test_scans.py | 6 ++++-- 4 files changed, 24 insertions(+), 9 deletions(-) diff --git a/app/api/schemas_in.py b/app/api/schemas_in.py index bc521b313d49..109ed3f0752f 100644 --- a/app/api/schemas_in.py +++ b/app/api/schemas_in.py @@ -101,10 +101,12 @@ def _tags_norm(cls, v): return [] if isinstance(v, list): out = [str(t).strip().lower() for t in v if str(t).strip()] - seen = set(); dedup = [] + seen = set() + dedup = [] for t in out: if t not in seen: - seen.add(t); dedup.append(t) + seen.add(t) + dedup.append(t) return dedup if isinstance(v, str): return [t.strip().lower() for t in v.split(",") if t.strip()] diff --git a/app/assets_scanner.py b/app/assets_scanner.py index 6cca5b16571c..738b4e7814b7 100644 --- a/app/assets_scanner.py +++ b/app/assets_scanner.py @@ -239,7 +239,8 @@ async def _run_hash_verify_pipeline(root: schemas_in.RootType, prog: ScanProgres for lst in (verify_ids, unhashed_ids): for sid in lst: if sid not in seen: - seen.add(sid); ordered.append(sid) + seen.add(sid) + ordered.append(sid) prog.discovered = len(ordered) @@ -382,7 +383,9 @@ async def _close_when_ready(): asyncio.create_task(_close_when_ready()) -async def _await_state_workers_then_finish(root: schemas_in.RootType, prog: ScanProgress, state: SlowQueueState) -> None: +async def _await_state_workers_then_finish( + root: schemas_in.RootType, prog: ScanProgress, state: SlowQueueState +) -> None: if state.workers: await asyncio.gather(*state.workers, return_exceptions=True) await _reconcile_missing_tags_for_root(root, prog) diff --git a/tests-assets/test_metadata_filters.py b/tests-assets/test_metadata_filters.py index 360a24527d55..4c4c8f946314 100644 --- a/tests-assets/test_metadata_filters.py +++ b/tests-assets/test_metadata_filters.py @@ -5,7 +5,9 @@ @pytest.mark.asyncio -async def test_meta_and_across_keys_and_types(http: aiohttp.ClientSession, api_base: 
str, asset_factory, make_asset_bytes): +async def test_meta_and_across_keys_and_types( + http: aiohttp.ClientSession, api_base: str, asset_factory, make_asset_bytes +): name = "mf_and_mix.safetensors" tags = ["models", "checkpoints", "unit-tests", "mf-and"] meta = {"purpose": "mix", "epoch": 1, "active": True, "score": 1.23} @@ -126,7 +128,9 @@ async def test_meta_any_of_list_of_scalars(http, api_base, asset_factory, make_a @pytest.mark.asyncio -async def test_meta_none_semantics_missing_or_null_and_any_of_with_none(http, api_base, asset_factory, make_asset_bytes): +async def test_meta_none_semantics_missing_or_null_and_any_of_with_none( + http, api_base, asset_factory, make_asset_bytes +): # a1: key missing; a2: explicit null; a3: concrete value t = ["models", "checkpoints", "unit-tests", "mf-none"] a1 = await asset_factory("mf_none_missing.safetensors", t, {"x": 1}, make_asset_bytes("a1")) @@ -362,7 +366,11 @@ async def test_meta_sort_and_paging_under_filter(http, api_base, asset_factory, await asset_factory(n3, t, {"group": "g"}, make_asset_bytes(n3, 3072)) # Sort by size ascending with paging - q = {"include_tags": "unit-tests,mf-sort", "metadata_filter": json.dumps({"group": "g"}), "sort": "size", "order": "asc", "limit": "2"} + q = { + "include_tags": "unit-tests,mf-sort", + "metadata_filter": json.dumps({"group": "g"}), + "sort": "size", "order": "asc", "limit": "2", + } async with http.get(api_base + "/api/assets", params=q) as r1: b1 = await r1.json() assert r1.status == 200 diff --git a/tests-assets/test_scans.py b/tests-assets/test_scans.py index 7774569cadd7..2058685ddbcd 100644 --- a/tests-assets/test_scans.py +++ b/tests-assets/test_scans.py @@ -259,9 +259,11 @@ async def test_delete_one_assetinfo_of_missing_asset_keeps_identity( # Both infos should be marked missing async with http.get(f"{api_base}/api/assets/{a1['id']}") as g1: - d1 = await g1.json(); assert "missing" in set(d1.get("tags", [])) + d1 = await g1.json() + assert "missing" in set(d1.get("tags", [])) async with http.get(f"{api_base}/api/assets/{a2['id']}") as g2: - d2 = await g2.json(); assert "missing" in set(d2.get("tags", [])) + d2 = await g2.json() + assert "missing" in set(d2.get("tags", [])) # Delete one info async with http.delete(f"{api_base}/api/assets/{a1['id']}") as rd: From 1d970382f0a030432c7bc8626814193a4c9e8558 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Sun, 14 Sep 2025 20:02:28 +0300 Subject: [PATCH 58/82] added final tests --- tests-assets/test_assets_missing_sync.py | 155 ++++++++++++++++++++++- tests-assets/test_scans.py | 46 +++++++ tests-assets/test_uploads.py | 39 ++++++ 3 files changed, 237 insertions(+), 3 deletions(-) diff --git a/tests-assets/test_assets_missing_sync.py b/tests-assets/test_assets_missing_sync.py index aec6606d60c0..87a6b1a32074 100644 --- a/tests-assets/test_assets_missing_sync.py +++ b/tests-assets/test_assets_missing_sync.py @@ -1,3 +1,4 @@ +import os import uuid from pathlib import Path @@ -7,7 +8,9 @@ @pytest.mark.asyncio +@pytest.mark.parametrize("root", ["input", "output"]) async def test_seed_asset_removed_when_file_is_deleted( + root: str, http: aiohttp.ClientSession, api_base: str, comfy_tmp_base_dir: Path, @@ -16,7 +19,7 @@ async def test_seed_asset_removed_when_file_is_deleted( after triggering sync_seed_assets, Asset + AssetInfo disappear. 
""" # Create a file directly under input/unit-tests/ so tags include "unit-tests" - case_dir = comfy_tmp_base_dir / "input" / "unit-tests" / "syncseed" + case_dir = comfy_tmp_base_dir / root / "unit-tests" / "syncseed" case_dir.mkdir(parents=True, exist_ok=True) name = f"seed_{uuid.uuid4().hex[:8]}.bin" fp = case_dir / name @@ -130,6 +133,12 @@ async def test_hashed_asset_two_assetinfos_both_get_missing( assert p.exists() p.unlink() + async with http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}) as r0: + tags0 = await r0.json() + assert r0.status == 200, tags0 + byname0 = {t["name"]: t for t in tags0.get("tags", [])} + old_missing = int(byname0.get("missing", {}).get("count", 0)) + # Sync -> both AssetInfos for this asset must receive 'missing' await trigger_sync_seed_assets(http, api_base) @@ -143,6 +152,14 @@ async def test_hashed_asset_two_assetinfos_both_get_missing( assert gb.status == 200, db assert "missing" in set(db.get("tags", [])) + # Tag usage for 'missing' increased by exactly 2 (two AssetInfos) + async with http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}) as r1: + tags1 = await r1.json() + assert r1.status == 200, tags1 + byname1 = {t["name"]: t for t in tags1.get("tags", [])} + new_missing = int(byname1.get("missing", {}).get("count", 0)) + assert new_missing == old_missing + 2 + @pytest.mark.asyncio async def test_hashed_asset_two_cache_states_partial_delete_then_full_delete( @@ -173,7 +190,7 @@ async def test_hashed_asset_two_cache_states_partial_delete_then_full_delete( # Fast seed so the second path appears (as a seed initially) await trigger_sync_seed_assets(http, api_base) - # Now run a 'models' scan so the seed copy is hashed and deduped + # Deduplication of AssetInfo-s will not happen as first AssetInfo has owner='default' and second has empty owner. 
await run_scan_and_wait("input") # Remove only one file and sync -> asset should still be healthy (no 'missing') @@ -185,7 +202,13 @@ async def test_hashed_asset_two_cache_states_partial_delete_then_full_delete( assert g1.status == 200, d1 assert "missing" not in set(d1.get("tags", [])), "Should not be missing while one valid path remains" - # Remove the second (last) file and sync -> now we expect 'missing' + # Baseline 'missing' usage count just before last file removal + async with http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}) as r0: + tags0 = await r0.json() + assert r0.status == 200, tags0 + old_missing = int({t["name"]: t for t in tags0.get("tags", [])}.get("missing", {}).get("count", 0)) + + # Remove the second (last) file and sync -> now we expect 'missing' on this AssetInfo path2.unlink() await trigger_sync_seed_assets(http, api_base) @@ -193,3 +216,129 @@ async def test_hashed_asset_two_cache_states_partial_delete_then_full_delete( d2 = await g2.json() assert g2.status == 200, d2 assert "missing" in set(d2.get("tags", [])), "Missing must be set once no valid paths remain" + + # Tag usage for 'missing' increased by exactly 2 (two AssetInfo for one Asset) + async with http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}) as r1: + tags1 = await r1.json() + assert r1.status == 200, tags1 + new_missing = int({t["name"]: t for t in tags1.get("tags", [])}.get("missing", {}).get("count", 0)) + assert new_missing == old_missing + 2 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("root", ["input", "output"]) +async def test_missing_tag_clears_on_fastpass_when_mtime_and_size_match( + root: str, + http: aiohttp.ClientSession, + api_base: str, + comfy_tmp_base_dir: Path, + asset_factory, + make_asset_bytes, +): + """ + Fast pass alone clears 'missing' when size and mtime match exactly: + 1) upload (hashed), record original mtime_ns + 2) delete -> fast pass adds 'missing' + 3) restore same bytes and set mtime back to the original value + 4) run fast pass again -> 'missing' is removed (no slow scan) + """ + scope = f"fastclear-{uuid.uuid4().hex[:6]}" + name = "fastpass_clear.bin" + data = make_asset_bytes(name, 3072) + + a = await asset_factory(name, [root, "unit-tests", scope], {}, data) + aid = a["id"] + base = comfy_tmp_base_dir / root / "unit-tests" / scope + p = base / name + st0 = p.stat() + orig_mtime_ns = getattr(st0, "st_mtime_ns", int(st0.st_mtime * 1_000_000_000)) + + # Delete -> fast pass adds 'missing' + p.unlink() + await trigger_sync_seed_assets(http, api_base) + async with http.get(f"{api_base}/api/assets/{aid}") as g1: + d1 = await g1.json() + assert g1.status == 200, d1 + assert "missing" in set(d1.get("tags", [])) + + # Restore same bytes and revert mtime to the original value + p.parent.mkdir(parents=True, exist_ok=True) + p.write_bytes(data) + # set both atime and mtime in ns to ensure exact match + os.utime(p, ns=(orig_mtime_ns, orig_mtime_ns)) + + # Fast pass should clear 'missing' without a scan + await trigger_sync_seed_assets(http, api_base) + async with http.get(f"{api_base}/api/assets/{aid}") as g2: + d2 = await g2.json() + assert g2.status == 200, d2 + assert "missing" not in set(d2.get("tags", [])), "Fast pass should clear 'missing' when size+mtime match" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("root", ["input", "output"]) +async def test_fastpass_removes_stale_state_row_no_missing( + root: str, + http: aiohttp.ClientSession, + api_base: str, + comfy_tmp_base_dir: Path, + 
asset_factory, + make_asset_bytes, + run_scan_and_wait, +): + """ + Hashed asset with two states: + - delete one file + - run fast pass only + Expect: + - asset stays healthy (no 'missing') + - stale AssetCacheState row for the deleted path is removed. + We verify this behaviorally by recreating the deleted path and running fast pass again: + a new *seed* AssetInfo is created, which proves the old state row was not reused. + """ + scope = f"stale-{uuid.uuid4().hex[:6]}" + name = "two_states.bin" + data = make_asset_bytes(name, 2048) + + # Upload hashed asset at path1 + a = await asset_factory(name, [root, "unit-tests", scope], {}, data) + aid = a["id"] + h = a["asset_hash"] + base = comfy_tmp_base_dir / root / "unit-tests" / scope + p1 = base / name + assert p1.exists() + + # Create second state path2, seed+scan to dedupe into the same Asset + p2 = base / "copy" / name + p2.parent.mkdir(parents=True, exist_ok=True) + p2.write_bytes(data) + await trigger_sync_seed_assets(http, api_base) + await run_scan_and_wait(root) + + # Delete path1 and run fast pass -> no 'missing' and stale state row should be removed + p1.unlink() + await trigger_sync_seed_assets(http, api_base) + async with http.get(f"{api_base}/api/assets/{aid}") as g1: + d1 = await g1.json() + assert g1.status == 200, d1 + assert "missing" not in set(d1.get("tags", [])) + + # Recreate path1 and run fast pass again. + # If the stale state row was removed, a NEW seed AssetInfo will appear for this path. + p1.write_bytes(data) + await trigger_sync_seed_assets(http, api_base) + + async with http.get( + api_base + "/api/assets", + params={"include_tags": f"unit-tests,{scope}", "name_contains": name}, + ) as rl: + bl = await rl.json() + assert rl.status == 200, bl + items = bl.get("assets", []) + # one hashed AssetInfo (asset_hash == h) + one seed AssetInfo (asset_hash == null) + hashes = [it.get("asset_hash") for it in items if it.get("name") == name] + assert h in hashes and any(x is None for x in hashes), "Expected a new seed AssetInfo for the recreated path" + + # Asset identity still healthy + async with http.head(f"{api_base}/api/assets/hash/{h}") as rh: + assert rh.status == 200 diff --git a/tests-assets/test_scans.py b/tests-assets/test_scans.py index 2058685ddbcd..fa3b415d58bc 100644 --- a/tests-assets/test_scans.py +++ b/tests-assets/test_scans.py @@ -464,3 +464,49 @@ async def test_concurrent_seed_hashing_same_file_no_dupes( matches = [a for a in b.get("assets", []) if a.get("name") == name] assert len(matches) == 1 assert matches[0].get("asset_hash"), "Seed should have been hashed into an Asset" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("root", ["input", "output"]) +async def test_cache_state_retarget_on_content_change_asset_info_stays( + root: str, + http, + api_base: str, + comfy_tmp_base_dir: Path, + asset_factory, + make_asset_bytes, + run_scan_and_wait, +): + """ + Start with hashed H1 (AssetInfo A1). Replace file bytes on disk to become H2. + After scan: AssetCacheState points to H2; A1 still references H1; downloading A1 -> 404. 
+ """ + scope = f"retarget-{uuid.uuid4().hex[:6]}" + name = "content_change.bin" + d1 = make_asset_bytes("v1-" + scope, 2048) + + a1 = await asset_factory(name, [root, "unit-tests", scope], {}, d1) + aid = a1["id"] + h1 = a1["asset_hash"] + + p = comfy_tmp_base_dir / root / "unit-tests" / scope / name + assert p.exists() + + # Change the bytes in place to force a new content hash (H2) + d2 = make_asset_bytes("v2-" + scope, 3072) + p.write_bytes(d2) + + # Scan to verify and retarget the state; reconcilers run after scan + await run_scan_and_wait(root) + + # AssetInfo still on the old content identity (H1) + async with http.get(f"{api_base}/api/assets/{aid}") as rg: + g = await rg.json() + assert rg.status == 200, g + assert g.get("asset_hash") == h1 + + # Download must fail until a state exists for H1 (we removed the only one by retarget) + async with http.get(f"{api_base}/api/assets/{aid}/content") as dl: + body = await dl.json() + assert dl.status == 404, body + assert body["error"]["code"] == "FILE_NOT_FOUND" diff --git a/tests-assets/test_uploads.py b/tests-assets/test_uploads.py index c1a962d263d4..de318d0576cb 100644 --- a/tests-assets/test_uploads.py +++ b/tests-assets/test_uploads.py @@ -284,3 +284,42 @@ async def test_upload_tags_traversal_guard(http: aiohttp.ClientSession, api_base body = await r.json() assert r.status == 400 assert body["error"]["code"] in ("BAD_REQUEST", "INVALID_BODY") + + +@pytest.mark.asyncio +@pytest.mark.parametrize("root", ["input", "output"]) +async def test_duplicate_upload_same_path_updates_state( + root: str, + http, + api_base: str, + asset_factory, + make_asset_bytes, +): + """ + Two uploads target the exact same destination path (same tags+name) with different bytes. + Expect: file on disk is from the last upload; its AssetInfo serves content; the first AssetInfo's content 404s. + This validates that AssetCacheState(file_path) remains unique and its asset_id/mtime_ns were updated. 
+ """ + scope = f"dup-path-{uuid.uuid4().hex[:6]}" + name = "same_path.bin" + + d1 = make_asset_bytes(scope + "-v1", 1536) + d2 = make_asset_bytes(scope + "-v2", 2048) + tags = [root, "unit-tests", scope] + + first = await asset_factory(name, tags, {}, d1) + second = await asset_factory(name, tags, {}, d2) + + # Second one must serve the new bytes + async with http.get(f"{api_base}/api/assets/{second['id']}/content") as r2: + b2 = await r2.read() + assert r2.status == 200 + assert b2 == d2 + + # The first AssetInfo now points to an identity with no live state for that path -> 404 + async with http.get(f"{api_base}/api/assets/{first['id']}/content") as r1: + try: + body = await r1.json() + except Exception: + body = {} + assert r1.status == 404, body From dda31de690d26870fc617159437fee3346ca004a Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Sun, 14 Sep 2025 21:28:31 +0300 Subject: [PATCH 59/82] rework: AssetInfo.name is only a display name --- app/api/schemas_in.py | 7 ++++-- app/assets_manager.py | 19 +++++++++----- tests-assets/conftest.py | 6 ++++- tests-assets/test_assets_missing_sync.py | 27 +++++++++++--------- tests-assets/test_crud.py | 4 +-- tests-assets/test_downloads.py | 16 ++++++------ tests-assets/test_scans.py | 32 +++++++++++------------- tests-assets/test_uploads.py | 32 ++++++++++++------------ 8 files changed, 79 insertions(+), 64 deletions(-) diff --git a/app/api/schemas_in.py b/app/api/schemas_in.py index 109ed3f0752f..1469d325d8e1 100644 --- a/app/api/schemas_in.py +++ b/app/api/schemas_in.py @@ -170,14 +170,17 @@ class UploadAssetSpec(BaseModel): """Upload Asset operation. - tags: ordered; first is root ('models'|'input'|'output'); if root == 'models', second must be a valid category from folder_paths.folder_names_and_paths - - name: desired filename (optional); fallback will be the file hash + - name: display name - user_metadata: arbitrary JSON object (optional) - hash: optional canonical 'blake3:' provided by the client for validation / fast-path + + Files created via this endpoint are stored on disk using the **content hash** as the filename stem + and the original extension is preserved when available. 
""" model_config = ConfigDict(extra="ignore", str_strip_whitespace=True) tags: list[str] = Field(..., min_length=1) - name: Optional[str] = Field(default=None, max_length=512) + name: Optional[str] = Field(default=None, max_length=512, description="Display Name") user_metadata: dict[str, Any] = Field(default_factory=dict) hash: Optional[str] = Field(default=None) diff --git a/app/assets_manager.py b/app/assets_manager.py index f3da06633eab..4aae6e8ad08b 100644 --- a/app/assets_manager.py +++ b/app/assets_manager.py @@ -214,11 +214,11 @@ async def upload_asset_from_temp_path( if temp_path and os.path.exists(temp_path): os.remove(temp_path) - desired_name = _safe_filename(spec.name or (client_filename or ""), fallback=digest) + display_name = _safe_filename(spec.name or (client_filename or ""), fallback=digest) info = await create_asset_info_for_existing_asset( session, asset_hash=asset_hash, - name=desired_name, + name=display_name, user_metadata=spec.user_metadata or {}, tags=spec.tags or [], tag_origin="manual", @@ -245,11 +245,18 @@ async def upload_asset_from_temp_path( dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir os.makedirs(dest_dir, exist_ok=True) - desired_name = _safe_filename(spec.name or (client_filename or ""), fallback=digest) - dest_abs = os.path.abspath(os.path.join(dest_dir, desired_name)) + src_for_ext = (client_filename or spec.name or "").strip() + _ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else "" + ext = _ext if 0 < len(_ext) <= 16 else "" + hashed_basename = f"{digest}{ext}" + dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename)) ensure_within_base(dest_abs, base_dir) - content_type = mimetypes.guess_type(desired_name, strict=False)[0] or "application/octet-stream" + content_type = ( + mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0] + or mimetypes.guess_type(hashed_basename, strict=False)[0] + or "application/octet-stream" + ) try: os.replace(temp_path, dest_abs) @@ -269,7 +276,7 @@ async def upload_asset_from_temp_path( size_bytes=size_bytes, mtime_ns=mtime_ns, mime_type=content_type, - info_name=os.path.basename(dest_abs), + info_name=_safe_filename(spec.name or (client_filename or ""), fallback=digest), owner_id=owner_id, preview_id=None, user_metadata=spec.user_metadata or {}, diff --git a/tests-assets/conftest.py b/tests-assets/conftest.py index 84f8747874f3..3f31f226efcb 100644 --- a/tests-assets/conftest.py +++ b/tests-assets/conftest.py @@ -228,7 +228,7 @@ async def create(name: str, tags: list[str], meta: dict, data: bytes) -> dict: @pytest_asyncio.fixture async def seeded_asset(request: pytest.FixtureRequest, http: aiohttp.ClientSession, api_base: str) -> dict: """ - Upload one asset into models/checkpoints/unit-tests/. + Upload one asset with ".safetensors" extension into models/checkpoints/unit-tests/. Returns response dict with id, asset_hash, tags, etc. 
""" name = "unit_1_example.safetensors" @@ -301,3 +301,7 @@ async def _run(root: str, timeout: float = 120.0): raise TimeoutError(f"Timed out waiting for scan of root={root}") await asyncio.sleep(0.1) return _run + + +def get_asset_filename(asset_hash: str, extension: str) -> str: + return asset_hash.removeprefix("blake3:") + extension diff --git a/tests-assets/test_assets_missing_sync.py b/tests-assets/test_assets_missing_sync.py index 87a6b1a32074..b959e33f0a33 100644 --- a/tests-assets/test_assets_missing_sync.py +++ b/tests-assets/test_assets_missing_sync.py @@ -4,7 +4,7 @@ import aiohttp import pytest -from conftest import trigger_sync_seed_assets +from conftest import get_asset_filename, trigger_sync_seed_assets @pytest.mark.asyncio @@ -77,7 +77,7 @@ async def test_hashed_asset_missing_tag_added_then_removed_after_scan( a = await asset_factory(name, tags, {}, data) # Compute its on-disk path and remove it - dest = comfy_tmp_base_dir / "input" / "unit-tests" / "msync2" / name + dest = comfy_tmp_base_dir / "input" / "unit-tests" / "msync2" / get_asset_filename(a["asset_hash"], ".png") assert dest.exists(), f"Expected asset file at {dest}" dest.unlink() @@ -102,7 +102,7 @@ async def test_hashed_asset_missing_tag_added_then_removed_after_scan( @pytest.mark.asyncio -async def test_hashed_asset_two_assetinfos_both_get_missing( +async def test_hashed_asset_two_asset_infos_both_get_missing( http: aiohttp.ClientSession, api_base: str, comfy_tmp_base_dir: Path, @@ -129,7 +129,7 @@ async def test_hashed_asset_two_assetinfos_both_get_missing( second_id = b2["id"] # Remove the single underlying file - p = comfy_tmp_base_dir / "input" / "unit-tests" / "multiinfo" / name + p = comfy_tmp_base_dir / "input" / "unit-tests" / "multiinfo" / get_asset_filename(b2["asset_hash"], ".png") assert p.exists() p.unlink() @@ -179,7 +179,7 @@ async def test_hashed_asset_two_cache_states_partial_delete_then_full_delete( data = make_asset_bytes(name, 3072) created = await asset_factory(name, tags, {}, data) - path1 = comfy_tmp_base_dir / "input" / "unit-tests" / "dual" / name + path1 = comfy_tmp_base_dir / "input" / "unit-tests" / "dual" / get_asset_filename(created["asset_hash"], ".png") assert path1.exists() # Create a second on-disk copy under the same root but different subfolder @@ -249,7 +249,7 @@ async def test_missing_tag_clears_on_fastpass_when_mtime_and_size_match( a = await asset_factory(name, [root, "unit-tests", scope], {}, data) aid = a["id"] base = comfy_tmp_base_dir / root / "unit-tests" / scope - p = base / name + p = base / get_asset_filename(a["asset_hash"], ".bin") st0 = p.stat() orig_mtime_ns = getattr(st0, "st_mtime_ns", int(st0.st_mtime * 1_000_000_000)) @@ -302,12 +302,14 @@ async def test_fastpass_removes_stale_state_row_no_missing( # Upload hashed asset at path1 a = await asset_factory(name, [root, "unit-tests", scope], {}, data) - aid = a["id"] - h = a["asset_hash"] base = comfy_tmp_base_dir / root / "unit-tests" / scope - p1 = base / name + a1_filename = get_asset_filename(a["asset_hash"], ".bin") + p1 = base / a1_filename assert p1.exists() + aid = a["id"] + h = a["asset_hash"] + # Create second state path2, seed+scan to dedupe into the same Asset p2 = base / "copy" / name p2.parent.mkdir(parents=True, exist_ok=True) @@ -330,14 +332,15 @@ async def test_fastpass_removes_stale_state_row_no_missing( async with http.get( api_base + "/api/assets", - params={"include_tags": f"unit-tests,{scope}", "name_contains": name}, + params={"include_tags": f"unit-tests,{scope}"}, ) as rl: bl = await 
rl.json() assert rl.status == 200, bl items = bl.get("assets", []) # one hashed AssetInfo (asset_hash == h) + one seed AssetInfo (asset_hash == null) - hashes = [it.get("asset_hash") for it in items if it.get("name") == name] - assert h in hashes and any(x is None for x in hashes), "Expected a new seed AssetInfo for the recreated path" + hashes = [it.get("asset_hash") for it in items if it.get("name") in (name, a1_filename)] + assert h in hashes + assert any(x is None for x in hashes), "Expected a new seed AssetInfo for the recreated path" # Asset identity still healthy async with http.head(f"{api_base}/api/assets/hash/{h}") as rh: diff --git a/tests-assets/test_crud.py b/tests-assets/test_crud.py index ad435d65bf25..f2e4c2699dd7 100644 --- a/tests-assets/test_crud.py +++ b/tests-assets/test_crud.py @@ -4,7 +4,7 @@ import aiohttp import pytest -from conftest import trigger_sync_seed_assets +from conftest import get_asset_filename, trigger_sync_seed_assets @pytest.mark.asyncio @@ -286,7 +286,7 @@ async def test_metadata_filename_computed_and_updated_on_retarget( aid = a["id"] root_base = comfy_tmp_base_dir / root - p1 = root_base / "unit-tests" / scope / "a" / "b" / name1 + p1 = (root_base / "unit-tests" / scope / "a" / "b" / get_asset_filename(a["asset_hash"], ".png")) assert p1.exists() # filename at ingest should be the path relative to root diff --git a/tests-assets/test_downloads.py b/tests-assets/test_downloads.py index cb8b36220177..181aad6f60fa 100644 --- a/tests-assets/test_downloads.py +++ b/tests-assets/test_downloads.py @@ -6,7 +6,7 @@ import aiohttp import pytest -from conftest import trigger_sync_seed_assets +from conftest import get_asset_filename, trigger_sync_seed_assets @pytest.mark.asyncio @@ -53,7 +53,7 @@ async def test_download_chooses_existing_state_and_updates_access_time( aid = a["id"] base = comfy_tmp_base_dir / root / "unit-tests" / scope - path1 = base / name + path1 = base / get_asset_filename(a["asset_hash"], ".bin") assert path1.exists() # Seed path2 by copying, then scan to dedupe into a second state @@ -108,14 +108,14 @@ async def test_download_missing_file_returns_404( async with http.get(f"{api_base}/api/assets/{aid}") as rg: detail = await rg.json() assert rg.status == 200 - rel_inside_category = detail["name"] - abs_path = comfy_tmp_base_dir / "models" / "checkpoints" / rel_inside_category - if abs_path.exists(): - abs_path.unlink() + asset_filename = get_asset_filename(detail["asset_hash"], ".safetensors") + abs_path = comfy_tmp_base_dir / "models" / "checkpoints" / asset_filename + assert abs_path.exists() + abs_path.unlink() async with http.get(f"{api_base}/api/assets/{aid}/content") as r2: - body = await r2.json() assert r2.status == 404 + body = await r2.json() assert body["error"]["code"] == "FILE_NOT_FOUND" finally: # We created asset without the "unit-tests" tag(see `autoclean_unit_test_assets`), we need to clear it manually. 
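A minimal sketch of the content-addressed path convention these download tests now rely on; `expected_asset_path`, `base_dir` and `ext` are illustrative names, the real helper is `get_asset_filename` in `tests-assets/conftest.py`:

    import os

    def expected_asset_path(base_dir: str, tags: list[str], asset_hash: str, ext: str) -> str:
        # Uploads are stored content-addressed: the filename stem is the blake3
        # digest with its "blake3:" prefix stripped, and the original extension
        # (when known) is preserved.
        return os.path.join(base_dir, *tags, asset_hash.removeprefix("blake3:") + ext)

    # e.g. expected_asset_path(str(comfy_tmp_base_dir), ["models", "checkpoints"],
    #                          detail["asset_hash"], ".safetensors")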
@@ -144,7 +144,7 @@ async def test_download_404_if_all_states_missing( aid = a["id"] base = comfy_tmp_base_dir / root / "unit-tests" / scope - p1 = base / name + p1 = base / get_asset_filename(a["asset_hash"], ".bin") assert p1.exists() # Seed a second state and dedupe diff --git a/tests-assets/test_scans.py b/tests-assets/test_scans.py index fa3b415d58bc..fcedce5cf7d4 100644 --- a/tests-assets/test_scans.py +++ b/tests-assets/test_scans.py @@ -5,7 +5,7 @@ import aiohttp import pytest -from conftest import trigger_sync_seed_assets +from conftest import get_asset_filename, trigger_sync_seed_assets def _base_for(root: str, comfy_tmp_base_dir: Path) -> Path: @@ -138,7 +138,8 @@ async def test_scan_records_file_errors_permission_denied( deny_dir.mkdir(parents=True, exist_ok=True) name = "deny.bin" - await asset_factory(name, [root, "unit-tests", scope, "deny"], {}, b"X" * 2048) + a1 = await asset_factory(name, [root, "unit-tests", scope, "deny"], {}, b"X" * 2048) + asset_filename = get_asset_filename(a1["asset_hash"], ".bin") try: os.chmod(deny_dir, 0x000) async with http.post(api_base + "/api/assets/scan/schedule", json={"roots": [root]}) as r: @@ -152,10 +153,11 @@ async def test_scan_records_file_errors_permission_denied( assert len(scans) == 1 errs = scans[0].get("file_errors", []) # Should contain at least one PermissionError-like record - assert errs and any(e.get("path", "").endswith(name) and e.get("message") for e in errs) + assert errs + assert any(e.get("path", "").endswith(asset_filename) and e.get("message") for e in errs) finally: try: - os.chmod(deny_dir, 0o755) + os.chmod(deny_dir, 0x755) except Exception: pass @@ -182,7 +184,7 @@ async def test_missing_tag_created_and_visible_in_tags( created = await asset_factory(name, [root, "unit-tests", scope], {}, b"Y" * 4096) # Remove the only file and trigger fast pass - p = _base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope / name + p = _base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope / get_asset_filename(created["asset_hash"], ".bin") assert p.exists() p.unlink() await trigger_sync_seed_assets(http, api_base) @@ -217,7 +219,7 @@ async def test_missing_reapplies_after_manual_removal( created = await asset_factory(name, [root, "unit-tests", scope], {}, b"Z" * 1024) # Make it missing - p = _base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope / name + p = _base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope / get_asset_filename(created["asset_hash"], ".bin") p.unlink() await trigger_sync_seed_assets(http, api_base) @@ -237,7 +239,7 @@ async def test_missing_reapplies_after_manual_removal( @pytest.mark.asyncio @pytest.mark.parametrize("root", ["input", "output"]) -async def test_delete_one_assetinfo_of_missing_asset_keeps_identity( +async def test_delete_one_asset_info_of_missing_asset_keeps_identity( root: str, http, api_base: str, @@ -253,7 +255,7 @@ async def test_delete_one_assetinfo_of_missing_asset_keeps_identity( a2 = await asset_factory("copy_" + name, [root, "unit-tests", scope], {}, b"W" * 2048) # Remove file of the first (both point to the same Asset, but we know on-disk path name for a1) - p1 = _base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope / name + p1 = _base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope / get_asset_filename(a1["asset_hash"], ".bin") p1.unlink() await trigger_sync_seed_assets(http, api_base) @@ -282,7 +284,7 @@ async def test_delete_one_assetinfo_of_missing_asset_keeps_identity( @pytest.mark.asyncio @pytest.mark.parametrize("keep_root", ["input", "output"]) 
-async def test_delete_last_assetinfo_false_keeps_asset_and_states_multiroot( +async def test_delete_last_asset_info_false_keeps_asset_and_states_multiroot( keep_root: str, http, api_base: str, @@ -293,21 +295,18 @@ async def test_delete_last_assetinfo_false_keeps_asset_and_states_multiroot( """Delete last AssetInfo with delete_content_if_orphan=false keeps asset and the underlying on-disk content.""" other_root = "output" if keep_root == "input" else "input" scope = f"delfalse-{uuid.uuid4().hex[:6]}" - name1, name2 = "keep1.bin", "keep2.bin" data = make_asset_bytes(scope, 3072) # First upload creates the physical file - a1 = await asset_factory(name1, [keep_root, "unit-tests", scope], {}, data) + a1 = await asset_factory("keep1.bin", [keep_root, "unit-tests", scope], {}, data) # Second upload (other root) is deduped to the same content; no new file on disk - a2 = await asset_factory(name2, [other_root, "unit-tests", scope], {}, data) + a2 = await asset_factory("keep2.bin", [other_root, "unit-tests", scope], {}, data) h = a1["asset_hash"] - p1 = _base_for(keep_root, comfy_tmp_base_dir) / "unit-tests" / scope / name1 - p2 = _base_for(other_root, comfy_tmp_base_dir) / "unit-tests" / scope / name2 + p1 = _base_for(keep_root, comfy_tmp_base_dir) / "unit-tests" / scope / get_asset_filename(h, ".bin") # De-dup semantics: only the first physical file exists assert p1.exists(), "Expected the first physical file to exist" - assert not p2.exists(), "Second duplicate must not create another physical file" # Delete both AssetInfos; keep content on the very last delete async with http.delete(f"{api_base}/api/assets/{a2['id']}") as rfirst: @@ -319,7 +318,6 @@ async def test_delete_last_assetinfo_false_keeps_asset_and_states_multiroot( async with http.head(f"{api_base}/api/assets/hash/{h}") as rh: assert rh.status == 200 assert p1.exists(), "Content file should remain after keep-content delete" - assert not p2.exists(), "There was never a second physical file" # Cleanup: re-create a reference by hash and then delete to purge content payload = { @@ -489,7 +487,7 @@ async def test_cache_state_retarget_on_content_change_asset_info_stays( aid = a1["id"] h1 = a1["asset_hash"] - p = comfy_tmp_base_dir / root / "unit-tests" / scope / name + p = comfy_tmp_base_dir / root / "unit-tests" / scope / get_asset_filename(a1["asset_hash"], ".bin") assert p.exists() # Change the bytes in place to force a new content hash (H2) diff --git a/tests-assets/test_uploads.py b/tests-assets/test_uploads.py index de318d0576cb..f1b116c1aca2 100644 --- a/tests-assets/test_uploads.py +++ b/tests-assets/test_uploads.py @@ -288,7 +288,7 @@ async def test_upload_tags_traversal_guard(http: aiohttp.ClientSession, api_base @pytest.mark.asyncio @pytest.mark.parametrize("root", ["input", "output"]) -async def test_duplicate_upload_same_path_updates_state( +async def test_duplicate_upload_same_display_name_does_not_clobber( root: str, http, api_base: str, @@ -296,30 +296,30 @@ async def test_duplicate_upload_same_path_updates_state( make_asset_bytes, ): """ - Two uploads target the exact same destination path (same tags+name) with different bytes. - Expect: file on disk is from the last upload; its AssetInfo serves content; the first AssetInfo's content 404s. - This validates that AssetCacheState(file_path) remains unique and its asset_id/mtime_ns were updated. + Two uploads use the same tags and the same display name but different bytes. + With hash-based filenames, they must NOT overwrite each other. 
Both assets + remain accessible and serve their original content. """ scope = f"dup-path-{uuid.uuid4().hex[:6]}" - name = "same_path.bin" + display_name = "same_display.bin" d1 = make_asset_bytes(scope + "-v1", 1536) d2 = make_asset_bytes(scope + "-v2", 2048) tags = [root, "unit-tests", scope] - first = await asset_factory(name, tags, {}, d1) - second = await asset_factory(name, tags, {}, d2) + first = await asset_factory(display_name, tags, {}, d1) + second = await asset_factory(display_name, tags, {}, d2) - # Second one must serve the new bytes + assert first["id"] != second["id"] + assert first["asset_hash"] != second["asset_hash"] # different content + assert first["name"] == second["name"] == display_name + + # Both must be independently retrievable + async with http.get(f"{api_base}/api/assets/{first['id']}/content") as r1: + b1 = await r1.read() + assert r1.status == 200 + assert b1 == d1 async with http.get(f"{api_base}/api/assets/{second['id']}/content") as r2: b2 = await r2.read() assert r2.status == 200 assert b2 == d2 - - # The first AssetInfo now points to an identity with no live state for that path -> 404 - async with http.get(f"{api_base}/api/assets/{first['id']}/content") as r1: - try: - body = await r1.json() - except Exception: - body = {} - assert r1.status == 404, body From 7becb84341887933cec0f8d9efde25ed6efdb62e Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Sun, 14 Sep 2025 22:19:00 +0300 Subject: [PATCH 60/82] fixed tests on SQLite file --- tests-assets/conftest.py | 2 +- tests-assets/test_scans.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests-assets/conftest.py b/tests-assets/conftest.py index 3f31f226efcb..9ee1fa86b113 100644 --- a/tests-assets/conftest.py +++ b/tests-assets/conftest.py @@ -274,7 +274,7 @@ async def trigger_sync_seed_assets(session: aiohttp.ClientSession, base_url: str """Force a fast sync/seed pass by calling the ComfyUI '/object_info' endpoint.""" async with session.get(base_url + "/object_info") as r: await r.read() - await asyncio.sleep(0.05) # tiny yield to the event loop to let any final DB commits flush + await asyncio.sleep(0.1) # tiny yield to the event loop to let any final DB commits flush @pytest_asyncio.fixture diff --git a/tests-assets/test_scans.py b/tests-assets/test_scans.py index fcedce5cf7d4..e82ae5f6d264 100644 --- a/tests-assets/test_scans.py +++ b/tests-assets/test_scans.py @@ -85,7 +85,7 @@ async def test_scan_status_filter_by_root_and_file_errors( data = b"A" * 4096 await asset_factory(name_in, ["input", "unit-tests", in_scope, "deny"], {}, data) try: - os.chmod(protected_dir, 0x000) + os.chmod(protected_dir, 0o000) # Also schedule a scan for output root (no errors there) out_scope = f"filter-out-{uuid.uuid4().hex[:6]}" @@ -141,7 +141,7 @@ async def test_scan_records_file_errors_permission_denied( a1 = await asset_factory(name, [root, "unit-tests", scope, "deny"], {}, b"X" * 2048) asset_filename = get_asset_filename(a1["asset_hash"], ".bin") try: - os.chmod(deny_dir, 0x000) + os.chmod(deny_dir, 0o000) async with http.post(api_base + "/api/assets/scan/schedule", json={"roots": [root]}) as r: assert r.status == 202 await run_scan_and_wait(root) @@ -157,7 +157,7 @@ async def test_scan_records_file_errors_permission_denied( assert any(e.get("path", "").endswith(asset_filename) and e.get("message") for e in errs) finally: try: - os.chmod(deny_dir, 0x755) + os.chmod(deny_dir, 0o755) except Exception: pass @@ -402,7 +402,7 @@ async def test_sync_seed_nested_dirs_produce_parent_tags( # nested: 
unit-tests / scope / a / b / c / deep.txt deep_dir = _base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope / "a" / "b" / "c" deep_dir.mkdir(parents=True, exist_ok=True) - (deep_dir / "deep.txt").write_bytes(b"content") + (deep_dir / "deep.txt").write_bytes(scope.encode()) await trigger_sync_seed_assets(http, api_base) From 025fc49b4e667f53ece6978b61d9d8bf093983ee Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Mon, 15 Sep 2025 10:26:13 +0300 Subject: [PATCH 61/82] optimization: DB Queries (Tags) --- app/database/helpers/tags.py | 125 +++++++++++++++++------------------ 1 file changed, 59 insertions(+), 66 deletions(-) diff --git a/app/database/helpers/tags.py b/app/database/helpers/tags.py index b3e006c0e80a..058869eca01c 100644 --- a/app/database/helpers/tags.py +++ b/app/database/helpers/tags.py @@ -1,6 +1,6 @@ from typing import Iterable -from sqlalchemy import delete, select +import sqlalchemy as sa from sqlalchemy.dialects import postgresql as d_pg from sqlalchemy.dialects import sqlite as d_sqlite from sqlalchemy.ext.asyncio import AsyncSession @@ -10,34 +10,27 @@ from ..timeutil import utcnow -async def ensure_tags_exist(session: AsyncSession, names: Iterable[str], tag_type: str = "user") -> list[Tag]: +async def ensure_tags_exist(session: AsyncSession, names: Iterable[str], tag_type: str = "user") -> None: wanted = normalize_tags(list(names)) if not wanted: - return [] - existing = (await session.execute(select(Tag).where(Tag.name.in_(wanted)))).scalars().all() - existing_names = {t.name for t in existing} - missing = [n for n in wanted if n not in existing_names] - if missing: - dialect = session.bind.dialect.name - rows = [{"name": n, "tag_type": tag_type} for n in missing] - if dialect == "sqlite": - ins = ( - d_sqlite.insert(Tag) - .values(rows) - .on_conflict_do_nothing(index_elements=[Tag.name]) - ) - elif dialect == "postgresql": - ins = ( - d_pg.insert(Tag) - .values(rows) - .on_conflict_do_nothing(index_elements=[Tag.name]) - ) - else: - raise NotImplementedError(f"Unsupported database dialect: {dialect}") - await session.execute(ins) - existing = (await session.execute(select(Tag).where(Tag.name.in_(wanted)))).scalars().all() - by_name = {t.name: t for t in existing} - return [by_name[n] for n in wanted if n in by_name] + return + rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))] + dialect = session.bind.dialect.name + if dialect == "sqlite": + ins = ( + d_sqlite.insert(Tag) + .values(rows) + .on_conflict_do_nothing(index_elements=[Tag.name]) + ) + elif dialect == "postgresql": + ins = ( + d_pg.insert(Tag) + .values(rows) + .on_conflict_do_nothing(index_elements=[Tag.name]) + ) + else: + raise NotImplementedError(f"Unsupported database dialect: {dialect}") + await session.execute(ins) async def add_missing_tag_for_asset_id( @@ -45,53 +38,53 @@ async def add_missing_tag_for_asset_id( *, asset_id: str, origin: str = "automatic", -) -> int: - """Ensure every AssetInfo for asset_id has 'missing' tag.""" - ids = (await session.execute(select(AssetInfo.id).where(AssetInfo.asset_id == asset_id))).scalars().all() - if not ids: - return 0 - - existing = { - asset_info_id - for (asset_info_id,) in ( - await session.execute( - select(AssetInfoTag.asset_info_id).where( - AssetInfoTag.asset_info_id.in_(ids), - AssetInfoTag.tag_name == "missing", - ) +) -> None: + select_rows = ( + sa.select( + AssetInfo.id.label("asset_info_id"), + sa.literal("missing").label("tag_name"), + sa.literal(origin).label("origin"), + sa.literal(utcnow()).label("added_at"), + 
) + .where(AssetInfo.asset_id == asset_id) + .where( + sa.not_( + sa.exists().where((AssetInfoTag.asset_info_id == AssetInfo.id) & (AssetInfoTag.tag_name == "missing")) ) - ).all() - } - to_add = [i for i in ids if i not in existing] - if not to_add: - return 0 - - now = utcnow() - session.add_all( - [ - AssetInfoTag(asset_info_id=i, tag_name="missing", origin=origin, added_at=now) - for i in to_add - ] + ) ) - await session.flush() - return len(to_add) + dialect = session.bind.dialect.name + if dialect == "sqlite": + ins = ( + d_sqlite.insert(AssetInfoTag) + .from_select( + ["asset_info_id", "tag_name", "origin", "added_at"], + select_rows, + ) + .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name]) + ) + elif dialect == "postgresql": + ins = ( + d_pg.insert(AssetInfoTag) + .from_select( + ["asset_info_id", "tag_name", "origin", "added_at"], + select_rows, + ) + .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name]) + ) + else: + raise NotImplementedError(f"Unsupported database dialect: {dialect}") + await session.execute(ins) async def remove_missing_tag_for_asset_id( session: AsyncSession, *, asset_id: str, -) -> int: - """Remove the 'missing' tag from all AssetInfos for asset_id.""" - ids = (await session.execute(select(AssetInfo.id).where(AssetInfo.asset_id == asset_id))).scalars().all() - if not ids: - return 0 - - res = await session.execute( - delete(AssetInfoTag).where( - AssetInfoTag.asset_info_id.in_(ids), +) -> None: + await session.execute( + sa.delete(AssetInfoTag).where( + AssetInfoTag.asset_info_id.in_(sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)), AssetInfoTag.tag_name == "missing", ) ) - await session.flush() - return int(res.rowcount or 0) From 5f187fe6fb81f8131ce1a09e7d100cbc3274d3c7 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Mon, 15 Sep 2025 12:46:35 +0300 Subject: [PATCH 62/82] optimization: make list_unhashed_candidates_under_prefixes single-query instead of N+1 --- app/database/services/content.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/app/database/services/content.py b/app/database/services/content.py index ead2e2389276..33b74a2dc986 100644 --- a/app/database/services/content.py +++ b/app/database/services/content.py @@ -408,9 +408,7 @@ async def compute_hash_and_dedup_for_cache_state( raise -async def list_unhashed_candidates_under_prefixes( - session: AsyncSession, *, prefixes: Sequence[str] -) -> list[int]: +async def list_unhashed_candidates_under_prefixes(session: AsyncSession, *, prefixes: list[str]) -> list[int]: if not prefixes: return [] @@ -421,23 +419,25 @@ async def list_unhashed_candidates_under_prefixes( base += os.sep conds.append(AssetCacheState.file_path.like(base + "%")) - rows = ( - await session.execute( + path_filter = sa.or_(*conds) if len(conds) > 1 else conds[0] + if session.bind.dialect.name == "postgresql": + stmt = ( sa.select(AssetCacheState.id) .join(Asset, Asset.id == AssetCacheState.asset_id) - .where(Asset.hash.is_(None)) - .where(sa.or_(*conds)) + .where(Asset.hash.is_(None), path_filter) .order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc()) + .distinct(AssetCacheState.asset_id) ) - ).scalars().all() - seen = set() - result: list[int] = [] - for sid in rows: - st = await session.get(AssetCacheState, sid) - if st and st.asset_id not in seen: - seen.add(st.asset_id) - result.append(sid) - return result + else: + first_id = 
sa.func.min(AssetCacheState.id).label("first_id") + stmt = ( + sa.select(first_id) + .join(Asset, Asset.id == AssetCacheState.asset_id) + .where(Asset.hash.is_(None), path_filter) + .group_by(AssetCacheState.asset_id) + .order_by(first_id.asc()) + ) + return [int(x) for x in (await session.execute(stmt)).scalars().all()] async def list_verify_candidates_under_prefixes( From f3cf99d10c6f7e37b1759cdf2f49dc06df92037d Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Mon, 15 Sep 2025 17:29:27 +0300 Subject: [PATCH 63/82] fix+test: escape "_" symbol in tags filtering --- app/database/helpers/__init__.py | 2 ++ app/database/helpers/escape_like.py | 7 +++++++ app/database/services/info.py | 7 +++++-- tests-assets/test_tags.py | 25 +++++++++++++++++++++++++ 4 files changed, 39 insertions(+), 2 deletions(-) create mode 100644 app/database/helpers/escape_like.py diff --git a/app/database/helpers/__init__.py b/app/database/helpers/__init__.py index 310583607c80..5a8e87b2c1da 100644 --- a/app/database/helpers/__init__.py +++ b/app/database/helpers/__init__.py @@ -1,3 +1,4 @@ +from .escape_like import escape_like_prefix from .filters import apply_metadata_filter, apply_tag_filters from .ownership import visible_owner_clause from .projection import is_scalar, project_kv @@ -10,6 +11,7 @@ __all__ = [ "apply_tag_filters", "apply_metadata_filter", + "escape_like_prefix", "is_scalar", "project_kv", "ensure_tags_exist", diff --git a/app/database/helpers/escape_like.py b/app/database/helpers/escape_like.py new file mode 100644 index 000000000000..f905bd40b52c --- /dev/null +++ b/app/database/helpers/escape_like.py @@ -0,0 +1,7 @@ +def escape_like_prefix(s: str, escape: str = "!") -> tuple[str, str]: + """Escapes %, _ and the escape char itself in a LIKE prefix. + Returns (escaped_prefix, escape_char). Caller should append '%' and pass escape=escape_char to .like(). 
+ """ + s = s.replace(escape, escape + escape) # escape the escape char first + s = s.replace("%", escape + "%").replace("_", escape + "_") # escape LIKE wildcards + return s, escape diff --git a/app/database/services/info.py b/app/database/services/info.py index d2fd1f503f33..13af76ea0819 100644 --- a/app/database/services/info.py +++ b/app/database/services/info.py @@ -13,6 +13,7 @@ apply_metadata_filter, apply_tag_filters, ensure_tags_exist, + escape_like_prefix, project_kv, visible_owner_clause, ) @@ -527,7 +528,8 @@ async def list_tags_with_usage( ) if prefix: - q = q.where(Tag.name.like(prefix.strip().lower() + "%")) + escaped, esc = escape_like_prefix(prefix.strip().lower()) + q = q.where(Tag.name.like(escaped + "%", escape=esc)) if not include_zero: q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0) @@ -539,7 +541,8 @@ async def list_tags_with_usage( total_q = select(func.count()).select_from(Tag) if prefix: - total_q = total_q.where(Tag.name.like(prefix.strip().lower() + "%")) + escaped, esc = escape_like_prefix(prefix.strip().lower()) + total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc)) if not include_zero: total_q = total_q.where( Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name)) diff --git a/tests-assets/test_tags.py b/tests-assets/test_tags.py index 9ad3c3f860a6..9bdf770c4c54 100644 --- a/tests-assets/test_tags.py +++ b/tests-assets/test_tags.py @@ -201,3 +201,28 @@ async def test_tags_endpoints_invalid_bodies(http: aiohttp.ClientSession, api_ba b3 = await r3.json() assert r3.status == 400 assert b3["error"]["code"] == "INVALID_QUERY" + + +@pytest.mark.asyncio +async def test_tags_prefix_treats_underscore_literal( + http, + api_base, + asset_factory, + make_asset_bytes, +): + """'prefix' for /api/tags must treat '_' literally, not as a wildcard.""" + base = f"pref_{uuid.uuid4().hex[:6]}" + tag_ok = f"{base}_ok" # should match prefix=f"{base}_" + tag_bad = f"{base}xok" # must NOT match if '_' is escaped + scope = f"tags-underscore-{uuid.uuid4().hex[:6]}" + + await asset_factory("t1.bin", ["input", "unit-tests", scope, tag_ok], {}, make_asset_bytes("t1", 512)) + await asset_factory("t2.bin", ["input", "unit-tests", scope, tag_bad], {}, make_asset_bytes("t2", 512)) + + async with http.get(api_base + "/api/tags", params={"include_zero": "false", "prefix": f"{base}_"}) as r: + body = await r.json() + assert r.status == 200, body + names = [t["name"] for t in body["tags"]] + assert tag_ok in names, f"Expected {tag_ok} to be returned for prefix '{base}_'" + assert tag_bad not in names, f"'{tag_bad}' must not match — '_' is not a wildcard" + assert body["total"] == 1 From f1fb7432a073fc07d64095070d5c78c34dd1556d Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Mon, 15 Sep 2025 19:19:47 +0300 Subject: [PATCH 64/82] fix+test: escape "_" symbol in assets filtering --- app/database/services/info.py | 6 ++++-- tests-assets/test_list_filter.py | 34 ++++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 2 deletions(-) diff --git a/app/database/services/info.py b/app/database/services/info.py index 13af76ea0819..7583383683ee 100644 --- a/app/database/services/info.py +++ b/app/database/services/info.py @@ -47,7 +47,8 @@ async def list_asset_infos_page( ) if name_contains: - base = base.where(AssetInfo.name.ilike(f"%{name_contains}%")) + escaped, esc = escape_like_prefix(name_contains) + base = base.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc)) base = apply_tag_filters(base, include_tags, exclude_tags) base = 
apply_metadata_filter(base, metadata_filter) @@ -73,7 +74,8 @@ async def list_asset_infos_page( .where(visible_owner_clause(owner_id)) ) if name_contains: - count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{name_contains}%")) + escaped, esc = escape_like_prefix(name_contains) + count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc)) count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags) count_stmt = apply_metadata_filter(count_stmt, metadata_filter) diff --git a/tests-assets/test_list_filter.py b/tests-assets/test_list_filter.py index b0b476af5bf4..835de0367058 100644 --- a/tests-assets/test_list_filter.py +++ b/tests-assets/test_list_filter.py @@ -1,4 +1,5 @@ import asyncio +import uuid import aiohttp import pytest @@ -301,3 +302,36 @@ async def test_list_assets_invalid_query_rejected(http: aiohttp.ClientSession, a b2 = await r2.json() assert r2.status == 400 assert b2["error"]["code"] == "INVALID_QUERY" + + +@pytest.mark.asyncio +async def test_list_assets_name_contains_literal_underscore( + http, + api_base, + asset_factory, + make_asset_bytes, +): + """'name_contains' must treat '_' literally, not as a SQL wildcard. + We create: + - foo_bar.safetensors (should match) + - fooxbar.safetensors (must NOT match if '_' is escaped) + - foobar.safetensors (must NOT match) + """ + scope = f"lf-underscore-{uuid.uuid4().hex[:6]}" + tags = ["models", "checkpoints", "unit-tests", scope] + + a = await asset_factory("foo_bar.safetensors", tags, {}, make_asset_bytes("a", 700)) + b = await asset_factory("fooxbar.safetensors", tags, {}, make_asset_bytes("b", 700)) + c = await asset_factory("foobar.safetensors", tags, {}, make_asset_bytes("c", 700)) + + async with http.get( + api_base + "/api/assets", + params={"include_tags": f"unit-tests,{scope}", "name_contains": "foo_bar"}, + ) as r: + body = await r.json() + assert r.status == 200, body + names = [x["name"] for x in body["assets"]] + assert a["name"] in names, f"Expected literal underscore match to include {a['name']}" + assert b["name"] not in names, "Underscore must be escaped — should not match 'fooxbar'" + assert c["name"] not in names, "Underscore must be escaped — should not match 'foobar'" + assert body["total"] == 1 From 0be513b213bf8d9d0164987f21d7476eec39c9fd Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Mon, 15 Sep 2025 20:26:48 +0300 Subject: [PATCH 65/82] fix: escape "_" symbol in all other places --- app/assets_scanner.py | 12 +++++++----- app/database/services/content.py | 10 +++++++--- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/app/assets_scanner.py b/app/assets_scanner.py index 738b4e7814b7..29c77a7c6b31 100644 --- a/app/assets_scanner.py +++ b/app/assets_scanner.py @@ -22,6 +22,7 @@ from .database.db import create_session from .database.helpers import ( add_missing_tag_for_asset_id, + escape_like_prefix, remove_missing_tag_for_asset_id, ) from .database.models import Asset, AssetCacheState, AssetInfo @@ -160,7 +161,7 @@ def _scan_progress_to_scan_status_model(progress: ScanProgress) -> schemas_out.A ) -async def _refresh_verify_flags_for_root(root: schemas_in.RootType, prog: ScanProgress) -> None: +async def _refresh_verify_flags_for_root(root: schemas_in.RootType) -> None: """Fast pass to mark verify candidates by comparing stored mtime_ns with on-disk mtime.""" prefixes = prefixes_for_root(root) if not prefixes: @@ -171,7 +172,8 @@ async def _refresh_verify_flags_for_root(root: schemas_in.RootType, prog: ScanPr base = os.path.abspath(p) if not base.endswith(os.sep): 
base += os.sep - conds.append(AssetCacheState.file_path.like(base + "%")) + escaped, esc = escape_like_prefix(base) + conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc)) async with await create_session() as sess: rows = ( @@ -227,7 +229,7 @@ async def _run_hash_verify_pipeline(root: schemas_in.RootType, prog: ScanProgres try: prefixes = prefixes_for_root(root) - await _refresh_verify_flags_for_root(root, prog) + await _refresh_verify_flags_for_root(root) # collect candidates from DB async with await create_session() as sess: @@ -419,7 +421,8 @@ async def _fast_db_consistency_pass(root: schemas_in.RootType) -> None: base = os.path.abspath(p) if not base.endswith(os.sep): base += os.sep - conds.append(AssetCacheState.file_path.like(base + "%")) + escaped, esc = escape_like_prefix(base) + conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc)) async with await create_session() as sess: if not conds: @@ -443,7 +446,6 @@ async def _fast_db_consistency_pass(root: schemas_in.RootType) -> None: acc = {"hash": a_hash, "size_db": int(a_size or 0), "states": []} by_asset[aid] = acc - exists = False fast_ok = False try: s = os.stat(st.file_path, follow_symlinks=True) diff --git a/app/database/services/content.py b/app/database/services/content.py index 33b74a2dc986..0b33e1eee357 100644 --- a/app/database/services/content.py +++ b/app/database/services/content.py @@ -16,6 +16,7 @@ from ...storage import hashing as hashing_mod from ..helpers import ( ensure_tags_exist, + escape_like_prefix, remove_missing_tag_for_asset_id, ) from ..models import Asset, AssetCacheState, AssetInfo, AssetInfoTag, Tag @@ -417,7 +418,8 @@ async def list_unhashed_candidates_under_prefixes(session: AsyncSession, *, pref base = os.path.abspath(p) if not base.endswith(os.sep): base += os.sep - conds.append(AssetCacheState.file_path.like(base + "%")) + escaped, esc = escape_like_prefix(base) + conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc)) path_filter = sa.or_(*conds) if len(conds) > 1 else conds[0] if session.bind.dialect.name == "postgresql": @@ -450,7 +452,8 @@ async def list_verify_candidates_under_prefixes( base = os.path.abspath(p) if not base.endswith(os.sep): base += os.sep - conds.append(AssetCacheState.file_path.like(base + "%")) + escaped, esc = escape_like_prefix(base) + conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc)) return ( await session.execute( @@ -756,7 +759,8 @@ async def list_cache_states_with_asset_under_prefixes( base = os.path.abspath(p) if not base.endswith(os.sep): base = base + os.sep - conds.append(AssetCacheState.file_path.like(base + "%")) + escaped, esc = escape_like_prefix(base) + conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc)) if not conds: return [] From 24a95f5ca4590893bc0e1d35d483e4cd04c22ada Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Tue, 16 Sep 2025 11:28:29 +0300 Subject: [PATCH 66/82] removed default scanning of "input" and "output" folders; added separate endpoint for test suite. 
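
A minimal usage sketch of the new seed endpoint, assuming a locally running server on the default address (the host/port are assumptions; the request and response shapes follow the handler added below):

    import asyncio
    import aiohttp

    async def seed_assets(roots: list[str]) -> dict:
        # POST /api/assets/scan/seed with {"roots": [...]} runs sync_seed_assets()
        # for those roots and returns {"synced": true, "roots": [...]}.
        async with aiohttp.ClientSession() as session:
            async with session.post(
                "http://127.0.0.1:8188/api/assets/scan/seed",  # assumed default host/port
                json={"roots": roots},
            ) as resp:
                return await resp.json()

    # Example: asyncio.run(seed_assets(["models", "input", "output"]))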
--- app/api/assets_routes.py | 20 ++++++++++++++++++++ main.py | 2 +- server.py | 2 +- tests-assets/conftest.py | 2 +- 4 files changed, 23 insertions(+), 3 deletions(-) diff --git a/app/api/assets_routes.py b/app/api/assets_routes.py index 3c3ea5d257d3..6bb0ed77e27b 100644 --- a/app/api/assets_routes.py +++ b/app/api/assets_routes.py @@ -483,6 +483,26 @@ async def delete_asset_tags(request: web.Request) -> web.Response: return web.json_response(result.model_dump(mode="json"), status=200) +@ROUTES.post("/api/assets/scan/seed") +async def seed_assets(request: web.Request) -> web.Response: + try: + payload = await request.json() + except Exception: + payload = {} + + try: + body = schemas_in.ScheduleAssetScanBody.model_validate(payload) + except ValidationError as ve: + return _validation_error_response("INVALID_BODY", ve) + + try: + await assets_scanner.sync_seed_assets(body.roots) + except Exception: + LOGGER.exception("sync_seed_assets failed for roots=%s", body.roots) + return _error_response(500, "INTERNAL", "Unexpected server error.") + return web.json_response({"synced": True, "roots": body.roots}, status=200) + + @ROUTES.post("/api/assets/scan/schedule") async def schedule_asset_scan(request: web.Request) -> web.Response: try: diff --git a/main.py b/main.py index db0ee04f5059..18c97e5e1933 100644 --- a/main.py +++ b/main.py @@ -283,7 +283,7 @@ async def setup_database(): await init_db_engine() if not args.disable_assets_autoscan: - await sync_seed_assets(["models", "input", "output"]) + await sync_seed_assets(["models"]) def start_comfyui(asyncio_loop=None): diff --git a/server.py b/server.py index c3a688a75499..ddd188ebcb78 100644 --- a/server.py +++ b/server.py @@ -629,7 +629,7 @@ def node_info(node_class): @routes.get("/object_info") async def get_object_info(request): - await sync_seed_assets(["models", "input", "output"]) + await sync_seed_assets(["models"]) with folder_paths.cache_helper: out = {} for x in nodes.NODE_CLASS_MAPPINGS: diff --git a/tests-assets/conftest.py b/tests-assets/conftest.py index 9ee1fa86b113..7d1ea5acb570 100644 --- a/tests-assets/conftest.py +++ b/tests-assets/conftest.py @@ -272,7 +272,7 @@ async def autoclean_unit_test_assets(http: aiohttp.ClientSession, api_base: str) async def trigger_sync_seed_assets(session: aiohttp.ClientSession, base_url: str) -> None: """Force a fast sync/seed pass by calling the ComfyUI '/object_info' endpoint.""" - async with session.get(base_url + "/object_info") as r: + async with session.post(base_url + "/api/assets/scan/seed", json={"roots": ["models", "input", "output"]}) as r: await r.read() await asyncio.sleep(0.1) # tiny yield to the event loop to let any final DB commits flush From 77332d30549845eb5dc9e91de0473f052dd657f6 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Tue, 16 Sep 2025 14:21:40 +0300 Subject: [PATCH 67/82] optimization: fast scan: commit to the DB in chunks --- app/assets_scanner.py | 100 ++++++++++++++++--------------- app/database/services/content.py | 1 + 2 files changed, 53 insertions(+), 48 deletions(-) diff --git a/app/assets_scanner.py b/app/assets_scanner.py index 29c77a7c6b31..1863cff924c5 100644 --- a/app/assets_scanner.py +++ b/app/assets_scanner.py @@ -1,4 +1,5 @@ import asyncio +import contextlib import logging import os import time @@ -95,45 +96,55 @@ async def schedule_scans(roots: list[schemas_in.RootType]) -> schemas_out.AssetS async def sync_seed_assets(roots: list[schemas_in.RootType]) -> None: - for r in roots: - try: - await _fast_db_consistency_pass(r) - except Exception as ex: - 
LOGGER.exception("fast DB reconciliation failed for %s: %s", r, ex) - - paths: list[str] = [] - if "models" in roots: - paths.extend(collect_models_files()) - if "input" in roots: - paths.extend(list_tree(folder_paths.get_input_directory())) - if "output" in roots: - paths.extend(list_tree(folder_paths.get_output_directory())) - - for p in paths: - try: - st = os.stat(p, follow_symlinks=True) - if not int(st.st_size or 0): - continue - size_bytes = int(st.st_size) - mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)) - name, tags = get_name_and_tags_from_asset_path(p) - await _seed_one_async(p, size_bytes, mtime_ns, name, tags) - except OSError: - continue - - -async def _seed_one_async(p: str, size_bytes: int, mtime_ns: int, name: str, tags: list[str]) -> None: - async with await create_session() as sess: - await ensure_seed_for_path( - sess, - abs_path=p, - size_bytes=size_bytes, - mtime_ns=mtime_ns, - info_name=name, - tags=tags, - owner_id="", + t_total = time.perf_counter() + try: + for r in roots: + try: + await _fast_db_consistency_pass(r) + except Exception as ex: + LOGGER.exception("fast DB reconciliation failed for %s: %s", r, ex) + + paths: list[str] = [] + if "models" in roots: + paths.extend(collect_models_files()) + if "input" in roots: + paths.extend(list_tree(folder_paths.get_input_directory())) + if "output" in roots: + paths.extend(list_tree(folder_paths.get_output_directory())) + + processed = 0 + async with await create_session() as sess: + for p in paths: + try: + st = os.stat(p, follow_symlinks=True) + if not int(st.st_size or 0): + continue + size_bytes = int(st.st_size) + mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)) + name, tags = get_name_and_tags_from_asset_path(p) + + await ensure_seed_for_path( + sess, + abs_path=p, + size_bytes=size_bytes, + mtime_ns=mtime_ns, + info_name=name, + tags=tags, + owner_id="", + ) + + processed += 1 + if processed % 500 == 0: + await sess.commit() + except OSError: + continue + await sess.commit() + finally: + LOGGER.info( + "Assets scan(roots=%s) completed in %.3f s", + roots, + time.perf_counter() - t_total, ) - await sess.commit() def _status_response_for(progresses: list[ScanProgress]) -> schemas_out.AssetScanStatusResponse: @@ -482,20 +493,13 @@ async def _fast_db_consistency_pass(root: schemas_in.RootType) -> None: if any_fast_ok: # Remove 'missing' and delete just the stale state rows for st in missing_states: - try: + with contextlib.suppress(Exception): await sess.delete(await sess.get(AssetCacheState, st.id)) - except Exception: - pass - try: + with contextlib.suppress(Exception): await remove_missing_tag_for_asset_id(sess, asset_id=aid) - except Exception: - pass else: - # No fast-ok path: mark as missing - try: + with contextlib.suppress(Exception): await add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic") - except Exception: - pass await sess.flush() await sess.commit() diff --git a/app/database/services/content.py b/app/database/services/content.py index 0b33e1eee357..a8ce200d1422 100644 --- a/app/database/services/content.py +++ b/app/database/services/content.py @@ -87,6 +87,7 @@ async def ensure_seed_for_path( state_row.needs_verify = True if asset_row.size_bytes == 0 and size_bytes > 0: asset_row.size_bytes = int(size_bytes) + await session.flush() return asset_row.id asset = Asset(hash=None, size_bytes=int(size_bytes), mime_type=None, created_at=now) From a336c7c165c925a62fc776d928487cba35864f92 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Tue, 16 
Sep 2025 19:19:18 +0300 Subject: [PATCH 68/82] refactor(1): use general fast_asset_file_check helper for fast check --- app/assets_scanner.py | 60 +++++++++++++++--------------- app/database/helpers/__init__.py | 2 + app/database/helpers/fast_check.py | 19 ++++++++++ 3 files changed, 52 insertions(+), 29 deletions(-) create mode 100644 app/database/helpers/fast_check.py diff --git a/app/assets_scanner.py b/app/assets_scanner.py index 1863cff924c5..5b5b19a31671 100644 --- a/app/assets_scanner.py +++ b/app/assets_scanner.py @@ -24,6 +24,7 @@ from .database.helpers import ( add_missing_tag_for_asset_id, escape_like_prefix, + fast_asset_file_check, remove_missing_tag_for_asset_id, ) from .database.models import Asset, AssetCacheState, AssetInfo @@ -194,6 +195,7 @@ async def _refresh_verify_flags_for_root(root: schemas_in.RootType) -> None: AssetCacheState.mtime_ns, AssetCacheState.needs_verify, Asset.hash, + Asset.size_bytes, AssetCacheState.file_path, ) .join(Asset, Asset.id == AssetCacheState.asset_id) @@ -203,22 +205,18 @@ async def _refresh_verify_flags_for_root(root: schemas_in.RootType) -> None: to_set = [] to_clear = [] - for sid, mtime_db, needs_verify, a_hash, fp in rows: + for sid, mtime_db, needs_verify, a_hash, size_db, fp in rows: try: st = os.stat(fp, follow_symlinks=True) except OSError: - # Missing files are handled by missing-tag reconciliation later. - continue + continue # Missing files are handled by missing-tag reconciliation later. - actual_mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)) if a_hash is not None: - if mtime_db is None or int(mtime_db) != int(actual_mtime_ns): - if not needs_verify: - to_set.append(sid) - else: + if fast_asset_file_check(mtime_db=mtime_db, size_db=size_db, stat_result=st): if needs_verify: to_clear.append(sid) - + elif not needs_verify: + to_set.append(sid) if to_set: await sess.execute( sa.update(AssetCacheState) @@ -306,15 +304,10 @@ async def _reconcile_missing_tags_for_root(root: schemas_in.RootType, prog: Scan acc = {"any_fast_ok_here": False, "hashed": (a_hash is not None), "size_db": int(size_db or 0)} by_asset[aid] = acc try: - st = os.stat(state.file_path, follow_symlinks=True) - actual_mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)) - fast_ok = False if acc["hashed"]: - if state.mtime_ns is not None and int(state.mtime_ns) == int(actual_mtime_ns): - if int(acc["size_db"]) > 0 and int(st.st_size) == int(acc["size_db"]): - fast_ok = True - if fast_ok: - acc["any_fast_ok_here"] = True + st = os.stat(state.file_path, follow_symlinks=True) + if fast_asset_file_check(mtime_db=state.mtime_ns, size_db=acc["size_db"], stat_result=st): + acc["any_fast_ok_here"] = True except FileNotFoundError: pass except OSError as e: @@ -333,12 +326,11 @@ async def _reconcile_missing_tags_for_root(root: schemas_in.RootType, prog: Scan others = await list_cache_states_by_asset_id(sess, asset_id=aid) for st in others: try: - s = os.stat(st.file_path, follow_symlinks=True) - actual_mtime_ns = getattr(s, "st_mtime_ns", int(s.st_mtime * 1_000_000_000)) - if st.mtime_ns is not None and int(st.mtime_ns) == int(actual_mtime_ns): - if acc["size_db"] > 0 and int(s.st_size) == acc["size_db"]: - any_fast_ok_global = True - break + any_fast_ok_global = fast_asset_file_check( + mtime_db=st.mtime_ns, + size_db=acc["size_db"], + stat_result=os.stat(st.file_path, follow_symlinks=True), + ) except OSError: continue @@ -459,12 +451,12 @@ async def _fast_db_consistency_pass(root: schemas_in.RootType) -> None: fast_ok = False try: 
- s = os.stat(st.file_path, follow_symlinks=True) + fast_ok = fast_asset_file_check( + mtime_db=st.mtime_ns, + size_db=acc["size_db"], + stat_result=os.stat(st.file_path, follow_symlinks=True), + ) exists = True - actual_mtime_ns = getattr(s, "st_mtime_ns", int(s.st_mtime * 1_000_000_000)) - if st.mtime_ns is not None and int(st.mtime_ns) == int(actual_mtime_ns): - if acc["size_db"] == 0 or int(s.st_size) == acc["size_db"]: - fast_ok = True except FileNotFoundError: exists = False except OSError as ex: @@ -474,6 +466,7 @@ async def _fast_db_consistency_pass(root: schemas_in.RootType) -> None: acc["states"].append({"obj": st, "exists": exists, "fast_ok": fast_ok}) # Apply actions + to_set_verify: list[int] = [] for aid, acc in by_asset.items(): a_hash = acc["hash"] states = acc["states"] @@ -500,6 +493,15 @@ async def _fast_db_consistency_pass(root: schemas_in.RootType) -> None: else: with contextlib.suppress(Exception): await add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic") - + for s in states: + if s["exists"] and not s["fast_ok"]: + to_set_verify.append(s["obj"].id) + await sess.flush() + if to_set_verify: + await sess.execute( + sa.update(AssetCacheState) + .where(AssetCacheState.id.in_(to_set_verify)) + .values(needs_verify=True) + ) await sess.flush() await sess.commit() diff --git a/app/database/helpers/__init__.py b/app/database/helpers/__init__.py index 5a8e87b2c1da..8119f72e9a6f 100644 --- a/app/database/helpers/__init__.py +++ b/app/database/helpers/__init__.py @@ -1,4 +1,5 @@ from .escape_like import escape_like_prefix +from .fast_check import fast_asset_file_check from .filters import apply_metadata_filter, apply_tag_filters from .ownership import visible_owner_clause from .projection import is_scalar, project_kv @@ -12,6 +13,7 @@ "apply_tag_filters", "apply_metadata_filter", "escape_like_prefix", + "fast_asset_file_check", "is_scalar", "project_kv", "ensure_tags_exist", diff --git a/app/database/helpers/fast_check.py b/app/database/helpers/fast_check.py new file mode 100644 index 000000000000..940d6984f535 --- /dev/null +++ b/app/database/helpers/fast_check.py @@ -0,0 +1,19 @@ +import os +from typing import Optional + + +def fast_asset_file_check( + *, + mtime_db: Optional[int], + size_db: Optional[int], + stat_result: os.stat_result, +) -> bool: + if mtime_db is None: + return False + actual_mtime_ns = getattr(stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000)) + if int(mtime_db) != int(actual_mtime_ns): + return False + sz = int(size_db or 0) + if sz > 0: + return int(stat_result.st_size) == sz + return True From 31ec744317bcacf1b2a8198a5f58918bbb6c802f Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Tue, 16 Sep 2025 19:50:21 +0300 Subject: [PATCH 69/82] refactor(2)/fix: skip double checking the existing files during fast check --- app/assets_scanner.py | 45 +++++++++++++++++++++++++++++++++---------- 1 file changed, 35 insertions(+), 10 deletions(-) diff --git a/app/assets_scanner.py b/app/assets_scanner.py index 5b5b19a31671..967bc64bce7a 100644 --- a/app/assets_scanner.py +++ b/app/assets_scanner.py @@ -98,14 +98,19 @@ async def schedule_scans(roots: list[schemas_in.RootType]) -> schemas_out.AssetS async def sync_seed_assets(roots: list[schemas_in.RootType]) -> None: t_total = time.perf_counter() + created = 0 + skipped_existing = 0 + paths: list[str] = [] try: + existing_paths: set[str] = set() for r in roots: try: - await _fast_db_consistency_pass(r) + survivors = await _fast_db_consistency_pass(r, collect_existing_paths=True) + if 
survivors: + existing_paths.update(survivors) except Exception as ex: LOGGER.exception("fast DB reconciliation failed for %s: %s", r, ex) - paths: list[str] = [] if "models" in roots: paths.extend(collect_models_files()) if "input" in roots: @@ -113,10 +118,12 @@ async def sync_seed_assets(roots: list[schemas_in.RootType]) -> None: if "output" in roots: paths.extend(list_tree(folder_paths.get_output_directory())) - processed = 0 async with await create_session() as sess: for p in paths: try: + if os.path.abspath(p) in existing_paths: + skipped_existing += 1 + continue st = os.stat(p, follow_symlinks=True) if not int(st.st_size or 0): continue @@ -134,17 +141,20 @@ async def sync_seed_assets(roots: list[schemas_in.RootType]) -> None: owner_id="", ) - processed += 1 - if processed % 500 == 0: + created += 1 + if created % 500 == 0: await sess.commit() except OSError: continue await sess.commit() finally: LOGGER.info( - "Assets scan(roots=%s) completed in %.3f s", + "Assets scan(roots=%s) completed in %.3f s (created=%d, skipped_existing=%d, total_seen=%d)", roots, time.perf_counter() - t_total, + created, + skipped_existing, + len(paths), ) @@ -406,7 +416,9 @@ def _append_error(prog: ScanProgress, *, path: str, message: str) -> None: }) -async def _fast_db_consistency_pass(root: schemas_in.RootType) -> None: +async def _fast_db_consistency_pass( + root: schemas_in.RootType, *, collect_existing_paths: bool = False +) -> Optional[set[str]]: """ Quick pass over asset_cache_state for `root`: - If file missing and Asset.hash is NULL and the Asset has no other states, delete the Asset and its infos. @@ -414,10 +426,12 @@ async def _fast_db_consistency_pass(root: schemas_in.RootType) -> None: * If at least one state for this Asset is fast-ok, delete the missing state. * If none are fast-ok, add 'missing' tag to all AssetInfos for this Asset. - If at least one state becomes fast-ok for a hashed Asset, remove the 'missing' tag. + When collect_existing_paths is True, returns a set of absolute file paths + that still have a live asset_cache_state row for this root after reconciliation. 
""" prefixes = prefixes_for_root(root) if not prefixes: - return + return set() if collect_existing_paths else None conds = [] for p in prefixes: @@ -429,7 +443,7 @@ async def _fast_db_consistency_pass(root: schemas_in.RootType) -> None: async with await create_session() as sess: if not conds: - return + return set() if collect_existing_paths else None rows = ( await sess.execute( @@ -467,6 +481,7 @@ async def _fast_db_consistency_pass(root: schemas_in.RootType) -> None: # Apply actions to_set_verify: list[int] = [] + survivors: set[str] = set() for aid, acc in by_asset.items(): a_hash = acc["hash"] states = acc["states"] @@ -481,7 +496,10 @@ async def _fast_db_consistency_pass(root: schemas_in.RootType) -> None: asset = await sess.get(Asset, aid) if asset: await sess.delete(asset) - # else leave it for the slow scan to verify/rehash + else: + for s in states: + if s["exists"]: + survivors.add(os.path.abspath(s["obj"].file_path)) else: if any_fast_ok: # Remove 'missing' and delete just the stale state rows @@ -490,9 +508,15 @@ async def _fast_db_consistency_pass(root: schemas_in.RootType) -> None: await sess.delete(await sess.get(AssetCacheState, st.id)) with contextlib.suppress(Exception): await remove_missing_tag_for_asset_id(sess, asset_id=aid) + for s in states: + if s["exists"]: + survivors.add(os.path.abspath(s["obj"].file_path)) else: with contextlib.suppress(Exception): await add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic") + for s in states: + if s["exists"]: + survivors.add(os.path.abspath(s["obj"].file_path)) for s in states: if s["exists"] and not s["fast_ok"]: to_set_verify.append(s["obj"].id) @@ -505,3 +529,4 @@ async def _fast_db_consistency_pass(root: schemas_in.RootType) -> None: ) await sess.flush() await sess.commit() + return survivors if collect_existing_paths else None From 677a0e2508c46fedf9d3b759fead96761328647b Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Tue, 16 Sep 2025 20:29:50 +0300 Subject: [PATCH 70/82] refactor(3): unite logic for Asset fast check --- app/assets_scanner.py | 178 ++++++++++++++++-------------------------- 1 file changed, 69 insertions(+), 109 deletions(-) diff --git a/app/assets_scanner.py b/app/assets_scanner.py index 967bc64bce7a..27bc52a25009 100644 --- a/app/assets_scanner.py +++ b/app/assets_scanner.py @@ -105,7 +105,7 @@ async def sync_seed_assets(roots: list[schemas_in.RootType]) -> None: existing_paths: set[str] = set() for r in roots: try: - survivors = await _fast_db_consistency_pass(r, collect_existing_paths=True) + survivors = await _fast_db_consistency_pass(r, collect_existing_paths=True, update_missing_tags=True) if survivors: existing_paths.update(survivors) except Exception as ex: @@ -183,72 +183,13 @@ def _scan_progress_to_scan_status_model(progress: ScanProgress) -> schemas_out.A ) -async def _refresh_verify_flags_for_root(root: schemas_in.RootType) -> None: - """Fast pass to mark verify candidates by comparing stored mtime_ns with on-disk mtime.""" - prefixes = prefixes_for_root(root) - if not prefixes: - return - - conds = [] - for p in prefixes: - base = os.path.abspath(p) - if not base.endswith(os.sep): - base += os.sep - escaped, esc = escape_like_prefix(base) - conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc)) - - async with await create_session() as sess: - rows = ( - await sess.execute( - sa.select( - AssetCacheState.id, - AssetCacheState.mtime_ns, - AssetCacheState.needs_verify, - Asset.hash, - Asset.size_bytes, - AssetCacheState.file_path, - ) - .join(Asset, Asset.id == 
AssetCacheState.asset_id) - .where(sa.or_(*conds)) - ) - ).all() - - to_set = [] - to_clear = [] - for sid, mtime_db, needs_verify, a_hash, size_db, fp in rows: - try: - st = os.stat(fp, follow_symlinks=True) - except OSError: - continue # Missing files are handled by missing-tag reconciliation later. - - if a_hash is not None: - if fast_asset_file_check(mtime_db=mtime_db, size_db=size_db, stat_result=st): - if needs_verify: - to_clear.append(sid) - elif not needs_verify: - to_set.append(sid) - if to_set: - await sess.execute( - sa.update(AssetCacheState) - .where(AssetCacheState.id.in_(to_set)) - .values(needs_verify=True) - ) - if to_clear: - await sess.execute( - sa.update(AssetCacheState) - .where(AssetCacheState.id.in_(to_clear)) - .values(needs_verify=False) - ) - await sess.commit() - - async def _run_hash_verify_pipeline(root: schemas_in.RootType, prog: ScanProgress, state: SlowQueueState) -> None: prog.status = "running" prog.started_at = time.time() try: prefixes = prefixes_for_root(root) - await _refresh_verify_flags_for_root(root) + await _fast_db_consistency_pass(root) # collect candidates from DB async with await create_session() as sess: @@ -417,17 +358,17 @@ def _append_error(prog: ScanProgress, *, path: str, message: str) -> None: async def _fast_db_consistency_pass( - root: schemas_in.RootType, *, collect_existing_paths: bool = False + root: schemas_in.RootType, + *, + collect_existing_paths: bool = False, + update_missing_tags: bool = False, ) -> Optional[set[str]]: - """ - Quick pass over asset_cache_state for `root`: - - If file missing and Asset.hash is NULL and the Asset has no other states, delete the Asset and its infos. - - If file missing and Asset.hash is NOT NULL: - * If at least one state for this Asset is fast-ok, delete the missing state. - * If none are fast-ok, add 'missing' tag to all AssetInfos for this Asset. - - If at least one state becomes fast-ok for a hashed Asset, remove the 'missing' tag. - When collect_existing_paths is True, returns a set of absolute file paths - that still have a live asset_cache_state row for this root after reconciliation. 
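Both the old per-root verify pass and the unified pass below lean on fast_asset_file_check to decide, without hashing, whether a cached row still matches the file on disk. The helper itself is not part of this series; a minimal sketch of the assumed behaviour, matching the keyword signature used in these hunks (mtime_db, size_db, stat_result), could look like:

    import os

    def fast_asset_file_check(*, mtime_db, size_db, stat_result: os.stat_result) -> bool:
        # Assumed semantics: a cache row is "fast ok" when the stored mtime_ns
        # matches the on-disk mtime exactly and the stored size, when known,
        # matches the on-disk size; a zero stored size is treated as "unknown".
        mtime_ns = getattr(stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000))
        if mtime_db is None or int(mtime_db) != mtime_ns:
            return False
        size_db = int(size_db or 0)
        return size_db == 0 or size_db == int(stat_result.st_size)
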
+ """Fast DB+FS pass for a root: + - Toggle needs_verify per state using fast check + - For hashed assets with at least one fast-ok state in this root: delete stale missing states + - For seed assets with all states missing: delete Asset and its AssetInfos + - Optionally add/remove 'missing' tags based on fast-ok in this root + - Optionally return surviving absolute paths """ prefixes = prefixes_for_root(root) if not prefixes: @@ -442,22 +383,25 @@ async def _fast_db_consistency_pass( conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc)) async with await create_session() as sess: - if not conds: - return set() if collect_existing_paths else None - rows = ( await sess.execute( - sa.select(AssetCacheState, Asset.hash, Asset.size_bytes) + sa.select( + AssetCacheState.id, + AssetCacheState.file_path, + AssetCacheState.mtime_ns, + AssetCacheState.needs_verify, + AssetCacheState.asset_id, + Asset.hash, + Asset.size_bytes, + ) .join(Asset, Asset.id == AssetCacheState.asset_id) .where(sa.or_(*conds)) .order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc()) ) ).all() - # Group by asset_id with status per state by_asset: dict[str, dict] = {} - for st, a_hash, a_size in rows: - aid = st.asset_id + for sid, fp, mtime_db, needs_verify, aid, a_hash, a_size in rows: acc = by_asset.get(aid) if acc is None: acc = {"hash": a_hash, "size_db": int(a_size or 0), "states": []} @@ -465,33 +409,46 @@ async def _fast_db_consistency_pass( fast_ok = False try: + exists = True fast_ok = fast_asset_file_check( - mtime_db=st.mtime_ns, + mtime_db=mtime_db, size_db=acc["size_db"], - stat_result=os.stat(st.file_path, follow_symlinks=True), + stat_result=os.stat(fp, follow_symlinks=True), ) - exists = True except FileNotFoundError: exists = False - except OSError as ex: + except OSError: exists = False - LOGGER.debug("fast pass stat error for %s: %s", st.file_path, ex) - acc["states"].append({"obj": st, "exists": exists, "fast_ok": fast_ok}) + acc["states"].append({ + "sid": sid, + "fp": fp, + "exists": exists, + "fast_ok": fast_ok, + "needs_verify": bool(needs_verify), + }) - # Apply actions to_set_verify: list[int] = [] + to_clear_verify: list[int] = [] + stale_state_ids: list[int] = [] survivors: set[str] = set() + for aid, acc in by_asset.items(): a_hash = acc["hash"] states = acc["states"] any_fast_ok = any(s["fast_ok"] for s in states) all_missing = all(not s["exists"] for s in states) - missing_states = [s["obj"] for s in states if not s["exists"]] + + for s in states: + if not s["exists"]: + continue + if s["fast_ok"] and s["needs_verify"]: + to_clear_verify.append(s["sid"]) + if not s["fast_ok"] and not s["needs_verify"]: + to_set_verify.append(s["sid"]) if a_hash is None: - # Seed asset: if all states gone (and in practice there is only one), remove the whole Asset - if states and all_missing: + if states and all_missing: # remove seed Asset completely, if no valid AssetCache exists await sess.execute(sa.delete(AssetInfo).where(AssetInfo.asset_id == aid)) asset = await sess.get(Asset, aid) if asset: @@ -499,34 +456,37 @@ async def _fast_db_consistency_pass( else: for s in states: if s["exists"]: - survivors.add(os.path.abspath(s["obj"].file_path)) - else: - if any_fast_ok: - # Remove 'missing' and delete just the stale state rows - for st in missing_states: - with contextlib.suppress(Exception): - await sess.delete(await sess.get(AssetCacheState, st.id)) + survivors.add(os.path.abspath(s["fp"])) + continue + + if any_fast_ok: # if Asset has at least one valid AssetCache record, remove 
any invalid AssetCache records + for s in states: + if not s["exists"]: + stale_state_ids.append(s["sid"]) + if update_missing_tags: with contextlib.suppress(Exception): await remove_missing_tag_for_asset_id(sess, asset_id=aid) - for s in states: - if s["exists"]: - survivors.add(os.path.abspath(s["obj"].file_path)) - else: - with contextlib.suppress(Exception): - await add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic") - for s in states: - if s["exists"]: - survivors.add(os.path.abspath(s["obj"].file_path)) - for s in states: - if s["exists"] and not s["fast_ok"]: - to_set_verify.append(s["obj"].id) - await sess.flush() + elif update_missing_tags: + with contextlib.suppress(Exception): + await add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic") + + for s in states: + if s["exists"]: + survivors.add(os.path.abspath(s["fp"])) + + if stale_state_ids: + await sess.execute(sa.delete(AssetCacheState).where(AssetCacheState.id.in_(stale_state_ids))) if to_set_verify: await sess.execute( sa.update(AssetCacheState) .where(AssetCacheState.id.in_(to_set_verify)) .values(needs_verify=True) ) - await sess.flush() + if to_clear_verify: + await sess.execute( + sa.update(AssetCacheState) + .where(AssetCacheState.id.in_(to_clear_verify)) + .values(needs_verify=False) + ) await sess.commit() return survivors if collect_existing_paths else None From d0aa64d57b19efc49e9664c7188fa51af9a74169 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Tue, 16 Sep 2025 21:18:18 +0300 Subject: [PATCH 71/82] refactor(4): use one query to init DB with all tags for assets --- app/assets_scanner.py | 71 +++++++++++++++++++------------- app/database/services/content.py | 15 +++---- 2 files changed, 51 insertions(+), 35 deletions(-) diff --git a/app/assets_scanner.py b/app/assets_scanner.py index 27bc52a25009..0f199719dc1f 100644 --- a/app/assets_scanner.py +++ b/app/assets_scanner.py @@ -23,6 +23,7 @@ from .database.db import create_session from .database.helpers import ( add_missing_tag_for_asset_id, + ensure_tags_exist, escape_like_prefix, fast_asset_file_check, remove_missing_tag_for_asset_id, @@ -118,38 +119,52 @@ async def sync_seed_assets(roots: list[schemas_in.RootType]) -> None: if "output" in roots: paths.extend(list_tree(folder_paths.get_output_directory())) + new_specs: list[tuple[str, int, int, str, list[str]]] = [] + tag_pool: set[str] = set() + for p in paths: + ap = os.path.abspath(p) + if ap in existing_paths: + skipped_existing += 1 + continue + try: + st = os.stat(p, follow_symlinks=True) + except OSError: + continue + if not int(st.st_size or 0): + continue + name, tags = get_name_and_tags_from_asset_path(ap) + new_specs.append(( + ap, + int(st.st_size), + getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)), + name, + tags, + )) + for t in tags: + tag_pool.add(t) + async with await create_session() as sess: - for p in paths: - try: - if os.path.abspath(p) in existing_paths: - skipped_existing += 1 - continue - st = os.stat(p, follow_symlinks=True) - if not int(st.st_size or 0): - continue - size_bytes = int(st.st_size) - mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)) - name, tags = get_name_and_tags_from_asset_path(p) - - await ensure_seed_for_path( - sess, - abs_path=p, - size_bytes=size_bytes, - mtime_ns=mtime_ns, - info_name=name, - tags=tags, - owner_id="", - ) - - created += 1 - if created % 500 == 0: - await sess.commit() - except OSError: - continue + if tag_pool: + await ensure_tags_exist(sess, tag_pool, tag_type="user") + for ap, sz, 
mt, name, tags in new_specs: + await ensure_seed_for_path( + sess, + abs_path=ap, + size_bytes=sz, + mtime_ns=mt, + info_name=name, + tags=tags, + owner_id="", + skip_tag_ensure=True, + ) + + created += 1 + if created % 500 == 0: + await sess.commit() await sess.commit() finally: LOGGER.info( - "Assets scan(roots=%s) completed in %.3f s (created=%d, skipped_existing=%d, total_seen=%d)", + "Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, total_seen=%d)", roots, time.perf_counter() - t_total, created, diff --git a/app/database/services/content.py b/app/database/services/content.py index a8ce200d1422..298660129edd 100644 --- a/app/database/services/content.py +++ b/app/database/services/content.py @@ -65,6 +65,7 @@ async def ensure_seed_for_path( info_name: str, tags: Sequence[str], owner_id: str = "", + skip_tag_ensure: bool = False, ) -> str: """Ensure: Asset(hash=NULL), AssetCacheState(file_path), and AssetInfo exist for the path. Returns asset_id.""" locator = os.path.abspath(abs_path) @@ -81,20 +82,20 @@ async def ensure_seed_for_path( if state: state_row: AssetCacheState = state[0] asset_row: Asset = state[1] - changed = state_row.mtime_ns is None or int(state_row.mtime_ns) != int(mtime_ns) + changed = state_row.mtime_ns is None or int(state_row.mtime_ns) != mtime_ns if changed: - state_row.mtime_ns = int(mtime_ns) + state_row.mtime_ns = mtime_ns state_row.needs_verify = True if asset_row.size_bytes == 0 and size_bytes > 0: - asset_row.size_bytes = int(size_bytes) + asset_row.size_bytes = size_bytes await session.flush() return asset_row.id - asset = Asset(hash=None, size_bytes=int(size_bytes), mime_type=None, created_at=now) + asset = Asset(hash=None, size_bytes=size_bytes, mime_type=None, created_at=now) session.add(asset) await session.flush() # to get id - cs = AssetCacheState(asset_id=asset.id, file_path=locator, mtime_ns=int(mtime_ns), needs_verify=False) + cs = AssetCacheState(asset_id=asset.id, file_path=locator, mtime_ns=mtime_ns, needs_verify=False) session.add(cs) info = AssetInfo( @@ -120,12 +121,12 @@ async def ensure_seed_for_path( want = normalize_tags(tags) if want: - await ensure_tags_exist(session, want, tag_type="user") + if not skip_tag_ensure: + await ensure_tags_exist(session, want, tag_type="user") session.add_all([ AssetInfoTag(asset_info_id=info.id, tag_name=t, origin="automatic", added_at=now) for t in want ]) - await session.flush() return asset.id From 621faaa19558a7bed70d10837920fd7cbc2e1731 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Wed, 17 Sep 2025 10:23:48 +0300 Subject: [PATCH 72/82] refactor(5): use less DB queries to create seed asset --- app/assets_scanner.py | 4 +- app/database/services/__init__.py | 4 +- app/database/services/content.py | 181 ++++++++++++++++++++---------- 3 files changed, 126 insertions(+), 63 deletions(-) diff --git a/app/assets_scanner.py b/app/assets_scanner.py index 0f199719dc1f..5ec1ebe8870e 100644 --- a/app/assets_scanner.py +++ b/app/assets_scanner.py @@ -31,11 +31,11 @@ from .database.models import Asset, AssetCacheState, AssetInfo from .database.services import ( compute_hash_and_dedup_for_cache_state, - ensure_seed_for_path, list_cache_states_by_asset_id, list_cache_states_with_asset_under_prefixes, list_unhashed_candidates_under_prefixes, list_verify_candidates_under_prefixes, + seed_from_path, ) LOGGER = logging.getLogger(__name__) @@ -147,7 +147,7 @@ async def sync_seed_assets(roots: list[schemas_in.RootType]) -> None: if tag_pool: await ensure_tags_exist(sess, tag_pool, tag_type="user") 
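Each spec row carries the name and tags that get_name_and_tags_from_asset_path derives from the scanned path, so the shape of that mapping is worth spelling out. A hypothetical example (the real root-category label and tag normalization come from folder_paths and helpers not shown in this series, so the values are illustrative):

    from app._assets_helpers import get_name_and_tags_from_asset_path

    # Illustrative values only - the actual root category and tag order depend
    # on get_relative_to_root_category_path_of_asset and normalize_tags.
    name, tags = get_name_and_tags_from_asset_path(
        "/ComfyUI/models/checkpoints/sd15/dreamshaper.safetensors"
    )
    # name -> "dreamshaper.safetensors"
    # tags -> something like ["models", "checkpoints", "sd15"]
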
for ap, sz, mt, name, tags in new_specs: - await ensure_seed_for_path( + await seed_from_path( sess, abs_path=ap, size_bytes=sz, diff --git a/app/database/services/__init__.py b/app/database/services/__init__.py index 88e97bfb049f..fae9eb6703b7 100644 --- a/app/database/services/__init__.py +++ b/app/database/services/__init__.py @@ -1,12 +1,12 @@ from .content import ( check_fs_asset_exists_quick, compute_hash_and_dedup_for_cache_state, - ensure_seed_for_path, ingest_fs_asset, list_cache_states_with_asset_under_prefixes, list_unhashed_candidates_under_prefixes, list_verify_candidates_under_prefixes, redirect_all_references_then_delete_asset, + seed_from_path, touch_asset_infos_by_fs_path, ) from .info import ( @@ -49,7 +49,7 @@ "get_asset_tags", "list_tags_with_usage", "set_asset_info_preview", "fetch_asset_info_and_asset", "fetch_asset_info_asset_and_tags", # content - "check_fs_asset_exists_quick", "ensure_seed_for_path", + "check_fs_asset_exists_quick", "seed_from_path", "redirect_all_references_then_delete_asset", "compute_hash_and_dedup_for_cache_state", "list_unhashed_candidates_under_prefixes", "list_verify_candidates_under_prefixes", diff --git a/app/database/services/content.py b/app/database/services/content.py index 298660129edd..84fa01f01a71 100644 --- a/app/database/services/content.py +++ b/app/database/services/content.py @@ -1,6 +1,7 @@ import contextlib import logging import os +import uuid from datetime import datetime from typing import Any, Optional, Sequence, Union @@ -19,7 +20,7 @@ escape_like_prefix, remove_missing_tag_for_asset_id, ) -from ..models import Asset, AssetCacheState, AssetInfo, AssetInfoTag, Tag +from ..models import Asset, AssetCacheState, AssetInfo, AssetInfoMeta, AssetInfoTag, Tag from ..timeutil import utcnow from .info import replace_asset_info_metadata_projection from .queries import list_cache_states_by_asset_id, pick_best_live_path @@ -56,7 +57,7 @@ async def check_fs_asset_exists_quick( return row is not None -async def ensure_seed_for_path( +async def seed_from_path( session: AsyncSession, *, abs_path: str, @@ -66,69 +67,131 @@ async def ensure_seed_for_path( tags: Sequence[str], owner_id: str = "", skip_tag_ensure: bool = False, -) -> str: - """Ensure: Asset(hash=NULL), AssetCacheState(file_path), and AssetInfo exist for the path. Returns asset_id.""" +) -> None: + """Creates Asset(hash=NULL), AssetCacheState(file_path), and AssetInfo exist for the path.""" locator = os.path.abspath(abs_path) now = utcnow() + dialect = session.bind.dialect.name - state = ( - await session.execute( - sa.select(AssetCacheState, Asset) - .join(Asset, Asset.id == AssetCacheState.asset_id) - .where(AssetCacheState.file_path == locator) - .limit(1) + new_asset_id = str(uuid.uuid4()) + new_info_id = str(uuid.uuid4()) + + # 1) Insert Asset (hash=NULL) – no conflict expected + asset_vals = { + "id": new_asset_id, + "hash": None, + "size_bytes": size_bytes, + "mime_type": None, + "created_at": now, + } + if dialect == "sqlite": + await session.execute(d_sqlite.insert(Asset).values(**asset_vals)) + elif dialect == "postgresql": + await session.execute(d_pg.insert(Asset).values(**asset_vals)) + else: + raise NotImplementedError(f"Unsupported database dialect: {dialect}") + + # 2) Try to claim file_path in AssetCacheState. Our concurrency gate. 
+ acs_vals = { + "asset_id": new_asset_id, + "file_path": locator, + "mtime_ns": mtime_ns, + } + if dialect == "sqlite": + ins_state = ( + d_sqlite.insert(AssetCacheState) + .values(**acs_vals) + .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path]) ) - ).first() - if state: - state_row: AssetCacheState = state[0] - asset_row: Asset = state[1] - changed = state_row.mtime_ns is None or int(state_row.mtime_ns) != mtime_ns - if changed: - state_row.mtime_ns = mtime_ns - state_row.needs_verify = True - if asset_row.size_bytes == 0 and size_bytes > 0: - asset_row.size_bytes = size_bytes - await session.flush() - return asset_row.id - - asset = Asset(hash=None, size_bytes=size_bytes, mime_type=None, created_at=now) - session.add(asset) - await session.flush() # to get id - - cs = AssetCacheState(asset_id=asset.id, file_path=locator, mtime_ns=mtime_ns, needs_verify=False) - session.add(cs) - - info = AssetInfo( - owner_id=owner_id, - name=info_name, - asset_id=asset.id, - preview_id=None, - created_at=now, - updated_at=now, - last_access_time=now, - ) - session.add(info) - await session.flush() + state_inserted = int((await session.execute(ins_state)).rowcount or 0) > 0 + else: + ins_state = ( + d_pg.insert(AssetCacheState) + .values(**acs_vals) + .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path]) + .returning(AssetCacheState.id) + ) + state_inserted = (await session.execute(ins_state)).scalar_one_or_none() is not None - with contextlib.suppress(Exception): - computed = compute_relative_filename(locator) - if computed: - await replace_asset_info_metadata_projection( - session, - asset_info_id=info.id, - user_metadata={"filename": computed}, - ) + if not state_inserted: + # Lost the race - clean up our orphan seed Asset and exit + with contextlib.suppress(Exception): + await session.execute(sa.delete(Asset).where(Asset.id == new_asset_id)) + return - want = normalize_tags(tags) - if want: - if not skip_tag_ensure: - await ensure_tags_exist(session, want, tag_type="user") - session.add_all([ - AssetInfoTag(asset_info_id=info.id, tag_name=t, origin="automatic", added_at=now) - for t in want - ]) - await session.flush() - return asset.id + # 3) Create AssetInfo (unique(asset_id, owner_id, name)). + fname = compute_relative_filename(locator) + + info_vals = { + "id": new_info_id, + "owner_id": owner_id, + "name": info_name, + "asset_id": new_asset_id, + "preview_id": None, + "user_metadata": {"filename": fname} if fname else None, + "created_at": now, + "updated_at": now, + "last_access_time": now, + } + if dialect == "sqlite": + ins_info = ( + d_sqlite.insert(AssetInfo) + .values(**info_vals) + .on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name]) + ) + info_inserted = int((await session.execute(ins_info)).rowcount or 0) > 0 + else: + ins_info = ( + d_pg.insert(AssetInfo) + .values(**info_vals) + .on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name]) + .returning(AssetInfo.id) + ) + info_inserted = (await session.execute(ins_info)).scalar_one_or_none() is not None + + # 4) If we actually inserted AssetInfo, attach tags and filename. 
+ if info_inserted: + want = normalize_tags(tags) + if want: + if not skip_tag_ensure: + await ensure_tags_exist(session, want, tag_type="user") + tag_rows = [ + { + "asset_info_id": new_info_id, + "tag_name": t, + "origin": "automatic", + "added_at": now, + } + for t in want + ] + if dialect == "sqlite": + ins_links = ( + d_sqlite.insert(AssetInfoTag) + .values(tag_rows) + .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name]) + ) + else: + ins_links = ( + d_pg.insert(AssetInfoTag) + .values(tag_rows) + .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name]) + ) + await session.execute(ins_links) + + if fname: # simple filename projection with single row + meta_row = { + "asset_info_id": new_info_id, + "key": "filename", + "ordinal": 0, + "val_str": fname, + "val_num": None, + "val_bool": None, + "val_json": None, + } + if dialect == "sqlite": + await session.execute(d_sqlite.insert(AssetInfoMeta).values(**meta_row)) + else: + await session.execute(d_pg.insert(AssetInfoMeta).values(**meta_row)) async def redirect_all_references_then_delete_asset( From 5b6810a2c665c9539da7f9f2d72e4ca2407c02c4 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Wed, 17 Sep 2025 13:25:56 +0300 Subject: [PATCH 73/82] fixed hash calculation during model loading in ComfyUI --- app/database/services/content.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/app/database/services/content.py b/app/database/services/content.py index 84fa01f01a71..13d63cc1f602 100644 --- a/app/database/services/content.py +++ b/app/database/services/content.py @@ -33,14 +33,18 @@ async def check_fs_asset_exists_quick( size_bytes: Optional[int] = None, mtime_ns: Optional[int] = None, ) -> bool: - """Return True if a cache row exists for this absolute path and (optionally) mtime/size match.""" + """Returns True if we already track this absolute path with a HASHED asset and the cached mtime/size match.""" locator = os.path.abspath(file_path) stmt = ( sa.select(sa.literal(True)) .select_from(AssetCacheState) .join(Asset, Asset.id == AssetCacheState.asset_id) - .where(AssetCacheState.file_path == locator) + .where( + AssetCacheState.file_path == locator, + Asset.hash.isnot(None), + AssetCacheState.needs_verify.is_(False), + ) .limit(1) ) @@ -49,12 +53,9 @@ async def check_fs_asset_exists_quick( conds.append(AssetCacheState.mtime_ns == int(mtime_ns)) if size_bytes is not None: conds.append(sa.or_(Asset.size_bytes == 0, Asset.size_bytes == int(size_bytes))) - if conds: stmt = stmt.where(*conds) - - row = (await session.execute(stmt)).first() - return row is not None + return (await session.execute(stmt)).first() is not None async def seed_from_path( From 85ef08449dc01e32349db23e4f3f5bdeb1e13c0d Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Wed, 17 Sep 2025 13:40:08 +0300 Subject: [PATCH 74/82] optimization: initial scan speed(batching tags) --- app/assets_scanner.py | 10 +++++++++- app/database/helpers/__init__.py | 2 ++ app/database/helpers/tags.py | 16 ++++++++++++++++ app/database/services/content.py | 18 ++---------------- 4 files changed, 29 insertions(+), 17 deletions(-) diff --git a/app/assets_scanner.py b/app/assets_scanner.py index 5ec1ebe8870e..b90be1a12ec9 100644 --- a/app/assets_scanner.py +++ b/app/assets_scanner.py @@ -26,6 +26,7 @@ ensure_tags_exist, escape_like_prefix, fast_asset_file_check, + insert_tags_from_batch, remove_missing_tag_for_asset_id, ) from .database.models import Asset, AssetCacheState, AssetInfo @@ -146,6 
+147,8 @@ async def sync_seed_assets(roots: list[schemas_in.RootType]) -> None: async with await create_session() as sess: if tag_pool: await ensure_tags_exist(sess, tag_pool, tag_type="user") + + pending_tag_links: list[dict] = [] for ap, sz, mt, name, tags in new_specs: await seed_from_path( sess, @@ -155,12 +158,17 @@ async def sync_seed_assets(roots: list[schemas_in.RootType]) -> None: info_name=name, tags=tags, owner_id="", - skip_tag_ensure=True, + collected_tag_rows=pending_tag_links, ) created += 1 if created % 500 == 0: + if pending_tag_links: + await insert_tags_from_batch(sess, tag_rows=pending_tag_links) + pending_tag_links.clear() await sess.commit() + if pending_tag_links: + await insert_tags_from_batch(sess, tag_rows=pending_tag_links) await sess.commit() finally: LOGGER.info( diff --git a/app/database/helpers/__init__.py b/app/database/helpers/__init__.py index 8119f72e9a6f..6d3db744f65d 100644 --- a/app/database/helpers/__init__.py +++ b/app/database/helpers/__init__.py @@ -6,6 +6,7 @@ from .tags import ( add_missing_tag_for_asset_id, ensure_tags_exist, + insert_tags_from_batch, remove_missing_tag_for_asset_id, ) @@ -19,5 +20,6 @@ "ensure_tags_exist", "add_missing_tag_for_asset_id", "remove_missing_tag_for_asset_id", + "insert_tags_from_batch", "visible_owner_clause", ] diff --git a/app/database/helpers/tags.py b/app/database/helpers/tags.py index 058869eca01c..40e22ac074c6 100644 --- a/app/database/helpers/tags.py +++ b/app/database/helpers/tags.py @@ -88,3 +88,19 @@ async def remove_missing_tag_for_asset_id( AssetInfoTag.tag_name == "missing", ) ) + + +async def insert_tags_from_batch(session: AsyncSession, *, tag_rows: list[dict]) -> None: + if session.bind.dialect.name == "sqlite": + ins_links = ( + d_sqlite.insert(AssetInfoTag) + .values(tag_rows) + .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name]) + ) + else: + ins_links = ( + d_pg.insert(AssetInfoTag) + .values(tag_rows) + .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name]) + ) + await session.execute(ins_links) diff --git a/app/database/services/content.py b/app/database/services/content.py index 13d63cc1f602..ae50e29ec5d6 100644 --- a/app/database/services/content.py +++ b/app/database/services/content.py @@ -67,7 +67,7 @@ async def seed_from_path( info_name: str, tags: Sequence[str], owner_id: str = "", - skip_tag_ensure: bool = False, + collected_tag_rows: list[dict], ) -> None: """Creates Asset(hash=NULL), AssetCacheState(file_path), and AssetInfo exist for the path.""" locator = os.path.abspath(abs_path) @@ -154,8 +154,6 @@ async def seed_from_path( if info_inserted: want = normalize_tags(tags) if want: - if not skip_tag_ensure: - await ensure_tags_exist(session, want, tag_type="user") tag_rows = [ { "asset_info_id": new_info_id, @@ -165,19 +163,7 @@ async def seed_from_path( } for t in want ] - if dialect == "sqlite": - ins_links = ( - d_sqlite.insert(AssetInfoTag) - .values(tag_rows) - .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name]) - ) - else: - ins_links = ( - d_pg.insert(AssetInfoTag) - .values(tag_rows) - .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name]) - ) - await session.execute(ins_links) + collected_tag_rows.extend(tag_rows) if fname: # simple filename projection with single row meta_row = { From f9602457d6aeff401c3af5b8ac69c93555f1b070 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Wed, 17 Sep 2025 16:47:27 +0300 Subject: [PATCH 
75/82] optimization: initial scan speed(batching metadata[filename]) --- app/assets_scanner.py | 8 ++++++++ app/database/helpers/__init__.py | 2 ++ app/database/helpers/meta.py | 30 ++++++++++++++++++++++++++++++ app/database/helpers/tags.py | 4 +++- app/database/services/content.py | 27 +++++++++++++-------------- 5 files changed, 56 insertions(+), 15 deletions(-) create mode 100644 app/database/helpers/meta.py diff --git a/app/assets_scanner.py b/app/assets_scanner.py index b90be1a12ec9..e622d6e3c48b 100644 --- a/app/assets_scanner.py +++ b/app/assets_scanner.py @@ -26,6 +26,7 @@ ensure_tags_exist, escape_like_prefix, fast_asset_file_check, + insert_meta_from_batch, insert_tags_from_batch, remove_missing_tag_for_asset_id, ) @@ -149,6 +150,7 @@ async def sync_seed_assets(roots: list[schemas_in.RootType]) -> None: await ensure_tags_exist(sess, tag_pool, tag_type="user") pending_tag_links: list[dict] = [] + pending_meta_rows: list[dict] = [] for ap, sz, mt, name, tags in new_specs: await seed_from_path( sess, @@ -159,6 +161,7 @@ async def sync_seed_assets(roots: list[schemas_in.RootType]) -> None: tags=tags, owner_id="", collected_tag_rows=pending_tag_links, + collected_meta_rows=pending_meta_rows, ) created += 1 @@ -166,9 +169,14 @@ async def sync_seed_assets(roots: list[schemas_in.RootType]) -> None: if pending_tag_links: await insert_tags_from_batch(sess, tag_rows=pending_tag_links) pending_tag_links.clear() + if pending_meta_rows: + await insert_meta_from_batch(sess, rows=pending_meta_rows) + pending_meta_rows.clear() await sess.commit() if pending_tag_links: await insert_tags_from_batch(sess, tag_rows=pending_tag_links) + if pending_meta_rows: + await insert_meta_from_batch(sess, rows=pending_meta_rows) await sess.commit() finally: LOGGER.info( diff --git a/app/database/helpers/__init__.py b/app/database/helpers/__init__.py index 6d3db744f65d..fda457ca95d4 100644 --- a/app/database/helpers/__init__.py +++ b/app/database/helpers/__init__.py @@ -1,6 +1,7 @@ from .escape_like import escape_like_prefix from .fast_check import fast_asset_file_check from .filters import apply_metadata_filter, apply_tag_filters +from .meta import insert_meta_from_batch from .ownership import visible_owner_clause from .projection import is_scalar, project_kv from .tags import ( @@ -20,6 +21,7 @@ "ensure_tags_exist", "add_missing_tag_for_asset_id", "remove_missing_tag_for_asset_id", + "insert_meta_from_batch", "insert_tags_from_batch", "visible_owner_clause", ] diff --git a/app/database/helpers/meta.py b/app/database/helpers/meta.py new file mode 100644 index 000000000000..a2c801a32797 --- /dev/null +++ b/app/database/helpers/meta.py @@ -0,0 +1,30 @@ +from sqlalchemy.dialects import postgresql as d_pg +from sqlalchemy.dialects import sqlite as d_sqlite +from sqlalchemy.ext.asyncio import AsyncSession + +from ..models import AssetInfoMeta + + +async def insert_meta_from_batch(session: AsyncSession, *, rows: list[dict]) -> None: + """Bulk insert rows into asset_info_meta with ON CONFLICT DO NOTHING. 
+ Each row should contain: asset_info_id, key, ordinal, val_str, val_num, val_bool, val_json + """ + if session.bind.dialect.name == "sqlite": + ins = ( + d_sqlite.insert(AssetInfoMeta) + .values(rows) + .on_conflict_do_nothing( + index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal] + ) + ) + elif session.bind.dialect.name == "postgresql": + ins = ( + d_pg.insert(AssetInfoMeta) + .values(rows) + .on_conflict_do_nothing( + index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal] + ) + ) + else: + raise NotImplementedError(f"Unsupported database dialect: {session.bind.dialect.name}") + await session.execute(ins) diff --git a/app/database/helpers/tags.py b/app/database/helpers/tags.py index 40e22ac074c6..5bc393a8bf09 100644 --- a/app/database/helpers/tags.py +++ b/app/database/helpers/tags.py @@ -97,10 +97,12 @@ async def insert_tags_from_batch(session: AsyncSession, *, tag_rows: list[dict]) .values(tag_rows) .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name]) ) - else: + elif session.bind.dialect.name == "postgresql": ins_links = ( d_pg.insert(AssetInfoTag) .values(tag_rows) .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name]) ) + else: + raise NotImplementedError(f"Unsupported database dialect: {session.bind.dialect.name}") await session.execute(ins_links) diff --git a/app/database/services/content.py b/app/database/services/content.py index ae50e29ec5d6..903238c9f7b7 100644 --- a/app/database/services/content.py +++ b/app/database/services/content.py @@ -20,7 +20,7 @@ escape_like_prefix, remove_missing_tag_for_asset_id, ) -from ..models import Asset, AssetCacheState, AssetInfo, AssetInfoMeta, AssetInfoTag, Tag +from ..models import Asset, AssetCacheState, AssetInfo, AssetInfoTag, Tag from ..timeutil import utcnow from .info import replace_asset_info_metadata_projection from .queries import list_cache_states_by_asset_id, pick_best_live_path @@ -68,6 +68,7 @@ async def seed_from_path( tags: Sequence[str], owner_id: str = "", collected_tag_rows: list[dict], + collected_meta_rows: list[dict], ) -> None: """Creates Asset(hash=NULL), AssetCacheState(file_path), and AssetInfo exist for the path.""" locator = os.path.abspath(abs_path) @@ -166,19 +167,17 @@ async def seed_from_path( collected_tag_rows.extend(tag_rows) if fname: # simple filename projection with single row - meta_row = { - "asset_info_id": new_info_id, - "key": "filename", - "ordinal": 0, - "val_str": fname, - "val_num": None, - "val_bool": None, - "val_json": None, - } - if dialect == "sqlite": - await session.execute(d_sqlite.insert(AssetInfoMeta).values(**meta_row)) - else: - await session.execute(d_pg.insert(AssetInfoMeta).values(**meta_row)) + collected_meta_rows.append( + { + "asset_info_id": new_info_id, + "key": "filename", + "ordinal": 0, + "val_str": fname, + "val_num": None, + "val_bool": None, + "val_json": None, + } + ) async def redirect_all_references_then_delete_asset( From 1a37d1476dd728abf0301c86bcd5658f3510032e Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Wed, 17 Sep 2025 20:15:50 +0300 Subject: [PATCH 76/82] refactor(6): fully batched initial scan --- app/_assets_helpers.py | 2 +- app/assets_scanner.py | 60 +++----- app/database/helpers/__init__.py | 6 +- app/database/helpers/bulk_ops.py | 231 ++++++++++++++++++++++++++++++ app/database/helpers/meta.py | 30 ---- app/database/helpers/tags.py | 18 --- app/database/services/__init__.py | 3 +- app/database/services/content.py 
| 125 +--------------- 8 files changed, 255 insertions(+), 220 deletions(-) create mode 100644 app/database/helpers/bulk_ops.py delete mode 100644 app/database/helpers/meta.py diff --git a/app/_assets_helpers.py b/app/_assets_helpers.py index 98761284547c..59141e99707e 100644 --- a/app/_assets_helpers.py +++ b/app/_assets_helpers.py @@ -97,7 +97,7 @@ def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]: root_category, some_path = get_relative_to_root_category_path_of_asset(file_path) p = Path(some_path) parent_parts = [part for part in p.parent.parts if part not in (".", "..", p.anchor)] - return p.name, normalize_tags([root_category, *parent_parts]) + return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts]))) def normalize_tags(tags: Optional[Sequence[str]]) -> list[str]: diff --git a/app/assets_scanner.py b/app/assets_scanner.py index e622d6e3c48b..7ef64a0526ed 100644 --- a/app/assets_scanner.py +++ b/app/assets_scanner.py @@ -12,6 +12,7 @@ from ._assets_helpers import ( collect_models_files, + compute_relative_filename, get_comfy_models_folders, get_name_and_tags_from_asset_path, list_tree, @@ -26,9 +27,8 @@ ensure_tags_exist, escape_like_prefix, fast_asset_file_check, - insert_meta_from_batch, - insert_tags_from_batch, remove_missing_tag_for_asset_id, + seed_from_paths_batch, ) from .database.models import Asset, AssetCacheState, AssetInfo from .database.services import ( @@ -37,7 +37,6 @@ list_cache_states_with_asset_under_prefixes, list_unhashed_candidates_under_prefixes, list_verify_candidates_under_prefixes, - seed_from_path, ) LOGGER = logging.getLogger(__name__) @@ -121,7 +120,7 @@ async def sync_seed_assets(roots: list[schemas_in.RootType]) -> None: if "output" in roots: paths.extend(list_tree(folder_paths.get_output_directory())) - new_specs: list[tuple[str, int, int, str, list[str]]] = [] + specs: list[dict] = [] tag_pool: set[str] = set() for p in paths: ap = os.path.abspath(p) @@ -129,54 +128,33 @@ async def sync_seed_assets(roots: list[schemas_in.RootType]) -> None: skipped_existing += 1 continue try: - st = os.stat(p, follow_symlinks=True) + st = os.stat(ap, follow_symlinks=True) except OSError: continue - if not int(st.st_size or 0): + if not st.st_size: continue name, tags = get_name_and_tags_from_asset_path(ap) - new_specs.append(( - ap, - int(st.st_size), - getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)), - name, - tags, - )) + specs.append( + { + "abs_path": ap, + "size_bytes": st.st_size, + "mtime_ns": getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)), + "info_name": name, + "tags": tags, + "fname": compute_relative_filename(ap), + } + ) for t in tags: tag_pool.add(t) + if not specs: + return async with await create_session() as sess: if tag_pool: await ensure_tags_exist(sess, tag_pool, tag_type="user") - pending_tag_links: list[dict] = [] - pending_meta_rows: list[dict] = [] - for ap, sz, mt, name, tags in new_specs: - await seed_from_path( - sess, - abs_path=ap, - size_bytes=sz, - mtime_ns=mt, - info_name=name, - tags=tags, - owner_id="", - collected_tag_rows=pending_tag_links, - collected_meta_rows=pending_meta_rows, - ) - - created += 1 - if created % 500 == 0: - if pending_tag_links: - await insert_tags_from_batch(sess, tag_rows=pending_tag_links) - pending_tag_links.clear() - if pending_meta_rows: - await insert_meta_from_batch(sess, rows=pending_meta_rows) - pending_meta_rows.clear() - await sess.commit() - if pending_tag_links: - await insert_tags_from_batch(sess, 
tag_rows=pending_tag_links) - if pending_meta_rows: - await insert_meta_from_batch(sess, rows=pending_meta_rows) + result = await seed_from_paths_batch(sess, specs=specs, owner_id="") + created += result["inserted_infos"] await sess.commit() finally: LOGGER.info( diff --git a/app/database/helpers/__init__.py b/app/database/helpers/__init__.py index fda457ca95d4..9ae13cd02e61 100644 --- a/app/database/helpers/__init__.py +++ b/app/database/helpers/__init__.py @@ -1,13 +1,12 @@ +from .bulk_ops import seed_from_paths_batch from .escape_like import escape_like_prefix from .fast_check import fast_asset_file_check from .filters import apply_metadata_filter, apply_tag_filters -from .meta import insert_meta_from_batch from .ownership import visible_owner_clause from .projection import is_scalar, project_kv from .tags import ( add_missing_tag_for_asset_id, ensure_tags_exist, - insert_tags_from_batch, remove_missing_tag_for_asset_id, ) @@ -21,7 +20,6 @@ "ensure_tags_exist", "add_missing_tag_for_asset_id", "remove_missing_tag_for_asset_id", - "insert_meta_from_batch", - "insert_tags_from_batch", + "seed_from_paths_batch", "visible_owner_clause", ] diff --git a/app/database/helpers/bulk_ops.py b/app/database/helpers/bulk_ops.py new file mode 100644 index 000000000000..4578511e5bb4 --- /dev/null +++ b/app/database/helpers/bulk_ops.py @@ -0,0 +1,231 @@ +import os +import uuid +from typing import Iterable, Sequence + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql as d_pg +from sqlalchemy.dialects import sqlite as d_sqlite +from sqlalchemy.ext.asyncio import AsyncSession + +from ..models import Asset, AssetCacheState, AssetInfo, AssetInfoMeta, AssetInfoTag +from ..timeutil import utcnow + + +MAX_BIND_PARAMS = 800 + + +async def seed_from_paths_batch( + session: AsyncSession, + *, + specs: Sequence[dict], + owner_id: str = "", +) -> dict: + """Each spec is a dict with keys: + - abs_path: str + - size_bytes: int + - mtime_ns: int + - info_name: str + - tags: list[str] + - fname: Optional[str] + """ + if not specs: + return {"inserted_infos": 0, "won_states": 0, "lost_states": 0} + + now = utcnow() + dialect = session.bind.dialect.name + if dialect not in ("sqlite", "postgresql"): + raise NotImplementedError(f"Unsupported database dialect: {dialect}") + + asset_rows: list[dict] = [] + state_rows: list[dict] = [] + path_to_asset: dict[str, str] = {} + asset_to_info: dict[str, dict] = {} # asset_id -> prepared info row + path_list: list[str] = [] + + for sp in specs: + ap = os.path.abspath(sp["abs_path"]) + aid = str(uuid.uuid4()) + iid = str(uuid.uuid4()) + path_list.append(ap) + path_to_asset[ap] = aid + + asset_rows.append( + { + "id": aid, + "hash": None, + "size_bytes": sp["size_bytes"], + "mime_type": None, + "created_at": now, + } + ) + state_rows.append( + { + "asset_id": aid, + "file_path": ap, + "mtime_ns": sp["mtime_ns"], + } + ) + asset_to_info[aid] = { + "id": iid, + "owner_id": owner_id, + "name": sp["info_name"], + "asset_id": aid, + "preview_id": None, + "user_metadata": {"filename": sp["fname"]} if sp["fname"] else None, + "created_at": now, + "updated_at": now, + "last_access_time": now, + "_tags": sp["tags"], + "_filename": sp["fname"], + } + + # insert all seed Assets (hash=NULL) + ins_asset = d_sqlite.insert(Asset) if dialect == "sqlite" else d_pg.insert(Asset) + for chunk in _iter_chunks(asset_rows, _rows_per_stmt(5)): + await session.execute(ins_asset, chunk) + + # try to claim AssetCacheState (file_path) + winners_by_path: set[str] = set() + if dialect == "sqlite": 
+ ins_state = ( + d_sqlite.insert(AssetCacheState) + .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path]) + .returning(AssetCacheState.file_path) + ) + else: + ins_state = ( + d_pg.insert(AssetCacheState) + .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path]) + .returning(AssetCacheState.file_path) + ) + for chunk in _iter_chunks(state_rows, _rows_per_stmt(3)): + winners_by_path.update((await session.execute(ins_state, chunk)).scalars().all()) + + all_paths_set = set(path_list) + losers_by_path = all_paths_set - winners_by_path + lost_assets = [path_to_asset[p] for p in losers_by_path] + if lost_assets: # losers get their Asset removed + for id_chunk in _iter_chunks(lost_assets, MAX_BIND_PARAMS): + await session.execute(sa.delete(Asset).where(Asset.id.in_(id_chunk))) + + if not winners_by_path: + return {"inserted_infos": 0, "won_states": 0, "lost_states": len(losers_by_path)} + + # insert AssetInfo only for winners + winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path] + if dialect == "sqlite": + ins_info = ( + d_sqlite.insert(AssetInfo) + .on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name]) + .returning(AssetInfo.id) + ) + else: + ins_info = ( + d_pg.insert(AssetInfo) + .on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name]) + .returning(AssetInfo.id) + ) + + inserted_info_ids: set[str] = set() + for chunk in _iter_chunks(winner_info_rows, _rows_per_stmt(9)): + inserted_info_ids.update((await session.execute(ins_info, chunk)).scalars().all()) + + # build and insert tag + meta rows for the AssetInfo + tag_rows: list[dict] = [] + meta_rows: list[dict] = [] + if inserted_info_ids: + for row in winner_info_rows: + iid = row["id"] + if iid not in inserted_info_ids: + continue + for t in row["_tags"]: + tag_rows.append({ + "asset_info_id": iid, + "tag_name": t, + "origin": "automatic", + "added_at": now, + }) + if row["_filename"]: + meta_rows.append( + { + "asset_info_id": iid, + "key": "filename", + "ordinal": 0, + "val_str": row["_filename"], + "val_num": None, + "val_bool": None, + "val_json": None, + } + ) + + await bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=meta_rows, max_bind_params=MAX_BIND_PARAMS) + return { + "inserted_infos": len(inserted_info_ids), + "won_states": len(winners_by_path), + "lost_states": len(losers_by_path), + } + + +async def bulk_insert_tags_and_meta( + session: AsyncSession, + *, + tag_rows: list[dict], + meta_rows: list[dict], + max_bind_params: int, +) -> None: + """Batch insert into asset_info_tags and asset_info_meta with ON CONFLICT DO NOTHING. 
+ - tag_rows keys: asset_info_id, tag_name, origin, added_at + - meta_rows keys: asset_info_id, key, ordinal, val_str, val_num, val_bool, val_json + """ + dialect = session.bind.dialect.name + if tag_rows: + if dialect == "sqlite": + ins_links = ( + d_sqlite.insert(AssetInfoTag) + .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name]) + ) + elif dialect == "postgresql": + ins_links = ( + d_pg.insert(AssetInfoTag) + .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name]) + ) + else: + raise NotImplementedError(f"Unsupported database dialect: {dialect}") + for chunk in _chunk_rows(tag_rows, cols_per_row=4, max_bind_params=max_bind_params): + await session.execute(ins_links, chunk) + if meta_rows: + if dialect == "sqlite": + ins_meta = ( + d_sqlite.insert(AssetInfoMeta) + .on_conflict_do_nothing( + index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal] + ) + ) + elif dialect == "postgresql": + ins_meta = ( + d_pg.insert(AssetInfoMeta) + .on_conflict_do_nothing( + index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal] + ) + ) + else: + raise NotImplementedError(f"Unsupported database dialect: {dialect}") + for chunk in _chunk_rows(meta_rows, cols_per_row=7, max_bind_params=max_bind_params): + await session.execute(ins_meta, chunk) + + +def _chunk_rows(rows: list[dict], cols_per_row: int, max_bind_params: int) -> Iterable[list[dict]]: + if not rows: + return [] + rows_per_stmt = max(1, max_bind_params // max(1, cols_per_row)) + for i in range(0, len(rows), rows_per_stmt): + yield rows[i:i + rows_per_stmt] + + +def _iter_chunks(seq, n: int): + for i in range(0, len(seq), n): + yield seq[i:i + n] + + +def _rows_per_stmt(cols: int) -> int: + return max(1, MAX_BIND_PARAMS // max(1, cols)) diff --git a/app/database/helpers/meta.py b/app/database/helpers/meta.py deleted file mode 100644 index a2c801a32797..000000000000 --- a/app/database/helpers/meta.py +++ /dev/null @@ -1,30 +0,0 @@ -from sqlalchemy.dialects import postgresql as d_pg -from sqlalchemy.dialects import sqlite as d_sqlite -from sqlalchemy.ext.asyncio import AsyncSession - -from ..models import AssetInfoMeta - - -async def insert_meta_from_batch(session: AsyncSession, *, rows: list[dict]) -> None: - """Bulk insert rows into asset_info_meta with ON CONFLICT DO NOTHING. 
- Each row should contain: asset_info_id, key, ordinal, val_str, val_num, val_bool, val_json - """ - if session.bind.dialect.name == "sqlite": - ins = ( - d_sqlite.insert(AssetInfoMeta) - .values(rows) - .on_conflict_do_nothing( - index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal] - ) - ) - elif session.bind.dialect.name == "postgresql": - ins = ( - d_pg.insert(AssetInfoMeta) - .values(rows) - .on_conflict_do_nothing( - index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal] - ) - ) - else: - raise NotImplementedError(f"Unsupported database dialect: {session.bind.dialect.name}") - await session.execute(ins) diff --git a/app/database/helpers/tags.py b/app/database/helpers/tags.py index 5bc393a8bf09..058869eca01c 100644 --- a/app/database/helpers/tags.py +++ b/app/database/helpers/tags.py @@ -88,21 +88,3 @@ async def remove_missing_tag_for_asset_id( AssetInfoTag.tag_name == "missing", ) ) - - -async def insert_tags_from_batch(session: AsyncSession, *, tag_rows: list[dict]) -> None: - if session.bind.dialect.name == "sqlite": - ins_links = ( - d_sqlite.insert(AssetInfoTag) - .values(tag_rows) - .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name]) - ) - elif session.bind.dialect.name == "postgresql": - ins_links = ( - d_pg.insert(AssetInfoTag) - .values(tag_rows) - .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name]) - ) - else: - raise NotImplementedError(f"Unsupported database dialect: {session.bind.dialect.name}") - await session.execute(ins_links) diff --git a/app/database/services/__init__.py b/app/database/services/__init__.py index fae9eb6703b7..6c6f26e514d7 100644 --- a/app/database/services/__init__.py +++ b/app/database/services/__init__.py @@ -6,7 +6,6 @@ list_unhashed_candidates_under_prefixes, list_verify_candidates_under_prefixes, redirect_all_references_then_delete_asset, - seed_from_path, touch_asset_infos_by_fs_path, ) from .info import ( @@ -49,7 +48,7 @@ "get_asset_tags", "list_tags_with_usage", "set_asset_info_preview", "fetch_asset_info_and_asset", "fetch_asset_info_asset_and_tags", # content - "check_fs_asset_exists_quick", "seed_from_path", + "check_fs_asset_exists_quick", "redirect_all_references_then_delete_asset", "compute_hash_and_dedup_for_cache_state", "list_unhashed_candidates_under_prefixes", "list_verify_candidates_under_prefixes", diff --git a/app/database/services/content.py b/app/database/services/content.py index 903238c9f7b7..11eff76f9ed1 100644 --- a/app/database/services/content.py +++ b/app/database/services/content.py @@ -1,7 +1,6 @@ import contextlib import logging import os -import uuid from datetime import datetime from typing import Any, Optional, Sequence, Union @@ -13,7 +12,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import noload -from ..._assets_helpers import compute_relative_filename, normalize_tags +from ..._assets_helpers import compute_relative_filename from ...storage import hashing as hashing_mod from ..helpers import ( ensure_tags_exist, @@ -58,128 +57,6 @@ async def check_fs_asset_exists_quick( return (await session.execute(stmt)).first() is not None -async def seed_from_path( - session: AsyncSession, - *, - abs_path: str, - size_bytes: int, - mtime_ns: int, - info_name: str, - tags: Sequence[str], - owner_id: str = "", - collected_tag_rows: list[dict], - collected_meta_rows: list[dict], -) -> None: - """Creates Asset(hash=NULL), AssetCacheState(file_path), and 
AssetInfo exist for the path.""" - locator = os.path.abspath(abs_path) - now = utcnow() - dialect = session.bind.dialect.name - - new_asset_id = str(uuid.uuid4()) - new_info_id = str(uuid.uuid4()) - - # 1) Insert Asset (hash=NULL) – no conflict expected - asset_vals = { - "id": new_asset_id, - "hash": None, - "size_bytes": size_bytes, - "mime_type": None, - "created_at": now, - } - if dialect == "sqlite": - await session.execute(d_sqlite.insert(Asset).values(**asset_vals)) - elif dialect == "postgresql": - await session.execute(d_pg.insert(Asset).values(**asset_vals)) - else: - raise NotImplementedError(f"Unsupported database dialect: {dialect}") - - # 2) Try to claim file_path in AssetCacheState. Our concurrency gate. - acs_vals = { - "asset_id": new_asset_id, - "file_path": locator, - "mtime_ns": mtime_ns, - } - if dialect == "sqlite": - ins_state = ( - d_sqlite.insert(AssetCacheState) - .values(**acs_vals) - .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path]) - ) - state_inserted = int((await session.execute(ins_state)).rowcount or 0) > 0 - else: - ins_state = ( - d_pg.insert(AssetCacheState) - .values(**acs_vals) - .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path]) - .returning(AssetCacheState.id) - ) - state_inserted = (await session.execute(ins_state)).scalar_one_or_none() is not None - - if not state_inserted: - # Lost the race - clean up our orphan seed Asset and exit - with contextlib.suppress(Exception): - await session.execute(sa.delete(Asset).where(Asset.id == new_asset_id)) - return - - # 3) Create AssetInfo (unique(asset_id, owner_id, name)). - fname = compute_relative_filename(locator) - - info_vals = { - "id": new_info_id, - "owner_id": owner_id, - "name": info_name, - "asset_id": new_asset_id, - "preview_id": None, - "user_metadata": {"filename": fname} if fname else None, - "created_at": now, - "updated_at": now, - "last_access_time": now, - } - if dialect == "sqlite": - ins_info = ( - d_sqlite.insert(AssetInfo) - .values(**info_vals) - .on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name]) - ) - info_inserted = int((await session.execute(ins_info)).rowcount or 0) > 0 - else: - ins_info = ( - d_pg.insert(AssetInfo) - .values(**info_vals) - .on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name]) - .returning(AssetInfo.id) - ) - info_inserted = (await session.execute(ins_info)).scalar_one_or_none() is not None - - # 4) If we actually inserted AssetInfo, attach tags and filename. 
- if info_inserted: - want = normalize_tags(tags) - if want: - tag_rows = [ - { - "asset_info_id": new_info_id, - "tag_name": t, - "origin": "automatic", - "added_at": now, - } - for t in want - ] - collected_tag_rows.extend(tag_rows) - - if fname: # simple filename projection with single row - collected_meta_rows.append( - { - "asset_info_id": new_info_id, - "key": "filename", - "ordinal": 0, - "val_str": fname, - "val_num": None, - "val_bool": None, - "val_json": None, - } - ) - - async def redirect_all_references_then_delete_asset( session: AsyncSession, *, From 283cd27bdce1837a4e812e84378bba42381c38d2 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Thu, 18 Sep 2025 09:47:50 +0300 Subject: [PATCH 77/82] final adjustments --- app/database/helpers/bulk_ops.py | 1 - app/database/services/content.py | 10 ++-------- app/database/services/info.py | 9 ++------- 3 files changed, 4 insertions(+), 16 deletions(-) diff --git a/app/database/helpers/bulk_ops.py b/app/database/helpers/bulk_ops.py index 4578511e5bb4..feefbb5bd63b 100644 --- a/app/database/helpers/bulk_ops.py +++ b/app/database/helpers/bulk_ops.py @@ -10,7 +10,6 @@ from ..models import Asset, AssetCacheState, AssetInfo, AssetInfoMeta, AssetInfoTag from ..timeutil import utcnow - MAX_BIND_PARAMS = 800 diff --git a/app/database/services/content.py b/app/database/services/content.py index 11eff76f9ed1..05450f8231bb 100644 --- a/app/database/services/content.py +++ b/app/database/services/content.py @@ -642,10 +642,9 @@ async def touch_asset_infos_by_fs_path( file_path: str, ts: Optional[datetime] = None, only_if_newer: bool = True, -) -> int: +) -> None: locator = os.path.abspath(file_path) ts = ts or utcnow() - stmt = sa.update(AssetInfo).where( sa.exists( sa.select(sa.literal(1)) @@ -656,7 +655,6 @@ async def touch_asset_infos_by_fs_path( ) ) ) - if only_if_newer: stmt = stmt.where( sa.or_( @@ -664,11 +662,7 @@ async def touch_asset_infos_by_fs_path( AssetInfo.last_access_time < ts, ) ) - - stmt = stmt.values(last_access_time=ts) - - res = await session.execute(stmt) - return int(res.rowcount or 0) + await session.execute(stmt.values(last_access_time=ts)) async def list_cache_states_with_asset_under_prefixes( diff --git a/app/database/services/info.py b/app/database/services/info.py index 7583383683ee..8a28bcf4c38c 100644 --- a/app/database/services/info.py +++ b/app/database/services/info.py @@ -373,17 +373,14 @@ async def touch_asset_info_by_id( asset_info_id: str, ts: Optional[datetime] = None, only_if_newer: bool = True, -) -> bool: +) -> None: ts = ts or utcnow() stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id) if only_if_newer: stmt = stmt.where( sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts) ) - stmt = stmt.values(last_access_time=ts) - if session.bind.dialect.name == "postgresql": - return (await session.execute(stmt.returning(AssetInfo.id))).scalar_one_or_none() is not None - return int((await session.execute(stmt)).rowcount or 0) > 0 + await session.execute(stmt.values(last_access_time=ts)) async def delete_asset_info_by_id(session: AsyncSession, *, asset_info_id: str, owner_id: str) -> bool: @@ -391,8 +388,6 @@ async def delete_asset_info_by_id(session: AsyncSession, *, asset_info_id: str, AssetInfo.id == asset_info_id, visible_owner_clause(owner_id), ) - if session.bind.dialect.name == "postgresql": - return (await session.execute(stmt.returning(AssetInfo.id))).scalar_one_or_none() is not None return int((await session.execute(stmt)).rowcount or 0) > 0 From 
adccfb2dfdfd1793ce9aec372f0091a2e428f874 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 26 Sep 2025 20:33:46 -0700 Subject: [PATCH 78/82] Remove populate_db_with_asset from load_torch_file for now, as nothing yet uses the hashes --- comfy/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/utils.py b/comfy/utils.py index f13a780e8c30..2d96e1d0f308 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -103,7 +103,7 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False): else: sd = pl_sd - populate_db_with_asset(ckpt) + # populate_db_with_asset(ckpt) # surprise tool that can help us later - performs hashing on model file return (sd, metadata) if return_metadata else sd def save_torch_file(sd, ckpt, metadata=None): From fbba2e59e53aefb3e3778d46f85aea7540498851 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 26 Sep 2025 20:39:23 -0700 Subject: [PATCH 79/82] Satisfy ruff --- comfy/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/comfy/utils.py b/comfy/utils.py index 2d96e1d0f308..73461b8761fb 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -29,7 +29,6 @@ from torch.nn.functional import interpolate from einops import rearrange from comfy.cli_args import args -from app.assets_manager import populate_db_with_asset MMAP_TORCH_FILES = args.mmap_torch_files DISABLE_MMAP = args.disable_mmap From 94941c50b3dc7f7facd5bf39ea0e3c2c4781e623 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 3 Oct 2025 01:01:16 +0300 Subject: [PATCH 80/82] move alembic_db inside app folder (#10163) --- .github/workflows/test-assets.yml | 1 - alembic.ini | 2 +- {alembic_db => app/alembic_db}/README.md | 0 {alembic_db => app/alembic_db}/env.py | 0 {alembic_db => app/alembic_db}/script.py.mako | 0 {alembic_db => app/alembic_db}/versions/0001_assets.py | 0 app/database/db.py | 6 +++--- 7 files changed, 4 insertions(+), 5 deletions(-) rename {alembic_db => app/alembic_db}/README.md (100%) rename {alembic_db => app/alembic_db}/env.py (100%) rename {alembic_db => app/alembic_db}/script.py.mako (100%) rename {alembic_db => app/alembic_db}/versions/0001_assets.py (100%) diff --git a/.github/workflows/test-assets.yml b/.github/workflows/test-assets.yml index 3b3a7c73f3a4..ef80fc48ab4c 100644 --- a/.github/workflows/test-assets.yml +++ b/.github/workflows/test-assets.yml @@ -4,7 +4,6 @@ on: push: paths: - 'app/**' - - 'alembic_db/**' - 'tests-assets/**' - '.github/workflows/test-assets.yml' - 'requirements.txt' diff --git a/alembic.ini b/alembic.ini index 12f18712f430..360efd386901 100644 --- a/alembic.ini +++ b/alembic.ini @@ -3,7 +3,7 @@ [alembic] # path to migration scripts # Use forward slashes (/) also on windows to provide an os agnostic path -script_location = alembic_db +script_location = app/alembic_db # template used to generate migration file names; The default value is %%(rev)s_%%(slug)s # Uncomment the line below if you want the files to be prepended with date and time diff --git a/alembic_db/README.md b/app/alembic_db/README.md similarity index 100% rename from alembic_db/README.md rename to app/alembic_db/README.md diff --git a/alembic_db/env.py b/app/alembic_db/env.py similarity index 100% rename from alembic_db/env.py rename to app/alembic_db/env.py diff --git a/alembic_db/script.py.mako b/app/alembic_db/script.py.mako similarity index 100% rename from alembic_db/script.py.mako rename to app/alembic_db/script.py.mako diff --git a/alembic_db/versions/0001_assets.py 
b/app/alembic_db/versions/0001_assets.py similarity index 100% rename from alembic_db/versions/0001_assets.py rename to app/alembic_db/versions/0001_assets.py diff --git a/app/database/db.py b/app/database/db.py index 82c9cc737c39..54f9000cc231 100644 --- a/app/database/db.py +++ b/app/database/db.py @@ -26,9 +26,9 @@ def _root_paths(): """Resolve alembic.ini and migrations script folder.""" - root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) - config_path = os.path.abspath(os.path.join(root_path, "alembic.ini")) - scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db")) + root_path = os.path.abspath(os.path.dirname(__file__)) + config_path = os.path.abspath(os.path.join(root_path, "../../alembic.ini")) + scripts_path = os.path.abspath(os.path.join(root_path, "../alembic_db")) return config_path, scripts_path From fd6ac0a7654a3842681559c46021f41df006ea7e Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 3 Oct 2025 21:34:06 +0300 Subject: [PATCH 81/82] drop PgSQL 14, unite migration for SQLite and PgSQL (#10165) --- .github/workflows/test-assets.yml | 2 +- app/alembic_db/versions/0001_assets.py | 16 ++-------------- app/database/models.py | 1 + 3 files changed, 4 insertions(+), 15 deletions(-) diff --git a/.github/workflows/test-assets.yml b/.github/workflows/test-assets.yml index ef80fc48ab4c..4ae26ba5ff2d 100644 --- a/.github/workflows/test-assets.yml +++ b/.github/workflows/test-assets.yml @@ -99,7 +99,7 @@ jobs: fail-fast: false matrix: python: ['3.9', '3.12'] - pgsql: ['14', '16'] + pgsql: ['16', '18'] services: postgres: diff --git a/app/alembic_db/versions/0001_assets.py b/app/alembic_db/versions/0001_assets.py index 1f5fb462280d..589b22ac8a49 100644 --- a/app/alembic_db/versions/0001_assets.py +++ b/app/alembic_db/versions/0001_assets.py @@ -26,16 +26,7 @@ def upgrade() -> None: sa.Column("created_at", sa.DateTime(timezone=False), nullable=False), sa.CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"), ) - if op.get_bind().dialect.name == "postgresql": - op.create_index( - "uq_assets_hash_not_null", - "assets", - ["hash"], - unique=True, - postgresql_where=sa.text("hash IS NOT NULL"), - ) - else: - op.create_index("uq_assets_hash", "assets", ["hash"], unique=True) + op.create_index("uq_assets_hash", "assets", ["hash"], unique=True) op.create_index("ix_assets_mime_type", "assets", ["mime_type"]) # ASSETS_INFO: user-visible references @@ -179,9 +170,6 @@ def downgrade() -> None: op.drop_index("ix_assets_info_owner_id", table_name="assets_info") op.drop_table("assets_info") - if op.get_bind().dialect.name == "postgresql": - op.drop_index("uq_assets_hash_not_null", table_name="assets") - else: - op.drop_index("uq_assets_hash", table_name="assets") + op.drop_index("uq_assets_hash", table_name="assets") op.drop_index("ix_assets_mime_type", table_name="assets") op.drop_table("assets") diff --git a/app/database/models.py b/app/database/models.py index 6a6798bcfd5a..c6555fa61732 100644 --- a/app/database/models.py +++ b/app/database/models.py @@ -77,6 +77,7 @@ class Asset(Base): ) __table_args__ = ( + Index("uq_assets_hash", "hash", unique=True), Index("ix_assets_mime_type", "mime_type"), CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"), ) From 917177e821983632854e714d90b491dd2311e316 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 3 Oct 2025 21:53:26 +0300 Subject: [PATCH 82/82] move assets related stuff to 
"app/assets" folder (#10184) --- app/__init__.py | 5 --- app/alembic_db/env.py | 3 +- app/assets/__init__.py | 4 ++ .../_helpers.py} | 0 app/{ => assets}/api/__init__.py | 0 .../assets_routes.py => assets/api/routes.py} | 37 ++++++++++--------- app/{ => assets}/api/schemas_in.py | 0 app/{ => assets}/api/schemas_out.py | 0 app/{ => assets}/database/__init__.py | 0 app/{ => assets}/database/helpers/__init__.py | 0 app/{ => assets}/database/helpers/bulk_ops.py | 0 .../database/helpers/escape_like.py | 0 .../database/helpers/fast_check.py | 0 app/{ => assets}/database/helpers/filters.py | 2 +- .../database/helpers/ownership.py | 0 .../database/helpers/projection.py | 0 app/{ => assets}/database/helpers/tags.py | 2 +- app/{ => assets}/database/models.py | 0 .../database/services/__init__.py | 0 app/{ => assets}/database/services/content.py | 2 +- app/{ => assets}/database/services/info.py | 2 +- app/{ => assets}/database/services/queries.py | 0 app/{ => assets}/database/timeutil.py | 0 app/{assets_manager.py => assets/manager.py} | 4 +- app/{assets_scanner.py => assets/scanner.py} | 4 +- app/{ => assets}/storage/__init__.py | 0 app/{ => assets}/storage/hashing.py | 0 app/{database => }/db.py | 4 +- main.py | 3 +- server.py | 2 +- 30 files changed, 37 insertions(+), 37 deletions(-) delete mode 100644 app/__init__.py create mode 100644 app/assets/__init__.py rename app/{_assets_helpers.py => assets/_helpers.py} (100%) rename app/{ => assets}/api/__init__.py (100%) rename app/{api/assets_routes.py => assets/api/routes.py} (94%) rename app/{ => assets}/api/schemas_in.py (100%) rename app/{ => assets}/api/schemas_out.py (100%) rename app/{ => assets}/database/__init__.py (100%) rename app/{ => assets}/database/helpers/__init__.py (100%) rename app/{ => assets}/database/helpers/bulk_ops.py (100%) rename app/{ => assets}/database/helpers/escape_like.py (100%) rename app/{ => assets}/database/helpers/fast_check.py (100%) rename app/{ => assets}/database/helpers/filters.py (98%) rename app/{ => assets}/database/helpers/ownership.py (100%) rename app/{ => assets}/database/helpers/projection.py (100%) rename app/{ => assets}/database/helpers/tags.py (98%) rename app/{ => assets}/database/models.py (100%) rename app/{ => assets}/database/services/__init__.py (100%) rename app/{ => assets}/database/services/content.py (99%) rename app/{ => assets}/database/services/info.py (99%) rename app/{ => assets}/database/services/queries.py (100%) rename app/{ => assets}/database/timeutil.py (100%) rename app/{assets_manager.py => assets/manager.py} (99%) rename app/{assets_scanner.py => assets/scanner.py} (99%) rename app/{ => assets}/storage/__init__.py (100%) rename app/{ => assets}/storage/hashing.py (100%) rename app/{database => }/db.py (98%) diff --git a/app/__init__.py b/app/__init__.py deleted file mode 100644 index f73951107da2..000000000000 --- a/app/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .api.assets_routes import register_assets_system -from .assets_scanner import sync_seed_assets -from .database.db import init_db_engine - -__all__ = ["init_db_engine", "sync_seed_assets", "register_assets_system"] diff --git a/app/alembic_db/env.py b/app/alembic_db/env.py index 4d7770679875..44f4e1a0c9e3 100644 --- a/app/alembic_db/env.py +++ b/app/alembic_db/env.py @@ -2,13 +2,12 @@ from sqlalchemy import pool from alembic import context +from app.assets.database.models import Base # this is the Alembic Config object, which provides # access to the values within the .ini file in use. 
config = context.config - -from app.database.models import Base target_metadata = Base.metadata # other values from the config, defined by the needs of env.py, diff --git a/app/assets/__init__.py b/app/assets/__init__.py new file mode 100644 index 000000000000..28020a2935fb --- /dev/null +++ b/app/assets/__init__.py @@ -0,0 +1,4 @@ +from .api.routes import register_assets_system +from .scanner import sync_seed_assets + +__all__ = ["sync_seed_assets", "register_assets_system"] diff --git a/app/_assets_helpers.py b/app/assets/_helpers.py similarity index 100% rename from app/_assets_helpers.py rename to app/assets/_helpers.py diff --git a/app/api/__init__.py b/app/assets/api/__init__.py similarity index 100% rename from app/api/__init__.py rename to app/assets/api/__init__.py diff --git a/app/api/assets_routes.py b/app/assets/api/routes.py similarity index 94% rename from app/api/assets_routes.py rename to app/assets/api/routes.py index 6bb0ed77e27b..4ca7467750fd 100644 --- a/app/api/assets_routes.py +++ b/app/assets/api/routes.py @@ -10,7 +10,8 @@ import folder_paths -from .. import assets_manager, assets_scanner, user_manager +from ... import user_manager +from .. import manager, scanner from . import schemas_in, schemas_out ROUTES = web.RouteTableDef() @@ -29,7 +30,7 @@ async def head_asset_by_hash(request: web.Request) -> web.Response: algo, digest = hash_str.split(":", 1) if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"): return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") - exists = await assets_manager.asset_exists(asset_hash=hash_str) + exists = await manager.asset_exists(asset_hash=hash_str) return web.Response(status=200 if exists else 404) @@ -51,7 +52,7 @@ async def list_assets(request: web.Request) -> web.Response: except ValidationError as ve: return _validation_error_response("INVALID_QUERY", ve) - payload = await assets_manager.list_assets( + payload = await manager.list_assets( include_tags=q.include_tags, exclude_tags=q.exclude_tags, name_contains=q.name_contains, @@ -72,7 +73,7 @@ async def download_asset_content(request: web.Request) -> web.Response: disposition = "attachment" try: - abs_path, content_type, filename = await assets_manager.resolve_asset_content_for_download( + abs_path, content_type, filename = await manager.resolve_asset_content_for_download( asset_info_id=str(uuid.UUID(request.match_info["id"])), owner_id=USER_MANAGER.get_request_user_id(request), ) @@ -102,7 +103,7 @@ async def create_asset_from_hash(request: web.Request) -> web.Response: except Exception: return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") - result = await assets_manager.create_asset_from_hash( + result = await manager.create_asset_from_hash( hash_str=body.hash, name=body.name, tags=body.tags, @@ -154,7 +155,7 @@ async def upload_asset(request: web.Request) -> web.Response: return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") provided_hash = f"{algo}:{digest}" try: - provided_hash_exists = await assets_manager.asset_exists(asset_hash=provided_hash) + provided_hash_exists = await manager.asset_exists(asset_hash=provided_hash) except Exception: provided_hash_exists = None # do not fail the whole request here @@ -241,7 +242,7 @@ async def upload_asset(request: web.Request) -> web.Response: # Fast path: if a valid provided hash exists, create AssetInfo without writing anything if spec.hash and provided_hash_exists is True: try: - result = await 
assets_manager.create_asset_from_hash( + result = await manager.create_asset_from_hash( hash_str=spec.hash, name=spec.name or (spec.hash.split(":", 1)[1]), tags=spec.tags, @@ -269,7 +270,7 @@ async def upload_asset(request: web.Request) -> web.Response: return _error_response(404, "ASSET_NOT_FOUND", "Provided hash not found and no file uploaded.") try: - created = await assets_manager.upload_asset_from_temp_path( + created = await manager.upload_asset_from_temp_path( spec, temp_path=tmp_path, client_filename=file_client_name, @@ -300,7 +301,7 @@ async def upload_asset(request: web.Request) -> web.Response: async def get_asset(request: web.Request) -> web.Response: asset_info_id = str(uuid.UUID(request.match_info["id"])) try: - result = await assets_manager.get_asset( + result = await manager.get_asset( asset_info_id=asset_info_id, owner_id=USER_MANAGER.get_request_user_id(request), ) @@ -327,7 +328,7 @@ async def update_asset(request: web.Request) -> web.Response: return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") try: - result = await assets_manager.update_asset( + result = await manager.update_asset( asset_info_id=asset_info_id, name=body.name, tags=body.tags, @@ -357,7 +358,7 @@ async def set_asset_preview(request: web.Request) -> web.Response: return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") try: - result = await assets_manager.set_asset_preview( + result = await manager.set_asset_preview( asset_info_id=asset_info_id, preview_asset_id=body.preview_id, owner_id=USER_MANAGER.get_request_user_id(request), @@ -381,7 +382,7 @@ async def delete_asset(request: web.Request) -> web.Response: delete_content = True if delete_content is None else delete_content.lower() not in {"0", "false", "no"} try: - deleted = await assets_manager.delete_asset_reference( + deleted = await manager.delete_asset_reference( asset_info_id=asset_info_id, owner_id=USER_MANAGER.get_request_user_id(request), delete_content_if_orphan=delete_content, @@ -411,7 +412,7 @@ async def get_tags(request: web.Request) -> web.Response: status=400, ) - result = await assets_manager.list_tags( + result = await manager.list_tags( prefix=query.prefix, limit=query.limit, offset=query.offset, @@ -434,7 +435,7 @@ async def add_asset_tags(request: web.Request) -> web.Response: return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") try: - result = await assets_manager.add_tags_to_asset( + result = await manager.add_tags_to_asset( asset_info_id=asset_info_id, tags=data.tags, origin="manual", @@ -465,7 +466,7 @@ async def delete_asset_tags(request: web.Request) -> web.Response: return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") try: - result = await assets_manager.remove_tags_from_asset( + result = await manager.remove_tags_from_asset( asset_info_id=asset_info_id, tags=data.tags, owner_id=USER_MANAGER.get_request_user_id(request), @@ -496,7 +497,7 @@ async def seed_assets(request: web.Request) -> web.Response: return _validation_error_response("INVALID_BODY", ve) try: - await assets_scanner.sync_seed_assets(body.roots) + await scanner.sync_seed_assets(body.roots) except Exception: LOGGER.exception("sync_seed_assets failed for roots=%s", body.roots) return _error_response(500, "INTERNAL", "Unexpected server error.") @@ -515,14 +516,14 @@ async def schedule_asset_scan(request: web.Request) -> web.Response: except ValidationError as ve: return _validation_error_response("INVALID_BODY", ve) - states = await 
assets_scanner.schedule_scans(body.roots) + states = await scanner.schedule_scans(body.roots) return web.json_response(states.model_dump(mode="json"), status=202) @ROUTES.get("/api/assets/scan") async def get_asset_scan_status(request: web.Request) -> web.Response: root = request.query.get("root", "").strip().lower() - states = assets_scanner.current_statuses() + states = scanner.current_statuses() if root in {"models", "input", "output"}: states = [s for s in states.scans if s.root == root] # type: ignore states = schemas_out.AssetScanStatusResponse(scans=states) diff --git a/app/api/schemas_in.py b/app/assets/api/schemas_in.py similarity index 100% rename from app/api/schemas_in.py rename to app/assets/api/schemas_in.py diff --git a/app/api/schemas_out.py b/app/assets/api/schemas_out.py similarity index 100% rename from app/api/schemas_out.py rename to app/assets/api/schemas_out.py diff --git a/app/database/__init__.py b/app/assets/database/__init__.py similarity index 100% rename from app/database/__init__.py rename to app/assets/database/__init__.py diff --git a/app/database/helpers/__init__.py b/app/assets/database/helpers/__init__.py similarity index 100% rename from app/database/helpers/__init__.py rename to app/assets/database/helpers/__init__.py diff --git a/app/database/helpers/bulk_ops.py b/app/assets/database/helpers/bulk_ops.py similarity index 100% rename from app/database/helpers/bulk_ops.py rename to app/assets/database/helpers/bulk_ops.py diff --git a/app/database/helpers/escape_like.py b/app/assets/database/helpers/escape_like.py similarity index 100% rename from app/database/helpers/escape_like.py rename to app/assets/database/helpers/escape_like.py diff --git a/app/database/helpers/fast_check.py b/app/assets/database/helpers/fast_check.py similarity index 100% rename from app/database/helpers/fast_check.py rename to app/assets/database/helpers/fast_check.py diff --git a/app/database/helpers/filters.py b/app/assets/database/helpers/filters.py similarity index 98% rename from app/database/helpers/filters.py rename to app/assets/database/helpers/filters.py index 0b6d85b8d572..0edc0c66d88c 100644 --- a/app/database/helpers/filters.py +++ b/app/assets/database/helpers/filters.py @@ -3,7 +3,7 @@ import sqlalchemy as sa from sqlalchemy import exists -from ..._assets_helpers import normalize_tags +from ..._helpers import normalize_tags from ..models import AssetInfo, AssetInfoMeta, AssetInfoTag diff --git a/app/database/helpers/ownership.py b/app/assets/database/helpers/ownership.py similarity index 100% rename from app/database/helpers/ownership.py rename to app/assets/database/helpers/ownership.py diff --git a/app/database/helpers/projection.py b/app/assets/database/helpers/projection.py similarity index 100% rename from app/database/helpers/projection.py rename to app/assets/database/helpers/projection.py diff --git a/app/database/helpers/tags.py b/app/assets/database/helpers/tags.py similarity index 98% rename from app/database/helpers/tags.py rename to app/assets/database/helpers/tags.py index 058869eca01c..402dc346d430 100644 --- a/app/database/helpers/tags.py +++ b/app/assets/database/helpers/tags.py @@ -5,7 +5,7 @@ from sqlalchemy.dialects import sqlite as d_sqlite from sqlalchemy.ext.asyncio import AsyncSession -from ..._assets_helpers import normalize_tags +from ..._helpers import normalize_tags from ..models import AssetInfo, AssetInfoTag, Tag from ..timeutil import utcnow diff --git a/app/database/models.py b/app/assets/database/models.py similarity index 100% 
rename from app/database/models.py rename to app/assets/database/models.py diff --git a/app/database/services/__init__.py b/app/assets/database/services/__init__.py similarity index 100% rename from app/database/services/__init__.py rename to app/assets/database/services/__init__.py diff --git a/app/database/services/content.py b/app/assets/database/services/content.py similarity index 99% rename from app/database/services/content.py rename to app/assets/database/services/content.py index 05450f8231bb..864c190442cb 100644 --- a/app/database/services/content.py +++ b/app/assets/database/services/content.py @@ -12,7 +12,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import noload -from ..._assets_helpers import compute_relative_filename +from ..._helpers import compute_relative_filename from ...storage import hashing as hashing_mod from ..helpers import ( ensure_tags_exist, diff --git a/app/database/services/info.py b/app/assets/database/services/info.py similarity index 99% rename from app/database/services/info.py rename to app/assets/database/services/info.py index 8a28bcf4c38c..b499557418c8 100644 --- a/app/database/services/info.py +++ b/app/assets/database/services/info.py @@ -8,7 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import contains_eager, noload -from ..._assets_helpers import compute_relative_filename, normalize_tags +from ..._helpers import compute_relative_filename, normalize_tags from ..helpers import ( apply_metadata_filter, apply_tag_filters, diff --git a/app/database/services/queries.py b/app/assets/database/services/queries.py similarity index 100% rename from app/database/services/queries.py rename to app/assets/database/services/queries.py diff --git a/app/database/timeutil.py b/app/assets/database/timeutil.py similarity index 100% rename from app/database/timeutil.py rename to app/assets/database/timeutil.py diff --git a/app/assets_manager.py b/app/assets/manager.py similarity index 99% rename from app/assets_manager.py rename to app/assets/manager.py index 4aae6e8ad08b..50cf146d26a3 100644 --- a/app/assets_manager.py +++ b/app/assets/manager.py @@ -6,13 +6,13 @@ from comfy_api.internal import async_to_sync -from ._assets_helpers import ( +from ..db import create_session +from ._helpers import ( ensure_within_base, get_name_and_tags_from_asset_path, resolve_destination_from_tags, ) from .api import schemas_in, schemas_out -from .database.db import create_session from .database.models import Asset from .database.services import ( add_tags_to_asset_info, diff --git a/app/assets_scanner.py b/app/assets/scanner.py similarity index 99% rename from app/assets_scanner.py rename to app/assets/scanner.py index 7ef64a0526ed..aa8123a49c5a 100644 --- a/app/assets_scanner.py +++ b/app/assets/scanner.py @@ -10,7 +10,8 @@ import folder_paths -from ._assets_helpers import ( +from ..db import create_session +from ._helpers import ( collect_models_files, compute_relative_filename, get_comfy_models_folders, @@ -21,7 +22,6 @@ ts_to_iso, ) from .api import schemas_in, schemas_out -from .database.db import create_session from .database.helpers import ( add_missing_tag_for_asset_id, ensure_tags_exist, diff --git a/app/storage/__init__.py b/app/assets/storage/__init__.py similarity index 100% rename from app/storage/__init__.py rename to app/assets/storage/__init__.py diff --git a/app/storage/hashing.py b/app/assets/storage/hashing.py similarity index 100% rename from app/storage/hashing.py rename to app/assets/storage/hashing.py 
diff --git a/app/database/db.py b/app/db.py similarity index 98% rename from app/database/db.py rename to app/db.py index 54f9000cc231..f125706f0189 100644 --- a/app/database/db.py +++ b/app/db.py @@ -27,8 +27,8 @@ def _root_paths(): """Resolve alembic.ini and migrations script folder.""" root_path = os.path.abspath(os.path.dirname(__file__)) - config_path = os.path.abspath(os.path.join(root_path, "../../alembic.ini")) - scripts_path = os.path.abspath(os.path.join(root_path, "../alembic_db")) + config_path = os.path.abspath(os.path.join(root_path, "../alembic.ini")) + scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db")) return config_path, scripts_path diff --git a/main.py b/main.py index 62279268dc94..9bc9ac9ed527 100644 --- a/main.py +++ b/main.py @@ -279,7 +279,8 @@ def cleanup_temp(): shutil.rmtree(temp_dir, ignore_errors=True) async def setup_database(): - from app import init_db_engine, sync_seed_assets + from app.assets import sync_seed_assets + from app.db import init_db_engine await init_db_engine() if not args.disable_assets_autoscan: diff --git a/server.py b/server.py index d8c1c02c3a1c..424ca9b593d7 100644 --- a/server.py +++ b/server.py @@ -37,7 +37,7 @@ from app.custom_node_manager import CustomNodeManager from typing import Optional, Union from api_server.routes.internal.internal_routes import InternalRoutes -from app import sync_seed_assets, register_assets_system +from app.assets import sync_seed_assets, register_assets_system from protocol import BinaryEventTypes # Import cache control middleware
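
A note on the unified uq_assets_hash index from PATCH 81/82: both SQLite and PostgreSQL treat NULLs as distinct values under a plain unique index, so the PostgreSQL-specific partial index (WHERE hash IS NOT NULL) can be dropped without making assets.hash non-nullable, and rows that have not been hashed yet still coexist. The snippet below is a minimal, self-contained sketch of that behaviour using SQLAlchemy 2.0-style declarative mapping; it mirrors only the columns relevant here, the column types are placeholders, and it is not the full Asset model from app/assets/database/models.py.

    # Minimal sketch (not part of the patch series): a plain unique index on a
    # nullable "hash" column admits multiple NULL rows on SQLite, matching
    # PostgreSQL's default treatment of NULLs as distinct.
    from typing import Optional

    import sqlalchemy as sa
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


    class Base(DeclarativeBase):
        pass


    class Asset(Base):
        __tablename__ = "assets"

        id: Mapped[str] = mapped_column(sa.String(36), primary_key=True)
        # hash stays nullable: an asset row can exist before hashing completes
        hash: Mapped[Optional[str]] = mapped_column(sa.String(256), nullable=True)

        __table_args__ = (
            # single index definition for every backend, as in 0001_assets.py
            sa.Index("uq_assets_hash", "hash", unique=True),
        )


    if __name__ == "__main__":
        engine = sa.create_engine("sqlite://")
        Base.metadata.create_all(engine)
        with Session(engine) as session:
            session.add_all([Asset(id="a"), Asset(id="b")])  # both hashes are NULL
            session.commit()  # succeeds: NULL values do not collide under the index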
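
For code written against the pre-PATCH-82 layout (custom nodes, tests, or scripts that imported app.assets_manager, app.assets_scanner, or app.database), the rename map above translates roughly to the imports sketched below. Only create_session, init_db_engine, sync_seed_assets, and register_assets_system are confirmed as entry points by the diffs in this series; anything else should be checked against the merged tree.

    # Sketch of import paths after PATCH 82/82, derived from the rename map above.
    #
    # Old layout (through PATCH 81):
    #   from app.assets_manager import ...        (now app/assets/manager.py)
    #   from app.assets_scanner import ...        (now app/assets/scanner.py)
    #   from app.database.db import create_session
    #   from app.database.models import Asset
    #
    # New layout:
    from app.db import create_session, init_db_engine
    from app.assets import register_assets_system, sync_seed_assets
    from app.assets.database.models import Asset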