Skip to content

Commit 715008a

Browse files
committed
Work around type hinting problems with redis.asyncio
Typing for the asyncio redis client is messed up, and functions return `T | Awaitable[T]` instead of `T`. Add a workaround wrapper that asserts that we are in the second case. See: redis/redis-py#3619 for an in-progress upstream fix.
1 parent 9f25dc0 commit 715008a

File tree

4 files changed

+40
-19
lines changed

4 files changed

+40
-19
lines changed

beeai/agents/backport_agent.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from tools.text import CreateTool, InsertTool, StrReplaceTool, ViewTool
3333
from tools.wicked_git import GitLogSearchTool, GitPatchCreationTool
3434
from triage_agent import BackportData, ErrorData
35-
from utils import check_subprocess, get_agent_execution_config, mcp_tools, redis_client, post_private_jira_comment
35+
from utils import fix_await, check_subprocess, get_agent_execution_config, mcp_tools, redis_client, post_private_jira_comment
3636

3737
logger = logging.getLogger(__name__)
3838

@@ -234,7 +234,7 @@ class Task(BaseModel):
234234

235235
while True:
236236
logger.info("Waiting for tasks from backport_queue (timeout: 30s)...")
237-
element = await redis.brpop("backport_queue", timeout=30)
237+
element = await fix_await(redis.brpop(["backport_queue"], timeout=30))
238238
if element is None:
239239
logger.info("No tasks received, continuing to wait...")
240240
continue
@@ -256,13 +256,13 @@ async def retry(task, error):
256256
f"Task failed (attempt {task.attempts}/{max_retries}), "
257257
f"re-queuing for retry: {backport_data.jira_issue}"
258258
)
259-
await redis.lpush("backport_queue", task.model_dump_json())
259+
await fix_await(redis.lpush("backport_queue", task.model_dump_json()))
260260
else:
261261
logger.error(
262262
f"Task failed after {max_retries} attempts, "
263263
f"moving to error list: {backport_data.jira_issue}"
264264
)
265-
await redis.lpush("error_list", error)
265+
await fix_await(redis.lpush("error_list", error))
266266

267267
try:
268268
logger.info(f"Starting backport processing for {backport_data.jira_issue}")
@@ -298,7 +298,7 @@ async def retry(task, error):
298298
rmtree(local_clone)
299299
if state.backport_data.success:
300300
logger.info(f"Backport successful for {backport_data.jira_issue}, " f"adding to completed list")
301-
await redis.lpush("completed_backport_list", state.backport_data.model_dump_json())
301+
await redis.lpush("completed_backport_list", output.model_dump_json())
302302
else:
303303
logger.warning(f"Backport failed for {backport_data.jira_issue}: {state.backport_data.error}")
304304
await retry(task, state.backport_data.error)

beeai/agents/rebase_agent.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from tools.specfile import AddChangelogEntryTool
2929
from tools.text import CreateTool, InsertTool, StrReplaceTool, ViewTool
3030
from triage_agent import RebaseData, ErrorData
31-
from utils import get_agent_execution_config, mcp_tools, redis_client, run_tool, post_private_jira_comment
31+
from utils import fix_await, get_agent_execution_config, mcp_tools, redis_client, run_tool, post_private_jira_comment
3232

3333
logger = logging.getLogger(__name__)
3434

@@ -237,7 +237,7 @@ class Task(BaseModel):
237237

238238
while True:
239239
logger.info("Waiting for tasks from rebase_queue (timeout: 30s)...")
240-
element = await redis.brpop("rebase_queue", timeout=30)
240+
element = await fix_await(redis.brpop(["rebase_queue"], timeout=30))
241241
if element is None:
242242
logger.info("No tasks received, continuing to wait...")
243243
continue
@@ -260,13 +260,13 @@ async def retry(task, error):
260260
f"Task failed (attempt {task.attempts}/{max_retries}), "
261261
f"re-queuing for retry: {rebase_data.jira_issue}"
262262
)
263-
await redis.lpush("rebase_queue", task.model_dump_json())
263+
await fix_await(redis.lpush("rebase_queue", task.model_dump_json()))
264264
else:
265265
logger.error(
266266
f"Task failed after {max_retries} attempts, "
267267
f"moving to error list: {rebase_data.jira_issue}"
268268
)
269-
await redis.lpush("error_list", error)
269+
await fix_await(redis.lpush("error_list", error))
270270

271271
try:
272272
logger.info(f"Starting rebase processing for {rebase_data.jira_issue}")
@@ -296,7 +296,7 @@ async def retry(task, error):
296296
else:
297297
if state.rebase_result.success:
298298
logger.info(f"Rebase successful for {rebase_data.jira_issue}, " f"adding to completed list")
299-
await redis.lpush("completed_rebase_list", state.rebase_result.model_dump_json())
299+
await fix_await(redis.lpush("completed_rebase_list", state.rebase_result.model_dump_json()))
300300
else:
301301
logger.warning(f"Rebase failed for {rebase_data.jira_issue}: {state.rebase_result.error}")
302302
await retry(task, state.rebase_result.error)

beeai/agents/triage_agent.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from tools.commands import RunShellCommandTool
2525
from tools.patch_validator import PatchValidatorTool
2626
from tools.version_mapper import VersionMapperTool
27-
from utils import get_agent_execution_config, mcp_tools, redis_client, post_private_jira_comment
27+
from utils import fix_await, get_agent_execution_config, mcp_tools, redis_client, post_private_jira_comment
2828

2929
logger = logging.getLogger(__name__)
3030

@@ -303,7 +303,7 @@ class Task(BaseModel):
303303

304304
while True:
305305
logger.info("Waiting for tasks from triage_queue (timeout: 30s)...")
306-
element = await redis.brpop("triage_queue", timeout=30)
306+
element = await fix_await(redis.brpop(["triage_queue"], timeout=30))
307307
if element is None:
308308
logger.info("No tasks received, continuing to wait...")
309309
continue
@@ -322,12 +322,12 @@ async def retry(task, error):
322322
f"Task failed (attempt {task.attempts}/{max_retries}), "
323323
f"re-queuing for retry: {input.issue}"
324324
)
325-
await redis.lpush("triage_queue", task.model_dump_json())
325+
await fix_await(redis.lpush("triage_queue", task.model_dump_json()))
326326
else:
327327
logger.error(
328328
f"Task failed after {max_retries} attempts, " f"moving to error list: {input.issue}"
329329
)
330-
await redis.lpush("error_list", error)
330+
await fix_await(redis.lpush("error_list", error))
331331

332332
try:
333333
logger.info(f"Starting triage processing for {input.issue}")
@@ -353,21 +353,21 @@ async def retry(task, error):
353353
if output.resolution == Resolution.REBASE:
354354
logger.info(f"Triage resolved as REBASE for {input.issue}, " f"adding to rebase queue")
355355
task = Task(metadata=output.data.model_dump())
356-
await redis.lpush("rebase_queue", task.model_dump_json())
356+
await fix_await(redis.lpush("rebase_queue", task.model_dump_json()))
357357
elif output.resolution == Resolution.BACKPORT:
358358
logger.info(f"Triage resolved as BACKPORT for {input.issue}, " f"adding to backport queue")
359359
task = Task(metadata=output.data.model_dump())
360-
await redis.lpush("backport_queue", task.model_dump_json())
360+
await fix_await(redis.lpush("backport_queue", task.model_dump_json()))
361361
elif output.resolution == Resolution.CLARIFICATION_NEEDED:
362362
logger.info(
363363
f"Triage resolved as CLARIFICATION_NEEDED for {input.issue}, "
364364
f"adding to clarification needed queue"
365365
)
366366
task = Task(metadata=output.data.model_dump())
367-
await redis.lpush("clarification_needed_queue", task.model_dump_json())
367+
await fix_await(redis.lpush("clarification_needed_queue", task.model_dump_json()))
368368
elif output.resolution == Resolution.NO_ACTION:
369369
logger.info(f"Triage resolved as NO_ACTION for {input.issue}, " f"adding to no action list")
370-
await redis.lpush("no_action_list", output.data.model_dump_json())
370+
await fix_await(redis.lpush("no_action_list", output.data.model_dump_json()))
371371
elif output.resolution == Resolution.ERROR:
372372
logger.warning(f"Triage resolved as ERROR for {input.issue}, retrying")
373373
await retry(task, output.data.model_dump_json())

beeai/agents/utils.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import asyncio
2+
import inspect
23
import logging
34
import os
45
import shlex
56
import subprocess
67
from contextlib import asynccontextmanager
78
from pathlib import Path
8-
from typing import Any, AsyncGenerator, Callable, Tuple
9+
from typing import Any, AsyncGenerator, Awaitable, Callable, TypeVar, Tuple
910

1011
import redis.asyncio as redis
1112
from mcp import ClientSession
@@ -96,6 +97,26 @@ async def redis_client(redis_url: str) -> AsyncGenerator[redis.Redis, None]:
9697
await client.aclose()
9798

9899

100+
T = TypeVar("T")
101+
102+
async def fix_await(v: T | Awaitable[T]) -> T:
103+
"""
104+
Work around typing problems in the asyncio redis client.
105+
106+
Typing for the asyncio redis client is messed up, and functions
107+
return `T | Awaitable[T]` instead of `T`. This function
108+
fixes the type error by asserting that the value is awaitable
109+
before awaiting it.
110+
111+
For a proper fix, see: https://github.com/redis/redis-py/pull/3619
112+
113+
114+
Usage: `await fixAwait(redis.get("key"))`
115+
"""
116+
assert inspect.isawaitable(v)
117+
return await v
118+
119+
99120
@asynccontextmanager
100121
async def mcp_tools(
101122
sse_url: str, filter: Callable[[str], bool] | None = None

0 commit comments

Comments
 (0)