Skip to content

Commit eadc3db

Browse files
authored
Merge pull request #803 from parea-ai/PAI-918-dspy-integration
Pai 918 dspy integration
2 parents 32f6370 + 01e2863 commit eadc3db

File tree

6 files changed

+470
-8
lines changed

6 files changed

+470
-8
lines changed

parea/client.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,14 +95,33 @@ def wrap_anthropic_client(self, client: "Anthropic", integration: Optional[str]
9595
if integration:
9696
self._client.add_integration(integration)
9797

98-
def auto_trace_openai_clients(self) -> None:
98+
def auto_trace_openai_clients(self, integration: Optional[str] = None) -> None:
    """Patch the OpenAI client classes so that new clients are wrapped by Parea.

    Each listed attribute of the ``openai`` module is replaced by the result of
    ``patch_openai_client_classes``, which creates a subclass that wraps itself
    with Parea at instantiation.

    Args:
        integration: Optional integration name to register on the underlying
            Parea client once patching is done. Defaults to ``None`` so that
            existing callers which pass no argument (for example
            ``integrate_with_sglang``) keep working.
    """
    import openai

    # Patched in this fixed order; "_ModuleClient" is listed alongside the
    # public client classes.
    client_class_names = (
        "_ModuleClient",
        "OpenAI",
        "AsyncOpenAI",
        "AzureOpenAI",
        "AsyncAzureOpenAI",
    )
    for class_name in client_class_names:
        patched = patch_openai_client_classes(getattr(openai, class_name), self)
        setattr(openai, class_name, patched)

    if integration:
        self._client.add_integration(integration)
109+
110+
def trace_dspy(self):
    """Enable Parea tracing for DSPy programs.

    When the ``openai`` package can be imported, its calls are traced under the
    "dspy" integration name; afterwards the DSPy entry points are instrumented
    through ``DSPyInstrumentor``.
    """
    from parea.utils.trace_integrations.dspy import DSPyInstrumentor

    try:
        import openai

        # A version string starting with "0." takes the wrap_openai_client
        # path; every other version patches the client classes instead.
        is_zero_dot_release = openai.version.__version__.startswith("0.")
        if not is_zero_dot_release:
            self.auto_trace_openai_clients("dspy")
        else:
            self.wrap_openai_client(openai, "dspy")
    except ImportError:
        # Best effort: without an importable openai only DSPy itself is instrumented.
        pass

    instrumentor = DSPyInstrumentor()
    instrumentor.instrument()
124+
106125
def integrate_with_sglang(self):
    """Trace OpenAI clients and register the "sglang" integration.

    ``auto_trace_openai_clients`` takes the integration name as an argument
    and registers it on the Parea client after patching, so the name is passed
    through instead of calling ``add_integration`` separately. Calling it with
    no argument would raise ``TypeError`` with its current signature.
    """
    self.auto_trace_openai_clients("sglang")
@@ -377,9 +396,6 @@ async def aget_experiment_trace_logs(self, experiment_uuid: str, filters: TraceL
377396
return structure_trace_logs_from_api(response.json())
378397

379398

380-
_initialized_parea_wrapper = False
381-
382-
383399
def patch_openai_client_classes(openai_client, parea_client: Parea):
384400
"""Creates a subclass of the given openai_client to always wrap it with Parea at instantiation."""
385401

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
import os
2+
3+
import dspy
4+
from dotenv import load_dotenv
5+
6+
from parea import Parea
7+
8+
# Load variables from a .env file so that os.getenv("PAREA_API_KEY") below can find the key.
load_dotenv()

p = Parea(api_key=os.getenv("PAREA_API_KEY"), project_name="testing")
# Install the DSPy instrumentation before any DSPy object below is built or called.
p.trace_dspy()

# Language model used by every DSPy module in this script.
gpt3_turbo = dspy.OpenAI(model="gpt-3.5-turbo-1106", max_tokens=300)
dspy.configure(lm=gpt3_turbo)
# print(gpt3_turbo("hello! this is a raw prompt to GPT-3.5."))

# Define a retrieval model server to send retrieval requests to
colbertv2_wiki17_abstracts = dspy.ColBERTv2(url="http://20.102.90.50:2017/wiki17_abstracts")

# Configure retrieval server internally
dspy.settings.configure(rm=colbertv2_wiki17_abstracts)


# --- Example 1: dspy.Predict ---
# Define a simple signature for basic question answering
class BasicQA(dspy.Signature):
    """Answer questions with short factoid answers."""

    question = dspy.InputField()
    answer = dspy.OutputField(desc="often between 1 and 5 words")


# Pass signature to Predict module
generate_answer = dspy.Predict(BasicQA)

# Call the predictor on a particular input.
question = "What is the color of the sky?"
pred = generate_answer(question=question)

print(f"Question: {question}")
print(f"Predicted Answer: {pred.answer}")

# --- Example 2: dspy.ChainOfThought declared from an inline string signature ---
question = "What's something great about the ColBERT retrieval model ?!?abc"

# 1) Declare with a signature, and pass some config.
classify = dspy.ChainOfThought("question -> answer", n=1)

# 2) Call with input argument.
response = classify(question=question)

# 3) Access the outputs.
print(response.completions.answer)


# --- Example 3: dspy.ChainOfThought with a class-based signature ---
# NOTE: this rebinds the module-level name BasicQA; the fields match the earlier definition.
# Define a simple signature for basic question answering
class BasicQA(dspy.Signature):
    """Answer questions with short factoid answers."""

    question = dspy.InputField()
    answer = dspy.OutputField(desc="often between 1 and 5 words")


# Pass signature to ChainOfThought module
generate_answer = dspy.ChainOfThought(BasicQA)

# Call the predictor on a particular input.
question = "What is the color of the sky?12"
pred = generate_answer(question=question)

print(f"Question: {question}")
print(f"Predicted Answer: {pred.answer}")


# --- Example 4: dspy.MultiChainComparison ---
# NOTE: BasicQA is rebound once more with the same fields.
class BasicQA(dspy.Signature):
    """Answer questions with short factoid answers."""

    question = dspy.InputField()
    answer = dspy.OutputField(desc="often between 1 and 5 words")


# Example completions generated by a model for reference
completions = [
    dspy.Prediction(rationale=" I recall that during clear days, the sky often appears this colo12r", answer="blue"),
    dspy.Prediction(rationale=" Based on common knowledge, I believe the sky is typically seen 12as this color", answer="green"),
    dspy.Prediction(rationale=" From images and depictions in media, the sky is frequently42 represented with this hue", answer="blue"),
]

# Pass signature to MultiChainComparison module
compare_answers = dspy.MultiChainComparison(BasicQA)

# Call the MultiChainComparison on the completions
question = " What is the color of th e sky14?"
final_pred = compare_answers(completions, question=question)

print(f"Question: {question}")
print(f"Final Predicted Answer (after comparison): {final_pred.answer}")
print(f"Final Rationale: {final_pred.rationale}")


# --- Example 5: dspy.ProgramOfThought ---
# Define a simple signature for basic question answering
class GenerateAnswer(dspy.Signature):
    """Answer questions with short factoid answers."""

    question = dspy.InputField()
    answer = dspy.OutputField(desc="often between 1 and 5 words")


# Pass signature to ProgramOfThought Module
pot = dspy.ProgramOfThought(GenerateAnswer)

# Call the ProgramOfThought module on a particular input
question = "Sarah has 5 applez. She buys 123 more apples from the store. How many apples does Sarah have now?"
result = pot(question=question)

print(f"Question: {question}")
print(f"Final Predicted Answer (after ProgramOfThought process): {result.answer}")


# --- Example 6: dspy.ReAct (constructed with an empty tool list) ---
# Define a simple signature for basic question answering
class BasicQA(dspy.Signature):
    """Answer questions with short factoid answers."""

    question = dspy.InputField()
    answer = dspy.OutputField(desc="often between 1 and 5 words")


# Pass signature to ReAct module
react_module = dspy.ReAct(BasicQA, tools=[])

# Call the ReAct module on a particular input
question = "What is the color of the 2 skies?"
result = react_module(question=question)

print(f"Question: {question}")
print(f"Final Predicted Answer (after ReAct process): {result.answer}")


# --- Example 7: dspy.Retrieve against the retrieval model configured near the top ---
query = "Where was the first FIFA World Cup held?12"


# Call the retriever on a particular query.
retrieve = dspy.Retrieve(k=3)
topK_passages = retrieve(query).passages

print(f"Top {retrieve.k} passages for question: {query} \n", "-" * 30, "\n")

# Print each passage with a 1-based index.
for idx, passage in enumerate(topK_passages):
    print(f"{idx+1}]", passage, "\n")
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
from typing import Any, Callable, Dict, Mapping, Optional, Tuple
2+
3+
from copy import copy, deepcopy
4+
5+
from wrapt import BoundFunctionWrapper, FunctionWrapper, wrap_object
6+
7+
from parea import trace
8+
9+
_DSPY_MODULE_NAME = "dspy"
10+
_DSP_MODULE_NAME = "dsp"
11+
12+
13+
class DSPyInstrumentor:
    """Installs tracing wrappers on DSPy language-model, predictor, module and retriever entry points."""

    def instrument(self) -> None:
        """Wrap the DSP/DSPy call sites listed below with copy-friendly function wrappers.

        The targets are patched in a fixed order: language models, ``Predict``
        and its direct subclasses, ``Retrieve``, ``Module`` and ``ColBERTv2``.
        """
        from dsp.modules.lm import LM
        from dspy import Predict

        def _install(module_name: str, target: str, handler: Any) -> None:
            # Every patch goes through the same factory so the wrapped
            # attributes stay copyable and deep-copyable.
            wrap_object(
                module=module_name,
                name=target,
                factory=CopyableFunctionWrapper,
                args=(handler,),
            )

        # Language-model calls: one wrapper per direct subclass of LM.
        for lm_cls in LM.__subclasses__():
            _install(_DSP_MODULE_NAME, f"{lm_cls.__name__}.basic_request", _GeneralDSPyWrapper("request"))

        # Predict is concrete and may be invoked directly, while its
        # subclasses override forward; both the base class and every direct
        # subclass are wrapped.
        _install(_DSPY_MODULE_NAME, "Predict.forward", _PredictForwardWrapper())
        for predict_cls in Predict.__subclasses__():
            _install(_DSPY_MODULE_NAME, f"{predict_cls.__name__}.forward", _PredictForwardWrapper())

        _install(_DSPY_MODULE_NAME, "Retrieve.forward", _GeneralDSPyWrapper("forward"))

        # dspy.Module has no abstract forward method; it assumes user-defined
        # subclasses implement forward and invokes it through __call__, so
        # __call__ is the method wrapped here.
        _install(_DSPY_MODULE_NAME, "Module.__call__", _GeneralDSPyWrapper("forward"))

        # Retriever models have no common parent class the way language
        # models do, so they are wrapped case by case.
        _install(_DSP_MODULE_NAME, "ColBERTv2.__call__", _GeneralDSPyWrapper("__call__"))
75+
76+
77+
class CopyableBoundFunctionWrapper(BoundFunctionWrapper):  # type: ignore
    """
    Bound function wrapper with support for ``copy`` and ``deepcopy``. Wrapping
    a class method with it lets the whole class be copied and deep-copied.

    For reference, see
    https://github.com/GrahamDumpleton/wrapt/issues/86#issuecomment-426161271
    and
    https://wrapt.readthedocs.io/en/master/wrappers.html#custom-function-wrappers
    """

    def __copy__(self) -> "CopyableBoundFunctionWrapper":
        # Only the wrapped callable is copied; instance and wrapper are reused.
        cloned_target = copy(self.__wrapped__)
        return CopyableBoundFunctionWrapper(cloned_target, self._self_instance, self._self_wrapper)

    def __deepcopy__(self, memo: Dict[Any, Any]) -> "CopyableBoundFunctionWrapper":
        # Only the wrapped callable is deep-copied; instance and wrapper are reused.
        cloned_target = deepcopy(self.__wrapped__, memo)
        return CopyableBoundFunctionWrapper(cloned_target, self._self_instance, self._self_wrapper)
94+
95+
96+
class CopyableFunctionWrapper(FunctionWrapper):  # type: ignore
    """
    Function wrapper with support for ``copy`` and ``deepcopy``. Wrapping a
    class method with it lets the whole class be copied and deep-copied.

    For reference, see
    https://github.com/GrahamDumpleton/wrapt/issues/86#issuecomment-426161271
    and
    https://wrapt.readthedocs.io/en/master/wrappers.html#custom-function-wrappers
    """

    # Bound accesses use the copy-friendly bound wrapper as well.
    __bound_function_wrapper__ = CopyableBoundFunctionWrapper

    def __copy__(self) -> "CopyableFunctionWrapper":
        # Only the wrapped callable is copied; the wrapper object is reused.
        cloned_target = copy(self.__wrapped__)
        return CopyableFunctionWrapper(cloned_target, self._self_wrapper)

    def __deepcopy__(self, memo: Dict[Any, Any]) -> "CopyableFunctionWrapper":
        # Only the wrapped callable is deep-copied; the wrapper object is reused.
        cloned_target = deepcopy(self.__wrapped__, memo)
        return CopyableFunctionWrapper(cloned_target, self._self_wrapper)
114+
115+
116+
class _GeneralDSPyWrapper:
117+
def __init__(self, method_name: str):
118+
self._method_name = method_name
119+
120+
def __call__(
121+
self,
122+
wrapped: Callable[..., Any],
123+
instance: Any,
124+
args: Tuple[type, Any],
125+
kwargs: Mapping[str, Any],
126+
) -> Any:
127+
span_name = instance.__class__.__name__ + "." + self._method_name
128+
return trace(name=span_name)(wrapped)(*args, **kwargs)
129+
130+
131+
class _PredictForwardWrapper:
    """
    A wrapper for the Predict class to have a chain span for each prediction
    """

    def __call__(
        self,
        wrapped: Callable[..., Any],
        instance: Any,
        args: Tuple[type, Any],
        kwargs: Mapping[str, Any],
    ) -> Any:
        """Call ``wrapped`` inside a trace span unless an outer span already covers it.

        Args:
            wrapped: The original ``forward`` callable.
            instance: The object ``forward`` was invoked on.
            args: Positional arguments of the call.
            kwargs: Keyword arguments of the call.

        Returns:
            Whatever ``wrapped`` returns.
        """
        from dspy import Predict

        # Bind the class unconditionally. Binding it with a walrus on the
        # right-hand side of `and` leaves the name unset whenever the
        # isinstance check is False, and the next line would then fail with
        # UnboundLocalError.
        cls = instance.__class__

        # At this time, subclasses of Predict override the base class' forward
        # method and invoke the parent class' forward method from within the
        # overridden method. The forward method for both Predict and its
        # subclasses have been instrumented. To avoid creating duplicate spans
        # for a single invocation, we don't create a span for the base class'
        # forward method if the instance belongs to a proper subclass of Predict
        # with an overridden forward method.
        is_instance_of_predict_subclass = isinstance(instance, Predict) and cls is not Predict
        has_overridden_forward_method = getattr(cls, "forward", None) is not getattr(Predict, "forward", None)
        wrapped_method_is_base_class_forward_method = wrapped.__qualname__ == Predict.forward.__qualname__
        if is_instance_of_predict_subclass and has_overridden_forward_method and wrapped_method_is_base_class_forward_method:
            return wrapped(*args, **kwargs)
        return trace(name=_get_predict_span_name(instance))(wrapped)(*args, **kwargs)
159+
160+
161+
def _get_predict_span_name(instance: Any) -> str:
162+
"""
163+
Gets the name for the Predict span, which are the composition of a Predict
164+
class or subclass and a user-defined signature. An example name would be
165+
"Predict(UserDefinedSignature).forward".
166+
"""
167+
class_name = str(instance.__class__.__name__)
168+
if (signature := getattr(instance, "signature", None)) and (signature_name := _get_signature_name(signature)):
169+
return f"{class_name}({signature_name}).forward"
170+
return f"{class_name}.forward"
171+
172+
173+
def _get_signature_name(signature: Any) -> Optional[str]:
174+
"""
175+
A best-effort attempt to get the name of a signature.
176+
"""
177+
if (
178+
# At the time of this writing, the __name__ attribute on signatures does
179+
# not return the user-defined class name, but __qualname__ does.
180+
qual_name := getattr(signature, "__qualname__", None)
181+
) is None:
182+
return None
183+
return str(qual_name.split(".")[-1])

0 commit comments

Comments
 (0)