diff --git a/openai/__init__.py b/openai/__init__.py index f80085eada..5180f88fc6 100644 --- a/openai/__init__.py +++ b/openai/__init__.py @@ -26,6 +26,7 @@ if TYPE_CHECKING: from aiohttp import ClientSession + import requests api_key = os.environ.get("OPENAI_API_KEY") # Path of a file with an API key, whose contents can change. Supercedes @@ -47,6 +48,10 @@ debug = False log = None # Set to either 'debug' or 'info', controls console logging +session: ContextVar[Optional["requests.Session"]] = ContextVar( + "requests-session", default=None +) + aiosession: ContextVar[Optional["ClientSession"]] = ContextVar( "aiohttp-session", default=None ) # Acts as a global aiohttp ClientSession that reuses connections. diff --git a/openai/api_requestor.py b/openai/api_requestor.py index 827b73b78e..4a9124e70f 100644 --- a/openai/api_requestor.py +++ b/openai/api_requestor.py @@ -511,7 +511,10 @@ def request_raw( ) if not hasattr(_thread_context, "session"): - _thread_context.session = _make_session() + session = openai.session.get() + if not session: + session = _make_session() + _thread_context.session = session try: result = _thread_context.session.request( method, diff --git a/openai/tests/test_endpoints.py b/openai/tests/test_endpoints.py index c3fc1094bb..4deb227934 100644 --- a/openai/tests/test_endpoints.py +++ b/openai/tests/test_endpoints.py @@ -2,6 +2,7 @@ import json import pytest +import requests import openai from openai import error @@ -86,3 +87,11 @@ def test_timeout_does_not_error(): model="ada", request_timeout=10, ) + + +def test_user_passed_session(): + with requests.Session() as session: + openai.session.set(session) + + completion = openai.Completion.create(prompt="test", model="ada") + assert completion