Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions app/data_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import json
from typing import Any, Dict, List, Optional
from dataclasses import dataclass


@dataclass()
class EnrichedData:
    """Envelope for a single transaction payload routed to task processors.

    Wraps one raw ``data`` dict into the ``{"input": {"task": ..., "transactions": [...]}}``
    shape expected by downstream consumers (e.g. KubeRouter/KubeProcessor).

    Attributes:
        data: The raw transaction payload (must be JSON-serializable for
            ``__repr__`` to work — TODO confirm callers guarantee this).
        input: The fully-wrapped inner envelope. NOTE: any value passed by the
            caller is ignored — ``__post_init__`` always overwrites it from
            ``data`` and ``task``. Kept in the signature for dataclass
            compatibility with existing callers.
        task: Task identifier; defaults to ``"process_algos"``.
    """

    data: Dict[str, Any]
    input: Optional[Dict[str, Any]] = None
    task: str = "process_algos"

    def __post_init__(self) -> None:
        # Always derive the inner envelope from data/task; a caller-supplied
        # `input` is deliberately discarded (see class docstring).
        self.input = {"task": self.task, "transactions": [self.data]}

    def __repr__(self) -> str:
        # Unconventional: repr is the JSON-serialized wire envelope, not a
        # reconstructable Python expression. Existing code logs instances
        # directly, so this behavior is load-bearing.
        return json.dumps(self.wrap())

    def wrap(self) -> Dict[str, Any]:
        """Return the outer envelope, reusing the dict built in __post_init__.

        Previously this rebuilt ``{"task": ..., "transactions": [...]}`` inline,
        duplicating ``__post_init__`` and risking drift between the two.
        """
        return {"input": self.input}

    def transactions(self) -> List[Dict[str, Any]]:
        """Return the single-element transaction list for this payload."""
        return [self.data]
126 changes: 89 additions & 37 deletions app/jetstream.py
Copy link
Collaborator Author

@bgroupe bgroupe May 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@DGaffney I actually have a submodule with just this jetstream and a slimmed down build, since it doesn't need any of the heavy grazer dependencies. But I haven't pushed it yet because I was curious if this would get replaced at some point by jetstream-turbo?

Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from app.settings import JETSTREAM_URL as SETTINGS_JETSTREAM_URL
from app.sentry import sentry_sdk


class Jetstream:
"""Consume Bluesky Jetstream and (optionally) forward *app.bsky.feed.post*
messages to an AWS SQS queue.
Expand All @@ -30,18 +31,23 @@ class Jetstream:
JETSTREAM_URL: str = os.getenv(
"JETSTREAM_URL",
f"{SETTINGS_JETSTREAM_URL}&wantedCollections=app.bsky.feed.post"
if "wantedCollections" not in SETTINGS_JETSTREAM_URL else SETTINGS_JETSTREAM_URL,
if "wantedCollections" not in SETTINGS_JETSTREAM_URL
else SETTINGS_JETSTREAM_URL,
)

# Determine the AWS region from the SQS queue URL
SQS_QUEUE_URL: str | None = os.getenv("SQS_QUEUE_URL")
AWS_REGION: str = os.getenv("AWS_REGION", "us-east-1") # Default to us-east-1 based on the queue URL
AWS_REGION: str = os.getenv(
"AWS_REGION", "us-east-1"
) # Default to us-east-1 based on the queue URL
REDIS_URL: str | None = os.getenv("REDIS_URL")

BATCH_SIZE: int = int(os.getenv("BATCH_SIZE", 2_000))
FLUSH_INTERVAL: int = int(os.getenv("FLUSH_INTERVAL", 30)) # seconds
MAX_SQS_BATCH: int = 10 # AWS hard‑limit per SendMessageBatch
QUEUE_MAX_SIZE: int = int(os.getenv("QUEUE_MAX_SIZE", 10_000)) # Maximum size for the message queue
QUEUE_MAX_SIZE: int = int(
os.getenv("QUEUE_MAX_SIZE", 10_000)
) # Maximum size for the message queue

# ------------------------------------------------------------------
# Helpers
Expand All @@ -61,12 +67,16 @@ def _validate_raw_message(raw: str) -> str | None:
if created_at:
for fmt in ("%Y-%m-%dT%H:%M:%S.%fZ", "%Y-%m-%dT%H:%M:%SZ"):
try:
ts = datetime.strptime(created_at, fmt).replace(tzinfo=timezone.utc)
ts = datetime.strptime(created_at, fmt).replace(
tzinfo=timezone.utc
)
break
except ValueError:
ts = None
if not ts or ts > datetime.utcnow().replace(tzinfo=timezone.utc):
data["commit"]["record"]["createdAt"] = datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%S.%fZ")
data["commit"]["record"]["createdAt"] = datetime.utcnow().strftime(
"%Y-%m-%dT%H:%M:%S.%fZ"
)
return json.dumps(data)
except Exception as exc: # noqa: BLE001
logger.debug("validation error: %s", exc)
Expand All @@ -85,20 +95,23 @@ async def _init_sqs_client(cls):
try:
logger.info("Initializing SQS client in region %s", cls.AWS_REGION)
session = aioboto3.Session()
cls.SQS_CLIENT = await session.client("sqs", region_name=cls.AWS_REGION).__aenter__()
cls.SQS_CLIENT = await session.client(
"sqs", region_name=cls.AWS_REGION
).__aenter__()

# Verify the queue exists
logger.info("Verifying SQS queue exists: %s", cls.SQS_QUEUE_URL)
try:
# Get queue attributes as a simple way to check if queue exists
await cls.SQS_CLIENT.get_queue_attributes(
QueueUrl=cls.SQS_QUEUE_URL,
AttributeNames=['QueueArn']
QueueUrl=cls.SQS_QUEUE_URL, AttributeNames=["QueueArn"]
)
logger.info("SQS queue verified successfully")
except Exception as e:
logger.error("SQS queue verification failed: %s", e)
logger.error("Please check if the queue URL is correct and the queue exists")
logger.error(
"Please check if the queue URL is correct and the queue exists"
)
# Reset client so we don't keep using a bad configuration
await cls.SQS_CLIENT.close()
cls.SQS_CLIENT = None
Expand Down Expand Up @@ -129,21 +142,29 @@ async def _send_batch_to_sqs(cls, batch: list[str]):
return

for chunk in cls._chunk(batch, cls.MAX_SQS_BATCH):
entries = [{"Id": str(i), "MessageBody": body} for i, body in enumerate(chunk)]
entries = [
{"Id": str(i), "MessageBody": body} for i, body in enumerate(chunk)
]
try:
logger.debug("Sending batch of %d messages to SQS", len(entries))
resp = await cls.SQS_CLIENT.send_message_batch(QueueUrl=cls.SQS_QUEUE_URL, Entries=entries)
resp = await cls.SQS_CLIENT.send_message_batch(
QueueUrl=cls.SQS_QUEUE_URL, Entries=entries
)
if failed := resp.get("Failed"):
logger.error("%d SQS failures: %s", len(failed), failed)
else:
logger.debug("Successfully sent %d messages to SQS", len(entries))
logger.debug(
"Successfully sent %d messages to SQS", len(entries)
)
except Exception as e:
logger.error("SQS send_message_batch error: %s", e)
sentry_sdk.capture_exception(e)

# If we get a NonExistentQueue error, try to reinitialize the client
if "NonExistentQueue" in str(e):
logger.info("Attempting to reinitialize SQS client due to NonExistentQueue error")
logger.info(
"Attempting to reinitialize SQS client due to NonExistentQueue error"
)
if cls.SQS_CLIENT:
await cls.SQS_CLIENT.close()
cls.SQS_CLIENT = None
Expand All @@ -157,7 +178,9 @@ async def _send_batch_to_sqs(cls, batch: list[str]):
# Producer / flusher tasks
# ------------------------------------------------------------------
@classmethod
async def _producer(cls, queue: asyncio.Queue, redis, shutdown_event: asyncio.Event):
async def _producer(
cls, queue: asyncio.Queue, redis, shutdown_event: asyncio.Event
):
"""Continuously read websocket and enqueue validated messages.

This task runs independently and is never blocked by SQS operations.
Expand All @@ -169,21 +192,27 @@ async def _producer(cls, queue: asyncio.Queue, redis, shutdown_event: asyncio.Ev
2. Current time if no Redis cursor is found (start from now)
"""
# For SQS streaming, use current time as default or stored cursor from Redis
cursor = int(datetime.utcnow().timestamp() * 1_000_000) # Current time in microseconds
cursor = int(
datetime.utcnow().timestamp() * 1_000_000
) # Current time in microseconds
if redis and (stored := await redis.get("jetstream:last_cursor")):
try:
cursor = int(stored)
logger.info("Retrieved cursor from Redis: %s (timestamp: %s)",
cursor,
datetime.fromtimestamp(cursor / 1_000_000, tz=timezone.utc))
logger.info(
"Retrieved cursor from Redis: %s (timestamp: %s)",
cursor,
datetime.fromtimestamp(cursor / 1_000_000, tz=timezone.utc),
)
except (ValueError, TypeError) as e:
logger.warning("Failed to parse Redis cursor %r: %s", stored, e)

ws_url = f"{cls.JETSTREAM_URL}&maxMessageSizeBytes=100000&cursor={cursor}"
logger.info("Producer connecting with cursor for current time: %s", ws_url)

try:
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=10) as ws:
async with websockets.connect(
ws_url, ping_interval=20, ping_timeout=10
) as ws:
while not shutdown_event.is_set():
try:
raw = await asyncio.wait_for(ws.recv(), timeout=1.0)
Expand All @@ -193,13 +222,15 @@ async def _producer(cls, queue: asyncio.Queue, redis, shutdown_event: asyncio.Ev
await asyncio.wait_for(queue.put(msg), timeout=0.1)
except asyncio.TimeoutError:
# Queue is full, log warning but keep websocket reading
if not getattr(cls, '_queue_full_warned', False):
logger.warning("Message queue full! Some messages may be dropped.")
if not getattr(cls, "_queue_full_warned", False):
logger.warning(
"Message queue full! Some messages may be dropped."
)
cls._queue_full_warned = True
continue

# Reset warning flag if we successfully put a message
if getattr(cls, '_queue_full_warned', False):
if getattr(cls, "_queue_full_warned", False):
cls._queue_full_warned = False
except asyncio.TimeoutError:
# Just a timeout on websocket read, continue
Expand Down Expand Up @@ -229,8 +260,13 @@ async def _flusher(cls, queue: asyncio.Queue, redis, shutdown_event: asyncio.Eve
while not shutdown_event.is_set():
try:
# Calculate remaining time until next scheduled flush
timeout = cls.FLUSH_INTERVAL - (datetime.utcnow() - last_flush).total_seconds()
timeout = max(timeout, 0.1) # Small minimum timeout to check shutdown_event
timeout = (
cls.FLUSH_INTERVAL
- (datetime.utcnow() - last_flush).total_seconds()
)
timeout = max(
timeout, 0.1
) # Small minimum timeout to check shutdown_event

try:
# Get message with timeout to allow periodic flushing
Expand All @@ -243,8 +279,9 @@ async def _flusher(cls, queue: asyncio.Queue, redis, shutdown_event: asyncio.Eve

# Check if we should flush based on batch size or time elapsed
should_flush = (
len(batch) >= cls.BATCH_SIZE or
(datetime.utcnow() - last_flush).total_seconds() >= cls.FLUSH_INTERVAL
len(batch) >= cls.BATCH_SIZE
or (datetime.utcnow() - last_flush).total_seconds()
>= cls.FLUSH_INTERVAL
)

if should_flush and batch:
Expand Down Expand Up @@ -292,7 +329,9 @@ async def stream_to_sqs(cls):
if cls.SQS_CLIENT is not None:
logger.info("SQS client initialized successfully")
else:
logger.warning("Initial SQS client initialization failed, will retry later")
logger.warning(
"Initial SQS client initialization failed, will retry later"
)
except Exception as e:
logger.warning("Failed to initialize SQS client on startup: %s", e)
logger.info("Will continue and retry SQS initialization later")
Expand All @@ -315,18 +354,23 @@ async def stream_to_sqs(cls):
try:
logger.info("Starting producer and flusher tasks")
# Create both tasks and run them concurrently, but independently
producer = asyncio.create_task(cls._producer(queue, redis, shutdown_event))
flusher = asyncio.create_task(cls._flusher(queue, redis, shutdown_event))
producer = asyncio.create_task(
cls._producer(queue, redis, shutdown_event)
)
flusher = asyncio.create_task(
cls._flusher(queue, redis, shutdown_event)
)

# Wait for either task to complete
done, pending = await asyncio.wait(
[producer, flusher],
return_when=asyncio.FIRST_COMPLETED
[producer, flusher], return_when=asyncio.FIRST_COMPLETED
)

# Signal shutdown and cancel remaining tasks
shutdown_event.set()
logger.info("One task completed, signaling shutdown to remaining tasks")
logger.info(
"One task completed, signaling shutdown to remaining tasks"
)
for task in pending:
task.cancel()

Expand Down Expand Up @@ -380,7 +424,9 @@ async def stream_to_sqs(cls):
@classmethod
async def graceful_close(cls, ws):
try:
await asyncio.wait_for(ws.close(code=1000, reason="End of slice"), timeout=1)
await asyncio.wait_for(
ws.close(code=1000, reason="End of slice"), timeout=1
)
except asyncio.TimeoutError:
logger.warning("Graceful websocket close timed out; ignoring.")

Expand All @@ -394,7 +440,9 @@ async def fetch_minute_data(cls, start_us: int, end_us: int):
try:
raw_msg = await asyncio.wait_for(ws.recv(), timeout=2)
except asyncio.TimeoutError:
logger.info("No more data for slice %s->%s (2s silence).", start_us, end_us)
logger.info(
"No more data for slice %s->%s (2s silence).", start_us, end_us
)
break
except websockets.ConnectionClosed:
logger.info("Connection closed for slice %s->%s", start_us, end_us)
Expand All @@ -415,9 +463,11 @@ async def fetch_minute_data(cls, start_us: int, end_us: int):
asyncio.create_task(cls.graceful_close(ws))

@classmethod
async def yield_jetstream_reversed(cls, end_cursor: int | None = None, start_cursor: int | None = None):
async def yield_jetstream_reversed(
cls, end_cursor: int | None = None, start_cursor: int | None = None
):
now_us = int(datetime.utcnow().timestamp() * 1_000_000)
end_cursor = end_cursor or now_us - 60_000_000
end_cursor = end_cursor or now_us - 60_000_000
start_cursor = start_cursor or now_us - 24 * 3_600 * 1_000_000
if start_cursor >= end_cursor:
logger.warning("start_cursor >= end_cursor; no data to pull.")
Expand All @@ -426,7 +476,9 @@ async def yield_jetstream_reversed(cls, end_cursor: int | None = None, start_cur
while current_end > start_cursor:
one_min_ago = current_end - 60_000_000
time_range_start = max(one_min_ago, start_cursor)
utc_time = datetime.fromtimestamp(time_range_start / 1_000_000, tz=timezone.utc)
utc_time = datetime.fromtimestamp(
time_range_start / 1_000_000, tz=timezone.utc
)
logger.info("Reading slice from %s", utc_time)
async for record in cls.fetch_minute_data(time_range_start, current_end):
yield record
Expand Down
1 change: 1 addition & 0 deletions app/kube/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class KubeProcessor(KubeBase):
async def ingest_feed(cls, transactions):
records = []
deletes = []

for transaction in transactions:
if transaction.get("commit", {}).get("operation") == "create":
records.append(transaction)
Expand Down
13 changes: 7 additions & 6 deletions app/kube/router.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
from app.data_types import EnrichedData
from app.kube.processor import KubeProcessor
from app.logger import logger


class KubeRouter:
@classmethod
async def process_request(cls, dispatcher, params, noop: bool):
logger.info("Params are here!")
logger.info(type(params))
logger.info(params.keys())
async def process_request(cls, dispatcher, params: EnrichedData, noop: bool):
logger.info(params.wrap)
logger.info(params)
if noop:
logger.info("noop")
else:
if params.get("task") == "process_algos":
if params.task == "process_algos":

await KubeProcessor.process_algos(
dispatcher, params.get("transactions")
dispatcher, params.transactions()
)
2 changes: 1 addition & 1 deletion app/ray/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from app.sentry import sentry_sdk


@ray.remote(max_concurrency=1000) # type: ignore
@ray.remote(max_concurrency=1000, max_task_retries=-1, max_restarts=-1) # type: ignore
class Cache(TimingBase):
def __init__(self, key_prefix="ray_workers", batch_size=100):
"""
Expand Down
2 changes: 1 addition & 1 deletion app/ray/cpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from app.telemetry import Telemetry


@ray.remote(max_concurrency=5) # type: ignore
@ray.remote(max_concurrency=5, max_task_retries=-1, max_restarts=-1) # type: ignore
class CPUWorker(TimingBase):
def __init__(
self,
Expand Down
9 changes: 6 additions & 3 deletions app/ray/dispatcher.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import random
import asyncio
from app.ray.utils import discover_named_actors, discover_named_actor

from app.settings import OmniBootSettings

class Dispatcher:
def __init__(
Expand All @@ -22,6 +22,9 @@ def __init__(
gpu_worker: A reference to the GPUWorker actor.
cache: A reference to the shared Cache actor.
"""
self.boot_settings = OmniBootSettings()
namespace = self.boot_settings.namespace

print("Looking for cache...")
self.cache = cache or discover_named_actor("cache:", timeout=10)
print("Looking for Bluesky Semaphore...")
Expand All @@ -38,10 +41,10 @@ def __init__(
)
print("Looking for GPU Worker...")
self.gpu_embedding_workers = gpu_embedding_workers or discover_named_actors(
"gpu:embedders", timeout=10
f"gpu:{namespace}:embedders", timeout=10
)
self.gpu_classifier_workers = gpu_classifier_workers or discover_named_actors(
"gpu:classifiers", timeout=10
f"gpu:{namespace}:classifiers", timeout=10
)
print("Looking for CPU Workers...")
self.cpu_workers = cpu_workers or discover_named_actors("cpu:", timeout=10)
Expand Down
Loading