Skip to content

Commit 8b80208

Browse files
committed
refactor: simplify
1 parent 63ed61e commit 8b80208

File tree

1 file changed

+8
-66
lines changed
  • parea/utils/trace_integrations

1 file changed

+8
-66
lines changed

parea/utils/trace_integrations/dspy.py

Lines changed: 8 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def instrument(self) -> None:
2323
module=_DSP_MODULE_NAME,
2424
name=lm.__name__ + ".basic_request",
2525
factory=CopyableFunctionWrapper,
26-
args=(_LMBasicRequestWrapper(),),
26+
args=(_GeneralDSPyWrapper('request'),),
2727
)
2828

2929
# Predict is a concrete (non-abstract) class that may be invoked
@@ -51,7 +51,7 @@ def instrument(self) -> None:
5151
module=_DSPY_MODULE_NAME,
5252
name="Retrieve.forward",
5353
factory=CopyableFunctionWrapper,
54-
args=(_RetrieverForwardWrapper(),),
54+
args=(_GeneralDSPyWrapper('forward'),),
5555
)
5656

5757
wrap_object(
@@ -61,7 +61,7 @@ def instrument(self) -> None:
6161
# forward method and invokes that method using __call__.
6262
name="Module.__call__",
6363
factory=CopyableFunctionWrapper,
64-
args=(_ModuleForwardWrapper(),),
64+
args=(_GeneralDSPyWrapper('forward'),),
6565
)
6666

6767
# At this time, there is no common parent class for retriever models as
@@ -71,7 +71,7 @@ def instrument(self) -> None:
7171
module=_DSP_MODULE_NAME,
7272
name="ColBERTv2.__call__",
7373
factory=CopyableFunctionWrapper,
74-
args=(_RetrieverModelCallWrapper(),),
74+
args=(_GeneralDSPyWrapper('__call__'),),
7575
)
7676

7777

@@ -114,11 +114,9 @@ def __deepcopy__(self, memo: Dict[Any, Any]) -> "CopyableFunctionWrapper":
114114
return CopyableFunctionWrapper(deepcopy(self.__wrapped__, memo), self._self_wrapper)
115115

116116

117-
class _LMBasicRequestWrapper:
118-
"""
119-
Wrapper for DSP LM.basic_request
120-
Captures all calls to language models (lm)
121-
"""
117+
class _GeneralDSPyWrapper:
118+
def __init__(self, method_name: str):
119+
self._method_name = method_name
122120

123121
def __call__(
124122
self,
@@ -127,7 +125,7 @@ def __call__(
127125
args: Tuple[type, Any],
128126
kwargs: Mapping[str, Any],
129127
) -> Any:
130-
span_name = instance.__class__.__name__ + ".request"
128+
span_name = instance.__class__.__name__ + "." + self._method_name
131129
return trace(name=span_name)(wrapped)(*args, **kwargs)
132130

133131

@@ -161,62 +159,6 @@ def __call__(
161159
return trace(name=_get_predict_span_name(instance))(wrapped)(*args, **kwargs)
162160

163161

164-
class _ModuleForwardWrapper:
165-
"""
166-
Instruments the __call__ method of dspy.Module. DSPy end users define custom
167-
subclasses of Module implementing a forward method, loosely resembling the
168-
ergonomics of torch.nn.Module. The __call__ method of dspy.Module invokes
169-
the forward method of the user-defined subclass.
170-
"""
171-
172-
def __call__(
173-
self,
174-
wrapped: Callable[..., Any],
175-
instance: Any,
176-
args: Tuple[type, Any],
177-
kwargs: Mapping[str, Any],
178-
) -> Any:
179-
span_name = instance.__class__.__name__ + ".forward"
180-
return trace(name=span_name)(wrapped)(*args, **kwargs)
181-
182-
183-
class _RetrieverForwardWrapper:
184-
"""
185-
Instruments the forward method of dspy.Retrieve, which is a wrapper around
186-
retriever models such as ColBERTv2. At this time, Retrieve does not contain
187-
any additional information that cannot be gleaned from the underlying
188-
retriever model sub-span. It is, however, a user-facing concept, so we have
189-
decided to instrument it.
190-
"""
191-
192-
def __call__(
193-
self,
194-
wrapped: Callable[..., Any],
195-
instance: Any,
196-
args: Tuple[type, Any],
197-
kwargs: Mapping[str, Any],
198-
) -> Any:
199-
span_name = instance.__class__.__name__ + ".forward"
200-
return trace(name=span_name)(wrapped)(*args, **kwargs)
201-
202-
203-
class _RetrieverModelCallWrapper:
204-
"""
205-
Instruments the __call__ method of retriever models such as ColBERTv2.
206-
"""
207-
208-
def __call__(
209-
self,
210-
wrapped: Callable[..., Any],
211-
instance: Any,
212-
args: Tuple[type, Any],
213-
kwargs: Mapping[str, Any],
214-
) -> Any:
215-
class_name = instance.__class__.__name__
216-
span_name = class_name + ".__call__"
217-
return trace(name=span_name)(wrapped)(*args, **kwargs)
218-
219-
220162
def _get_predict_span_name(instance: Any) -> str:
221163
"""
222164
Gets the name for the Predict span, which are the composition of a Predict

0 commit comments

Comments
 (0)