
Commit 7984b3d

initial work on python conversation streaming
1 parent ea45284 commit 7984b3d

File tree

15 files changed: +3103 −510 lines

dapr/aio/clients/grpc/client.py

Lines changed: 83 additions & 1 deletion

@@ -24,7 +24,7 @@
 
 from warnings import warn
 
-from typing import Callable, Dict, Optional, Text, Union, Sequence, List, Any, Awaitable
+from typing import Callable, Dict, Optional, Text, Union, Sequence, List, Any, Awaitable, AsyncIterator
 from typing_extensions import Self
 
 from google.protobuf.message import Message as GrpcMessage
@@ -82,6 +82,8 @@
     BindingResponse,
     ConversationResponse,
     ConversationResult,
+    ConversationStreamResponse,
+    ConversationUsage,
     DaprResponse,
     GetSecretResponse,
     GetBulkSecretResponse,
@@ -1771,6 +1773,86 @@ async def converse_alpha1(
         except grpc.aio.AioRpcError as err:
             raise DaprGrpcError(err) from err
 
+    async def converse_stream_alpha1(
+        self,
+        name: str,
+        inputs: List[ConversationInput],
+        *,
+        context_id: Optional[str] = None,
+        parameters: Optional[Dict[str, GrpcAny]] = None,
+        metadata: Optional[Dict[str, str]] = None,
+        scrub_pii: Optional[bool] = None,
+        temperature: Optional[float] = None,
+    ) -> AsyncIterator[ConversationStreamResponse]:
+        """Invoke an LLM using the streaming conversation API (Alpha).
+
+        Args:
+            name: Name of the LLM component to invoke.
+            inputs: List of conversation inputs.
+            context_id: Optional ID for continuing an existing chat.
+            parameters: Optional custom parameters for the request.
+            metadata: Optional metadata for the component.
+            scrub_pii: Optional flag to scrub PII from inputs and outputs.
+            temperature: Optional temperature setting for the LLM, to tune for creativity or predictability.
+
+        Yields:
+            ConversationStreamResponse: Conversation result chunks.
+
+        Raises:
+            DaprGrpcError: If the Dapr runtime returns an error.
+        """
+        from dapr.clients.grpc._response import ConversationStreamResponse
+
+        inputs_pb = [
+            api_v1.ConversationInput(content=inp.content, role=inp.role, scrubPII=inp.scrub_pii)
+            for inp in inputs
+        ]
+
+        request = api_v1.ConversationRequest(
+            name=name,
+            inputs=inputs_pb,
+            contextID=context_id,
+            parameters=parameters or {},
+            metadata=metadata or {},
+            scrubPII=scrub_pii,
+            temperature=temperature,
+        )
+
+        try:
+            response_stream = self._stub.ConverseStreamAlpha1(request)
+
+            async for response in response_stream:
+                context_id = None
+                result = None
+                usage = None
+
+                # Handle a streamed content chunk.
+                if response.HasField('chunk'):
+                    result = ConversationResult(
+                        result=response.chunk.content,
+                        parameters={},
+                    )
+
+                # Handle the terminal message that completes the stream.
+                elif response.HasField('complete'):
+                    context_id = response.complete.contextID
+
+                    # Extract token usage information if available.
+                    if response.complete.HasField('usage'):
+                        usage = ConversationUsage(
+                            prompt_tokens=response.complete.usage.prompt_tokens,
+                            completion_tokens=response.complete.usage.completion_tokens,
+                            total_tokens=response.complete.usage.total_tokens,
+                        )
+
+                yield ConversationStreamResponse(
+                    context_id=context_id,
+                    result=result,
+                    usage=usage,
+                )
+        except grpc.aio.AioRpcError as err:
+            raise DaprGrpcError(err) from err
+
     async def wait(self, timeout_s: float):
         """Waits for sidecar to be available within the timeout.
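For orientation, here is a minimal consumption sketch for the new async API (not part of the commit). It assumes a running Dapr sidecar, a hypothetical conversation component named 'echo', and the ConversationInput import path used by the SDK's existing non-streaming conversation example:

import asyncio

from dapr.aio.clients import DaprClient
from dapr.clients.grpc._request import ConversationInput


async def main():
    async with DaprClient() as client:
        inputs = [ConversationInput(content='What is Dapr?', role='user')]
        # Content arrives as chunk frames; the final frame carries usage.
        async for resp in client.converse_stream_alpha1(name='echo', inputs=inputs):
            if resp.result:
                print(resp.result.result, end='', flush=True)
            if resp.usage:
                print(f'\n[total tokens: {resp.usage.total_tokens}]')


asyncio.run(main())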

dapr/clients/grpc/_response.py

Lines changed: 19 additions & 0 deletions

@@ -1087,3 +1087,22 @@ class ConversationResponse:
 
     context_id: Optional[str]
     outputs: List[ConversationResult]
+    usage: Optional[ConversationUsage] = None
+
+
+@dataclass
+class ConversationUsage:
+    """Token usage statistics from conversation API."""
+
+    prompt_tokens: int = 0
+    completion_tokens: int = 0
+    total_tokens: int = 0
+
+
+@dataclass
+class ConversationStreamResponse:
+    """Single response chunk from the streaming conversation API."""
+
+    context_id: Optional[str]
+    result: Optional[ConversationResult] = None
+    usage: Optional[ConversationUsage] = None
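Taken together, these dataclasses imply a simple frame protocol: chunk frames populate result, while the terminal frame populates context_id and, when the component reports it, usage. A sketch of how a caller might fold a stream back into a full completion (collect is a hypothetical helper, not part of the commit):

from typing import Iterable, Optional, Tuple

from dapr.clients.grpc._response import ConversationStreamResponse, ConversationUsage


def collect(stream: Iterable[ConversationStreamResponse]) -> Tuple[str, Optional[ConversationUsage]]:
    """Fold streamed frames into the full completion text plus final usage."""
    parts = []
    usage = None
    for frame in stream:
        if frame.result is not None:  # content chunk
            parts.append(frame.result.result)
        if frame.usage is not None:  # terminal frame with token counts
            usage = frame.usage
    return ''.join(parts), usage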

dapr/clients/grpc/client.py

Lines changed: 83 additions & 1 deletion

@@ -22,7 +22,7 @@
 
 from warnings import warn
 
-from typing import Callable, Dict, Optional, Text, Union, Sequence, List, Any
+from typing import Callable, Dict, Optional, Text, Union, Sequence, List, Any, Iterator
 from typing_extensions import Self
 from datetime import datetime
 from google.protobuf.message import Message as GrpcMessage
@@ -92,6 +92,8 @@
     TopicEventResponse,
     ConversationResponse,
     ConversationResult,
+    ConversationStreamResponse,
+    ConversationUsage,
 )
 
 
@@ -1773,6 +1775,86 @@ def converse_alpha1(
         except RpcError as err:
             raise DaprGrpcError(err) from err
 
+    def converse_stream_alpha1(
+        self,
+        name: str,
+        inputs: List[ConversationInput],
+        *,
+        context_id: Optional[str] = None,
+        parameters: Optional[Dict[str, GrpcAny]] = None,
+        metadata: Optional[Dict[str, str]] = None,
+        scrub_pii: Optional[bool] = None,
+        temperature: Optional[float] = None,
+    ) -> Iterator[ConversationStreamResponse]:
+        """Invoke an LLM using the streaming conversation API (Alpha).
+
+        Args:
+            name: Name of the LLM component to invoke.
+            inputs: List of conversation inputs.
+            context_id: Optional ID for continuing an existing chat.
+            parameters: Optional custom parameters for the request.
+            metadata: Optional metadata for the component.
+            scrub_pii: Optional flag to scrub PII from inputs and outputs.
+            temperature: Optional temperature setting for the LLM, to tune for creativity or predictability.
+
+        Yields:
+            ConversationStreamResponse: Conversation result chunks.
+
+        Raises:
+            DaprGrpcError: If the Dapr runtime returns an error.
+        """
+        from dapr.clients.grpc._response import ConversationStreamResponse
+
+        inputs_pb = [
+            api_v1.ConversationInput(content=inp.content, role=inp.role, scrubPII=inp.scrub_pii)
+            for inp in inputs
+        ]
+
+        request = api_v1.ConversationRequest(
+            name=name,
+            inputs=inputs_pb,
+            contextID=context_id,
+            parameters=parameters or {},
+            metadata=metadata or {},
+            scrubPII=scrub_pii,
+            temperature=temperature,
+        )
+
+        try:
+            response_stream = self.retry_policy.run_rpc(self._stub.ConverseStreamAlpha1, request)
+
+            for response in response_stream:
+                context_id = None
+                result = None
+                usage = None
+
+                # Handle a streamed content chunk.
+                if response.HasField('chunk'):
+                    result = ConversationResult(
+                        result=response.chunk.content,
+                        parameters={},
+                    )
+
+                # Handle the terminal message that completes the stream.
+                elif response.HasField('complete'):
+                    context_id = response.complete.contextID
+
+                    # Extract token usage information if available.
+                    if response.complete.HasField('usage'):
+                        usage = ConversationUsage(
+                            prompt_tokens=response.complete.usage.prompt_tokens,
+                            completion_tokens=response.complete.usage.completion_tokens,
+                            total_tokens=response.complete.usage.total_tokens,
+                        )
+
+                yield ConversationStreamResponse(
+                    context_id=context_id,
+                    result=result,
+                    usage=usage,
+                )
+        except RpcError as err:
+            raise DaprGrpcError(err) from err
+
     def wait(self, timeout_s: float):
         """Waits for sidecar to be available within the timeout.
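The synchronous client mirrors the async sketch above; the same assumptions apply (running sidecar, hypothetical component name 'echo', ConversationInput import path from the existing example):

from dapr.clients import DaprClient
from dapr.clients.grpc._request import ConversationInput

with DaprClient() as client:
    inputs = [ConversationInput(content='What is Dapr?', role='user')]
    for resp in client.converse_stream_alpha1(name='echo', inputs=inputs):
        if resp.result:
            print(resp.result.result, end='', flush=True)
        if resp.usage:
            print(f'\n[prompt={resp.usage.prompt_tokens}, completion={resp.usage.completion_tokens}]')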
