     PatchTaskRequest,
     UpdateTaskRequest,
 )
+
+# AIDEV-NOTE: Fix Pydantic forward reference issues
+# Import all step types first
+from agents_api.autogen.Tasks import (
+    EvaluateStep,
+    ForeachStep,
+    IfElseWorkflowStep,
+    ParallelStep,
+    PromptStep,
+    SwitchStep,
+    ToolCallStep,
+    WaitForInputStep,
+    YieldStep,
+)
 from agents_api.clients.pg import create_db_pool
 from agents_api.common.utils.memory import total_size
 from agents_api.env import api_key, api_key_header_name, multi_tenant_mode
@@ -47,40 +61,18 @@
 from agents_api.queries.tools.create_tools import create_tools
 from agents_api.queries.users.create_user import create_user
 from agents_api.web import app
-from aiobotocore.session import get_session
 from fastapi.testclient import TestClient
 from temporalio.client import WorkflowHandle
 from uuid_extensions import uuid7
 
 from .utils import (
-    get_localstack,
     get_pg_dsn,
     make_vector_with_similarity,
 )
 from .utils import (
     patch_embed_acompletion as patch_embed_acompletion_ctx,
 )
 
-# AIDEV-NOTE: Fix Pydantic forward reference issues
-# Import all step types first
-from agents_api.autogen.Tasks import (
-    EvaluateStep,
-    ErrorWorkflowStep,
-    ForeachStep,
-    GetStep,
-    IfElseWorkflowStep,
-    LogStep,
-    ParallelStep,
-    PromptStep,
-    ReturnStep,
-    SetStep,
-    SleepStep,
-    SwitchStep,
-    ToolCallStep,
-    WaitForInputStep,
-    YieldStep,
-)
-
 # Rebuild models to resolve forward references
 try:
     CreateTaskRequest.model_rebuild()
@@ -220,13 +212,13 @@ async def test_doc(pg_dsn, test_developer, test_agent):
         owner_id=test_agent.id,
         connection_pool=pool,
     )
-
+
     # Explicitly Refresh Indices
     await pool.execute("REINDEX DATABASE")
-
+
     doc = await get_doc(developer_id=test_developer.id, doc_id=resp.id, connection_pool=pool)
     yield doc
-
+
     # TODO: Delete the doc
     # await delete_doc(
     #     developer_id=test_developer.id,
@@ -245,7 +237,7 @@ async def test_doc_with_embedding(pg_dsn, test_developer, test_doc):
     embedding_with_confidence_0_5 = make_vector_with_similarity(d=0.5)
     embedding_with_confidence_neg_0_5 = make_vector_with_similarity(d=-0.5)
     embedding_with_confidence_1_neg = make_vector_with_similarity(d=-1.0)
-
+
     # Insert embedding with all 1.0s (similarity = 1.0)
     await pool.execute(
         """
@@ -257,7 +249,7 @@ async def test_doc_with_embedding(pg_dsn, test_developer, test_doc):
         test_doc.content[0] if isinstance(test_doc.content, list) else test_doc.content,
         f"[{', '.join([str(x) for x in [1.0] * 1024])}]",
     )
-
+
     # Insert embedding with confidence 0
     await pool.execute(
         """
@@ -269,7 +261,7 @@ async def test_doc_with_embedding(pg_dsn, test_developer, test_doc):
269261 "Test content 1" ,
270262 f"[{ ', ' .join ([str (x ) for x in embedding_with_confidence_0 ])} ]" ,
271263 )
272-
264+
273265 # Insert embedding with confidence 0.5
274266 await pool .execute (
275267 """
@@ -281,7 +273,7 @@ async def test_doc_with_embedding(pg_dsn, test_developer, test_doc):
281273 "Test content 2" ,
282274 f"[{ ', ' .join ([str (x ) for x in embedding_with_confidence_0_5 ])} ]" ,
283275 )
284-
276+
285277 # Insert embedding with confidence -0.5
286278 await pool .execute (
287279 """
@@ -293,7 +285,7 @@ async def test_doc_with_embedding(pg_dsn, test_developer, test_doc):
293285 "Test content 3" ,
294286 f"[{ ', ' .join ([str (x ) for x in embedding_with_confidence_neg_0_5 ])} ]" ,
295287 )
296-
288+
297289 # Insert embedding with confidence -1
298290 await pool .execute (
299291 """
@@ -305,11 +297,13 @@ async def test_doc_with_embedding(pg_dsn, test_developer, test_doc):
305297 "Test content 4" ,
306298 f"[{ ', ' .join ([str (x ) for x in embedding_with_confidence_1_neg ])} ]" ,
307299 )
308-
300+
309301 # Explicitly Refresh Indices
310302 await pool .execute ("REINDEX DATABASE" )
311-
312- yield await get_doc (developer_id = test_developer .id , doc_id = test_doc .id , connection_pool = pool )
303+
304+ yield await get_doc (
305+ developer_id = test_developer .id , doc_id = test_doc .id , connection_pool = pool
306+ )
313307
314308
315309@pytest .fixture
@@ -328,13 +322,13 @@ async def test_user_doc(pg_dsn, test_developer, test_user):
         owner_id=test_user.id,
         connection_pool=pool,
     )
-
+
     # Explicitly Refresh Indices
     await pool.execute("REINDEX DATABASE")
-
+
     doc = await get_doc(developer_id=test_developer.id, doc_id=resp.id, connection_pool=pool)
     yield doc
-
+
     # TODO: Delete the doc
 
 
@@ -376,7 +370,7 @@ async def test_new_developer(pg_dsn, random_email):
         developer_id=dev_id,
         connection_pool=pool,
     )
-
+
     return await get_developer(
         developer_id=dev_id,
         connection_pool=pool,
@@ -416,7 +410,7 @@ async def test_execution(
         client=None,
         id="blah",
     )
-
+
     execution = await create_execution(
         developer_id=test_developer_id,
         task_id=test_task.id,
@@ -450,7 +444,7 @@ async def test_execution_started(
         client=None,
         id="blah",
     )
-
+
     execution = await create_execution(
         developer_id=test_developer_id,
         task_id=test_task.id,
@@ -462,9 +456,9 @@ async def test_execution_started(
         workflow_handle=workflow_handle,
         connection_pool=pool,
     )
-
+
     actual_scope_id = custom_scope_id or uuid7()
-
+
     # Start the execution
     await create_execution_transition(
         developer_id=test_developer_id,
@@ -515,13 +509,13 @@ async def test_tool(
515509 "description" : "A function that prints hello world" ,
516510 "parameters" : {"type" : "object" , "properties" : {}},
517511 }
518-
512+
519513 tool_spec = {
520514 "function" : function ,
521515 "name" : "hello_world1" ,
522516 "type" : "function" ,
523517 }
524-
518+
525519 [tool , * _ ] = await create_tools (
526520 developer_id = test_developer_id ,
527521 agent_id = test_agent .id ,
@@ -539,8 +533,15 @@ async def test_tool(
 
 
 @pytest.fixture(scope="session")
-def client(pg_dsn):
+def client(pg_dsn, localstack_container):
     """Test client fixture."""
+    import os
+
+    # Set S3 environment variables before creating TestClient
+    os.environ["S3_ACCESS_KEY"] = localstack_container.env["AWS_ACCESS_KEY_ID"]
+    os.environ["S3_SECRET_KEY"] = localstack_container.env["AWS_SECRET_ACCESS_KEY"]
+    os.environ["S3_ENDPOINT"] = localstack_container.get_url()
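+    # These need to be in os.environ before the app lifespan runs, since startup is
+    # when the S3 client gets configured (see the s3_client fixture below).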
+
     with (
         TestClient(app=app) as test_client,
         patch(
@@ -550,63 +551,81 @@ def client(pg_dsn):
     ):
         yield test_client
 
+    # Clean up env vars
+    for key in ["S3_ACCESS_KEY", "S3_SECRET_KEY", "S3_ENDPOINT"]:
+        if key in os.environ:
+            del os.environ[key]
+
 
 @pytest.fixture
 async def make_request(client, test_developer_id):
     """Factory fixture for making authenticated requests."""
+
     def _make_request(method, url, **kwargs):
         headers = kwargs.pop("headers", {})
         headers = {
             **headers,
             api_key_header_name: api_key,
         }
-
+
         if multi_tenant_mode:
             headers["X-Developer-Id"] = str(test_developer_id)
-
+
         headers["Content-Length"] = str(total_size(kwargs.get("json", {})))
-
+
         return client.request(method, url, headers=headers, **kwargs)
-
+
     return _make_request
 
 
-@pytest_asyncio.fixture
-async def s3_client():
-    """S3 client fixture."""
-    with get_localstack() as localstack:
-        s3_endpoint = localstack.get_url()
-
-        from botocore.config import Config
-
-        session = get_session()
-        s3 = await session.create_client(
-            "s3",
-            endpoint_url=s3_endpoint,
-            aws_access_key_id=localstack.env["AWS_ACCESS_KEY_ID"],
-            aws_secret_access_key=localstack.env["AWS_SECRET_ACCESS_KEY"],
-            config=Config(s3={'addressing_style': 'path'})
-        ).__aenter__()
-
-        app.state.s3_client = s3
-
-        # Create the bucket if it doesn't exist
-        from agents_api.env import blob_store_bucket
-        try:
-            await s3.head_bucket(Bucket=blob_store_bucket)
-        except Exception:
-            await s3.create_bucket(Bucket=blob_store_bucket)
-
-        try:
-            yield s3
-        finally:
-            await s3.close()
-            app.state.s3_client = None
+@pytest.fixture(scope="session")
+def localstack_container():
+    """Session-scoped LocalStack container."""
+    from testcontainers.localstack import LocalStackContainer
+
+    localstack = LocalStackContainer(image="localstack/localstack:s3-latest").with_services(
+        "s3"
+    )
+    localstack.start()
+
+    try:
+        yield localstack
+    finally:
+        localstack.stop()
+
+
+@pytest.fixture(autouse=True, scope="session")
+def disable_s3_cache():
+    """Disable async_s3 cache during tests to avoid event loop issues."""
+    from agents_api.clients import async_s3
+
+    # Check if the functions are wrapped with alru_cache
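+    # alru_cache would otherwise keep handing back the S3 client created on the first
+    # event loop; later tests run on fresh loops, so that cached client may already be
+    # closed. Unwrapping falls back to the uncached implementations.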
+    if hasattr(async_s3.setup, "__wrapped__"):
+        # Save original functions
+        original_setup = async_s3.setup.__wrapped__
+        original_exists = async_s3.exists.__wrapped__
+        original_list_buckets = async_s3.list_buckets.__wrapped__
+
+        # Replace cached functions with uncached versions
+        async_s3.setup = original_setup
+        async_s3.exists = original_exists
+        async_s3.list_buckets = original_list_buckets
+
+    yield
+
+
+@pytest.fixture
+def s3_client():
+    """S3 client fixture that works with TestClient's event loop."""
+    # The TestClient's lifespan will create the S3 client
+    # The disable_s3_cache fixture ensures we don't have event loop issues
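+    # Nothing is built here: tests that request `s3_client` only need the fixture to
+    # resolve, while the actual client is created during app startup as noted above.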
+    yield
 
 
 @pytest.fixture
 async def clean_secrets(pg_dsn, test_developer_id):
     """Fixture to clean up secrets before and after tests."""
+
     async def purge() -> None:
         pool = await create_db_pool(dsn=pg_dsn)
         try:
@@ -623,7 +642,7 @@ async def purge() -> None:
         finally:
             # pool is closed in *the same* loop it was created in
             await pool.close()
-
+
     await purge()
     yield
     await purge()
@@ -635,4 +654,4 @@ def pytest_configure(config):
     config.addinivalue_line("markers", "slow: marks tests as slow")
     config.addinivalue_line("markers", "integration: marks tests as integration tests")
     config.addinivalue_line("markers", "unit: marks tests as unit tests")
-    config.addinivalue_line("markers", "workflow: marks tests as workflow tests")
+    config.addinivalue_line("markers", "workflow: marks tests as workflow tests")