Commit 854e320

DSPy enhancements (#362)
* fix dspy issue
* DSPy enhancements v1
* Add support for LiteLLM
* update readme
* DSPy enhancements
* bump version
1 parent 9c3673e commit 854e320

File tree

14 files changed: +1075 −26 lines changed

README.md

Lines changed: 10 additions & 1 deletion
@@ -238,6 +238,14 @@ By default, prompt and completion data are captured. If you would like to opt out
 
 `TRACE_PROMPT_COMPLETION_DATA=false`
 
+### Enable/Disable checkpoint tracing for DSPy
+
+By default, checkpoints are traced for DSPy pipelines. To disable checkpoint tracing, set the following env var:
+
+`TRACE_DSPY_CHECKPOINT=false`
+
+Note: checkpoint tracing increases execution latency because the pipeline state is serialized. Please disable it in production.
+
 ## Supported integrations
 
 Langtrace automatically captures traces from the following vendors:
@@ -253,8 +261,9 @@ Langtrace automatically captures traces from the following vendors:
 | Gemini | LLM | :x: | :white_check_mark: |
 | Mistral | LLM | :x: | :white_check_mark: |
 | Langchain | Framework | :x: | :white_check_mark: |
-| LlamaIndex | Framework | :white_check_mark: | :white_check_mark: |
 | Langgraph | Framework | :x: | :white_check_mark: |
+| LlamaIndex | Framework | :white_check_mark: | :white_check_mark: |
+| LiteLLM | Framework | :x: | :white_check_mark: |
 | DSPy | Framework | :x: | :white_check_mark: |
 | CrewAI | Framework | :x: | :white_check_mark: |
 | Ollama | Framework | :x: | :white_check_mark: |
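
For reference, a minimal sketch of turning the new flag off from application code rather than the shell. The env var name `TRACE_DSPY_CHECKPOINT` comes from the README addition above; the surrounding setup is illustrative:

import os

# Disable DSPy checkpoint tracing before the SDK reads the env var.
# (Illustrative setup; TRACE_DSPY_CHECKPOINT is the flag documented above.)
os.environ["TRACE_DSPY_CHECKPOINT"] = "false"

from langtrace_python_sdk import langtrace

langtrace.init()  # DSPy spans are still emitted; checkpoint state is not serialized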

pyproject.toml

Lines changed: 2 additions & 0 deletions
@@ -31,6 +31,7 @@ dependencies = [
     'fsspec>=2024.6.0',
     "transformers>=4.11.3",
     "sentry-sdk>=2.14.0",
+    "ujson>=5.10.0",
 ]
 
 requires-python = ">=3.9"
@@ -47,6 +48,7 @@ dev = [
     "langchain-community",
     "langchain-openai",
     "langchain-openai",
+    "litellm",
     "chromadb",
     "cohere",
     "qdrant_client",
Lines changed: 89 additions & 0 deletions (new file)

import dspy
from dotenv import find_dotenv, load_dotenv
from dspy.datasets import HotPotQA
from dspy.teleprompt import BootstrapFewShot

from langtrace_python_sdk import inject_additional_attributes, langtrace

_ = load_dotenv(find_dotenv())

langtrace.init()

turbo = dspy.LM('openai/gpt-4o-mini')
colbertv2_wiki17_abstracts = dspy.ColBERTv2(url='http://20.102.90.50:2017/wiki17_abstracts')

dspy.settings.configure(lm=turbo, rm=colbertv2_wiki17_abstracts)


# Load the dataset.
dataset = HotPotQA(train_seed=1, train_size=20, eval_seed=2023, dev_size=50, test_size=0)

# Tell DSPy that the 'question' field is the input. Any other fields are labels and/or metadata.
trainset = [x.with_inputs('question') for x in dataset.train]
devset = [x.with_inputs('question') for x in dataset.dev]


class GenerateAnswer(dspy.Signature):
    """Answer questions with short factoid answers."""

    context = dspy.InputField(desc="may contain relevant facts")
    question = dspy.InputField()
    answer = dspy.OutputField(desc="often between 1 and 5 words")


class RAG(dspy.Module):
    def __init__(self, num_passages=3):
        super().__init__()

        self.retrieve = dspy.Retrieve(k=num_passages)
        self.generate_answer = dspy.ChainOfThought(GenerateAnswer)

    def forward(self, question):
        context = self.retrieve(question).passages
        prediction = self.generate_answer(context=context, question=question)
        return dspy.Prediction(context=context, answer=prediction.answer)


# Validation logic: check that the predicted answer is correct.
# Also check that the retrieved context does actually contain that answer.
def validate_context_and_answer(example, prediction, trace=None):
    answer_em = dspy.evaluate.answer_exact_match(example, prediction)
    answer_pm = dspy.evaluate.answer_passage_match(example, prediction)
    return answer_em and answer_pm


# Set up a basic optimizer, which will compile our RAG program.
optimizer = BootstrapFewShot(metric=validate_context_and_answer)

# Compile!
compiled_rag = optimizer.compile(RAG(), trainset=trainset)

# Ask any question you like of this simple RAG program.
my_question = "Who was the hero of the movie peraanmai?"

# Get the prediction. This contains `pred.context` and `pred.answer`.
# pred = compiled_rag(my_question)
pred = inject_additional_attributes(
    lambda: compiled_rag(my_question),
    {'experiment': 'experiment 6', 'description': 'trying additional stuff', 'run_id': 'run_1'},
)
# compiled_rag.save('compiled_rag_v1.json')

# Print the contexts and the answer.
print(f"Question: {my_question}")
print(f"Predicted Answer: {pred.answer}")
print(f"Retrieved Contexts (truncated): {[c[:200] + '...' for c in pred.context]}")

# print("Inspecting the history of the optimizer:")
# turbo.inspect_history(n=1)

from dspy.evaluate import Evaluate


def validate_answer(example, pred, trace=None):
    return True


# Set up the evaluator, which can be used multiple times.
evaluate = Evaluate(devset=devset, metric=validate_answer, num_threads=4, display_progress=True, display_table=0)


# Evaluate our `compiled_rag` program.
evaluate(compiled_rag)
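
As a side note, the commented-out `compiled_rag.save(...)` call above hints at persisting the optimized program. A hedged sketch of that round trip, using the standard `save`/`load` methods on dspy modules (the file name is illustrative):

# Sketch only: persist the compiled program and reload it into a fresh module.
compiled_rag.save('compiled_rag_v1.json')

reloaded_rag = RAG()
reloaded_rag.load('compiled_rag_v1.json')
print(reloaded_rag(my_question).answer)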

src/examples/openai_example/chat_completion.py

Lines changed: 19 additions & 16 deletions
@@ -9,19 +9,19 @@
 
 _ = load_dotenv(find_dotenv())
 
-langtrace.init(write_spans_to_console=True)
+langtrace.init()
 client = OpenAI()
 
 
 def api():
     response = client.chat.completions.create(
-        model="gpt-4",
+        model="o1-mini",
         messages=[
-            {"role": "system", "content": "Talk like a pirate"},
-            {"role": "user", "content": "Tell me a story in 3 sentences or less."},
+            # {"role": "system", "content": "Talk like a pirate"},
+            {"role": "user", "content": "How many r's are in strawberry?"},
         ],
-        stream=True,
-        # stream=False,
+        # stream=True,
+        stream=False,
     )
     return response
 
@@ -31,14 +31,17 @@ def chat_completion():
     response = api()
     # print(response)
     # Uncomment this for streaming
-    result = []
-    for chunk in response:
-        if chunk.choices[0].delta.content is not None:
-            content = [
-                choice.delta.content if choice.delta and choice.delta.content else ""
-                for choice in chunk.choices
-            ]
-            result.append(content[0] if len(content) > 0 else "")
-
-    # print("".join(result))
+    # result = []
+    # for chunk in response:
+    #     if chunk.choices[0].delta.content is not None:
+    #         content = [
+    #             choice.delta.content if choice.delta and choice.delta.content else ""
+    #             for choice in chunk.choices
+    #         ]
+    #         result.append(content[0] if len(content) > 0 else "")
+
+    # # print("".join(result))
+    print(response)
     return response
+
+chat_completion()

src/langtrace_python_sdk/constants/instrumentation/common.py

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@
     "LANGCHAIN_COMMUNITY": "Langchain Community",
     "LANGCHAIN_CORE": "Langchain Core",
     "LANGGRAPH": "Langgraph",
+    "LITELLM": "Litellm",
     "LLAMAINDEX": "LlamaIndex",
     "OPENAI": "OpenAI",
     "PINECONE": "Pinecone",
Lines changed: 18 additions & 0 deletions (new file)

APIS = {
    "CHAT_COMPLETION": {
        "METHOD": "chat.completions.create",
        "ENDPOINT": "/chat/completions",
    },
    "IMAGES_GENERATION": {
        "METHOD": "images.generate",
        "ENDPOINT": "/images/generations",
    },
    "IMAGES_EDIT": {
        "METHOD": "images.edit",
        "ENDPOINT": "/images/edits",
    },
    "EMBEDDINGS_CREATE": {
        "METHOD": "embeddings.create",
        "ENDPOINT": "/embeddings",
    },
}

src/langtrace_python_sdk/instrumentation/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -19,6 +19,7 @@
 from .gemini import GeminiInstrumentation
 from .mistral import MistralInstrumentation
 from .embedchain import EmbedchainInstrumentation
+from .litellm import LiteLLMInstrumentation
 
 __all__ = [
     "AnthropicInstrumentation",
@@ -31,6 +32,7 @@
     "LangchainCommunityInstrumentation",
     "LangchainCoreInstrumentation",
     "LanggraphInstrumentation",
+    "LiteLLMInstrumentation",
     "LlamaindexInstrumentation",
     "OpenAIInstrumentation",
     "PineconeInstrumentation",

src/langtrace_python_sdk/instrumentation/dspy/patch.py

Lines changed: 18 additions & 8 deletions
@@ -1,6 +1,19 @@
 import json
+import os
+
+import ujson
+from colorama import Fore
 from importlib_metadata import version as v
+from langtrace.trace_attributes import FrameworkSpanAttributes
+from opentelemetry import baggage
+from opentelemetry.trace import SpanKind
+from opentelemetry.trace.status import Status, StatusCode
+
 from langtrace_python_sdk.constants import LANGTRACE_SDK_NAME
+from langtrace_python_sdk.constants.instrumentation.common import (
+    LANGTRACE_ADDITIONAL_SPAN_ATTRIBUTES_KEY,
+    SERVICE_PROVIDERS,
+)
 from langtrace_python_sdk.utils import set_span_attribute
 from langtrace_python_sdk.utils.llm import (
     get_extra_attributes,
@@ -9,14 +22,6 @@
     set_span_attributes,
 )
 from langtrace_python_sdk.utils.silently_fail import silently_fail
-from langtrace_python_sdk.constants.instrumentation.common import (
-    LANGTRACE_ADDITIONAL_SPAN_ATTRIBUTES_KEY,
-    SERVICE_PROVIDERS,
-)
-from opentelemetry import baggage
-from langtrace.trace_attributes import FrameworkSpanAttributes
-from opentelemetry.trace import SpanKind
-from opentelemetry.trace.status import Status, StatusCode
 
 
 def patch_bootstrapfewshot_optimizer(operation_name, version, tracer):
@@ -115,6 +120,8 @@ def traced_method(wrapped, instance, args, kwargs):
             **get_extra_attributes(),
         }
 
+        trace_checkpoint = os.environ.get("TRACE_DSPY_CHECKPOINT", "true").lower()
+
         if instance.__class__.__name__:
             span_attributes["dspy.signature.name"] = instance.__class__.__name__
             span_attributes["dspy.signature"] = str(instance.signature)
@@ -136,6 +143,9 @@ def traced_method(wrapped, instance, args, kwargs):
                     "dspy.signature.result",
                     json.dumps(result.toDict()),
                 )
+                if trace_checkpoint == "true":
+                    print(Fore.RED + "Note: DSPy checkpoint tracing is enabled in Langtrace. To disable it, set the env var TRACE_DSPY_CHECKPOINT to false" + Fore.RESET)
+                    set_span_attribute(span, "dspy.checkpoint", ujson.dumps(instance.dump_state(False), indent=2))
                 span.set_status(Status(StatusCode.OK))
 
             span.end()
Lines changed: 5 additions & 0 deletions (new file)

from .instrumentation import LiteLLMInstrumentation

__all__ = [
    "LiteLLMInstrumentation",
]
Lines changed: 87 additions & 0 deletions (new file)

"""
Copyright (c) 2024 Scale3 Labs

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from typing import Collection, Optional, Any
import importlib.metadata
import logging

from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.trace import get_tracer, TracerProvider
from wrapt import wrap_function_wrapper

from langtrace_python_sdk.instrumentation.litellm.patch import (
    async_chat_completions_create,
    async_embeddings_create,
    async_images_generate,
    chat_completions_create,
    embeddings_create,
    images_generate,
)

logging.basicConfig(level=logging.FATAL)


class LiteLLMInstrumentation(BaseInstrumentor):  # type: ignore

    def instrumentation_dependencies(self) -> Collection[str]:
        return ["litellm >= 1.48.0", "trace-attributes >= 4.0.5"]

    def _instrument(self, **kwargs: Any) -> None:
        tracer_provider: Optional[TracerProvider] = kwargs.get("tracer_provider")
        tracer = get_tracer(__name__, "", tracer_provider)
        version: str = importlib.metadata.version("openai")

        wrap_function_wrapper(
            "litellm",
            "completion",
            chat_completions_create(version, tracer),
        )

        wrap_function_wrapper(
            "litellm",
            "text_completion",
            chat_completions_create(version, tracer),
        )

        wrap_function_wrapper(
            "litellm.main",
            "acompletion",
            async_chat_completions_create(version, tracer),
        )

        wrap_function_wrapper(
            "litellm.main",
            "image_generation",
            images_generate(version, tracer),
        )

        wrap_function_wrapper(
            "litellm.main",
            "aimage_generation",
            async_images_generate(version, tracer),
        )

        wrap_function_wrapper(
            "litellm.main",
            "embedding",
            embeddings_create(version, tracer),
        )

        wrap_function_wrapper(
            "litellm.main",
            "aembedding",
            async_embeddings_create(version, tracer),
        )

    def _uninstrument(self, **kwargs: Any) -> None:
        pass
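
To show what this instrumentation covers in practice, a minimal usage sketch: once `langtrace.init()` runs, calls through the wrapped LiteLLM entry points above (`completion`, `embedding`, and their async variants) emit spans without further code changes. The model names and prompt are illustrative, and an `OPENAI_API_KEY` is assumed to be set:

import litellm
from langtrace_python_sdk import langtrace

langtrace.init()  # installs the wrappers registered by LiteLLMInstrumentation

# Traced via the wrapped litellm.completion entry point.
response = litellm.completion(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "Say hello in one word."}],
)
print(response.choices[0].message.content)

# Traced via the wrapped litellm.main.embedding entry point.
embedding = litellm.embedding(
    model="text-embedding-ada-002",
    input=["hello world"],
)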
