Skip to content

Commit fd5871b

Browse files
committed
Work around type hinting problems with redis.asyncio
Typing for the asyncio redis client is messed up, and functions return `T | Awaitable[T]` instead of `T`. Add a workaround wrapper that asserts that we are in the second case. See: redis/redis-py#3619 for an in-progress upstream fix.
1 parent b4ed3a9 commit fd5871b

File tree

4 files changed

+40
-19
lines changed

4 files changed

+40
-19
lines changed

beeai/agents/backport_agent.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from observability import setup_observability
3131
from tools.commands import RunShellCommandTool
3232
from triage_agent import BackportData, ErrorData
33-
from utils import get_agent_execution_config, mcp_tools, redis_client, get_git_finalization_steps
33+
from utils import fixAwait, get_agent_execution_config, mcp_tools, redis_client, get_git_finalization_steps
3434

3535
logger = logging.getLogger(__name__)
3636

@@ -230,7 +230,7 @@ class Task(BaseModel):
230230

231231
while True:
232232
logger.info("Waiting for tasks from backport_queue (timeout: 30s)...")
233-
element = await redis.brpop("backport_queue", timeout=30)
233+
element = await fixAwait(redis.brpop(["backport_queue"], timeout=30))
234234
if element is None:
235235
logger.info("No tasks received, continuing to wait...")
236236
continue
@@ -264,13 +264,13 @@ async def retry(task, error):
264264
f"Task failed (attempt {task.attempts}/{max_retries}), "
265265
f"re-queuing for retry: {backport_data.jira_issue}"
266266
)
267-
await redis.lpush("backport_queue", task.model_dump_json())
267+
await fixAwait(redis.lpush("backport_queue", task.model_dump_json()))
268268
else:
269269
logger.error(
270270
f"Task failed after {max_retries} attempts, "
271271
f"moving to error list: {backport_data.jira_issue}"
272272
)
273-
await redis.lpush("error_list", error)
273+
await fixAwait(redis.lpush("error_list", error))
274274

275275
try:
276276
logger.info(f"Starting backport processing for {backport_data.jira_issue}")
@@ -287,7 +287,7 @@ async def retry(task, error):
287287
rmtree(local_clone)
288288
if output.success:
289289
logger.info(f"Backport successful for {backport_data.jira_issue}, " f"adding to completed list")
290-
await redis.lpush("completed_backport_list", output.model_dump_json())
290+
await fixAwait(redis.lpush("completed_backport_list", output.model_dump_json()))
291291
else:
292292
logger.warning(f"Backport failed for {backport_data.jira_issue}: {output.error}")
293293
await retry(task, output.error)

beeai/agents/rebase_agent.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from observability import setup_observability
2525
from tools.commands import RunShellCommandTool
2626
from triage_agent import RebaseData, ErrorData
27-
from utils import get_agent_execution_config, mcp_tools, redis_client, get_git_finalization_steps
27+
from utils import fixAwait, get_agent_execution_config, mcp_tools, redis_client, get_git_finalization_steps
2828

2929
logger = logging.getLogger(__name__)
3030

@@ -195,7 +195,7 @@ class Task(BaseModel):
195195

196196
while True:
197197
logger.info("Waiting for tasks from rebase_queue (timeout: 30s)...")
198-
element = await redis.brpop("rebase_queue", timeout=30)
198+
element = await fixAwait(redis.brpop(["rebase_queue"], timeout=30))
199199
if element is None:
200200
logger.info("No tasks received, continuing to wait...")
201201
continue
@@ -225,13 +225,13 @@ async def retry(task, error):
225225
f"Task failed (attempt {task.attempts}/{max_retries}), "
226226
f"re-queuing for retry: {rebase_data.jira_issue}"
227227
)
228-
await redis.lpush("rebase_queue", task.model_dump_json())
228+
await fixAwait(redis.lpush("rebase_queue", task.model_dump_json()))
229229
else:
230230
logger.error(
231231
f"Task failed after {max_retries} attempts, "
232232
f"moving to error list: {rebase_data.jira_issue}"
233233
)
234-
await redis.lpush("error_list", error)
234+
await fixAwait(redis.lpush("error_list", error))
235235

236236
try:
237237
logger.info(f"Starting rebase processing for {rebase_data.jira_issue}")
@@ -246,7 +246,7 @@ async def retry(task, error):
246246
else:
247247
if output.success:
248248
logger.info(f"Rebase successful for {rebase_data.jira_issue}, " f"adding to completed list")
249-
await redis.lpush("completed_rebase_list", output.model_dump_json())
249+
await fixAwait(redis.lpush("completed_rebase_list", output.model_dump_json()))
250250
else:
251251
logger.warning(f"Rebase failed for {rebase_data.jira_issue}: {output.error}")
252252
await retry(task, output.error)

beeai/agents/triage_agent.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from observability import setup_observability
2424
from tools.commands import RunShellCommandTool
2525
from tools.patch_validator import PatchValidatorTool
26-
from utils import get_agent_execution_config, mcp_tools, redis_client
26+
from utils import fixAwait, get_agent_execution_config, mcp_tools, redis_client
2727

2828
logger = logging.getLogger(__name__)
2929

@@ -285,7 +285,7 @@ class Task(BaseModel):
285285

286286
while True:
287287
logger.info("Waiting for tasks from triage_queue (timeout: 30s)...")
288-
element = await redis.brpop("triage_queue", timeout=30)
288+
element = await fixAwait(redis.brpop(["triage_queue"], timeout=30))
289289
if element is None:
290290
logger.info("No tasks received, continuing to wait...")
291291
continue
@@ -304,12 +304,12 @@ async def retry(task, error):
304304
f"Task failed (attempt {task.attempts}/{max_retries}), "
305305
f"re-queuing for retry: {input.issue}"
306306
)
307-
await redis.lpush("triage_queue", task.model_dump_json())
307+
await fixAwait(redis.lpush("triage_queue", task.model_dump_json()))
308308
else:
309309
logger.error(
310310
f"Task failed after {max_retries} attempts, " f"moving to error list: {input.issue}"
311311
)
312-
await redis.lpush("error_list", error)
312+
await fixAwait(redis.lpush("error_list", error))
313313

314314
try:
315315
logger.info(f"Starting triage processing for {input.issue}")
@@ -329,21 +329,21 @@ async def retry(task, error):
329329
if output.resolution == Resolution.REBASE:
330330
logger.info(f"Triage resolved as REBASE for {input.issue}, " f"adding to rebase queue")
331331
task = Task(metadata=output.data.model_dump())
332-
await redis.lpush("rebase_queue", task.model_dump_json())
332+
await fixAwait(redis.lpush("rebase_queue", task.model_dump_json()))
333333
elif output.resolution == Resolution.BACKPORT:
334334
logger.info(f"Triage resolved as BACKPORT for {input.issue}, " f"adding to backport queue")
335335
task = Task(metadata=output.data.model_dump())
336-
await redis.lpush("backport_queue", task.model_dump_json())
336+
await fixAwait(redis.lpush("backport_queue", task.model_dump_json()))
337337
elif output.resolution == Resolution.CLARIFICATION_NEEDED:
338338
logger.info(
339339
f"Triage resolved as CLARIFICATION_NEEDED for {input.issue}, "
340340
f"adding to clarification needed queue"
341341
)
342342
task = Task(metadata=output.data.model_dump())
343-
await redis.lpush("clarification_needed_queue", task.model_dump_json())
343+
await fixAwait(redis.lpush("clarification_needed_queue", task.model_dump_json()))
344344
elif output.resolution == Resolution.NO_ACTION:
345345
logger.info(f"Triage resolved as NO_ACTION for {input.issue}, " f"adding to no action list")
346-
await redis.lpush("no_action_list", output.data.model_dump_json())
346+
await fixAwait(redis.lpush("no_action_list", output.data.model_dump_json()))
347347
elif output.resolution == Resolution.ERROR:
348348
logger.warning(f"Triage resolved as ERROR for {input.issue}, retrying")
349349
await retry(task, output.data.model_dump_json())

beeai/agents/utils.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1+
import inspect
12
import os
23

34
from contextlib import asynccontextmanager
4-
from typing import AsyncGenerator, Callable
5+
from typing import AsyncGenerator, Awaitable, Callable, TypeVar
56

67
import redis.asyncio as redis
78
from mcp import ClientSession
@@ -29,6 +30,26 @@ async def redis_client(redis_url: str) -> AsyncGenerator[redis.Redis, None]:
2930
await client.aclose()
3031

3132

33+
T = TypeVar("T")
34+
35+
async def fixAwait(v: T | Awaitable[T]) -> T:
36+
"""
37+
Work around typing problems in the asyncio redis client.
38+
39+
Typing for the asyncio redis client is messed up, and functions
40+
return `T | Awaitable[T]` instead of `T`. This function
41+
fixes the type error by asserting that the value is awaitable
42+
before awaiting it.
43+
44+
For a proper fix, see: https://github.com/redis/redis-py/pull/3619
45+
46+
47+
Usage: `await fixAwait(redis.get("key"))`
48+
"""
49+
assert inspect.isawaitable(v)
50+
return await v
51+
52+
3253
@asynccontextmanager
3354
async def mcp_tools(
3455
sse_url: str, filter: Callable[[str], bool] | None = None

0 commit comments

Comments
 (0)