Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions comfy_cli/registry/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,18 @@ def publish_node_version(self, node_config: PyProjectConfig, token) -> PublishNo
"name": node_config.tool_comfy.display_name,
"license": license_json,
"repository": node_config.project.urls.repository,
"supported_os": node_config.project.supported_os,
"supported_accelerators": node_config.project.supported_accelerators,
"supported_comfyui_version": node_config.project.supported_comfyui_version,
"supported_comfyui_frontend_version": node_config.project.supported_comfyui_frontend_version,
},
"node_version": {
"version": node_config.project.version,
"dependencies": node_config.project.dependencies,
"supported_os": node_config.project.supported_os,
"supported_accelerators": node_config.project.supported_accelerators,
"supported_comfyui_version": node_config.project.supported_comfyui_version,
"supported_comfyui_frontend_version": node_config.project.supported_comfyui_frontend_version,
},
}
print(request_body)
Expand Down
99 changes: 98 additions & 1 deletion comfy_cli/registry/config_parser.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import re
import subprocess
from typing import Optional

Expand Down Expand Up @@ -87,6 +88,76 @@ def sanitize_node_name(name: str) -> str:
return name


def validate_and_extract_os_classifiers(classifiers: list) -> list:
os_classifiers = [c for c in classifiers if c.startswith("Operating System :: ")]
if not os_classifiers:
return []

os_values = [c[len("Operating System :: ") :] for c in os_classifiers]
valid_os_prefixes = {"Microsoft", "POSIX", "MacOS", "OS Independent"}

for os_value in os_values:
if not any(os_value.startswith(prefix) for prefix in valid_os_prefixes):
typer.echo(
'Warning: Invalid Operating System classifier found. Operating System classifiers must start with one of: "Microsoft", "POSIX", "MacOS", "OS Independent". '
'Examples: "Operating System :: Microsoft :: Windows", "Operating System :: POSIX :: Linux", "Operating System :: MacOS", "Operating System :: OS Independent". '
"No OS information will be populated."
)
return []

return os_values


def validate_and_extract_accelerator_classifiers(classifiers: list) -> list:
accelerator_classifiers = [c for c in classifiers if c.startswith("Environment ::")]
if not accelerator_classifiers:
return []

accelerator_values = [c[len("Environment :: ") :] for c in accelerator_classifiers]

valid_accelerators = {
"GPU :: NVIDIA CUDA",
"GPU :: AMD ROCm",
"GPU :: Intel Arc",
"NPU :: Huawei Ascend",
"GPU :: Apple Metal",
}

for accelerator_value in accelerator_values:
if accelerator_value not in valid_accelerators:
typer.echo(
"Warning: Invalid Environment classifier found. Environment classifiers must be one of: "
'"Environment :: GPU :: NVIDIA CUDA", "Environment :: GPU :: AMD ROCm", "Environment :: GPU :: Intel Arc", '
'"Environment :: NPU :: Huawei Ascend", "Environment :: GPU :: Apple Metal". '
"No accelerator information will be populated."
)
return []

return accelerator_values


def validate_version(version: str, field_name: str) -> str:
if not version:
return version

version_pattern = r"^(?:(==|>=|<=|!=|~=|>|<|<>|=)\s*)?(\d+\.\d+\.\d+(?:-[a-zA-Z0-9]+)?)?$"

version_parts = [part.strip() for part in version.split(",")]
for part in version_parts:
if not re.match(version_pattern, part):
typer.echo(
f'Warning: Invalid {field_name} format: "{version}". '
f"Each version part must follow the pattern: [operator][version] where operator is optional (==, >=, <=, !=, ~=, >, <, <>, =) "
f"and version is in format major.minor.patch[-suffix]. "
f"Multiple versions can be comma-separated. "
f'Examples: ">=1.0.0", "==2.1.0-beta", "1.5.2", ">=1.0.0,<2.0.0". '
f"No {field_name} will be populated."
)
return ""

return version


def initialize_project_config():
create_comfynode_config()

Expand Down Expand Up @@ -157,6 +228,28 @@ def extract_node_configuration(
urls_data = project_data.get("urls", {})
comfy_data = data.get("tool", {}).get("comfy", {})

dependencies = project_data.get("dependencies", [])
supported_comfyui_frontend_version = ""
for dep in dependencies:
if isinstance(dep, str) and dep.startswith("comfyui-frontend-package"):
supported_comfyui_frontend_version = dep.removeprefix("comfyui-frontend-package")
break

# Remove the ComfyUI-frontend dependency from the dependencies list
dependencies = [
dep for dep in dependencies if not (isinstance(dep, str) and dep.startswith("comfyui-frontend-package"))
]

supported_comfyui_version = data.get("tool", {}).get("comfy", {}).get("requires-comfyui", "")

classifiers = project_data.get("classifiers", [])
supported_os = validate_and_extract_os_classifiers(classifiers)
supported_accelerators = validate_and_extract_accelerator_classifiers(classifiers)
supported_comfyui_version = validate_version(supported_comfyui_version, "requires-comfyui")
supported_comfyui_frontend_version = validate_version(
supported_comfyui_frontend_version, "comfyui-frontend-package"
)

license_data = project_data.get("license", {})
if isinstance(license_data, str):
license = License(text=license_data)
Expand All @@ -182,14 +275,18 @@ def extract_node_configuration(
description=project_data.get("description", ""),
version=project_data.get("version", ""),
requires_python=project_data.get("requires-python", ""),
dependencies=project_data.get("dependencies", []),
dependencies=dependencies,
license=license,
urls=URLs(
homepage=urls_data.get("Homepage", ""),
documentation=urls_data.get("Documentation", ""),
repository=urls_data.get("Repository", ""),
issues=urls_data.get("Issues", ""),
),
supported_os=supported_os,
supported_accelerators=supported_accelerators,
supported_comfyui_version=supported_comfyui_version,
supported_comfyui_frontend_version=supported_comfyui_frontend_version,
)

comfy = ComfyConfig(
Expand Down
4 changes: 4 additions & 0 deletions comfy_cli/registry/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ class ProjectConfig:
dependencies: List[str] = field(default_factory=list)
license: License = field(default_factory=License)
urls: URLs = field(default_factory=URLs)
supported_os: List[str] = field(default_factory=list)
supported_accelerators: List[str] = field(default_factory=list)
supported_comfyui_version: str = ""
supported_comfyui_frontend_version: str = ""


@dataclass
Expand Down
202 changes: 201 additions & 1 deletion tests/comfy_cli/registry/test_config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@

import pytest

from comfy_cli.registry.config_parser import extract_node_configuration
from comfy_cli.registry.config_parser import (
extract_node_configuration,
validate_and_extract_accelerator_classifiers,
validate_and_extract_os_classifiers,
validate_version,
)
from comfy_cli.registry.types import (
License,
Model,
Expand Down Expand Up @@ -127,3 +132,198 @@ def test_extract_license_incorrect_format():
assert result is not None, "Expected PyProjectConfig, got None"
assert isinstance(result, PyProjectConfig)
assert result.project.license == License(text="MIT")


def test_extract_node_configuration_with_os_classifiers():
mock_data = {
"project": {
"classifiers": [
"Operating System :: OS Independent",
"Operating System :: Microsoft :: Windows",
"Programming Language :: Python :: 3",
"Topic :: Software Development",
]
}
}
with (
patch("os.path.isfile", return_value=True),
patch("builtins.open", mock_open()),
patch("tomlkit.load", return_value=mock_data),
):
result = extract_node_configuration("fake_path.toml")

assert result is not None
assert len(result.project.supported_os) == 2
assert "OS Independent" in result.project.supported_os
assert "Microsoft :: Windows" in result.project.supported_os


def test_extract_node_configuration_with_accelerator_classifiers():
mock_data = {
"project": {
"classifiers": [
"Environment :: GPU :: NVIDIA CUDA",
"Environment :: GPU :: AMD ROCm",
"Environment :: GPU :: Intel Arc",
"Environment :: NPU :: Huawei Ascend",
"Environment :: GPU :: Apple Metal",
"Programming Language :: Python :: 3",
"Topic :: Software Development",
]
}
}
with (
patch("os.path.isfile", return_value=True),
patch("builtins.open", mock_open()),
patch("tomlkit.load", return_value=mock_data),
):
result = extract_node_configuration("fake_path.toml")

assert result is not None
assert len(result.project.supported_accelerators) == 5
assert "GPU :: NVIDIA CUDA" in result.project.supported_accelerators
assert "GPU :: AMD ROCm" in result.project.supported_accelerators
assert "GPU :: Intel Arc" in result.project.supported_accelerators
assert "NPU :: Huawei Ascend" in result.project.supported_accelerators
assert "GPU :: Apple Metal" in result.project.supported_accelerators


def test_extract_node_configuration_with_comfyui_version():
mock_data = {"project": {"dependencies": ["packge1>=2.0.0", "comfyui-frontend-package>=1.2.3", "package2>=1.0.0"]}}
with (
patch("os.path.isfile", return_value=True),
patch("builtins.open", mock_open()),
patch("tomlkit.load", return_value=mock_data),
):
result = extract_node_configuration("fake_path.toml")

assert result is not None
assert result.project.supported_comfyui_frontend_version == ">=1.2.3"
assert len(result.project.dependencies) == 2
assert "comfyui-frontend-package>=1.2.3" not in result.project.dependencies
assert "packge1>=2.0.0" in result.project.dependencies
assert "package2>=1.0.0" in result.project.dependencies


def test_extract_node_configuration_with_requires_comfyui():
mock_data = {"project": {}, "tool": {"comfy": {"requires-comfyui": "2.0.0"}}}
with (
patch("os.path.isfile", return_value=True),
patch("builtins.open", mock_open()),
patch("tomlkit.load", return_value=mock_data),
):
result = extract_node_configuration("fake_path.toml")

assert result is not None
assert result.project.supported_comfyui_version == "2.0.0"


def test_validate_and_extract_os_classifiers_valid():
"""Test OS validation with valid classifiers."""
classifiers = [
"Operating System :: Microsoft :: Windows",
"Operating System :: POSIX :: Linux",
"Operating System :: MacOS",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
]
result = validate_and_extract_os_classifiers(classifiers)
expected = ["Microsoft :: Windows", "POSIX :: Linux", "MacOS", "OS Independent"]
assert result == expected


@patch("typer.echo")
def test_validate_and_extract_os_classifiers_invalid(mock_echo):
"""Test OS validation with invalid classifiers."""
classifiers = [
"Operating System :: Microsoft :: Windows",
"Operating System :: Linux", # Invalid - should be "POSIX :: Linux"
"Programming Language :: Python :: 3",
]
result = validate_and_extract_os_classifiers(classifiers)
assert result == []
mock_echo.assert_called_once()
assert "Invalid Operating System classifier found" in mock_echo.call_args[0][0]


def test_validate_and_extract_accelerator_classifiers_valid():
"""Test accelerator validation with valid classifiers."""
classifiers = [
"Environment :: GPU :: NVIDIA CUDA",
"Environment :: GPU :: AMD ROCm",
"Environment :: GPU :: Intel Arc",
"Environment :: NPU :: Huawei Ascend",
"Environment :: GPU :: Apple Metal",
"Programming Language :: Python :: 3",
]
result = validate_and_extract_accelerator_classifiers(classifiers)
expected = [
"GPU :: NVIDIA CUDA",
"GPU :: AMD ROCm",
"GPU :: Intel Arc",
"NPU :: Huawei Ascend",
"GPU :: Apple Metal",
]
assert result == expected


@patch("typer.echo")
def test_validate_and_extract_accelerator_classifiers_invalid(mock_echo):
"""Test accelerator validation with invalid classifiers."""
classifiers = [
"Environment :: GPU :: NVIDIA CUDA",
"Environment :: GPU :: Invalid GPU", # Invalid
"Programming Language :: Python :: 3",
]
result = validate_and_extract_accelerator_classifiers(classifiers)
assert result == []
mock_echo.assert_called_once()
assert "Invalid Environment classifier found" in mock_echo.call_args[0][0]


def test_validate_version_valid():
"""Test version validation with valid versions."""
valid_versions = [
"1.1.1",
">=1.0.0",
"==2.1.0-beta",
"1.5.2",
"~=3.0.0",
"!=1.2.3",
">2.0.0",
"<3.0.0",
"<=4.0.0",
"<>1.0.0",
"=1.0.0",
"1.0.0-alpha1",
">=1.0.0,<2.0.0",
"==1.2.3,!=1.2.4",
">=1.0.0,<=2.0.0,!=1.5.0",
"1.0.0,2.0.0",
">1.0.0,<2.0.0,!=1.5.0-beta",
]

for version in valid_versions:
result = validate_version(version, "test_field")
assert result == version, f"Version {version} should be valid"


@patch("typer.echo")
def test_validate_version_invalid(mock_echo):
"""Test version validation with invalid versions."""
invalid_versions = [
"1.0", # Missing patch version
">=abc", # Invalid version format
"invalid-version", # Completely invalid
"1.0.0.0", # Too many version parts
">>1.0.0", # Invalid operator
">=1.0.0,invalid",
"1.0,2.0.0",
">=1.0.0,>=abc",
]

for version in invalid_versions:
result = validate_version(version, "test_field")
assert result == "", f"Version {version} should be invalid"

assert mock_echo.call_count == len(invalid_versions)