11using System . Text . Json ;
22using Devlooped . Grok ;
3- using Google . Protobuf ;
43using Grpc . Core ;
54using Grpc . Net . Client ;
65using Microsoft . Extensions . AI ;
@@ -23,93 +22,133 @@ internal GrokChatClient(GrpcChannel channel, GrokClientOptions clientOptions, st
2322
2423 public async Task < ChatResponse > GetResponseAsync ( IEnumerable < ChatMessage > messages , ChatOptions ? options = null , CancellationToken cancellationToken = default )
2524 {
26- var requestDto = MapToRequest ( messages , options ) ;
27- var protoResponse = await client . GetCompletionAsync ( requestDto , cancellationToken : cancellationToken ) ;
28- var lastOutput = protoResponse . Outputs . OrderByDescending ( x => x . Index ) . FirstOrDefault ( ) ;
25+ var request = MapToRequest ( messages , options ) ;
26+ var response = await client . GetCompletionAsync ( request , cancellationToken : cancellationToken ) ;
27+ var lastOutput = response . Outputs . OrderByDescending ( x => x . Index ) . FirstOrDefault ( ) ;
2928
3029 if ( lastOutput == null )
3130 {
3231 return new ChatResponse ( )
3332 {
34- ResponseId = protoResponse . Id ,
35- ModelId = protoResponse . Model ,
36- CreatedAt = protoResponse . Created . ToDateTimeOffset ( ) ,
37- Usage = MapToUsage ( protoResponse . Usage ) ,
33+ ResponseId = response . Id ,
34+ ModelId = response . Model ,
35+ CreatedAt = response . Created . ToDateTimeOffset ( ) ,
36+ Usage = MapToUsage ( response . Usage ) ,
3837 } ;
3938 }
4039
4140 var message = new ChatMessage ( MapRole ( lastOutput . Message . Role ) , default ( string ) ) ;
42- var citations = protoResponse . Citations ? . Distinct ( ) . Select ( MapCitation ) . ToList < AIAnnotation > ( ) ;
41+ var citations = response . Citations ? . Distinct ( ) . Select ( MapCitation ) . ToList < AIAnnotation > ( ) ;
4342
44- foreach ( var output in protoResponse . Outputs . OrderBy ( x => x . Index ) )
43+ foreach ( var output in response . Outputs . OrderBy ( x => x . Index ) )
4544 {
4645 if ( output . Message . Content is { Length : > 0 } text )
4746 {
48- var content = new TextContent ( text )
47+ // Special-case output from tools
48+ if ( output . Message . Role == MessageRole . RoleTool &&
49+ output . Message . ToolCalls . Count == 1 &&
50+ output . Message . ToolCalls [ 0 ] is { } toolCall )
4951 {
50- Annotations = citations
51- } ;
52+ if ( toolCall . Type == ToolCallType . McpTool )
53+ {
54+ message . Contents . Add ( new McpServerToolCallContent ( toolCall . Id , toolCall . Function . Name , null )
55+ {
56+ RawRepresentation = toolCall
57+ } ) ;
58+ message . Contents . Add ( new McpServerToolResultContent ( toolCall . Id )
59+ {
60+ RawRepresentation = toolCall ,
61+ Output = [ new TextContent ( text ) ]
62+ } ) ;
63+ continue ;
64+ }
65+ else if ( toolCall . Type == ToolCallType . CodeExecutionTool )
66+ {
67+ message . Contents . Add ( new CodeInterpreterToolCallContent ( )
68+ {
69+ CallId = toolCall . Id ,
70+ RawRepresentation = toolCall
71+ } ) ;
72+ message . Contents . Add ( new CodeInterpreterToolResultContent ( )
73+ {
74+ CallId = toolCall . Id ,
75+ RawRepresentation = toolCall ,
76+ Outputs = [ new TextContent ( text ) ]
77+ } ) ;
78+ continue ;
79+ }
80+ }
81+
82+ var content = new TextContent ( text ) { Annotations = citations } ;
5283
5384 foreach ( var citation in output . Message . Citations )
54- {
5585 ( content . Annotations ??= [ ] ) . Add ( MapInlineCitation ( citation ) ) ;
56- }
86+
5787 message . Contents . Add ( content ) ;
5888 }
5989
6090 foreach ( var toolCall in output . Message . ToolCalls )
61- {
62- if ( toolCall . Type == ToolCallType . ClientSideTool )
63- {
64- var arguments = ! string . IsNullOrEmpty ( toolCall . Function . Arguments )
65- ? JsonSerializer . Deserialize < IDictionary < string , object ? > > ( toolCall . Function . Arguments )
66- : null ;
67-
68- var content = new FunctionCallContent (
69- toolCall . Id ,
70- toolCall . Function . Name ,
71- arguments ) ;
72-
73- message . Contents . Add ( content ) ;
74- }
75- else
76- {
77- message . Contents . Add ( new HostedToolCallContent ( toolCall ) ) ;
78- }
79- }
91+ message . Contents . Add ( MapToolCall ( toolCall ) ) ;
8092 }
8193
8294 return new ChatResponse ( message )
8395 {
84- ResponseId = protoResponse . Id ,
85- ModelId = protoResponse . Model ,
86- CreatedAt = protoResponse . Created . ToDateTimeOffset ( ) ,
96+ ResponseId = response . Id ,
97+ ModelId = response . Model ,
98+ CreatedAt = response . Created . ToDateTimeOffset ( ) ,
8799 FinishReason = lastOutput != null ? MapFinishReason ( lastOutput . FinishReason ) : null ,
88- Usage = MapToUsage ( protoResponse . Usage ) ,
100+ Usage = MapToUsage ( response . Usage ) ,
89101 } ;
90102 }
91103
104+ AIContent MapToolCall ( ToolCall toolCall ) => toolCall . Type switch
105+ {
106+ ToolCallType . ClientSideTool => new FunctionCallContent (
107+ toolCall . Id ,
108+ toolCall . Function . Name ,
109+ ! string . IsNullOrEmpty ( toolCall . Function . Arguments )
110+ ? JsonSerializer . Deserialize < IDictionary < string , object ? > > ( toolCall . Function . Arguments )
111+ : null )
112+ {
113+ RawRepresentation = toolCall
114+ } ,
115+ ToolCallType . McpTool => new McpServerToolCallContent ( toolCall . Id , toolCall . Function . Name , null )
116+ {
117+ RawRepresentation = toolCall
118+ } ,
119+ ToolCallType . CodeExecutionTool => new CodeInterpreterToolCallContent ( )
120+ {
121+ CallId = toolCall . Id ,
122+ RawRepresentation = toolCall
123+ } ,
124+ _ => new HostedToolCallContent ( )
125+ {
126+ CallId = toolCall . Id ,
127+ RawRepresentation = toolCall
128+ }
129+ } ;
130+
92131 public IAsyncEnumerable < ChatResponseUpdate > GetStreamingResponseAsync ( IEnumerable < ChatMessage > messages , ChatOptions ? options = null , CancellationToken cancellationToken = default )
93132 {
94133 return CompleteChatStreamingCore ( messages , options , cancellationToken ) ;
95134
96135 async IAsyncEnumerable < ChatResponseUpdate > CompleteChatStreamingCore ( IEnumerable < ChatMessage > messages , ChatOptions ? options , [ System . Runtime . CompilerServices . EnumeratorCancellation ] CancellationToken cancellationToken )
97136 {
98- var requestDto = MapToRequest ( messages , options ) ;
99- var call = client . GetCompletionChunk ( requestDto , cancellationToken : cancellationToken ) ;
100-
137+ var request = MapToRequest ( messages , options ) ;
138+ var call = client . GetCompletionChunk ( request , cancellationToken : cancellationToken ) ;
139+
101140 await foreach ( var chunk in call . ResponseStream . ReadAllAsync ( cancellationToken ) )
102141 {
103- var outputChunk = chunk . Outputs [ 0 ] ;
104- var text = outputChunk . Delta . Content is { Length : > 0 } delta ? delta : null ;
142+ var output = chunk . Outputs [ 0 ] ;
143+ var text = output . Delta . Content is { Length : > 0 } delta ? delta : null ;
105144
106145 // Use positional arguments for ChatResponseUpdate
107- var update = new ChatResponseUpdate ( MapRole ( outputChunk . Delta . Role ) , text )
146+ var update = new ChatResponseUpdate ( MapRole ( output . Delta . Role ) , text )
108147 {
109148 ResponseId = chunk . Id ,
110149 ModelId = chunk . Model ,
111150 CreatedAt = chunk . Created ? . ToDateTimeOffset ( ) ,
112- FinishReason = outputChunk . FinishReason != FinishReason . ReasonInvalid ? MapFinishReason ( outputChunk . FinishReason ) : null ,
151+ FinishReason = output . FinishReason != FinishReason . ReasonInvalid ? MapFinishReason ( output . FinishReason ) : null ,
113152 } ;
114153
115154 if ( chunk . Citations is { Count : > 0 } citations )
@@ -122,31 +161,11 @@ async IAsyncEnumerable<ChatResponseUpdate> CompleteChatStreamingCore(IEnumerable
122161 }
123162
124163 foreach ( var citation in citations . Distinct ( ) )
125- {
126164 ( textContent . Annotations ??= [ ] ) . Add ( MapCitation ( citation ) ) ;
127- }
128165 }
129166
130- foreach ( var toolCall in outputChunk . Delta . ToolCalls )
131- {
132- if ( toolCall . Type == ToolCallType . ClientSideTool )
133- {
134- var arguments = ! string . IsNullOrEmpty ( toolCall . Function . Arguments )
135- ? JsonSerializer . Deserialize < IDictionary < string , object ? > > ( toolCall . Function . Arguments )
136- : null ;
137-
138- var content = new FunctionCallContent (
139- toolCall . Id ,
140- toolCall . Function . Name ,
141- arguments ) ;
142-
143- update . Contents . Add ( content ) ;
144- }
145- else
146- {
147- update . Contents . Add ( new HostedToolCallContent ( toolCall ) ) ;
148- }
149- }
167+ foreach ( var toolCall in output . Delta . ToolCalls )
168+ update . Contents . Add ( MapToolCall ( toolCall ) ) ;
150169
151170 if ( update . Contents . Any ( ) )
152171 yield return update ;
@@ -191,6 +210,8 @@ GetCompletionsRequest MapToRequest(IEnumerable<ChatMessage> messages, ChatOption
191210 {
192211 var request = new GetCompletionsRequest
193212 {
213+ // By default always include citations in the final output if available
214+ Include = { IncludeOption . InlineCitations } ,
194215 Model = options ? . ModelId ?? defaultModelId ,
195216 } ;
196217
@@ -211,6 +232,10 @@ GetCompletionsRequest MapToRequest(IEnumerable<ChatMessage> messages, ChatOption
211232 {
212233 gmsg . Content . Add ( new Content { Text = textContent . Text } ) ;
213234 }
235+ else if ( content . RawRepresentation is ToolCall toolCall )
236+ {
237+ gmsg . ToolCalls . Add ( toolCall ) ;
238+ }
214239 else if ( content is FunctionCallContent functionCall )
215240 {
216241 gmsg . ToolCalls . Add ( new ToolCall
@@ -224,10 +249,6 @@ GetCompletionsRequest MapToRequest(IEnumerable<ChatMessage> messages, ChatOption
224249 }
225250 } ) ;
226251 }
227- else if ( content is HostedToolCallContent serverFunction )
228- {
229- gmsg . ToolCalls . Add ( serverFunction . ToolCall ) ;
230- }
231252 else if ( content is FunctionResultContent resultContent )
232253 {
233254 request . Messages . Add ( new Message
@@ -236,19 +257,49 @@ GetCompletionsRequest MapToRequest(IEnumerable<ChatMessage> messages, ChatOption
236257 Content = { new Content { Text = JsonSerializer . Serialize ( resultContent . Result ) ?? "null" } }
237258 } ) ;
238259 }
260+ else if ( content is McpServerToolResultContent mcpResult &&
261+ mcpResult . RawRepresentation is ToolCall mcpToolCall &&
262+ mcpResult . Output is { Count : 1 } &&
263+ mcpResult . Output [ 0 ] is TextContent mcpText )
264+ {
265+ request . Messages . Add ( new Message
266+ {
267+ Role = MessageRole . RoleTool ,
268+ ToolCalls = { mcpToolCall } ,
269+ Content = { new Content { Text = mcpText . Text } }
270+ } ) ;
271+ }
272+ else if ( content is CodeInterpreterToolResultContent codeResult &&
273+ codeResult . RawRepresentation is ToolCall codeToolCall &&
274+ codeResult . Outputs is { Count : 1 } &&
275+ codeResult . Outputs [ 0 ] is TextContent codeText )
276+ {
277+ request . Messages . Add ( new Message
278+ {
279+ Role = MessageRole . RoleTool ,
280+ ToolCalls = { codeToolCall } ,
281+ Content = { new Content { Text = codeText . Text } }
282+ } ) ;
283+ }
239284 }
240285
241286 if ( gmsg . Content . Count == 0 && gmsg . ToolCalls . Count == 0 )
242287 continue ;
243288
289+ // If we have only tool calls and no content, the gRPC enpoint fails, so add an empty one.
244290 if ( gmsg . Content . Count == 0 )
245291 gmsg . Content . Add ( new Content ( ) ) ;
246292
247293 request . Messages . Add ( gmsg ) ;
248294 }
249295
296+ IList < IncludeOption > includes = [ IncludeOption . InlineCitations ] ;
250297 if ( options is GrokChatOptions grokOptions )
251298 {
299+ // NOTE: overrides our default include for inline citations, potentially.
300+ request . Include . Clear ( ) ;
301+ request . Include . AddRange ( grokOptions . Include ) ;
302+
252303 if ( grokOptions . Search . HasFlag ( GrokSearch . X ) )
253304 {
254305 ( options . Tools ??= [ ] ) . Insert ( 0 , new GrokXSearchTool ( ) ) ;
0 commit comments