Skip to content

Commit ab38df4

Browse files
authored
Merge pull request #63 from jkoelker/jk/refactor
refactor(civitai): extract parse_civitai_url
2 parents 39005cf + 03cc644 commit ab38df4

File tree

7 files changed

+76
-146
lines changed

7 files changed

+76
-146
lines changed

src/oneiro/bot.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import discord
1010
from discord import option
1111

12-
from oneiro.civitai import CivitaiClient, CivitaiError
12+
from oneiro.civitai import CivitaiClient, CivitaiError, parse_civitai_url
1313
from oneiro.config import Config
1414
from oneiro.filters import ContentFilter
1515
from oneiro.lora_detector import AutoLoraDetector, create_detector_from_config
@@ -21,7 +21,7 @@
2121
PipelineManager,
2222
)
2323
from oneiro.pipelines.civitai_checkpoint import CivitaiCheckpointPipeline
24-
from oneiro.pipelines.lora import is_lora_compatible, parse_civitai_url
24+
from oneiro.pipelines.lora import is_lora_compatible
2525
from oneiro.queue import GenerationQueue, QueueStatus
2626

2727
# Global managers (initialized on bot ready)

src/oneiro/civitai.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import hashlib
55
import json
66
import os
7+
import re
78
import shutil
89
from collections.abc import Awaitable, Callable
910
from dataclasses import dataclass, field
@@ -63,6 +64,38 @@ class BaseModel(str, Enum):
6364
FLUX_1 = "Flux.1"
6465

6566

67+
def parse_civitai_url(url: str) -> tuple[int, int | None]:
68+
"""Parse Civitai URL to extract model ID and optional version ID.
69+
70+
Supports formats:
71+
- https://civitai.com/models/12345
72+
- https://civitai.com/models/12345/model-name
73+
- https://civitai.com/models/12345?modelVersionId=67890
74+
- https://civitai.com/models/12345/name?modelVersionId=67890
75+
76+
Args:
77+
url: Civitai model URL
78+
79+
Returns:
80+
Tuple of (model_id, version_id or None)
81+
82+
Raises:
83+
ValueError: If URL format is invalid
84+
"""
85+
# Match model ID in path
86+
model_match = re.search(r"/models/(\d+)", url)
87+
if not model_match:
88+
raise ValueError(f"Invalid Civitai URL format: {url}")
89+
90+
model_id = int(model_match.group(1))
91+
92+
# Check for version in query string
93+
version_match = re.search(r"modelVersionId=(\d+)", url)
94+
version_id = int(version_match.group(1)) if version_match else None
95+
96+
return model_id, version_id
97+
98+
6699
@dataclass
67100
class ModelFile:
68101
"""Represents a downloadable file in a model version."""

src/oneiro/pipelines/embedding.py

Lines changed: 1 addition & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
"""Textual inversion / embedding configuration types and loading utilities."""
22

3-
import re
43
from dataclasses import dataclass, field
54
from enum import Enum
65
from pathlib import Path
76
from typing import Any
87

9-
from oneiro.civitai import CivitaiClient
8+
from oneiro.civitai import CivitaiClient, parse_civitai_url
109
from oneiro.pipelines.lora import PIPELINE_BASE_MODEL_MAP
1110

1211

@@ -69,38 +68,6 @@ def __post_init__(self) -> None:
6968
raise ValueError("local source requires path")
7069

7170

72-
def parse_civitai_url(url: str) -> tuple[int, int | None]:
73-
"""Parse Civitai URL to extract model ID and optional version ID.
74-
75-
Supports formats:
76-
- https://civitai.com/models/12345
77-
- https://civitai.com/models/12345/model-name
78-
- https://civitai.com/models/12345?modelVersionId=67890
79-
- https://civitai.com/models/12345/name?modelVersionId=67890
80-
81-
Args:
82-
url: Civitai model URL
83-
84-
Returns:
85-
Tuple of (model_id, version_id or None)
86-
87-
Raises:
88-
ValueError: If URL format is invalid
89-
"""
90-
# Match model ID in path
91-
model_match = re.search(r"/models/(\d+)", url)
92-
if not model_match:
93-
raise ValueError(f"Invalid Civitai URL format: {url}")
94-
95-
model_id = int(model_match.group(1))
96-
97-
# Check for version in query string
98-
version_match = re.search(r"modelVersionId=(\d+)", url)
99-
version_id = int(version_match.group(1)) if version_match else None
100-
101-
return model_id, version_id
102-
103-
10471
def parse_embedding_config(
10572
config: dict[str, Any] | str, name: str | None = None
10673
) -> EmbeddingConfig:

src/oneiro/pipelines/lora.py

Lines changed: 1 addition & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
"""LoRA configuration types and loading utilities."""
22

3-
import re
43
from dataclasses import dataclass, field
54
from enum import Enum
65
from pathlib import Path
76
from typing import Any
87

9-
from oneiro.civitai import CivitaiClient
8+
from oneiro.civitai import CivitaiClient, parse_civitai_url
109

1110

1211
class LoraSource(str, Enum):
@@ -80,38 +79,6 @@ def __post_init__(self) -> None:
8079
raise ValueError("local source requires path")
8180

8281

83-
def parse_civitai_url(url: str) -> tuple[int, int | None]:
84-
"""Parse Civitai URL to extract model ID and optional version ID.
85-
86-
Supports formats:
87-
- https://civitai.com/models/12345
88-
- https://civitai.com/models/12345/model-name
89-
- https://civitai.com/models/12345?modelVersionId=67890
90-
- https://civitai.com/models/12345/name?modelVersionId=67890
91-
92-
Args:
93-
url: Civitai model URL
94-
95-
Returns:
96-
Tuple of (model_id, version_id or None)
97-
98-
Raises:
99-
ValueError: If URL format is invalid
100-
"""
101-
# Match model ID in path
102-
model_match = re.search(r"/models/(\d+)", url)
103-
if not model_match:
104-
raise ValueError(f"Invalid Civitai URL format: {url}")
105-
106-
model_id = int(model_match.group(1))
107-
108-
# Check for version in query string
109-
version_match = re.search(r"modelVersionId=(\d+)", url)
110-
version_id = int(version_match.group(1)) if version_match else None
111-
112-
return model_id, version_id
113-
114-
11582
def parse_lora_config(config: dict[str, Any] | str, index: int = 0) -> LoraConfig:
11683
"""Parse a LoRA configuration from TOML config format.
11784

tests/test_civitai.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,47 @@
1717
ModelFile,
1818
ModelType,
1919
ModelVersion,
20+
parse_civitai_url,
2021
)
2122

23+
24+
class TestParseCivitaiUrl:
25+
"""Tests for parse_civitai_url function."""
26+
27+
def test_basic_model_url(self):
28+
"""Parses basic model URL."""
29+
model_id, version_id = parse_civitai_url("https://civitai.com/models/12345")
30+
assert model_id == 12345
31+
assert version_id is None
32+
33+
def test_model_url_with_name(self):
34+
"""Parses model URL with name slug."""
35+
model_id, version_id = parse_civitai_url("https://civitai.com/models/12345/my-cool-model")
36+
assert model_id == 12345
37+
assert version_id is None
38+
39+
def test_model_url_with_version(self):
40+
"""Parses model URL with version ID in query string."""
41+
model_id, version_id = parse_civitai_url(
42+
"https://civitai.com/models/12345?modelVersionId=67890"
43+
)
44+
assert model_id == 12345
45+
assert version_id == 67890
46+
47+
def test_model_url_with_name_and_version(self):
48+
"""Parses model URL with name and version."""
49+
model_id, version_id = parse_civitai_url(
50+
"https://civitai.com/models/12345/model-name?modelVersionId=67890"
51+
)
52+
assert model_id == 12345
53+
assert version_id == 67890
54+
55+
def test_invalid_url_raises(self):
56+
"""Invalid URL raises ValueError."""
57+
with pytest.raises(ValueError, match="Invalid Civitai URL"):
58+
parse_civitai_url("https://example.com/something")
59+
60+
2261
# Sample API responses for testing
2362
SAMPLE_MODEL_RESPONSE = {
2463
"id": 12345,

tests/test_embedding.py

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -11,49 +11,11 @@
1111
EmbeddingIncompatibleError,
1212
EmbeddingSource,
1313
is_embedding_compatible,
14-
parse_civitai_url,
1514
parse_embedding_config,
1615
parse_embeddings_from_config,
1716
)
1817

1918

20-
class TestParseCivitaiUrl:
21-
"""Tests for parse_civitai_url function."""
22-
23-
def test_basic_model_url(self):
24-
"""Parses basic model URL."""
25-
model_id, version_id = parse_civitai_url("https://civitai.com/models/12345")
26-
assert model_id == 12345
27-
assert version_id is None
28-
29-
def test_model_url_with_name(self):
30-
"""Parses model URL with name slug."""
31-
model_id, version_id = parse_civitai_url("https://civitai.com/models/12345/my-cool-model")
32-
assert model_id == 12345
33-
assert version_id is None
34-
35-
def test_model_url_with_version(self):
36-
"""Parses model URL with version ID in query string."""
37-
model_id, version_id = parse_civitai_url(
38-
"https://civitai.com/models/12345?modelVersionId=67890"
39-
)
40-
assert model_id == 12345
41-
assert version_id == 67890
42-
43-
def test_model_url_with_name_and_version(self):
44-
"""Parses model URL with name and version."""
45-
model_id, version_id = parse_civitai_url(
46-
"https://civitai.com/models/12345/model-name?modelVersionId=67890"
47-
)
48-
assert model_id == 12345
49-
assert version_id == 67890
50-
51-
def test_invalid_url_raises(self):
52-
"""Invalid URL raises ValueError."""
53-
with pytest.raises(ValueError, match="Invalid Civitai URL"):
54-
parse_civitai_url("https://example.com/something")
55-
56-
5719
class TestEmbeddingConfig:
5820
"""Tests for EmbeddingConfig dataclass."""
5921

tests/test_lora.py

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -11,50 +11,12 @@
1111
LoraIncompatibleError,
1212
LoraSource,
1313
is_lora_compatible,
14-
parse_civitai_url,
1514
parse_lora_config,
1615
parse_loras_from_config,
1716
parse_loras_from_model_config,
1817
)
1918

2019

21-
class TestParseCivitaiUrl:
22-
"""Tests for parse_civitai_url function."""
23-
24-
def test_basic_model_url(self):
25-
"""Parses basic model URL."""
26-
model_id, version_id = parse_civitai_url("https://civitai.com/models/12345")
27-
assert model_id == 12345
28-
assert version_id is None
29-
30-
def test_model_url_with_name(self):
31-
"""Parses model URL with name slug."""
32-
model_id, version_id = parse_civitai_url("https://civitai.com/models/12345/my-cool-model")
33-
assert model_id == 12345
34-
assert version_id is None
35-
36-
def test_model_url_with_version(self):
37-
"""Parses model URL with version ID in query string."""
38-
model_id, version_id = parse_civitai_url(
39-
"https://civitai.com/models/12345?modelVersionId=67890"
40-
)
41-
assert model_id == 12345
42-
assert version_id == 67890
43-
44-
def test_model_url_with_name_and_version(self):
45-
"""Parses model URL with name and version."""
46-
model_id, version_id = parse_civitai_url(
47-
"https://civitai.com/models/12345/model-name?modelVersionId=67890"
48-
)
49-
assert model_id == 12345
50-
assert version_id == 67890
51-
52-
def test_invalid_url_raises(self):
53-
"""Invalid URL raises ValueError."""
54-
with pytest.raises(ValueError, match="Invalid Civitai URL"):
55-
parse_civitai_url("https://example.com/something")
56-
57-
5820
class TestLoraConfig:
5921
"""Tests for LoraConfig dataclass."""
6022

0 commit comments

Comments
 (0)