Skip to content

Commit 0ec56d3

Browse files
authored
fix: contextvars propagation in ThreadPoolExecutor calls (#60)
* fix: contextvars propagation in ThreadPoolExecutor calls
1 parent 04f7352 commit 0ec56d3

File tree

3 files changed

+94
-13
lines changed

3 files changed

+94
-13
lines changed

src/guardrails/resources/chat/chat.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import asyncio
44
from collections.abc import AsyncIterator
55
from concurrent.futures import ThreadPoolExecutor
6+
from contextvars import copy_context
7+
from functools import partial
68
from typing import Any
79

810
from ..._base_client import GuardrailsBaseClient
@@ -93,10 +95,10 @@ def create(self, messages: list[dict[str, str]], model: str, stream: bool = Fals
9395
if supports_safety_identifier(self._client._resource_client):
9496
llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER
9597

96-
llm_future = executor.submit(
97-
self._client._resource_client.chat.completions.create,
98-
**llm_kwargs,
99-
)
98+
llm_call_fn = partial(self._client._resource_client.chat.completions.create, **llm_kwargs)
99+
ctx = copy_context()
100+
llm_future = executor.submit(ctx.run, llm_call_fn)
101+
100102
input_results = self._client._run_stage_guardrails(
101103
"input",
102104
latest_message,

src/guardrails/resources/responses/responses.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import asyncio
44
from collections.abc import AsyncIterator
55
from concurrent.futures import ThreadPoolExecutor
6+
from contextvars import copy_context
7+
from functools import partial
68
from typing import Any
79

810
from pydantic import BaseModel
@@ -75,10 +77,10 @@ def create(
7577
if supports_safety_identifier(self._client._resource_client):
7678
llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER
7779

78-
llm_future = executor.submit(
79-
self._client._resource_client.responses.create,
80-
**llm_kwargs,
81-
)
80+
llm_call_fn = partial(self._client._resource_client.responses.create, **llm_kwargs)
81+
ctx = copy_context()
82+
llm_future = executor.submit(ctx.run, llm_call_fn)
83+
8284
input_results = self._client._run_stage_guardrails(
8385
"input",
8486
latest_message,
@@ -141,10 +143,10 @@ def parse(self, input: list[dict[str, str]], model: str, text_format: type[BaseM
141143
if supports_safety_identifier(self._client._resource_client):
142144
llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER
143145

144-
llm_future = executor.submit(
145-
self._client._resource_client.responses.parse,
146-
**llm_kwargs,
147-
)
146+
llm_call_fn = partial(self._client._resource_client.responses.parse, **llm_kwargs)
147+
ctx = copy_context()
148+
llm_future = executor.submit(ctx.run, llm_call_fn)
149+
148150
input_results = self._client._run_stage_guardrails(
149151
"input",
150152
latest_message,

tests/unit/test_context.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from __future__ import annotations
44

5+
from concurrent.futures import ThreadPoolExecutor
6+
from contextvars import ContextVar, copy_context
57
from dataclasses import FrozenInstanceError
68

79
import pytest
@@ -34,4 +36,79 @@ def test_context_is_immutable() -> None:
3436
context = GuardrailsContext(guardrail_llm=_StubClient())
3537

3638
with pytest.raises(FrozenInstanceError):
37-
context.guardrail_llm = None # type: ignore[misc]
39+
context.guardrail_llm = None
40+
41+
42+
def test_contextvar_propagates_with_copy_context() -> None:
43+
test_var: ContextVar[str | None] = ContextVar("test_var", default=None)
44+
test_var.set("test_value")
45+
46+
def get_contextvar():
47+
return test_var.get()
48+
49+
ctx = copy_context()
50+
result = ctx.run(get_contextvar)
51+
assert result == "test_value" # noqa: S101
52+
53+
54+
def test_contextvar_propagates_with_threadpool() -> None:
55+
test_var: ContextVar[str | None] = ContextVar("test_var", default=None)
56+
test_var.set("thread_test")
57+
58+
def get_contextvar():
59+
return test_var.get()
60+
61+
ctx = copy_context()
62+
with ThreadPoolExecutor(max_workers=1) as executor:
63+
future = executor.submit(ctx.run, get_contextvar)
64+
result = future.result()
65+
66+
assert result == "thread_test" # noqa: S101
67+
68+
69+
def test_guardrails_context_propagates_with_copy_context() -> None:
70+
context = GuardrailsContext(guardrail_llm=_StubClient())
71+
set_context(context)
72+
73+
def get_guardrails_context():
74+
return get_context()
75+
76+
ctx = copy_context()
77+
result = ctx.run(get_guardrails_context)
78+
assert result is context # noqa: S101
79+
80+
clear_context()
81+
82+
83+
def test_guardrails_context_propagates_with_threadpool() -> None:
84+
context = GuardrailsContext(guardrail_llm=_StubClient())
85+
set_context(context)
86+
87+
def get_guardrails_context():
88+
return get_context()
89+
90+
ctx = copy_context()
91+
with ThreadPoolExecutor(max_workers=1) as executor:
92+
future = executor.submit(ctx.run, get_guardrails_context)
93+
result = future.result()
94+
95+
assert result is context # noqa: S101
96+
97+
clear_context()
98+
99+
100+
def test_multiple_contextvars_propagate_with_threadpool() -> None:
101+
var1: ContextVar[str | None] = ContextVar("var1", default=None)
102+
var2: ContextVar[int | None] = ContextVar("var2", default=None)
103+
var1.set("value1")
104+
var2.set(42)
105+
106+
def get_multiple_contextvars():
107+
return (var1.get(), var2.get())
108+
109+
ctx = copy_context()
110+
with ThreadPoolExecutor(max_workers=1) as executor:
111+
future = executor.submit(ctx.run, get_multiple_contextvars)
112+
result = future.result()
113+
114+
assert result == ("value1", 42) # noqa: S101

0 commit comments

Comments
 (0)