Skip to content

Commit 45feffa

Browse files
author
Tim Huff
committed
responding to more AI PR feedback
1 parent b3a7f66 commit 45feffa

File tree

3 files changed

+137
-64
lines changed

3 files changed

+137
-64
lines changed

src/groundlight/edge/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,9 @@
33
DISABLED,
44
EDGE_ANSWERS_WITH_ESCALATION,
55
NO_CLOUD,
6-
DetectorsConfig,
76
DetectorConfig,
7+
DetectorsConfig,
88
EdgeEndpointConfig,
9-
EdgeInferenceConfig,
109
GlobalConfig,
1110
InferenceConfig,
1211
)
@@ -19,7 +18,6 @@
1918
"DetectorsConfig",
2019
"DetectorConfig",
2120
"EdgeEndpointConfig",
22-
"EdgeInferenceConfig",
2321
"GlobalConfig",
2422
"InferenceConfig",
2523
]

src/groundlight/edge/config.py

Lines changed: 63 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@ class InferenceConfig(BaseModel):
2424
model_config = ConfigDict(frozen=True)
2525

2626
name: str = Field(..., exclude=True, description="A unique name for this inference config preset.")
27-
enabled: bool = Field( # TODO investigate and update the functionality of this option
27+
enabled: bool = Field(
2828
default=True, description="Whether the edge endpoint should accept image queries for this detector."
2929
)
30-
api_token: Union[str, None] = Field(
30+
api_token: Optional[str] = Field(
3131
default=None, description="API token used to fetch the inference model for this detector."
3232
)
3333
always_return_edge_prediction: bool = Field(
@@ -74,6 +74,63 @@ class DetectorConfig(BaseModel):
7474
edge_inference_config: str = Field(..., description="Config for edge inference.")
7575

7676

77+
def _validate_detector_config_state(
78+
edge_inference_configs: dict[str, InferenceConfig], detectors: list[DetectorConfig]
79+
) -> None:
80+
for name, config in edge_inference_configs.items():
81+
if name != config.name:
82+
raise ValueError(f"Edge inference config key '{name}' must match InferenceConfig.name '{config.name}'.")
83+
84+
seen_detector_ids = set()
85+
duplicate_detector_ids = set()
86+
for detector_config in detectors:
87+
detector_id = detector_config.detector_id
88+
if detector_id in seen_detector_ids:
89+
duplicate_detector_ids.add(detector_id)
90+
else:
91+
seen_detector_ids.add(detector_id)
92+
if duplicate_detector_ids:
93+
duplicates = ", ".join(sorted(duplicate_detector_ids))
94+
raise ValueError(f"Duplicate detector IDs are not allowed: {duplicates}.")
95+
96+
for detector_config in detectors:
97+
if detector_config.edge_inference_config not in edge_inference_configs:
98+
raise ValueError(f"Edge inference config '{detector_config.edge_inference_config}' not defined.")
99+
100+
101+
def _add_detector_to_state(
102+
edge_inference_configs: dict[str, InferenceConfig],
103+
detectors: list[DetectorConfig],
104+
detector: Union[str, Detector],
105+
edge_inference_config: Union[str, InferenceConfig],
106+
) -> DetectorConfig:
107+
detector_id = detector.id if isinstance(detector, Detector) else detector
108+
if any(existing.detector_id == detector_id for existing in detectors):
109+
raise ValueError(f"A detector with ID '{detector_id}' already exists.")
110+
if isinstance(edge_inference_config, InferenceConfig):
111+
config = edge_inference_config
112+
existing = edge_inference_configs.get(config.name)
113+
if existing is None:
114+
edge_inference_configs[config.name] = config
115+
elif existing != config:
116+
raise ValueError(f"A different inference config named '{config.name}' is already registered.")
117+
config_name = config.name
118+
else:
119+
config_name = edge_inference_config
120+
if config_name not in edge_inference_configs:
121+
raise ValueError(
122+
f"Edge inference config '{config_name}' not defined. "
123+
f"Available configs: {list(edge_inference_configs.keys())}"
124+
)
125+
126+
detector_config = DetectorConfig(
127+
detector_id=detector_id,
128+
edge_inference_config=config_name,
129+
)
130+
detectors.append(detector_config)
131+
return detector_config
132+
133+
77134
class DetectorsConfig(BaseModel):
78135
"""
79136
Detector and inference-config mappings for edge inference.
@@ -84,54 +141,11 @@ class DetectorsConfig(BaseModel):
84141

85142
@model_validator(mode="after")
86143
def validate_inference_configs(self):
87-
for name, config in self.edge_inference_configs.items():
88-
if name != config.name:
89-
raise ValueError(
90-
f"Edge inference config key '{name}' must match InferenceConfig.name '{config.name}'."
91-
)
92-
93-
seen_detector_ids = set()
94-
duplicate_detector_ids = set()
95-
for detector_config in self.detectors:
96-
detector_id = detector_config.detector_id
97-
if detector_id in seen_detector_ids:
98-
duplicate_detector_ids.add(detector_id)
99-
else:
100-
seen_detector_ids.add(detector_id)
101-
if duplicate_detector_ids:
102-
duplicates = ", ".join(sorted(duplicate_detector_ids))
103-
raise ValueError(f"Duplicate detector IDs are not allowed: {duplicates}.")
104-
105-
for detector_config in self.detectors:
106-
if detector_config.edge_inference_config not in self.edge_inference_configs:
107-
raise ValueError(f"Edge inference config '{detector_config.edge_inference_config}' not defined.")
144+
_validate_detector_config_state(self.edge_inference_configs, self.detectors)
108145
return self
109146

110147
def add_detector(self, detector: Union[str, Detector], edge_inference_config: Union[str, InferenceConfig]) -> None:
111-
detector_id = detector.id if isinstance(detector, Detector) else detector
112-
if any(d.detector_id == detector_id for d in self.detectors):
113-
raise ValueError(f"A detector with ID '{detector_id}' already exists.")
114-
if isinstance(edge_inference_config, InferenceConfig):
115-
config = edge_inference_config
116-
existing = self.edge_inference_configs.get(config.name)
117-
if existing is None:
118-
self.edge_inference_configs[config.name] = config
119-
elif existing != config:
120-
raise ValueError(f"A different inference config named '{config.name}' is already registered.")
121-
config_name = config.name
122-
else:
123-
config_name = edge_inference_config
124-
if config_name not in self.edge_inference_configs:
125-
raise ValueError(
126-
f"Edge inference config '{config_name}' not defined. "
127-
f"Available configs: {list(self.edge_inference_configs.keys())}"
128-
)
129-
self.detectors.append(
130-
DetectorConfig(
131-
detector_id=detector_id,
132-
edge_inference_config=config_name,
133-
)
134-
)
148+
_add_detector_to_state(self.edge_inference_configs, self.detectors, detector, edge_inference_config)
135149

136150

137151
class EdgeEndpointConfig(BaseModel):
@@ -145,20 +159,11 @@ class EdgeEndpointConfig(BaseModel):
145159

146160
@model_validator(mode="after")
147161
def validate_inference_configs(self):
148-
DetectorsConfig(
149-
edge_inference_configs=self.edge_inference_configs,
150-
detectors=self.detectors,
151-
)
162+
_validate_detector_config_state(self.edge_inference_configs, self.detectors)
152163
return self
153164

154165
def add_detector(self, detector: Union[str, Detector], edge_inference_config: Union[str, InferenceConfig]) -> None:
155-
detectors_config = DetectorsConfig(
156-
edge_inference_configs=self.edge_inference_configs,
157-
detectors=self.detectors,
158-
)
159-
detectors_config.add_detector(detector, edge_inference_config)
160-
self.edge_inference_configs = detectors_config.edge_inference_configs
161-
self.detectors = detectors_config.detectors
166+
_add_detector_to_state(self.edge_inference_configs, self.detectors, detector, edge_inference_config)
162167

163168
@classmethod
164169
def from_detectors_config(
@@ -172,9 +177,6 @@ def from_detectors_config(
172177
)
173178

174179

175-
EdgeInferenceConfig = InferenceConfig
176-
177-
178180
# Preset inference configs matching the standard edge-endpoint defaults.
179181
DEFAULT = InferenceConfig(name="default")
180182
EDGE_ANSWERS_WITH_ESCALATION = InferenceConfig(

test/unit/test_edge_config.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,38 @@
1+
from datetime import datetime
2+
13
import pytest
4+
from model import Detector, DetectorTypeEnum
5+
26
from groundlight.edge import (
37
DEFAULT,
8+
DISABLED,
49
EDGE_ANSWERS_WITH_ESCALATION,
510
NO_CLOUD,
611
DetectorsConfig,
712
EdgeEndpointConfig,
13+
GlobalConfig,
814
InferenceConfig,
915
)
1016

1117
ONE_DETECTOR = 1
1218
TWO_DETECTORS = 2
19+
THREE_DETECTORS = 3
20+
CUSTOM_REFRESH_RATE = 10.0
21+
CUSTOM_AUDIT_RATE = 0.0
22+
23+
24+
def _make_detector(detector_id: str) -> Detector:
25+
return Detector(
26+
id=detector_id,
27+
type=DetectorTypeEnum.detector,
28+
created_at=datetime.utcnow(),
29+
name="test detector",
30+
query="Is there a dog?",
31+
group_name="default",
32+
metadata=None,
33+
mode="BINARY",
34+
mode_configuration=None,
35+
)
1336

1437

1538
def test_edge_endpoint_config_is_not_subclass_of_detectors_config():
@@ -89,6 +112,29 @@ def test_edge_endpoint_config_add_detector_delegates_to_detectors_logic():
89112
assert set(config.edge_inference_configs.keys()) == {"no_cloud", "edge_answers_with_escalation", "default"}
90113

91114

115+
def test_add_detector_accepts_detector_object():
116+
config = EdgeEndpointConfig()
117+
config.add_detector(_make_detector("det_1"), DEFAULT)
118+
119+
assert [detector.detector_id for detector in config.detectors] == ["det_1"]
120+
121+
122+
def test_add_detector_accepts_string_inference_config_name():
123+
config = EdgeEndpointConfig()
124+
config.edge_inference_configs["default"] = DEFAULT
125+
config.add_detector("det_1", "default")
126+
127+
assert [detector.edge_inference_config for detector in config.detectors] == ["default"]
128+
129+
130+
def test_disabled_preset_can_be_used():
131+
config = EdgeEndpointConfig()
132+
config.add_detector("det_1", DISABLED)
133+
134+
assert [detector.edge_inference_config for detector in config.detectors] == ["disabled"]
135+
assert config.edge_inference_configs["disabled"] == DISABLED
136+
137+
92138
def test_from_detectors_config_copies_detector_data():
93139
detectors_config = DetectorsConfig()
94140
detectors_config.add_detector("det_1", DEFAULT)
@@ -100,6 +146,33 @@ def test_from_detectors_config_copies_detector_data():
100146
assert len(detectors_config.detectors) == TWO_DETECTORS
101147

102148

149+
def test_from_detectors_config_uses_custom_global_config():
150+
detectors_config = DetectorsConfig()
151+
detectors_config.add_detector("det_1", DEFAULT)
152+
custom_global_config = GlobalConfig(refresh_rate=CUSTOM_REFRESH_RATE, confident_audit_rate=CUSTOM_AUDIT_RATE)
153+
154+
config = EdgeEndpointConfig.from_detectors_config(detectors_config, global_config=custom_global_config)
155+
156+
assert config.global_config == custom_global_config
157+
assert len(config.detectors) == ONE_DETECTOR
158+
159+
160+
def test_model_dump_shape_for_edge_endpoint_config():
161+
config = EdgeEndpointConfig(
162+
global_config=GlobalConfig(refresh_rate=CUSTOM_REFRESH_RATE, confident_audit_rate=CUSTOM_AUDIT_RATE)
163+
)
164+
config.add_detector("det_1", DEFAULT)
165+
config.add_detector("det_2", EDGE_ANSWERS_WITH_ESCALATION)
166+
config.add_detector("det_3", NO_CLOUD)
167+
168+
payload = config.model_dump()
169+
170+
assert payload["global_config"]["refresh_rate"] == CUSTOM_REFRESH_RATE
171+
assert payload["global_config"]["confident_audit_rate"] == CUSTOM_AUDIT_RATE
172+
assert len(payload["detectors"]) == THREE_DETECTORS
173+
assert set(payload["edge_inference_configs"].keys()) == {"default", "edge_answers_with_escalation", "no_cloud"}
174+
175+
103176
def test_inference_config_validation_errors():
104177
with pytest.raises(ValueError, match="disable_cloud_escalation"):
105178
InferenceConfig(name="bad", disable_cloud_escalation=True)

0 commit comments

Comments
 (0)