Skip to content

Commit b0b53db

Browse files
committed
code cleanup
1 parent 418eecf commit b0b53db

File tree

4 files changed

+177
-30
lines changed

4 files changed

+177
-30
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ python-dateutil = "^2.9.0"
2222
requests = "^2.28.2"
2323
typer = "^0.15.4"
2424
urllib3 = "^2.6.1"
25+
pyyaml = "^6.0.3"
2526

2627
[tool.poetry.group.dev.dependencies]
2728
datamodel-code-generator = "^0.35.0"

src/groundlight/edge/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .config import (
2+
ConfigBase,
23
DEFAULT,
34
DISABLED,
45
EDGE_ANSWERS_WITH_ESCALATION,
@@ -15,6 +16,7 @@
1516
"DISABLED",
1617
"EDGE_ANSWERS_WITH_ESCALATION",
1718
"NO_CLOUD",
19+
"ConfigBase",
1820
"DetectorsConfig",
1921
"DetectorConfig",
2022
"EdgeEndpointConfig",

src/groundlight/edge/config.py

Lines changed: 72 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
from typing import Any, Optional, Union
22

33
from model import Detector
4-
from pydantic import BaseModel, ConfigDict, Field, model_serializer, model_validator
4+
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
55
from typing_extensions import Self
6+
import yaml
67

78

89
class GlobalConfig(BaseModel):
10+
"""Global runtime settings for edge-endpoint behavior."""
11+
12+
model_config = ConfigDict(extra="forbid")
13+
914
refresh_rate: float = Field(
1015
default=60.0,
1116
description="The interval (in seconds) at which the inference server checks for a new model binary update.",
@@ -22,7 +27,7 @@ class InferenceConfig(BaseModel):
2227
"""
2328

2429
# Keep shared presets immutable (DEFAULT/NO_CLOUD/etc.) so one mutation cannot globally change behavior.
25-
model_config = ConfigDict(frozen=True)
30+
model_config = ConfigDict(extra="forbid", frozen=True)
2631

2732
name: str = Field(..., exclude=True, description="A unique name for this inference config preset.")
2833
enabled: bool = Field(
@@ -71,18 +76,41 @@ class DetectorConfig(BaseModel):
7176
Configuration for a specific detector.
7277
"""
7378

79+
model_config = ConfigDict(extra="forbid")
80+
7481
detector_id: str = Field(..., description="Detector ID")
7582
edge_inference_config: str = Field(..., description="Config for edge inference.")
7683

7784

78-
class DetectorsConfig(BaseModel):
79-
"""
80-
Detector and inference-config mappings for edge inference.
81-
"""
85+
class ConfigBase(BaseModel):
86+
"""Shared detector/inference configuration behavior for edge config models."""
87+
88+
model_config = ConfigDict(extra="forbid")
8289

8390
edge_inference_configs: dict[str, InferenceConfig] = Field(default_factory=dict)
8491
detectors: list[DetectorConfig] = Field(default_factory=list)
8592

93+
@field_validator("edge_inference_configs", mode="before")
94+
@classmethod
95+
def hydrate_inference_config_names(
96+
cls, value: dict[str, InferenceConfig | dict[str, Any]] | None
97+
) -> dict[str, InferenceConfig | dict[str, Any]]:
98+
"""Hydrate InferenceConfig.name from payload mapping keys."""
99+
if value is None:
100+
return {}
101+
if not isinstance(value, dict):
102+
return value
103+
104+
hydrated_configs: dict[str, InferenceConfig | dict[str, Any]] = {}
105+
for name, config in value.items():
106+
if isinstance(config, InferenceConfig):
107+
hydrated_configs[name] = config
108+
continue
109+
if not isinstance(config, dict):
110+
raise TypeError("Each edge inference config must be an object.")
111+
hydrated_configs[name] = {"name": name, **config}
112+
return hydrated_configs
113+
86114
@model_validator(mode="after")
87115
def validate_inference_configs(self):
88116
"""
@@ -128,7 +156,7 @@ def add_detector(self, detector: Union[str, Detector], edge_inference_config: In
128156
self.detectors.append(DetectorConfig(detector_id=detector_id, edge_inference_config=edge_inference_config.name))
129157

130158
def to_payload(self) -> dict[str, Any]:
131-
"""Return flattened detector payload used by edge-endpoint config HTTP APIs."""
159+
"""Return detector payload used by edge-endpoint config HTTP APIs."""
132160
return {
133161
"edge_inference_configs": {
134162
name: config.model_dump() for name, config in self.edge_inference_configs.items()
@@ -137,36 +165,54 @@ def to_payload(self) -> dict[str, Any]:
137165
}
138166

139167

140-
class EdgeEndpointConfig(BaseModel):
168+
class DetectorsConfig(ConfigBase):
169+
"""
170+
Detector and inference-config mappings for edge inference.
171+
"""
172+
173+
174+
class EdgeEndpointConfig(ConfigBase):
141175
"""
142176
Top-level edge endpoint configuration.
143177
"""
144178

145179
global_config: GlobalConfig = Field(default_factory=GlobalConfig)
146-
detectors_config: DetectorsConfig = Field(default_factory=DetectorsConfig)
147-
148-
@property
149-
def edge_inference_configs(self) -> dict[str, InferenceConfig]:
150-
"""Convenience accessor for detector inference config map."""
151-
return self.detectors_config.edge_inference_configs
152180

153-
@property
154-
def detectors(self) -> list[DetectorConfig]:
155-
"""Convenience accessor for detector assignments."""
156-
return self.detectors_config.detectors
181+
@classmethod
182+
def from_yaml(
183+
cls,
184+
filename: Optional[str] = None,
185+
yaml_str: Optional[str] = None,
186+
) -> "EdgeEndpointConfig":
187+
"""Create an EdgeEndpointConfig from a YAML filename or YAML string."""
188+
if filename is None and yaml_str is None:
189+
raise ValueError("Either filename or yaml_str must be provided.")
190+
if filename is not None and yaml_str is not None:
191+
raise ValueError("Only one of filename or yaml_str can be provided.")
192+
if filename is not None:
193+
if not filename.strip():
194+
raise ValueError("filename must be a non-empty path when provided.")
195+
with open(filename, "r") as f:
196+
yaml_str = f.read()
197+
198+
yaml_text = yaml_str or ""
199+
parsed = yaml.safe_load(yaml_text) or {}
200+
return cls.model_validate(parsed)
157201

158-
@model_serializer(mode="plain")
159-
def serialize(self):
160-
"""Serialize to the flattened shape expected by edge-endpoint configs."""
202+
def to_payload(self) -> dict[str, Any]:
203+
"""Return the full edge-endpoint payload shape."""
161204
return {
162205
"global_config": self.global_config.model_dump(),
163-
**self.detectors_config.to_payload(),
206+
"edge_inference_configs": {
207+
name: config.model_dump() for name, config in self.edge_inference_configs.items()
208+
},
209+
"detectors": [detector.model_dump() for detector in self.detectors],
164210
}
165211

166-
def add_detector(self, detector: Union[str, Detector], edge_inference_config: InferenceConfig) -> None:
167-
"""Add a detector with the given inference config. Accepts detector ID or Detector object."""
168-
self.detectors_config.add_detector(detector, edge_inference_config)
169-
212+
@classmethod
213+
def from_payload(cls, payload: dict[str, Any]) -> "EdgeEndpointConfig":
214+
"""Construct an EdgeEndpointConfig from a payload dictionary."""
215+
return cls.model_validate(payload)
170216

171217
# Preset inference configs matching the standard edge-endpoint defaults.
172218
DEFAULT = InferenceConfig(name="default")

test/unit/test_edge_config.py

Lines changed: 102 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from datetime import datetime, timezone
2-
32
import pytest
43
from groundlight.edge import (
54
DEFAULT,
@@ -108,6 +107,25 @@ def test_constructor_accepts_matching_inference_config_key_and_name():
108107
assert [detector.detector_id for detector in config.detectors] == ["det_1"]
109108

110109

110+
def test_constructor_hydrates_inference_config_name_from_dict_key():
111+
"""Hydrates inference config names from payload dict keys."""
112+
config = DetectorsConfig(
113+
edge_inference_configs={"default": {"enabled": True}},
114+
detectors=[{"detector_id": "det_1", "edge_inference_config": "default"}],
115+
)
116+
117+
assert config.edge_inference_configs["default"].name == "default"
118+
119+
120+
def test_constructor_rejects_detector_map_input():
121+
"""Rejects detector maps and requires detector list payloads."""
122+
with pytest.raises(ValueError):
123+
DetectorsConfig(
124+
edge_inference_configs={"default": {"enabled": True}},
125+
detectors={"det_1": {"detector_id": "det_1", "edge_inference_config": "default"}},
126+
)
127+
128+
111129
def test_constructor_rejects_undefined_inference_config_reference():
112130
"""Rejects detector entries that reference missing inference configs."""
113131
with pytest.raises(ValueError, match="not defined"):
@@ -117,7 +135,7 @@ def test_constructor_rejects_undefined_inference_config_reference():
117135
)
118136

119137

120-
def test_edge_endpoint_config_add_detector_delegates_to_detectors_logic():
138+
def test_edge_endpoint_config_add_detector_uses_shared_config_logic():
121139
"""Adds detectors via EdgeEndpointConfig and preserves inferred config mapping."""
122140
config = EdgeEndpointConfig()
123141
config.add_detector("det_1", NO_CLOUD)
@@ -157,23 +175,103 @@ def test_detectors_config_to_payload_shape():
157175
assert set(payload["edge_inference_configs"].keys()) == {"default", "no_cloud"}
158176

159177

178+
def test_edge_endpoint_config_accepts_top_level_payload_shape():
179+
"""Accepts the top-level edge endpoint payload shape used by APIs."""
180+
config = EdgeEndpointConfig.model_validate(
181+
{
182+
"global_config": {"refresh_rate": CUSTOM_REFRESH_RATE},
183+
"edge_inference_configs": {"default": {"enabled": True}},
184+
"detectors": [{"detector_id": "det_1", "edge_inference_config": "default"}],
185+
}
186+
)
187+
188+
assert config.global_config.refresh_rate == CUSTOM_REFRESH_RATE
189+
assert [detector.detector_id for detector in config.detectors] == ["det_1"]
190+
191+
192+
def test_edge_endpoint_config_from_yaml_accepts_yaml_text():
193+
"""Parses edge-endpoint YAML text using EdgeEndpointConfig.from_yaml."""
194+
config = EdgeEndpointConfig.from_yaml(
195+
yaml_str="""
196+
global_config:
197+
refresh_rate: 15.0
198+
edge_inference_configs:
199+
default:
200+
enabled: true
201+
detectors:
202+
- detector_id: det_1
203+
edge_inference_config: default
204+
"""
205+
)
206+
207+
assert config.global_config.refresh_rate == 15.0
208+
assert [detector.detector_id for detector in config.detectors] == ["det_1"]
209+
210+
211+
def test_edge_endpoint_config_from_yaml_accepts_filename(tmp_path):
212+
"""Parses edge-endpoint YAML from a file path."""
213+
config_file = tmp_path / "edge-config.yaml"
214+
config_file.write_text(
215+
"global_config: {}\n"
216+
"edge_inference_configs:\n"
217+
" default:\n"
218+
" enabled: true\n"
219+
"detectors:\n"
220+
" - detector_id: det_1\n"
221+
" edge_inference_config: default\n"
222+
)
223+
config = EdgeEndpointConfig.from_yaml(filename=str(config_file))
224+
225+
assert [detector.detector_id for detector in config.detectors] == ["det_1"]
226+
227+
228+
def test_edge_endpoint_config_from_yaml_requires_exactly_one_input():
229+
"""Rejects missing input and mixed filename/yaml_str input."""
230+
with pytest.raises(ValueError, match="Either filename or yaml_str must be provided"):
231+
EdgeEndpointConfig.from_yaml()
232+
233+
with pytest.raises(ValueError, match="Only one of filename or yaml_str can be provided"):
234+
EdgeEndpointConfig.from_yaml(filename="a.yaml", yaml_str="global_config: {}")
235+
236+
with pytest.raises(ValueError, match="filename must be a non-empty path"):
237+
EdgeEndpointConfig.from_yaml(filename=" ")
238+
239+
240+
def test_edge_endpoint_config_rejects_extra_top_level_fields():
241+
"""Rejects unknown top-level fields to avoid silent config drift."""
242+
with pytest.raises(ValueError, match="Extra inputs are not permitted"):
243+
EdgeEndpointConfig.model_validate({"global_config": {}, "unknown_field": True})
244+
245+
160246
def test_model_dump_shape_for_edge_endpoint_config():
161-
"""Serializes full edge endpoint config in flattened wire shape."""
247+
"""Serializes full edge endpoint config in wire payload shape."""
162248
config = EdgeEndpointConfig(
163249
global_config=GlobalConfig(refresh_rate=CUSTOM_REFRESH_RATE, confident_audit_rate=CUSTOM_AUDIT_RATE)
164250
)
165251
config.add_detector("det_1", DEFAULT)
166252
config.add_detector("det_2", EDGE_ANSWERS_WITH_ESCALATION)
167253
config.add_detector("det_3", NO_CLOUD)
168254

169-
payload = config.model_dump()
255+
payload = config.to_payload()
170256

171257
assert payload["global_config"]["refresh_rate"] == CUSTOM_REFRESH_RATE
172258
assert payload["global_config"]["confident_audit_rate"] == CUSTOM_AUDIT_RATE
173259
assert len(payload["detectors"]) == 3 # noqa: PLR2004
174260
assert set(payload["edge_inference_configs"].keys()) == {"default", "edge_answers_with_escalation", "no_cloud"}
175261

176262

263+
def test_edge_endpoint_config_from_payload_round_trip():
264+
"""Round-trips edge endpoint config through payload helpers."""
265+
config = EdgeEndpointConfig()
266+
config.add_detector("det_1", DEFAULT)
267+
config.add_detector("det_2", NO_CLOUD)
268+
269+
payload = config.to_payload()
270+
reconstructed = EdgeEndpointConfig.from_payload(payload)
271+
272+
assert reconstructed == config
273+
274+
177275
def test_inference_config_validation_errors():
178276
"""Raises on invalid inference config flag combinations and values."""
179277
with pytest.raises(ValueError, match="disable_cloud_escalation"):

0 commit comments

Comments
 (0)