Skip to content

Commit adfe04c

Browse files
authored
Video check (#566)
- remove "inline" mode - add gcs delete prefix to allow deleting the `msg_XYZ` directory in GCS (so that we dont rely on message in the DB) - add cleanup in the handle_retry_exhausted - change log error before retry_exhausted to warning - propagate trace through workers ~~Have not verified traces on google~~ Look good there ~~Still looking into why message_id isn't there~~ I think messages erroring (saw some from inference error) preventing creation of message (no message_Id)
1 parent 20b945f commit adfe04c

File tree

7 files changed

+212
-121
lines changed

7 files changed

+212
-121
lines changed

.env.test

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,6 @@ AI2_MODEL_HUB_API_KEY=fake
2020
RECAPTCHA_ENABLED=false
2121
RECAPTCHA_KEY=fake
2222

23-
VIDEO_SAFETY_CHECK_WORKER_STRATEGY=fake
23+
SAFETY_QUEUE_ENABLED=false
2424
SAFETY_QUEUE_URL=fake
2525
SAFTEY_GCS_UPLOAD_BUCKET=fake

.vscode/launch.json

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,25 @@
5353
"jinja": true,
5454
"justMyCode": false
5555
},
56+
{
57+
"name": "Python Debugger: API Safety Worker",
58+
"type": "debugpy",
59+
"request": "launch",
60+
"module": "dramatiq",
61+
"env": {
62+
"PYTHONPATH": "${workspaceFolder}/apps/api/src",
63+
"ENV": "development",
64+
"GOOGLE_APPLICATION_CREDENTIALS": "${workspaceFolder}/service_account.json"
65+
},
66+
"args": [
67+
"--processes", "1",
68+
"--threads", "1",
69+
"api.safety_queue:setup_safety_queue",
70+
"api.thread.chat.safety.safety_checkers.google_video_safety_checker"
71+
],
72+
"jinja": true,
73+
"justMyCode": false
74+
},
5675
{
5776
"name": "Alembic Auto-generate Debug Migration",
5877
"type": "debugpy",

apps/api/src/api/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,9 @@ class Settings(BaseSettings):
7272
RECAPTCHA_KEY: str = Field(init=False)
7373
RECAPTCHA_MIN_SCORE_REQUIREMENT: float = 0.3
7474

75+
SAFETY_QUEUE_ENABLED: bool = True
7576
SAFETY_QUEUE_URL: str = Field(init=False)
7677
SAFTEY_GCS_UPLOAD_BUCKET: str = Field(init=False)
77-
VIDEO_SAFETY_CHECK_WORKER_STRATEGY: str = "deferred"
7878

7979
model_config = SettingsConfigDict(
8080
extra="ignore",

apps/api/src/api/safety_queue.py

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,83 @@
11
import dramatiq
2+
import structlog
23
from dramatiq.brokers.redis import RedisBroker
34
from dramatiq.brokers.stub import StubBroker
45
from dramatiq.middleware.asyncio import AsyncIO
56
from dramatiq.middleware.prometheus import Prometheus
7+
from opentelemetry import context, propagate, trace
68
from typing_extensions import override
79

810
from api.config import settings
11+
from api.logging.fastapi_logger import FastAPIStructLogger
912
from api.logging.setup import setup_logging
1013
from api.otel.setup import setup_otel
1114

1215
# Dramatiq requires a broker to be set before actors are declared.
1316
# This sets a stub so actor modules can be imported safely before setup_safety_queue() is called.
1417
# https://github.com/Bogdanp/dramatiq/pull/762
1518
dramatiq.set_broker(StubBroker())
19+
logger = FastAPIStructLogger()
20+
21+
SAFETY_QUEUE_NAMESPACE = "playground_api_safety_queue"
1622

1723

1824
class OtelMiddleware(dramatiq.Middleware):
1925
@override
2026
def after_worker_boot(self, broker: dramatiq.Broker, worker: dramatiq.Worker) -> None:
2127
setup_otel()
2228

29+
@override
30+
def before_enqueue(self, broker: dramatiq.Broker, message: dramatiq.Message, delay: int) -> None:
31+
if "otel_context" not in message.options:
32+
carrier: dict = {}
33+
propagate.inject(carrier)
34+
message.options["otel_context"] = carrier
35+
36+
@override
37+
def before_process_message(self, broker: dramatiq.Broker, message: dramatiq.MessageProxy) -> None:
38+
# Context is in message.options for a regular dramatiq actor
39+
# or in the message.kwargs.failed_message.options for retry_exhausted actor
40+
carrier = message.options.get("otel_context") or (
41+
message.kwargs.get("failed_message", {}).get("options", {}).get("otel_context", {})
42+
)
43+
44+
ctx = propagate.extract(carrier)
45+
token = context.attach(ctx)
46+
message.options["_otel_token"] = token
47+
48+
structlog.contextvars.clear_contextvars()
49+
span_ctx = trace.get_current_span().get_span_context()
50+
if span_ctx.is_valid:
51+
structlog.contextvars.bind_contextvars(
52+
trace_id=format(span_ctx.trace_id, "032x"),
53+
span_id=format(span_ctx.span_id, "016x"),
54+
trace_flags=span_ctx.trace_flags,
55+
)
56+
57+
@override
58+
def after_process_message(
59+
self, broker: dramatiq.Broker, message: dramatiq.MessageProxy, *, result=None, exception=None
60+
) -> None:
61+
token = message.options.pop("_otel_token", None)
62+
if token is not None:
63+
context.detach(token)
64+
structlog.contextvars.clear_contextvars()
65+
66+
@override
67+
def after_skip_message(self, broker: dramatiq.Broker, message: dramatiq.MessageProxy) -> None:
68+
token = message.options.pop("_otel_token", None)
69+
if token is not None:
70+
context.detach(token)
71+
structlog.contextvars.clear_contextvars()
72+
2373

2474
def setup_safety_queue() -> None:
2575
setup_logging(json_logs=settings.LOG_JSON_FORMAT, log_level=settings.LOG_LEVEL)
2676

27-
if settings.VIDEO_SAFETY_CHECK_WORKER_STRATEGY != "deferred":
77+
if not settings.SAFETY_QUEUE_ENABLED:
2878
return
2979

30-
redis_broker = RedisBroker(url=settings.SAFETY_QUEUE_URL, namespace="playground_safety_queue")
80+
redis_broker = RedisBroker(url=settings.SAFETY_QUEUE_URL, namespace=SAFETY_QUEUE_NAMESPACE)
3181

3282
old_broker = dramatiq.get_broker()
3383
for existing_actor_name in old_broker.get_declared_actors():
@@ -36,7 +86,7 @@ def setup_safety_queue() -> None:
3686
redis_broker.declare_actor(actor)
3787

3888
redis_broker.add_middleware(Prometheus())
39-
redis_broker.add_middleware(AsyncIO())
4089
redis_broker.add_middleware(OtelMiddleware())
90+
redis_broker.add_middleware(AsyncIO())
4191

4292
dramatiq.set_broker(redis_broker)

apps/api/src/api/thread/chat/safety/safety_checkers/google_video_safety_checker.py

Lines changed: 61 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
SafetyChecker,
2727
SafetyCheckRequest,
2828
SafetyCheckResponse,
29-
SafetyCheckUnsafeError,
3029
SkippedSafetyCheckResponse,
3130
)
3231

@@ -56,9 +55,6 @@ def _make_worker_sessionmaker():
5655
class VideoIntelligenceOperationNotAvailableError(Exception): ...
5756

5857

59-
class VideoIntelligenceOperationNotFinishedError(Exception): ...
60-
61-
6258
class VideoIntelligenceOperationMessageNotFoundError(Exception): ...
6359

6460

@@ -108,46 +104,65 @@ async def check_request(self, request: SafetyCheckRequest, *, throw: bool = Fals
108104

109105
span.set_attribute("operation_name", operation_name)
110106

111-
if settings.VIDEO_SAFETY_CHECK_WORKER_STRATEGY == "deferred":
112-
handle_video_safety_check.send(
113-
operation_name=operation_name, file_url=request.content, message_id=request.message_id
114-
)
115-
116-
if settings.VIDEO_SAFETY_CHECK_WORKER_STRATEGY == "inline":
117-
response = await _handle_video_safety_check_async(
118-
operation_name=operation_name, file_url=request.content, message_id=request.message_id
119-
)
120-
if not response.is_safe() and throw:
121-
raise SafetyCheckUnsafeError
122-
123-
return response
107+
handle_video_safety_check.send(
108+
operation_name=operation_name, file_url=request.content, message_id=request.message_id
109+
)
124110

125111
return SkippedSafetyCheckResponse()
126112

127113

128114
@dramatiq.actor
129-
def handle_retry_exhausted(*args, **kwargs) -> None:
130-
# Logs retry limits so we can alert of off them
115+
@tracer.start_as_current_span("handle_retry_exhausted")
116+
async def handle_retry_exhausted(failed_message: dict, retry_info: dict):
117+
kwargs: dict = failed_message.get("kwargs", {})
118+
operation_name: str = kwargs.get("operation_name", "unknown")
119+
file_url: str = kwargs.get("file_url", "unknown")
120+
message_id: str = kwargs.get("message_id", "unknown")
121+
122+
span = trace.get_current_span()
123+
span.set_attributes({
124+
"operation_name": operation_name,
125+
"file_url": file_url,
126+
"message_id": message_id,
127+
"retries": retry_info.get("retries", 0),
128+
})
129+
span.set_status(trace.StatusCode.ERROR, "video safety check retries exhausted")
130+
131131
logger.error(
132132
"video_safety_worker.retry_exhausted",
133-
job_args=str(args),
134-
job_kwargs=str(kwargs),
133+
operation_name=operation_name,
134+
file_url=file_url,
135+
message_id=message_id,
136+
retries=retry_info.get("retries"),
137+
max_retries=retry_info.get("max_retries"),
138+
traceback=failed_message.get("options", {}).get("traceback"),
135139
)
136140

141+
storage_client = get_google_cloud_storage()
142+
safety_file_name = Path(file_url).parts[-1]
143+
144+
await storage_client.delete_file(filename=safety_file_name, bucket_name=settings.SAFTEY_GCS_UPLOAD_BUCKET)
145+
146+
await storage_client.delete_prefix(prefix=message_id, bucket_name=settings.USER_CONTENT_BUCKET)
147+
148+
Session = _make_worker_sessionmaker() # noqa: N806
149+
async with Session() as session:
150+
message_repository = AsyncMessageRepository(session)
151+
message = await message_repository.get_message_by_id(message_id)
152+
if message is not None:
153+
message.harmful = True
154+
message.file_urls = None
155+
await message_repository.update(message)
156+
await session.commit()
157+
137158

138-
@dramatiq.actor( # type: ignore
159+
@dramatiq.actor(
139160
queue_name=SAFETY_QUEUE_NAME,
140161
max_retries=5,
141162
on_retry_exhausted=handle_retry_exhausted.actor_name,
142163
)
143-
async def handle_video_safety_check(operation_name: str, file_url: str, message_id: str) -> None:
144-
await _handle_video_safety_check_async(operation_name=operation_name, file_url=file_url, message_id=message_id)
145-
146-
147164
@tracer.start_as_current_span("handle_video_safety_check")
148-
async def _handle_video_safety_check_async(
149-
operation_name: str, file_url: str, message_id: str
150-
) -> GoogleVideoIntelligenceResponse:
165+
async def handle_video_safety_check(operation_name: str, file_url: str, message_id: str):
151166
span = trace.get_current_span()
152167
span.set_attributes({
153168
"operation_name": operation_name,
@@ -168,6 +183,11 @@ async def _handle_video_safety_check_async(
168183
trace.StatusCode.ERROR,
169184
f"Operation {operation_name} not found. The Operation endpoint may not have the operation yet.",
170185
)
186+
logger.warning(
187+
"video_safety.operation_not_available",
188+
operation_name=operation_name,
189+
message_id=message_id,
190+
)
171191
raise VideoIntelligenceOperationNotAvailableError
172192

173193
try:
@@ -182,35 +202,30 @@ async def _handle_video_safety_check_async(
182202
trace.StatusCode.ERROR,
183203
f"Operation {operation_name} not found or not done. The Operation endpoint may not have the operation yet.",
184204
)
205+
logger.warning(
206+
"video_safety.operation_not_available",
207+
operation_name=operation_name,
208+
message_id=message_id,
209+
)
185210
raise VideoIntelligenceOperationNotAvailableError from e
186211

187212
result = await operation.result()
188213

189214
if not isinstance(result, AnnotateVideoResponse):
190215
msg = "Unexpected result from google video checker"
216+
span.set_status(trace.StatusCode.ERROR, msg)
191217
raise TypeError(msg)
192218

193219
mapped_response = GoogleVideoIntelligenceResponse(result)
194220
span.set_attribute("is_safe", mapped_response.is_safe())
195221

196-
# if we are blocking mode -- the message doesn't exist yet and we will return
197-
# a response to indicate its not safe back to the main event loop
198-
if settings.VIDEO_SAFETY_CHECK_WORKER_STRATEGY == "inline":
199-
logger.info(
200-
"video_safety.blocking_mode.complete",
201-
operation=operation_name,
202-
is_safe=mapped_response.is_safe(),
203-
message_id=message_id,
204-
)
205-
206-
return mapped_response
207-
208222
message_repository = AsyncMessageRepository(session)
209223
message = await message_repository.get_message_by_id(message_id)
210224

211225
if message is None:
212226
not_found_message = f"Message {message_id} not found when evaluating a video safety check"
213-
logger.error(
227+
span.set_status(trace.StatusCode.ERROR, not_found_message)
228+
logger.warning(
214229
"video_safety.message_not_found",
215230
operation=operation_name,
216231
message_id=message_id,
@@ -228,29 +243,22 @@ async def _handle_video_safety_check_async(
228243
)
229244

230245
storage_client = get_google_cloud_storage()
231-
232246
safety_file_name = Path(file_url).parts[-1]
233247

234248
await storage_client.delete_file(
235249
filename=safety_file_name,
236250
bucket_name=settings.SAFTEY_GCS_UPLOAD_BUCKET,
237251
)
238252

239-
if message.file_urls:
240-
await storage_client.delete_multiple_files_by_url(
241-
file_urls=message.file_urls,
242-
bucket_name=settings.USER_CONTENT_BUCKET,
243-
)
244-
message.file_urls = None
253+
await storage_client.delete_prefix(prefix=message_id, bucket_name=settings.USER_CONTENT_BUCKET)
254+
message.file_urls = None
245255

246256
await message_repository.update(message)
247257
await session.commit()
248258

249259
logger.info(
250-
"video_safety.queue_mode.complete",
251-
operation=operation_name,
260+
"video_safety.complete",
261+
operation_name=operation_name,
252262
is_safe=mapped_response.is_safe(),
253263
message_id=message_id,
254264
)
255-
256-
return mapped_response

0 commit comments

Comments
 (0)