@@ -23,7 +23,7 @@ def instrument(self) -> None:
2323 module = _DSP_MODULE_NAME ,
2424 name = lm .__name__ + ".basic_request" ,
2525 factory = CopyableFunctionWrapper ,
26- args = (_LMBasicRequestWrapper ( ),),
26+ args = (_GeneralDSPyWrapper ( 'request' ),),
2727 )
2828
2929 # Predict is a concrete (non-abstract) class that may be invoked
@@ -51,7 +51,7 @@ def instrument(self) -> None:
5151 module = _DSPY_MODULE_NAME ,
5252 name = "Retrieve.forward" ,
5353 factory = CopyableFunctionWrapper ,
54- args = (_RetrieverForwardWrapper ( ),),
54+ args = (_GeneralDSPyWrapper ( 'forward' ),),
5555 )
5656
5757 wrap_object (
@@ -61,7 +61,7 @@ def instrument(self) -> None:
6161 # forward method and invokes that method using __call__.
6262 name = "Module.__call__" ,
6363 factory = CopyableFunctionWrapper ,
64- args = (_ModuleForwardWrapper ( ),),
64+ args = (_GeneralDSPyWrapper ( 'forward' ),),
6565 )
6666
6767 # At this time, there is no common parent class for retriever models as
@@ -71,7 +71,7 @@ def instrument(self) -> None:
7171 module = _DSP_MODULE_NAME ,
7272 name = "ColBERTv2.__call__" ,
7373 factory = CopyableFunctionWrapper ,
74- args = (_RetrieverModelCallWrapper ( ),),
74+ args = (_GeneralDSPyWrapper ( '__call__' ),),
7575 )
7676
7777
@@ -114,11 +114,9 @@ def __deepcopy__(self, memo: Dict[Any, Any]) -> "CopyableFunctionWrapper":
114114 return CopyableFunctionWrapper (deepcopy (self .__wrapped__ , memo ), self ._self_wrapper )
115115
116116
117- class _LMBasicRequestWrapper :
118- """
119- Wrapper for DSP LM.basic_request
120- Captures all calls to language models (lm)
121- """
117+ class _GeneralDSPyWrapper :
118+ def __init__ (self , method_name : str ):
119+ self ._method_name = method_name
122120
123121 def __call__ (
124122 self ,
@@ -127,7 +125,7 @@ def __call__(
127125 args : Tuple [type , Any ],
128126 kwargs : Mapping [str , Any ],
129127 ) -> Any :
130- span_name = instance .__class__ .__name__ + ".request"
128+ span_name = instance .__class__ .__name__ + "." + self . _method_name
131129 return trace (name = span_name )(wrapped )(* args , ** kwargs )
132130
133131
@@ -161,62 +159,6 @@ def __call__(
161159 return trace (name = _get_predict_span_name (instance ))(wrapped )(* args , ** kwargs )
162160
163161
164- class _ModuleForwardWrapper :
165- """
166- Instruments the __call__ method of dspy.Module. DSPy end users define custom
167- subclasses of Module implementing a forward method, loosely resembling the
168- ergonomics of torch.nn.Module. The __call__ method of dspy.Module invokes
169- the forward method of the user-defined subclass.
170- """
171-
172- def __call__ (
173- self ,
174- wrapped : Callable [..., Any ],
175- instance : Any ,
176- args : Tuple [type , Any ],
177- kwargs : Mapping [str , Any ],
178- ) -> Any :
179- span_name = instance .__class__ .__name__ + ".forward"
180- return trace (name = span_name )(wrapped )(* args , ** kwargs )
181-
182-
183- class _RetrieverForwardWrapper :
184- """
185- Instruments the forward method of dspy.Retrieve, which is a wrapper around
186- retriever models such as ColBERTv2. At this time, Retrieve does not contain
187- any additional information that cannot be gleaned from the underlying
188- retriever model sub-span. It is, however, a user-facing concept, so we have
189- decided to instrument it.
190- """
191-
192- def __call__ (
193- self ,
194- wrapped : Callable [..., Any ],
195- instance : Any ,
196- args : Tuple [type , Any ],
197- kwargs : Mapping [str , Any ],
198- ) -> Any :
199- span_name = instance .__class__ .__name__ + ".forward"
200- return trace (name = span_name )(wrapped )(* args , ** kwargs )
201-
202-
203- class _RetrieverModelCallWrapper :
204- """
205- Instruments the __call__ method of retriever models such as ColBERTv2.
206- """
207-
208- def __call__ (
209- self ,
210- wrapped : Callable [..., Any ],
211- instance : Any ,
212- args : Tuple [type , Any ],
213- kwargs : Mapping [str , Any ],
214- ) -> Any :
215- class_name = instance .__class__ .__name__
216- span_name = class_name + ".__call__"
217- return trace (name = span_name )(wrapped )(* args , ** kwargs )
218-
219-
220162def _get_predict_span_name (instance : Any ) -> str :
221163 """
222164 Gets the name for the Predict span, which are the composition of a Predict
0 commit comments