Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion integrations/dspy/src/databricks_dspy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from databricks_dspy.adapters import DatabricksCitations, DatabricksDocument
from databricks_dspy.clients import DatabricksLM
from databricks_dspy.retrievers import DatabricksRM
from databricks_dspy.streaming import DatabricksStreamListener

# Public API of the ``databricks_dspy`` package. Declared exactly once: an
# earlier, shorter assignment would be dead code because the later assignment
# overwrites it immediately.
__all__ = [
    "DatabricksLM",
    "DatabricksRM",
    "DatabricksCitations",
    "DatabricksDocument",
    "DatabricksStreamListener",
]
4 changes: 4 additions & 0 deletions integrations/dspy/src/databricks_dspy/adapters/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from databricks_dspy.adapters.types.citation import DatabricksCitations
from databricks_dspy.adapters.types.document import DatabricksDocument

# Names re-exported by the ``adapters`` subpackage.
__all__ = [
    "DatabricksCitations",
    "DatabricksDocument",
]
164 changes: 164 additions & 0 deletions integrations/dspy/src/databricks_dspy/adapters/types/citation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
from typing import Any

import pydantic
from dspy.adapters.types.base_type import Type


class DatabricksCitations(Type):
    """Citations extracted from an LM response with source references.

    This type represents citations returned by language models that support
    citation extraction, particularly Anthropic's Citations API through LiteLLM.
    Citations include the quoted text and source information.

    Instances behave like a read-only sequence: iteration, ``len()`` and
    indexing all delegate to the ``citations`` list.

    Example:
        ```python
        import dspy
        from dspy.signatures import Signature

        from databricks_dspy import DatabricksCitations, DatabricksDocument

        class AnswerWithSources(Signature):
            '''Answer questions using provided documents with citations.'''
            documents: list[DatabricksDocument] = dspy.InputField()
            question: str = dspy.InputField()
            answer: str = dspy.OutputField()
            citations: DatabricksCitations = dspy.OutputField()

        # Create documents to provide as sources
        docs = [
            DatabricksDocument(
                data="The Earth orbits the Sun in an elliptical path.",
                title="Basic Astronomy Facts",
            ),
            DatabricksDocument(
                data="Water boils at 100°C at standard atmospheric pressure.",
                title="Physics Fundamentals",
            ),
        ]

        # Use with a model that supports citations like Claude
        lm = dspy.LM("anthropic/claude-opus-4-1-20250805")
        predictor = dspy.Predict(AnswerWithSources)
        result = predictor(documents=docs, question="What temperature does water boil?", lm=lm)

        for citation in result.citations.citations:
            print(citation.format())
        ```
    """

    class Citation(Type):
        """Individual citation with character location information.

        Attributes:
            type: Citation kind tag; defaults to ``"char_location"``.
            cited_text: The text being cited.
            document_index: Index of the cited document.
            document_title: Optional title of the cited document.
            start_char_index: Start of the cited span.
            end_char_index: End of the cited span.
                NOTE(review): inclusive vs. exclusive is not established
                here -- confirm against the provider documentation.
            supported_text: Optional text associated with the citation.
                NOTE(review): presumably the generated text the citation
                backs -- confirm.
        """

        type: str = "char_location"
        cited_text: str
        document_index: int
        document_title: str | None = None
        start_char_index: int
        end_char_index: int
        supported_text: str | None = None

        def format(self) -> dict[str, Any]:
            """Format citation as dictionary for LM consumption.

            Returns:
                A dictionary in the format expected by citation APIs. The
                optional ``document_title`` and ``supported_text`` keys are
                only present when the corresponding attribute is truthy.
            """
            citation_dict: dict[str, Any] = {
                "type": self.type,
                "cited_text": self.cited_text,
                "document_index": self.document_index,
                "start_char_index": self.start_char_index,
                "end_char_index": self.end_char_index,
            }

            if self.document_title:
                citation_dict["document_title"] = self.document_title

            if self.supported_text:
                citation_dict["supported_text"] = self.supported_text

            return citation_dict

    citations: list[Citation]

    @classmethod
    def from_dict_list(cls, citations_dicts: list[dict[str, Any]]) -> "DatabricksCitations":
        """Convert a list of dictionaries to a DatabricksCitations instance.

        Args:
            citations_dicts: A list of dictionaries, where each dictionary should
                have 'cited_text', 'document_index', 'start_char_index' and
                'end_char_index' keys.

        Returns:
            A DatabricksCitations instance.

        Example:
            ```python
            citations_dict = [
                {
                    "cited_text": "The sky is blue",
                    "document_index": 0,
                    "document_title": "Weather Guide",
                    "start_char_index": 0,
                    "end_char_index": 15,
                    "supported_text": "The sky was blue yesterday."
                }
            ]
            citations = DatabricksCitations.from_dict_list(citations_dict)
            ```
        """
        citations = [cls.Citation(**item) for item in citations_dicts]
        return cls(citations=citations)

    @classmethod
    def description(cls) -> str:
        """Description of the citations type for use in prompts."""
        return (
            "Citations with quoted text and source references. "
            "Include the exact text being cited and information about its source."
        )

    def format(self) -> list[dict[str, Any]]:
        """Format citations as a list of dictionaries."""
        return [citation.format() for citation in self.citations]

    @pydantic.model_validator(mode="before")
    @classmethod
    def validate_input(cls, data: Any):
        """Normalize the accepted input shapes before field validation.

        Accepted shapes:
            * an existing instance of this class (returned unchanged);
            * a list of dicts that each contain a ``"cited_text"`` key;
            * a dict with a ``"citations"`` key holding a list (dict items are
              converted to ``Citation``; other items are passed through);
            * a single citation dict containing a ``"cited_text"`` key.

        Raises:
            ValueError: If ``data`` matches none of the shapes above.
        """
        if isinstance(data, cls):
            return data

        # Handle case where data is a list of dicts with citation info
        if isinstance(data, list) and all(
            isinstance(item, dict) and "cited_text" in item for item in data
        ):
            return {"citations": [cls.Citation(**item) for item in data]}

        # Handle case where data is a dict
        if isinstance(data, dict):
            if "citations" in data:
                # Handle case where data is a dict with "citations" key
                citations_data = data["citations"]
                if isinstance(citations_data, list):
                    return {
                        "citations": [
                            cls.Citation(**item) if isinstance(item, dict) else item
                            for item in citations_data
                        ]
                    }
            elif "cited_text" in data:
                # Handle case where data is a single citation dict
                return {"citations": [cls.Citation(**data)]}

        # The message names this class (not the upstream ``dspy.Citations``) so
        # the failure points at the type the caller actually used.
        raise ValueError(f"Received invalid value for `DatabricksCitations`: {data}")

    def __iter__(self):
        """Allow iteration over citations."""
        return iter(self.citations)

    def __len__(self):
        """Return the number of citations."""
        return len(self.citations)

    def __getitem__(self, index):
        """Allow indexing into citations."""
        return self.citations[index]
110 changes: 110 additions & 0 deletions integrations/dspy/src/databricks_dspy/adapters/types/document.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
from typing import Any, Literal

import pydantic
from dspy.adapters.types.base_type import Type


class DatabricksDocument(Type):
    """A document type for providing content that can be cited by language models.

    This type represents documents that can be passed to language models for citation-enabled
    responses, particularly useful with Anthropic's Citations API. Documents include the content
    and metadata that helps the LM understand and reference the source material.

    Attributes:
        data: The content of the document.
            NOTE(review): for ``media_type="application/pdf"`` this is expected to be
            the base64-encoded PDF -- confirm with callers.
        title: Optional title for the document (used in citations)
        media_type: MIME type of the document content (defaults to "text/plain")
        context: Optional context information about the document

    Example:
        ```python
        import dspy
        from dspy.signatures import Signature

        from databricks_dspy import DatabricksCitations, DatabricksDocument

        class AnswerWithSources(Signature):
            '''Answer questions using provided documents with citations.'''
            documents: list[DatabricksDocument] = dspy.InputField()
            question: str = dspy.InputField()
            answer: str = dspy.OutputField()
            citations: DatabricksCitations = dspy.OutputField()

        # Create documents
        docs = [
            DatabricksDocument(
                data="The Earth orbits the Sun in an elliptical path.",
                title="Basic Astronomy Facts",
            ),
            DatabricksDocument(
                data="Water boils at 100°C at standard atmospheric pressure.",
                title="Physics Fundamentals",
            ),
        ]

        # Use with a citation-supporting model
        lm = dspy.LM("anthropic/claude-opus-4-1-20250805")
        predictor = dspy.Predict(AnswerWithSources)
        result = predictor(documents=docs, question="What temperature does water boil?", lm=lm)
        print(result.citations)
        ```
    """

    data: str
    title: str | None = None
    media_type: Literal["text/plain", "application/pdf"] = "text/plain"
    context: str | None = None

    def format(self) -> list[dict[str, Any]]:
        """Format document for LM consumption.

        Returns:
            A list containing the document block in the format expected by
            citation-enabled language models. ``title`` and ``context`` are
            only included when truthy.
        """
        # Anthropic document blocks pair "text/plain" with a "text" source and
        # "application/pdf" with a "base64" source; a "text" source carrying a
        # PDF media type is not a valid combination.
        source_type = "base64" if self.media_type == "application/pdf" else "text"

        document_block: dict[str, Any] = {
            "type": "document",
            "source": {
                "type": source_type,
                "media_type": self.media_type,
                "data": self.data,
            },
            "citations": {"enabled": True},
        }

        if self.title:
            document_block["title"] = self.title

        if self.context:
            document_block["context"] = self.context

        return [document_block]

    @classmethod
    def description(cls) -> str:
        """Description of the document type for use in prompts."""
        return (
            "A document containing text content that can be referenced and cited. "
            "Include the full text content and optionally a title for proper referencing."
        )

    @pydantic.model_validator(mode="before")
    @classmethod
    def validate_input(cls, data: Any):
        """Normalize the accepted input shapes before field validation.

        Accepts an existing instance (returned unchanged), a bare string
        (used as ``data``), or a dict of field values (returned unchanged).

        Raises:
            ValueError: If ``data`` is of any other type.
        """
        if isinstance(data, cls):
            return data

        # Handle case where data is just a string (data only)
        if isinstance(data, str):
            return {"data": data}

        # Handle case where data is a dict
        if isinstance(data, dict):
            return data

        # The message names this class (not the upstream ``dspy.Document``) so
        # the failure points at the type the caller actually used.
        raise ValueError(f"Received invalid value for `DatabricksDocument`: {data}")

    def __str__(self) -> str:
        """String representation showing title and content length."""
        title_part = f"'{self.title}': " if self.title else ""
        return f"Document({title_part}{len(self.data)} chars)"
3 changes: 3 additions & 0 deletions integrations/dspy/src/databricks_dspy/streaming/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from databricks_dspy.streaming.streaming_listener import DatabricksStreamListener

# Names re-exported by the ``streaming`` subpackage.
__all__ = [
    "DatabricksStreamListener",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from dspy.streaming.messages import StreamResponse
from dspy.streaming.streaming_listener import StreamListener
from litellm import ModelResponseStream

from databricks_dspy.adapters.types.citation import DatabricksCitations


class DatabricksStreamListener(StreamListener):
    """Stream listener that additionally surfaces citation chunks.

    When the listened-to output field is annotated as ``DatabricksCitations``
    and a chunk carries a provider-specific ``"citation"`` payload, that payload
    is emitted as a ``StreamResponse``. Every other chunk is delegated to the
    parent ``StreamListener``.
    """

    def receive(self, chunk: ModelResponseStream):
        """Process one streamed chunk.

        Args:
            chunk: A streamed model response chunk from LiteLLM.

        Returns:
            A ``StreamResponse`` wrapping a single-citation
            ``DatabricksCitations`` when the chunk carries a citation for a
            citations field; otherwise whatever the parent ``receive`` returns.
        """
        # Handle anthropic citations. see https://docs.litellm.ai/docs/providers/anthropic#beta-citations-api
        if self._is_citation_type():
            chunk_citation = self._extract_citation(chunk)
            if chunk_citation:
                try:
                    citations = DatabricksCitations.from_dict_list([chunk_citation])
                except (TypeError, ValueError):
                    # Best effort: a malformed citation payload (wrong shape or
                    # failed validation) must not break the stream, so fall
                    # through to the default handling below.
                    citations = None
                if citations is not None:
                    return StreamResponse(
                        self.predict_name,
                        self.signature_field_name,
                        citations,
                        is_last_chunk=False,
                    )

        # Propagate the parent's result; dropping it would make this method
        # return None for every non-citation chunk.
        return super().receive(chunk)

    @staticmethod
    def _extract_citation(chunk):
        """Return the provider-specific ``"citation"`` payload of ``chunk``, or None.

        Every level is probed defensively because chunks may have no choices,
        no delta, or no ``provider_specific_fields`` mapping.
        """
        choices = getattr(chunk, "choices", None)
        if not choices:
            return None
        delta = getattr(choices[0], "delta", None)
        fields = getattr(delta, "provider_specific_fields", None)
        if not isinstance(fields, dict):
            return None
        return fields.get("citation")

    def _is_citation_type(self) -> bool:
        """Check if the signature field is a citations field."""
        # Imported lazily, as in the original implementation.
        from dspy.predict import Predict

        predict = getattr(self, "predict", None)
        if not isinstance(predict, Predict):
            return False
        field = predict.signature.output_fields.get(self.signature_field_name, None)
        # Compare against the class imported by this module; the name
        # ``Citations`` is not defined here.
        return getattr(field, "annotation", None) is DatabricksCitations
Loading
Loading