Skip to content

Commit d33f4a9

Browse files
committed
feat: Introduce a lock on the agent such that we don't mutate the same row more than once at a time
1 parent e2fe3dc commit d33f4a9

File tree

2 files changed

+92
-0
lines changed

2 files changed

+92
-0
lines changed

encord_agents/fastapi/cors.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,16 @@
44
interactions from the Encord platform.
55
"""
66

7+
import asyncio
8+
import json
79
import typing
810
from http import HTTPStatus
11+
from uuid import UUID
912

1013
from encord.exceptions import AuthorisationError
14+
from pydantic import ValidationError
15+
16+
from encord_agents.core.data_model import FrameData
1117

1218
try:
1319
from fastapi import FastAPI, Request
@@ -102,6 +108,45 @@ async def _authorization_error_exception_handler(request: Request, exc: Authoris
102108
)
103109

104110

111+
class FieldPairLockMiddleware(BaseHTTPMiddleware):
112+
def __init__(
113+
self,
114+
app: ASGIApp,
115+
):
116+
super().__init__(app)
117+
self.field_locks: dict[tuple[UUID, UUID], asyncio.Lock] = {}
118+
self.locks_lock = asyncio.Lock()
119+
120+
async def get_lock(self, frame_data: FrameData) -> asyncio.Lock:
121+
lock_key = (frame_data.project_hash, frame_data.data_hash)
122+
async with self.locks_lock:
123+
if lock_key not in self.field_locks:
124+
self.field_locks[lock_key] = asyncio.Lock()
125+
return self.field_locks[lock_key]
126+
127+
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
128+
if request.method != "POST":
129+
return await call_next(request)
130+
try:
131+
body = await request.body()
132+
try:
133+
frame_data = FrameData.model_validate_json(body)
134+
except ValidationError:
135+
# Hope that route doesn't use FrameData
136+
return await call_next(request)
137+
lock = await self.get_lock(frame_data)
138+
async with lock:
139+
# Create a new request with the same body since we've already consumed it
140+
request._body = body
141+
return await call_next(request)
142+
except Exception as e:
143+
return Response(
144+
content=json.dumps({"detail": f"Error in middleware: {str(e)}"}),
145+
status_code=500,
146+
media_type="application/json",
147+
)
148+
149+
105150
def get_encord_app(*, custom_cors_regex: str | None = None) -> FastAPI:
106151
"""
107152
Get a FastAPI app with the Encord middleware.
@@ -114,10 +159,12 @@ def get_encord_app(*, custom_cors_regex: str | None = None) -> FastAPI:
114159
FastAPI: A FastAPI app with the Encord middleware.
115160
"""
116161
app = FastAPI()
162+
117163
app.add_middleware(
118164
EncordCORSMiddleware,
119165
allow_origin_regex=custom_cors_regex or ENCORD_DOMAIN_REGEX,
120166
)
121167
app.add_middleware(EncordTestHeaderMiddleware)
168+
app.add_middleware(FieldPairLockMiddleware)
122169
app.exception_handlers[AuthorisationError] = _authorization_error_exception_handler
123170
return app

tests/integration_tests/fastapi/test_dependencies.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
from http import HTTPStatus
23
from typing import Annotated, NamedTuple
34
from uuid import uuid4
@@ -306,3 +307,47 @@ def post_client(client: Annotated[EncordUserClient, Depends(dep_client)]) -> Non
306307
resp = client.post("/client", headers={"Origin": "https://example.com"})
307308
assert resp.status_code == 200, resp.content
308309
assert "Access-Control-Allow-Origin" not in resp.headers
310+
311+
312+
class TestFieldPairLockMiddleware:
313+
context: SharedResolutionContext
314+
client: TestClient
315+
list_holder: list[int | tuple[int, str]]
316+
317+
# Set the project and first label row for the class
318+
@classmethod
319+
@pytest.fixture(autouse=True)
320+
def setup(cls, context: SharedResolutionContext) -> None:
321+
cls.context = context
322+
app = get_encord_app()
323+
cls.list_holder = []
324+
325+
@app.post("/threadsafe-endpoint")
326+
async def threadsafe_endpoint(frame_data: FrameData) -> None:
327+
cls.list_holder.append(frame_data.frame)
328+
await asyncio.sleep(0.1)
329+
cls.list_holder.append((frame_data.frame, "DONE"))
330+
331+
cls.client = TestClient(app)
332+
333+
def test_field_pair_lock_middleware(self) -> None:
334+
assert self.list_holder == []
335+
resp = self.client.post(
336+
"/threadsafe-endpoint",
337+
json={
338+
"projectHash": self.context.project.project_hash,
339+
"dataHash": self.context.video_label_row.data_hash,
340+
"frame": 0,
341+
},
342+
)
343+
assert resp.status_code == 200, resp.content
344+
resp = self.client.post(
345+
"/threadsafe-endpoint",
346+
json={
347+
"projectHash": self.context.project.project_hash,
348+
"dataHash": self.context.video_label_row.data_hash,
349+
"frame": 1,
350+
},
351+
)
352+
assert resp.status_code == 200, resp.content
353+
assert self.list_holder == [0, (0, "DONE"), 1, (1, "DONE")]

0 commit comments

Comments
 (0)