Skip to content

Commit 8e377ed

Browse files
committed
refactor: lingering gen fixes; style: fix
1 parent 5b2e4ea commit 8e377ed

File tree

18 files changed

+524
-529
lines changed

18 files changed

+524
-529
lines changed

.github/workflows/maintests.yml

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,23 +19,26 @@ jobs:
1919
build:
2020
env:
2121
AIWORKER_CACHE_HOME: ${{ github.workspace }}/.cache
22-
TESTS_ONGOING: 1
23-
HORDE_SDK_TESTING: 1
24-
HORDE_MODEL_REFERENCE_MAKE_FOLDERS: 1
22+
TESTS_ONGOING: "1"
23+
AI_HORDE_TESTING: "1"
2524
runs-on: ubuntu-latest
2625
strategy:
2726
matrix:
28-
python: ["3.12", "3.13"]
27+
python-version:
28+
- "3.12"
29+
- "3.13"
2930

3031
steps:
31-
- uses: actions/checkout@v3
32-
- name: Setup Python
33-
uses: actions/setup-python@v4
32+
- uses: actions/checkout@v4
33+
- name: Install uv and set the python version
34+
uses: astral-sh/setup-uv@v6
3435
with:
35-
python-version: ${{ matrix.python }}
36-
- name: Install tox and any other packages
37-
run: |
38-
python -m pip install --upgrade pip
39-
pip install --upgrade -r requirements.dev.txt
40-
- name: Run unit tests
41-
run: tox -e tests-no-api-calls
36+
python-version: ${{ matrix.python-version }}
37+
enable-cache: true
38+
39+
- name: Install the project
40+
run: uv sync --locked --all-extras --dev
41+
42+
43+
- name: Run tests
44+
run: uv run pytest tests --ignore-glob='**/*api_calls*' -m "not api_side_ci"

.github/workflows/prtests.yml

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,25 +22,27 @@ jobs:
2222
build:
2323
env:
2424
AIWORKER_CACHE_HOME: ${{ github.workspace }}/.cache
25-
TESTS_ONGOING: 1
26-
HORDE_SDK_TESTING: 1
27-
HORDE_MODEL_REFERENCE_MAKE_FOLDERS: 1
25+
TESTS_ONGOING: "1"
26+
AI_HORDE_TESTING: "1"
2827
runs-on: ubuntu-latest
2928
strategy:
3029
matrix:
31-
python: ["3.12", "3.13"]
30+
python-version:
31+
- "3.12"
32+
- "3.13"
3233

3334
steps:
34-
- uses: actions/checkout@v3
35+
- uses: actions/checkout@v4
3536
with:
3637
ref: ${{ github.event.pull_request.head.sha }}
37-
- name: Setup Python
38-
uses: actions/setup-python@v4
38+
- name: Install uv and set the python version
39+
uses: astral-sh/setup-uv@v6
3940
with:
40-
python-version: ${{ matrix.python }}
41-
- name: Install tox and any other packages
42-
run: |
43-
python -m pip install --upgrade pip
44-
pip install --upgrade -r requirements.dev.txt
45-
- name: Run unit tests
46-
run: tox -e tests-no-api-calls
41+
python-version: ${{ matrix.python-version }}
42+
enable-cache: true
43+
44+
- name: Install the project
45+
run: uv sync --locked --all-extras --dev
46+
47+
- name: Run tests
48+
run: uv run pytest tests --ignore-glob='**/*api_calls*' -m "not api_side_ci"

horde_sdk/ai_horde_api/ai_horde_clients.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1456,7 +1456,7 @@ async def heartbeat_request(
14561456
async def image_generate_request(
14571457
self,
14581458
image_gen_request: ImageGenerateAsyncRequest,
1459-
timeout: int = GENERATION_MAX_LIFE, # noqa: ASYNC109 # FIXME
1459+
timeout: int = GENERATION_MAX_LIFE, # noqa: ASYNC109 # FIXME
14601460
check_callback: Callable[[ImageGenerateCheckResponse], None] | None = None,
14611461
delay: float = 0.0,
14621462
) -> tuple[ImageGenerateStatusResponse, GenerationID]:
@@ -1542,7 +1542,7 @@ async def image_generate_request_dry_run(
15421542
async def alchemy_request(
15431543
self,
15441544
alchemy_request: AlchemyAsyncRequest,
1545-
timeout: int = GENERATION_MAX_LIFE, # noqa: ASYNC109 # FIXME
1545+
timeout: int = GENERATION_MAX_LIFE, # noqa: ASYNC109 # FIXME
15461546
check_callback: Callable[[AlchemyStatusResponse], None] | None = None,
15471547
) -> tuple[AlchemyStatusResponse, GenerationID]:
15481548
"""Submit an alchemy request to the AI-Horde API, and wait for it to complete.
@@ -1588,7 +1588,7 @@ async def alchemy_request(
15881588
async def text_generate_request(
15891589
self,
15901590
text_gen_request: TextGenerateAsyncRequest,
1591-
timeout: int = GENERATION_MAX_LIFE, # noqa: ASYNC109 # FIXME
1591+
timeout: int = GENERATION_MAX_LIFE, # noqa: ASYNC109 # FIXME
15921592
check_callback: Callable[[TextGenerateStatusResponse], None] | None = None,
15931593
delay: float = 0.0,
15941594
) -> tuple[TextGenerateStatusResponse, GenerationID]:

horde_sdk/generation_parameters/generic/__init__.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class SchemaVersionedBaseModel(BaseModel):
4646

4747
@model_validator(mode="before")
4848
@classmethod
49-
def _assign_schema_version(cls, data: Any) -> Any: # noqa: ANN401
49+
def _assign_schema_version(cls, data: Any) -> Any: # noqa: ANN401
5050
"""Populate ``schema_version`` when omitted by callers."""
5151
if data is None:
5252
return {"schema_version": cls.SCHEMA_VERSION}
@@ -80,8 +80,6 @@ class GenerationParameterBaseModel(SchemaVersionedBaseModel, AbstractGenerationP
8080
for a list of those LoRa entries.
8181
"""
8282

83-
84-
8583
underlying_generation_scheme: UNDERLYING_GENERATION_SCHEME | None = None
8684
"""The underlying method the generation uses to produce results.
8785
@@ -160,7 +158,7 @@ class GenerationWithModelParameters(GenerationParameterBaseModel):
160158
underlying_generation_scheme: UNDERLYING_GENERATION_SCHEME = UNDERLYING_GENERATION_SCHEME.MODEL
161159
"""See :attr:`ComposedParameterSetBase.underlying_generation_scheme` for more information."""
162160

163-
model: str
161+
model: str | None = None
164162
model_baseline: str | None = None
165163

166164

horde_sdk/generation_parameters/image/object_models.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from collections.abc import Mapping, Sequence
3+
from collections.abc import Sequence
44
from pathlib import Path
55
from typing import override
66

@@ -259,6 +259,10 @@ class BasicImageGenerationParameters(BasicImageGenerationParametersTemplate):
259259
from_attributes=True,
260260
)
261261

262+
263+
model: str
264+
"""The model to use for the generation."""
265+
262266
prompt: str
263267
"""The prompt to use for the generation."""
264268

@@ -610,7 +614,8 @@ def get_number_expected_results(self: ImageGenerationParametersTemplate) -> int:
610614
def to_parameters(
611615
self,
612616
*,
613-
base_param_updates: Mapping[str, object] | None = None,
617+
base_param_updates: BasicImageGenerationParametersTemplate | None = None,
618+
additional_param_updates: ImageGenerationComponentContainer | None = None,
614619
result_ids: Sequence[ID_TYPES] | None = None,
615620
allocator: ResultIdAllocator | None = None,
616621
seed: str = "image",
@@ -623,9 +628,16 @@ def to_parameters(
623628
overrides: dict[str, object] | None = None
624629
if base_param_updates:
625630
overrides = {
626-
"base_params": base_params.model_copy(update=dict(base_param_updates)),
631+
"base_params": base_params.model_copy(update=base_param_updates.model_dump(exclude_none=True)),
627632
}
628633

634+
if additional_param_updates:
635+
if overrides is None:
636+
overrides = {}
637+
if not self.additional_params:
638+
raise ValueError("additional_params must be defined before applying updates.")
639+
overrides["additional_params"] = self.additional_params.model_copy(update=dict(additional_param_updates))
640+
629641
finalization = finalize_template_for_parameters(
630642
self,
631643
overrides=overrides,

horde_sdk/generation_parameters/text/object_models.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from collections.abc import Mapping, Sequence
3+
from collections.abc import Sequence
44
from enum import auto
55
from typing import Self, override
66

@@ -191,6 +191,9 @@ class BasicTextGenerationParameters(BasicTextGenerationParametersTemplate): # T
191191

192192
model_config = get_default_frozen_model_config_dict()
193193

194+
model: str
195+
"""The model to use for the generation."""
196+
194197
prompt: str # pyright: ignore[reportGeneralTypeIssues, reportIncompatibleVariableOverride]
195198
"""The prompt to use for the generation."""
196199

@@ -216,7 +219,7 @@ def get_number_expected_results(self) -> int:
216219
def to_parameters(
217220
self,
218221
*,
219-
base_param_updates: Mapping[str, object] | None = None,
222+
base_param_updates: BasicTextGenerationParametersTemplate | None = None,
220223
result_ids: Sequence[ID_TYPES] | None = None,
221224
allocator: ResultIdAllocator | None = None,
222225
seed: str = "text",
@@ -229,7 +232,7 @@ def to_parameters(
229232
overrides: dict[str, object] | None = None
230233
if base_param_updates:
231234
overrides = {
232-
"base_params": base_params.model_copy(update=dict(base_param_updates)),
235+
"base_params": base_params.model_copy(update=base_param_updates.model_dump(exclude_none=True)),
233236
}
234237

235238
def _inject_base_params_into_fingerprint(

horde_sdk/generic_api/apimodels.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,6 @@ class MyDataRootModel(HordeResponseRootModel[MyData]):
138138
)
139139

140140

141-
142141
class HordeResponseBaseModel(HordeResponse, BaseModel):
143142
"""Base class for all Horde API response data models (leveraging pydantic)."""
144143

horde_sdk/generic_api/generic_clients.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -973,11 +973,11 @@ async def _handle_exit_async(
973973
# Log the results of each cleanup request.
974974
for i, cleanup_response in enumerate(cleanup_responses):
975975
if isinstance(cleanup_response, Exception):
976-
logger.error(f"Recovery request {i+1} failed!")
976+
logger.error(f"Recovery request {i + 1} failed!")
977977

978-
logger.info(f"Recovery request {i+1} submitted!")
979-
logger.debug(f"Recovery request {i+1}: {cleanup_requests[i].log_safe_model_dump()}")
980-
logger.debug(f"Recovery response {i+1}: {cleanup_response}")
978+
logger.info(f"Recovery request {i + 1} submitted!")
979+
logger.debug(f"Recovery request {i + 1}: {cleanup_requests[i].log_safe_model_dump()}")
980+
logger.debug(f"Recovery response {i + 1}: {cleanup_response}")
981981

982982
# Return True to indicate that all requests were handled successfully.
983983
return True

horde_sdk/worker/builders.py

Lines changed: 0 additions & 88 deletions
This file was deleted.

horde_sdk/worker/generations.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,19 @@
66
from loguru import logger
77

88
from horde_sdk.consts import ID_TYPES
9+
from horde_sdk.generation_parameters import BasicImageGenerationParametersTemplate
910
from horde_sdk.generation_parameters.alchemy import SingleAlchemyParameters
1011
from horde_sdk.generation_parameters.alchemy.object_models import SingleAlchemyParametersTemplate
1112
from horde_sdk.generation_parameters.image import ImageGenerationParameters
12-
from horde_sdk.generation_parameters.image.object_models import ImageGenerationParametersTemplate
13+
from horde_sdk.generation_parameters.image.object_models import (
14+
ImageGenerationComponentContainer,
15+
ImageGenerationParametersTemplate,
16+
)
1317
from horde_sdk.generation_parameters.text import TextGenerationParameters
14-
from horde_sdk.generation_parameters.text.object_models import TextGenerationParametersTemplate
18+
from horde_sdk.generation_parameters.text.object_models import (
19+
BasicTextGenerationParametersTemplate,
20+
TextGenerationParametersTemplate,
21+
)
1522
from horde_sdk.generation_parameters.utils import ResultIdAllocator
1623
from horde_sdk.safety import SafetyRules, default_image_safety_rules, default_text_safety_rules
1724
from horde_sdk.worker.consts import (
@@ -198,7 +205,8 @@ def from_template(
198205
*,
199206
generation_id: ID_TYPES | None = None,
200207
dispatch_result_ids: Sequence[ID_TYPES] | None = None,
201-
base_param_updates: Mapping[str, object] | None = None,
208+
base_param_updates: BasicImageGenerationParametersTemplate | None = None,
209+
additional_param_updates: ImageGenerationComponentContainer | None = None,
202210
result_ids: Sequence[ID_TYPES] | None = None,
203211
allocator: ResultIdAllocator | None = None,
204212
seed: str = "image",
@@ -503,7 +511,7 @@ def from_template(
503511
*,
504512
generation_id: ID_TYPES | None = None,
505513
dispatch_result_ids: Sequence[ID_TYPES] | None = None,
506-
base_param_updates: Mapping[str, object] | None = None,
514+
base_param_updates: BasicTextGenerationParametersTemplate | None = None,
507515
result_ids: Sequence[ID_TYPES] | None = None,
508516
allocator: ResultIdAllocator | None = None,
509517
seed: str = "text",

0 commit comments

Comments
 (0)