Skip to content

Commit 76eb7d9

Browse files
committed
chore: Format
1 parent 1869c30 commit 76eb7d9

File tree

4 files changed

+131
-50
lines changed

4 files changed

+131
-50
lines changed

projects/policyengine-api-tagger/src/policyengine_api_tagger/api/revision_cleanup.py

Lines changed: 45 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
class TagInfo(BaseModel):
2222
"""Information about a traffic tag."""
23+
2324
tag: str
2425
revision: str
2526
country: str # "us" or "uk"
@@ -77,7 +78,7 @@ def _parse_tag(self, tag: str, revision: str) -> TagInfo | None:
7778
Example: country-us-model-1-459-0 -> country=us, version=(1, 459, 0)
7879
"""
7980
# Match pattern: country-{us|uk}-model-{version}
80-
match = re.match(r'^country-(us|uk)-model-(.+)$', tag)
81+
match = re.match(r"^country-(us|uk)-model-(.+)$", tag)
8182
if not match:
8283
log.debug(f"Tag '{tag}' doesn't match expected pattern, skipping")
8384
return None
@@ -86,11 +87,11 @@ def _parse_tag(self, tag: str, revision: str) -> TagInfo | None:
8687
version_with_dashes = match.group(2)
8788

8889
# Convert dashes back to dots for the version string
89-
version_str = version_with_dashes.replace('-', '.')
90+
version_str = version_with_dashes.replace("-", ".")
9091

9192
# Parse version into tuple of integers for comparison
9293
try:
93-
version_parts = tuple(int(p) for p in version_str.split('.'))
94+
version_parts = tuple(int(p) for p in version_str.split("."))
9495
except ValueError:
9596
log.warning(f"Could not parse version from tag '{tag}': {version_str}")
9697
return None
@@ -109,7 +110,9 @@ async def _get_service(self) -> Service:
109110
service_name = self._get_service_name()
110111
return await client.get_service(name=service_name)
111112

112-
async def _update_service_traffic(self, service: Service, tags_to_keep: list[TagInfo]) -> None:
113+
async def _update_service_traffic(
114+
self, service: Service, tags_to_keep: list[TagInfo]
115+
) -> None:
113116
"""
114117
Update the service traffic configuration to only include specified tags.
115118
@@ -130,19 +133,24 @@ async def _update_service_traffic(self, service: Service, tags_to_keep: list[Tag
130133
# Add the tags we want to keep (with percent=0)
131134
for tag_info in tags_to_keep:
132135
from google.cloud.run_v2 import TrafficTarget
133-
new_traffic.append(TrafficTarget(
134-
percent=0,
135-
revision=tag_info.revision,
136-
tag=tag_info.tag,
137-
))
136+
137+
new_traffic.append(
138+
TrafficTarget(
139+
percent=0,
140+
revision=tag_info.revision,
141+
tag=tag_info.tag,
142+
)
143+
)
138144

139145
# Update the service
140146
service.traffic = new_traffic
141147

142148
request = UpdateServiceRequest(service=service)
143149
await client.update_service(request=request)
144150

145-
async def _analyze_tags(self, keep_count: int) -> tuple[
151+
async def _analyze_tags(
152+
self, keep_count: int
153+
) -> tuple[
146154
Service | None,
147155
list[TagInfo],
148156
list[TagInfo],
@@ -226,7 +234,15 @@ async def _analyze_tags(self, keep_count: int) -> tuple[
226234

227235
log.info(f"Keeping {len(tags_to_keep)} tags, removing {len(tags_removed)} tags")
228236

229-
return service, all_tags, tags_to_keep, newest_us, newest_uk, tags_removed, errors
237+
return (
238+
service,
239+
all_tags,
240+
tags_to_keep,
241+
newest_us,
242+
newest_uk,
243+
tags_removed,
244+
errors,
245+
)
230246

231247
async def preview(self, keep_count: int = 40) -> CleanupResult:
232248
"""
@@ -238,8 +254,15 @@ async def preview(self, keep_count: int = 40) -> CleanupResult:
238254
Returns:
239255
CleanupResult showing what would be kept/removed
240256
"""
241-
service, all_tags, tags_to_keep, newest_us, newest_uk, tags_removed, errors = \
242-
await self._analyze_tags(keep_count)
257+
(
258+
service,
259+
all_tags,
260+
tags_to_keep,
261+
newest_us,
262+
newest_uk,
263+
tags_removed,
264+
errors,
265+
) = await self._analyze_tags(keep_count)
243266

244267
if service is None:
245268
return CleanupResult(
@@ -276,8 +299,15 @@ async def cleanup(self, keep_count: int = 40) -> CleanupResult:
276299
Returns:
277300
CleanupResult with details of what was cleaned up
278301
"""
279-
service, all_tags, tags_to_keep, newest_us, newest_uk, tags_removed, errors = \
280-
await self._analyze_tags(keep_count)
302+
(
303+
service,
304+
all_tags,
305+
tags_to_keep,
306+
newest_us,
307+
newest_uk,
308+
tags_removed,
309+
errors,
310+
) = await self._analyze_tags(keep_count)
281311

282312
if service is None:
283313
return CleanupResult(

projects/policyengine-api-tagger/src/policyengine_api_tagger/api/routes.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@ async def get_tag_uri(country: str, model_version: str) -> str:
2323
return uri
2424

2525
@router.post("/cleanup")
26-
async def cleanup_old_revisions(keep: int = 40, dry_run: bool = False) -> CleanupResult:
26+
async def cleanup_old_revisions(
27+
keep: int = 40, dry_run: bool = False
28+
) -> CleanupResult:
2729
"""
2830
Clean up old traffic tags, keeping the specified number.
2931

projects/policyengine-api-tagger/tests/api/test_cleanup_integration.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,7 @@ def client(self, mock_cloudrun_service):
7171

7272
yield TestClient(app)
7373

74-
def test_cleanup_returns_success_with_tags(
75-
self, client, mock_cloudrun_service
76-
):
74+
def test_cleanup_returns_success_with_tags(self, client, mock_cloudrun_service):
7775
"""Cleanup succeeds and returns tag information."""
7876
mock_cloudrun_service["service"].traffic = [
7977
make_mock_traffic_entry("country-us-model-1-459-0", "rev-us"),
@@ -90,9 +88,7 @@ def test_cleanup_returns_success_with_tags(
9088
assert result["newest_us_tag"] == "country-us-model-1-459-0"
9189
assert result["newest_uk_tag"] == "country-uk-model-2-65-9"
9290

93-
def test_cleanup_dry_run_does_not_modify(
94-
self, client, mock_cloudrun_service
95-
):
91+
def test_cleanup_dry_run_does_not_modify(self, client, mock_cloudrun_service):
9692
"""Cleanup with dry_run=true should NOT call update_service."""
9793
mock_cloudrun_service["service"].traffic = [
9894
make_mock_traffic_entry("country-us-model-1-100-0", "rev-1"),
@@ -112,9 +108,7 @@ def test_cleanup_dry_run_does_not_modify(
112108
# CRITICAL: update_service should NOT have been called
113109
mock_cloudrun_service["client"].update_service.assert_not_called()
114110

115-
def test_cleanup_identifies_safeguards(
116-
self, client, mock_cloudrun_service
117-
):
111+
def test_cleanup_identifies_safeguards(self, client, mock_cloudrun_service):
118112
"""Cleanup correctly identifies newest US and UK tags."""
119113
mock_cloudrun_service["service"].traffic = [
120114
make_mock_traffic_entry("country-us-model-1-100-0", "rev-1"),
@@ -130,9 +124,7 @@ def test_cleanup_identifies_safeguards(
130124
assert result["newest_us_tag"] == "country-us-model-1-459-0"
131125
assert result["newest_uk_tag"] == "country-uk-model-2-65-9"
132126

133-
def test_cleanup_respects_keep_count(
134-
self, client, mock_cloudrun_service
135-
):
127+
def test_cleanup_respects_keep_count(self, client, mock_cloudrun_service):
136128
"""Cleanup keeps the correct number of tags."""
137129
mock_cloudrun_service["service"].traffic = [
138130
make_mock_traffic_entry("country-us-model-1-100-0", "rev-1"),
@@ -156,9 +148,7 @@ def test_cleanup_validates_keep_minimum(self, client, mock_cloudrun_service):
156148
assert response.status_code == 400
157149
assert "at least 2" in response.json()["detail"]
158150

159-
def test_cleanup_handles_no_tags(
160-
self, client, mock_cloudrun_service
161-
):
151+
def test_cleanup_handles_no_tags(self, client, mock_cloudrun_service):
162152
"""Cleanup handles service with no tags."""
163153
mock_cloudrun_service["service"].traffic = [
164154
make_mock_traffic_entry(None, "rev-main", percent=100),

projects/policyengine-api-tagger/tests/api/test_revision_cleanup.py

Lines changed: 78 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,15 @@ async def test_identifies_newest_us_and_uk_tags(self, cleanup):
101101
with patch.object(cleanup, "_get_service", new_callable=AsyncMock) as mock_get:
102102
mock_get.return_value = mock_service
103103

104-
service, all_tags, tags_to_keep, newest_us, newest_uk, tags_removed, errors = \
105-
await cleanup._analyze_tags(keep_count=40)
104+
(
105+
service,
106+
all_tags,
107+
tags_to_keep,
108+
newest_us,
109+
newest_uk,
110+
tags_removed,
111+
errors,
112+
) = await cleanup._analyze_tags(keep_count=40)
106113

107114
assert newest_us is not None
108115
assert newest_us.tag == "country-us-model-1-459-0"
@@ -126,8 +133,15 @@ async def test_keeps_safeguards_plus_newest(self, cleanup):
126133
with patch.object(cleanup, "_get_service", new_callable=AsyncMock) as mock_get:
127134
mock_get.return_value = mock_service
128135

129-
service, all_tags, tags_to_keep, newest_us, newest_uk, tags_removed, errors = \
130-
await cleanup._analyze_tags(keep_count=4)
136+
(
137+
service,
138+
all_tags,
139+
tags_to_keep,
140+
newest_us,
141+
newest_uk,
142+
tags_removed,
143+
errors,
144+
) = await cleanup._analyze_tags(keep_count=4)
131145

132146
# Should keep: newest US, newest UK, then next 2 newest by version
133147
assert len(tags_to_keep) == 4
@@ -152,8 +166,15 @@ async def test_handles_keep_count_less_than_2(self, cleanup):
152166
with patch.object(cleanup, "_get_service", new_callable=AsyncMock) as mock_get:
153167
mock_get.return_value = mock_service
154168

155-
service, all_tags, tags_to_keep, newest_us, newest_uk, tags_removed, errors = \
156-
await cleanup._analyze_tags(keep_count=1) # Should become 2
169+
(
170+
service,
171+
all_tags,
172+
tags_to_keep,
173+
newest_us,
174+
newest_uk,
175+
tags_removed,
176+
errors,
177+
) = await cleanup._analyze_tags(keep_count=1) # Should become 2
157178

158179
# Should still keep both safeguards
159180
assert len(tags_to_keep) == 2
@@ -169,8 +190,15 @@ async def test_handles_no_tags(self, cleanup):
169190
with patch.object(cleanup, "_get_service", new_callable=AsyncMock) as mock_get:
170191
mock_get.return_value = mock_service
171192

172-
service, all_tags, tags_to_keep, newest_us, newest_uk, tags_removed, errors = \
173-
await cleanup._analyze_tags(keep_count=40)
193+
(
194+
service,
195+
all_tags,
196+
tags_to_keep,
197+
newest_us,
198+
newest_uk,
199+
tags_removed,
200+
errors,
201+
) = await cleanup._analyze_tags(keep_count=40)
174202

175203
assert len(all_tags) == 0
176204
assert len(tags_to_keep) == 0
@@ -189,8 +217,15 @@ async def test_handles_only_us_tags(self, cleanup):
189217
with patch.object(cleanup, "_get_service", new_callable=AsyncMock) as mock_get:
190218
mock_get.return_value = mock_service
191219

192-
service, all_tags, tags_to_keep, newest_us, newest_uk, tags_removed, errors = \
193-
await cleanup._analyze_tags(keep_count=40)
220+
(
221+
service,
222+
all_tags,
223+
tags_to_keep,
224+
newest_us,
225+
newest_uk,
226+
tags_removed,
227+
errors,
228+
) = await cleanup._analyze_tags(keep_count=40)
194229

195230
assert newest_us is not None
196231
assert newest_us.tag == "country-us-model-1-200-0"
@@ -202,8 +237,15 @@ async def test_handles_service_error(self, cleanup):
202237
with patch.object(cleanup, "_get_service", new_callable=AsyncMock) as mock_get:
203238
mock_get.side_effect = Exception("Cloud Run API error")
204239

205-
service, all_tags, tags_to_keep, newest_us, newest_uk, tags_removed, errors = \
206-
await cleanup._analyze_tags(keep_count=40)
240+
(
241+
service,
242+
all_tags,
243+
tags_to_keep,
244+
newest_us,
245+
newest_uk,
246+
tags_removed,
247+
errors,
248+
) = await cleanup._analyze_tags(keep_count=40)
207249

208250
assert service is None
209251
assert len(errors) == 1
@@ -250,7 +292,9 @@ async def test_preview_does_not_call_update(self, cleanup):
250292

251293
with (
252294
patch.object(cleanup, "_get_service", new_callable=AsyncMock) as mock_get,
253-
patch.object(cleanup, "_update_service_traffic", new_callable=AsyncMock) as mock_update,
295+
patch.object(
296+
cleanup, "_update_service_traffic", new_callable=AsyncMock
297+
) as mock_update,
254298
):
255299
mock_get.return_value = mock_service
256300

@@ -278,7 +322,9 @@ async def test_calls_update_when_tags_to_remove(self, cleanup):
278322

279323
with (
280324
patch.object(cleanup, "_get_service", new_callable=AsyncMock) as mock_get,
281-
patch.object(cleanup, "_update_service_traffic", new_callable=AsyncMock) as mock_update,
325+
patch.object(
326+
cleanup, "_update_service_traffic", new_callable=AsyncMock
327+
) as mock_update,
282328
):
283329
mock_get.return_value = mock_service
284330

@@ -299,7 +345,9 @@ async def test_does_not_call_update_when_nothing_to_remove(self, cleanup):
299345

300346
with (
301347
patch.object(cleanup, "_get_service", new_callable=AsyncMock) as mock_get,
302-
patch.object(cleanup, "_update_service_traffic", new_callable=AsyncMock) as mock_update,
348+
patch.object(
349+
cleanup, "_update_service_traffic", new_callable=AsyncMock
350+
) as mock_update,
303351
):
304352
mock_get.return_value = mock_service
305353

@@ -322,7 +370,9 @@ async def test_handles_update_failure(self, cleanup):
322370

323371
with (
324372
patch.object(cleanup, "_get_service", new_callable=AsyncMock) as mock_get,
325-
patch.object(cleanup, "_update_service_traffic", new_callable=AsyncMock) as mock_update,
373+
patch.object(
374+
cleanup, "_update_service_traffic", new_callable=AsyncMock
375+
) as mock_update,
326376
):
327377
mock_get.return_value = mock_service
328378
mock_update.side_effect = Exception("Update failed")
@@ -354,7 +404,9 @@ async def capture_update(service, tags_to_keep):
354404

355405
with (
356406
patch.object(cleanup, "_get_service", new_callable=AsyncMock) as mock_get,
357-
patch.object(cleanup, "_update_service_traffic", side_effect=capture_update) as mock_update,
407+
patch.object(
408+
cleanup, "_update_service_traffic", side_effect=capture_update
409+
) as mock_update,
358410
):
359411
mock_get.return_value = mock_service
360412

@@ -388,8 +440,15 @@ async def test_sorts_versions_correctly(self, cleanup):
388440
with patch.object(cleanup, "_get_service", new_callable=AsyncMock) as mock_get:
389441
mock_get.return_value = mock_service
390442

391-
service, all_tags, tags_to_keep, newest_us, newest_uk, tags_removed, errors = \
392-
await cleanup._analyze_tags(keep_count=2)
443+
(
444+
service,
445+
all_tags,
446+
tags_to_keep,
447+
newest_us,
448+
newest_uk,
449+
tags_removed,
450+
errors,
451+
) = await cleanup._analyze_tags(keep_count=2)
393452

394453
# 1.100.0 > 1.10.0 > 1.9.0 numerically
395454
assert newest_us.tag == "country-us-model-1-100-0"

0 commit comments

Comments
 (0)