Skip to content

Commit 63ed61e

Browse files
committed
feat: add dspy integration
1 parent 32f6370 commit 63ed61e

File tree

3 files changed

+271
-3
lines changed

3 files changed

+271
-3
lines changed

parea/client.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ def wrap_anthropic_client(self, client: "Anthropic", integration: Optional[str]
9898
def auto_trace_openai_clients(self) -> None:
9999
import openai
100100

101+
openai._ModuleClient = patch_openai_client_classes(openai._ModuleClient, self)
101102
openai.OpenAI = patch_openai_client_classes(openai.OpenAI, self)
102103
openai.AsyncOpenAI = patch_openai_client_classes(openai.AsyncOpenAI, self)
103104
openai.AzureOpenAI = patch_openai_client_classes(openai.AzureOpenAI, self)
@@ -377,9 +378,6 @@ async def aget_experiment_trace_logs(self, experiment_uuid: str, filters: TraceL
377378
return structure_trace_logs_from_api(response.json())
378379

379380

380-
_initialized_parea_wrapper = False
381-
382-
383381
def patch_openai_client_classes(openai_client, parea_client: Parea):
384382
"""Creates a subclass of the given openai_client to always wrap it with Parea at instantiation."""
385383

Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,242 @@
1+
from typing import Any, Callable, Dict, Mapping, Optional, Tuple
2+
3+
from copy import copy, deepcopy
4+
5+
from wrapt import BoundFunctionWrapper, FunctionWrapper, wrap_object
6+
7+
from parea import trace
8+
9+
_DSPY_MODULE_NAME = "dspy"
10+
_DSP_MODULE_NAME = "dsp"
11+
12+
13+
class DSPyInstrumentor:
14+
15+
def instrument(self) -> None:
16+
# Instrument LM (language model) calls
17+
from dsp.modules.lm import LM
18+
from dspy import Predict
19+
20+
language_model_classes = LM.__subclasses__()
21+
for lm in language_model_classes:
22+
wrap_object(
23+
module=_DSP_MODULE_NAME,
24+
name=lm.__name__ + ".basic_request",
25+
factory=CopyableFunctionWrapper,
26+
args=(_LMBasicRequestWrapper(),),
27+
)
28+
29+
# Predict is a concrete (non-abstract) class that may be invoked
30+
# directly, but DSPy also has subclasses of Predict that override the
31+
# forward method. We instrument both the forward methods of the base
32+
# class and all subclasses.
33+
wrap_object(
34+
module=_DSPY_MODULE_NAME,
35+
name="Predict.forward",
36+
factory=CopyableFunctionWrapper,
37+
args=(_PredictForwardWrapper(),),
38+
)
39+
40+
predict_subclasses = Predict.__subclasses__()
41+
for predict_subclass in predict_subclasses:
42+
wrap_object(
43+
module=_DSPY_MODULE_NAME,
44+
name=predict_subclass.__name__ + ".forward",
45+
factory=CopyableFunctionWrapper,
46+
args=(_PredictForwardWrapper(),),
47+
)
48+
49+
50+
wrap_object(
51+
module=_DSPY_MODULE_NAME,
52+
name="Retrieve.forward",
53+
factory=CopyableFunctionWrapper,
54+
args=(_RetrieverForwardWrapper(),),
55+
)
56+
57+
wrap_object(
58+
module=_DSPY_MODULE_NAME,
59+
# At this time, dspy.Module does not have an abstract forward
60+
# method, but assumes that user-defined subclasses implement the
61+
# forward method and invokes that method using __call__.
62+
name="Module.__call__",
63+
factory=CopyableFunctionWrapper,
64+
args=(_ModuleForwardWrapper(),),
65+
)
66+
67+
# At this time, there is no common parent class for retriever models as
68+
# there is for language models. We instrument the retriever models on a
69+
# case-by-case basis.
70+
wrap_object(
71+
module=_DSP_MODULE_NAME,
72+
name="ColBERTv2.__call__",
73+
factory=CopyableFunctionWrapper,
74+
args=(_RetrieverModelCallWrapper(),),
75+
)
76+
77+
78+
class CopyableBoundFunctionWrapper(BoundFunctionWrapper): # type: ignore
79+
"""
80+
A bound function wrapper that can be copied and deep-copied. When used to
81+
wrap a class method, this allows the entire class to be copied and
82+
deep-copied.
83+
84+
For reference, see
85+
https://github.com/GrahamDumpleton/wrapt/issues/86#issuecomment-426161271
86+
and
87+
https://wrapt.readthedocs.io/en/master/wrappers.html#custom-function-wrappers
88+
"""
89+
90+
def __copy__(self) -> "CopyableBoundFunctionWrapper":
91+
return CopyableBoundFunctionWrapper(copy(self.__wrapped__), self._self_instance, self._self_wrapper)
92+
93+
def __deepcopy__(self, memo: Dict[Any, Any]) -> "CopyableBoundFunctionWrapper":
94+
return CopyableBoundFunctionWrapper(deepcopy(self.__wrapped__, memo), self._self_instance, self._self_wrapper)
95+
96+
97+
class CopyableFunctionWrapper(FunctionWrapper): # type: ignore
98+
"""
99+
A function wrapper that can be copied and deep-copied. When used to wrap a
100+
class method, this allows the entire class to be copied and deep-copied.
101+
102+
For reference, see
103+
https://github.com/GrahamDumpleton/wrapt/issues/86#issuecomment-426161271
104+
and
105+
https://wrapt.readthedocs.io/en/master/wrappers.html#custom-function-wrappers
106+
"""
107+
108+
__bound_function_wrapper__ = CopyableBoundFunctionWrapper
109+
110+
def __copy__(self) -> "CopyableFunctionWrapper":
111+
return CopyableFunctionWrapper(copy(self.__wrapped__), self._self_wrapper)
112+
113+
def __deepcopy__(self, memo: Dict[Any, Any]) -> "CopyableFunctionWrapper":
114+
return CopyableFunctionWrapper(deepcopy(self.__wrapped__, memo), self._self_wrapper)
115+
116+
117+
class _LMBasicRequestWrapper:
118+
"""
119+
Wrapper for DSP LM.basic_request
120+
Captures all calls to language models (lm)
121+
"""
122+
123+
def __call__(
124+
self,
125+
wrapped: Callable[..., Any],
126+
instance: Any,
127+
args: Tuple[type, Any],
128+
kwargs: Mapping[str, Any],
129+
) -> Any:
130+
span_name = instance.__class__.__name__ + ".request"
131+
return trace(name=span_name)(wrapped)(*args, **kwargs)
132+
133+
134+
class _PredictForwardWrapper:
135+
"""
136+
A wrapper for the Predict class to have a chain span for each prediction
137+
"""
138+
139+
def __call__(
140+
self,
141+
wrapped: Callable[..., Any],
142+
instance: Any,
143+
args: Tuple[type, Any],
144+
kwargs: Mapping[str, Any],
145+
) -> Any:
146+
from dspy import Predict
147+
148+
# At this time, subclasses of Predict override the base class' forward
149+
# method and invoke the parent class' forward method from within the
150+
# overridden method. The forward method for both Predict and its
151+
# subclasses have been instrumented. To avoid creating duplicate spans
152+
# for a single invocation, we don't create a span for the base class'
153+
# forward method if the instance belongs to a proper subclass of Predict
154+
# with an overridden forward method.
155+
is_instance_of_predict_subclass = isinstance(instance, Predict) and (cls := instance.__class__) is not Predict
156+
has_overridden_forward_method = getattr(cls, "forward", None) is not getattr(Predict, "forward", None)
157+
wrapped_method_is_base_class_forward_method = wrapped.__qualname__ == Predict.forward.__qualname__
158+
if is_instance_of_predict_subclass and has_overridden_forward_method and wrapped_method_is_base_class_forward_method:
159+
return wrapped(*args, **kwargs)
160+
else:
161+
return trace(name=_get_predict_span_name(instance))(wrapped)(*args, **kwargs)
162+
163+
164+
class _ModuleForwardWrapper:
165+
"""
166+
Instruments the __call__ method of dspy.Module. DSPy end users define custom
167+
subclasses of Module implementing a forward method, loosely resembling the
168+
ergonomics of torch.nn.Module. The __call__ method of dspy.Module invokes
169+
the forward method of the user-defined subclass.
170+
"""
171+
172+
def __call__(
173+
self,
174+
wrapped: Callable[..., Any],
175+
instance: Any,
176+
args: Tuple[type, Any],
177+
kwargs: Mapping[str, Any],
178+
) -> Any:
179+
span_name = instance.__class__.__name__ + ".forward"
180+
return trace(name=span_name)(wrapped)(*args, **kwargs)
181+
182+
183+
class _RetrieverForwardWrapper:
184+
"""
185+
Instruments the forward method of dspy.Retrieve, which is a wrapper around
186+
retriever models such as ColBERTv2. At this time, Retrieve does not contain
187+
any additional information that cannot be gleaned from the underlying
188+
retriever model sub-span. It is, however, a user-facing concept, so we have
189+
decided to instrument it.
190+
"""
191+
192+
def __call__(
193+
self,
194+
wrapped: Callable[..., Any],
195+
instance: Any,
196+
args: Tuple[type, Any],
197+
kwargs: Mapping[str, Any],
198+
) -> Any:
199+
span_name = instance.__class__.__name__ + ".forward"
200+
return trace(name=span_name)(wrapped)(*args, **kwargs)
201+
202+
203+
class _RetrieverModelCallWrapper:
204+
"""
205+
Instruments the __call__ method of retriever models such as ColBERTv2.
206+
"""
207+
208+
def __call__(
209+
self,
210+
wrapped: Callable[..., Any],
211+
instance: Any,
212+
args: Tuple[type, Any],
213+
kwargs: Mapping[str, Any],
214+
) -> Any:
215+
class_name = instance.__class__.__name__
216+
span_name = class_name + ".__call__"
217+
return trace(name=span_name)(wrapped)(*args, **kwargs)
218+
219+
220+
def _get_predict_span_name(instance: Any) -> str:
221+
"""
222+
Gets the name for the Predict span, which are the composition of a Predict
223+
class or subclass and a user-defined signature. An example name would be
224+
"Predict(UserDefinedSignature).forward".
225+
"""
226+
class_name = str(instance.__class__.__name__)
227+
if (signature := getattr(instance, "signature", None)) and (signature_name := _get_signature_name(signature)):
228+
return f"{class_name}({signature_name}).forward"
229+
return f"{class_name}.forward"
230+
231+
232+
def _get_signature_name(signature: Any) -> Optional[str]:
233+
"""
234+
A best-effort attempt to get the name of a signature.
235+
"""
236+
if (
237+
# At the time of this writing, the __name__ attribute on signatures does
238+
# not return the user-defined class name, but __qualname__ does.
239+
qual_name := getattr(signature, "__qualname__", None)
240+
) is None:
241+
return None
242+
return str(qual_name.split(".")[-1])

parea/utils/universal_encoder.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,32 @@ class UniversalEncoder(json.JSONEncoder):
5555
A JSON encoder that can handle additional types such as dataclasses, attrs, and more.
5656
"""
5757

58+
def handle_dspy_response(self, obj) -> Any:
59+
try:
60+
import dspy
61+
except ImportError:
62+
return None
63+
64+
from dsp.templates.template_v3 import Template
65+
from dspy.primitives.example import Example
66+
67+
if hasattr(obj, "completions") and hasattr(obj.completions, "_completions"):
68+
# multiple completions
69+
return obj.completions._completions
70+
elif hasattr(obj, "_asdict"):
71+
# convert namedtuples to dictionaries
72+
return obj._asdict()
73+
elif isinstance(obj, Example):
74+
# handles Prediction objects and other sub-classes of Example
75+
return getattr(obj, "_store", {})
76+
elif isinstance(obj, Template):
77+
return {
78+
"fields": [self.default(field) for field in obj.fields],
79+
"instructions": obj.instructions,
80+
}
81+
else:
82+
return None
83+
5884
def default(self, obj: Any):
5985
if isinstance(obj, str):
6086
return obj
@@ -92,6 +118,8 @@ def default(self, obj: Any):
92118
return obj.tolist()
93119
elif is_pandas_instance(obj):
94120
return obj.to_dict(orient="records")
121+
elif dspy_response := self.handle_dspy_response(obj):
122+
return dspy_response
95123
else:
96124
return super().default(obj)
97125

0 commit comments

Comments
 (0)