Commit e2bdb90

✅ Update tests for breaking API changes
Signed-off-by: Evaline Ju <[email protected]>
1 parent e63e0db commit e2bdb90

File tree

tests/generative_detectors/test_base.py
tests/generative_detectors/test_granite_guardian.py
tests/generative_detectors/test_llama_guard.py

3 files changed: +42 additions, -30 deletions

tests/generative_detectors/test_base.py

Lines changed: 14 additions & 8 deletions
@@ -1,7 +1,7 @@
 # Standard
-from dataclasses import dataclass
-from typing import Optional
-from unittest.mock import patch
+from dataclasses import dataclass, field
+from typing import Any, Optional
+from unittest.mock import MagicMock, patch
 import asyncio

 # Third Party
@@ -49,17 +49,23 @@ class MockHFConfig:
 @dataclass
 class MockModelConfig:
     task = "generate"
+    runner_type = "generate"
     tokenizer = MODEL_NAME
     trust_remote_code = False
     tokenizer_mode = "auto"
     max_model_len = 100
     tokenizer_revision = None
-    embedding_mode = False
     multimodal_config = MultiModalConfig()
-    diff_sampling_param: Optional[dict] = None
     hf_config = MockHFConfig()
     logits_processor_pattern = None
+    logits_processors: list[str] | None = None
+    diff_sampling_param: dict | None = None
     allowed_local_media_path: str = ""
+    allowed_media_domains: list[str] | None = None
+    encoder_config = None
+    generation_config: str = "auto"
+    media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
+    skip_tokenizer_init = False

     def get_diff_sampling_param(self):
         return self.diff_sampling_param or {}
@@ -78,18 +84,18 @@ async def _async_serving_detection_completion_init():
     """Initialize a chat completion base with string templates"""
     engine = MockEngine()
     engine.errored = False
-    model_config = await engine.get_model_config()
+    engine.model_config = MockModelConfig()
+    engine.input_processor = MagicMock()
+    engine.io_processor = MagicMock()
     models = OpenAIServingModels(
         engine_client=engine,
-        model_config=model_config,
         base_model_paths=BASE_MODEL_PATHS,
     )

     detection_completion = ChatCompletionDetectionBase(
         task_template="hello {{user_text}}",
         output_template="bye {{text}}",
         engine_client=engine,
-        model_config=model_config,
         models=models,
         response_role="assistant",
         chat_template=CHAT_TEMPLATE,
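
All three test files make the same three-part change, so it is worth stating once: MockModelConfig gains the newer config attributes (runner_type, logits_processors, media_io_kwargs, and friends) and drops embedding_mode; MockEngine loses its async get_model_config(); and the fixtures attach model_config, input_processor, and io_processor directly to the engine client instead of passing model_config= into OpenAIServingModels and the detector constructors. Below is a minimal, self-contained sketch of the resulting fixture shape, using plain Python stand-ins rather than the real vLLM types, so treat it as an illustration of the wiring, not the actual API:

# Hedged sketch of the new fixture wiring (Python 3.10+).
# MockModelConfig is trimmed to a few fields; the real test files
# carry the full attribute set shown in the diffs.
from dataclasses import dataclass, field
from typing import Any
from unittest.mock import MagicMock


@dataclass
class MockModelConfig:
    runner_type = "generate"  # added alongside the existing task attribute
    diff_sampling_param: dict | None = None
    media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)

    def get_diff_sampling_param(self):
        return self.diff_sampling_param or {}


class MockEngine:
    """Stand-in engine client: no async get_model_config() anymore."""


engine = MockEngine()
engine.errored = False
engine.model_config = MockModelConfig()  # was: await engine.get_model_config()
engine.input_processor = MagicMock()
engine.io_processor = MagicMock()

# OpenAIServingModels and the detector constructors are now called without
# an explicit model_config=... keyword; they read the config off the
# engine client themselves.

The granite guardian and llama guard diffs below repeat exactly this wiring with their own detector classes.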

tests/generative_detectors/test_granite_guardian.py

Lines changed: 14 additions & 11 deletions
@@ -1,8 +1,8 @@
 # Standard
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from http import HTTPStatus
-from typing import Optional
-from unittest.mock import patch
+from typing import Any, Optional
+from unittest.mock import MagicMock, patch
 import asyncio
 import json

@@ -76,27 +76,30 @@ class MockHFConfig:
 @dataclass
 class MockModelConfig:
     task = "generate"
+    runner_type = "generate"
     tokenizer = MODEL_NAME
     trust_remote_code = False
     tokenizer_mode = "auto"
     max_model_len = 100
     tokenizer_revision = None
-    embedding_mode = False
     multimodal_config = MultiModalConfig()
-    diff_sampling_param: Optional[dict] = None
     hf_config = MockHFConfig()
     logits_processor_pattern = None
+    logits_processors: list[str] | None = None
+    diff_sampling_param: dict | None = None
     allowed_local_media_path: str = ""
+    allowed_media_domains: list[str] | None = None
+    encoder_config = None
+    generation_config: str = "auto"
+    media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
+    skip_tokenizer_init = False

     def get_diff_sampling_param(self):
         return self.diff_sampling_param or {}


 @dataclass
 class MockEngine:
-    async def get_model_config(self):
-        return MockModelConfig()
-
     async def get_tokenizer(self):
         return MockTokenizer()

@@ -105,18 +108,18 @@ async def _granite_guardian_init():
     """Initialize a granite guardian"""
     engine = MockEngine()
     engine.errored = False
-    model_config = await engine.get_model_config()
+    engine.model_config = MockModelConfig()
+    engine.input_processor = MagicMock()
+    engine.io_processor = MagicMock()
     models = OpenAIServingModels(
         engine_client=engine,
-        model_config=model_config,
         base_model_paths=BASE_MODEL_PATHS,
     )

     granite_guardian = GraniteGuardian(
         task_template=None,
         output_template=None,
         engine_client=engine,
-        model_config=model_config,
         models=models,
         response_role="assistant",
         chat_template=CHAT_TEMPLATE,

tests/generative_detectors/test_llama_guard.py

Lines changed: 14 additions & 11 deletions
@@ -1,8 +1,8 @@
 # Standard
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from http import HTTPStatus
-from typing import Optional
-from unittest.mock import patch
+from typing import Any, Optional
+from unittest.mock import MagicMock, patch
 import asyncio

 # Third Party
@@ -54,27 +54,30 @@ class MockHFConfig:
 @dataclass
 class MockModelConfig:
     task = "generate"
+    runner_type = "generate"
     tokenizer = MODEL_NAME
     trust_remote_code = False
     tokenizer_mode = "auto"
     max_model_len = 100
     tokenizer_revision = None
-    embedding_mode = False
     multimodal_config = MultiModalConfig()
-    diff_sampling_param: Optional[dict] = None
     hf_config = MockHFConfig()
     logits_processor_pattern = None
+    logits_processors: list[str] | None = None
+    diff_sampling_param: dict | None = None
     allowed_local_media_path: str = ""
+    allowed_media_domains: list[str] | None = None
+    encoder_config = None
+    generation_config: str = "auto"
+    media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
+    skip_tokenizer_init = False

     def get_diff_sampling_param(self):
         return self.diff_sampling_param or {}


 @dataclass
 class MockEngine:
-    async def get_model_config(self):
-        return MockModelConfig()
-
     async def get_tokenizer(self):
         return MockTokenizer()

@@ -83,18 +86,18 @@ async def _llama_guard_init():
     """Initialize a llama guard"""
     engine = MockEngine()
     engine.errored = False
-    model_config = await engine.get_model_config()
+    engine.model_config = MockModelConfig()
+    engine.input_processor = MagicMock()
+    engine.io_processor = MagicMock()
     models = OpenAIServingModels(
         engine_client=engine,
-        model_config=model_config,
         base_model_paths=BASE_MODEL_PATHS,
     )

     llama_guard_detection = LlamaGuard(
         task_template=None,
         output_template=None,
         engine_client=engine,
-        model_config=model_config,
         models=models,
         response_role="assistant",
         chat_template=CHAT_TEMPLATE,
