diff --git a/tests/test_session.py b/tests/test_session.py index 032f2bb38..b9e09d750 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -398,3 +398,86 @@ async def test_session_memory_rejects_both_session_and_list_input(runner_method) assert "manually manage conversation history" in str(exc_info.value) session.close() + +@pytest.mark.asyncio +async def test_sqlite_session_unicode_content(): + """Test that session correctly stores and retrieves unicode/non-ASCII content.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_unicode.db" + session_id = "unicode_test" + session = SQLiteSession(session_id, db_path) + + # Add unicode content to the session + items: list[TResponseInputItem] = [ + {"role": "user", "content": "こんにちは"}, + {"role": "assistant", "content": "😊👍"}, + {"role": "user", "content": "Привет"}, + ] + await session.add_items(items) + + # Retrieve items and verify unicode content + retrieved = await session.get_items() + assert retrieved[0].get("content") == "こんにちは" + assert retrieved[1].get("content") == "😊👍" + assert retrieved[2].get("content") == "Привет" + session.close() + + +@pytest.mark.asyncio +async def test_sqlite_session_special_characters_and_sql_injection(): + """ + Test that session safely stores and retrieves items with special characters and SQL keywords. + """ + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_special_chars.db" + session_id = "special_chars_test" + session = SQLiteSession(session_id, db_path) + + # Add items with special characters and SQL keywords + items: list[TResponseInputItem] = [ + {"role": "user", "content": "O'Reilly"}, + {"role": "assistant", "content": "DROP TABLE sessions;"}, + {"role": "user", "content": ( + '"SELECT * FROM users WHERE name = \"admin\";"' + )}, + {"role": "assistant", "content": "Robert'); DROP TABLE students;--"}, + {"role": "user", "content": "Normal message"}, + ] + await session.add_items(items) + retrieved = await session.get_items() + assert len(retrieved) == len(items) + for i, item in enumerate(items): + assert retrieved[i].get("content") == item["content"] + session.close() + +@pytest.mark.asyncio +async def test_sqlite_session_concurrent_access(): + """ + Test concurrent access to the same session to verify data integrity. + """ + import concurrent.futures + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_concurrent.db" + session_id = "concurrent_test" + session = SQLiteSession(session_id, db_path) + + # Add initial item + items: list[TResponseInputItem] = [ + {"role": "user", "content": f"Message {i}"} for i in range(10) + ] + + # Use ThreadPoolExecutor to simulate concurrent writes + def add_item(item): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(session.add_items([item])) + loop.close() + with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor: + executor.map(add_item, items) + + # Retrieve all items and verify all are present + retrieved = await session.get_items() + contents = {item.get("content") for item in retrieved} + expected = {f"Message {i}" for i in range(10)} + assert contents == expected + session.close()