Skip to content

Commit 2a172f2

Browse files
authored
[VertexAI] Support cancellation in GenerateContent (#1239)
* [VertexAI] Support cancellation in GenerateContent * Comment cleanup
1 parent 69fe6f9 commit 2a172f2

File tree

3 files changed

+80
-59
lines changed

3 files changed

+80
-59
lines changed

vertexai/src/Chat.cs

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,11 @@
1414
* limitations under the License.
1515
*/
1616

17-
using System;
1817
using System.Collections.Generic;
1918
using System.Collections.ObjectModel;
2019
using System.Linq;
20+
using System.Runtime.CompilerServices;
21+
using System.Threading;
2122
using System.Threading.Tasks;
2223
using Firebase.VertexAI.Internal;
2324

@@ -60,72 +61,78 @@ internal static Chat InternalCreateChat(GenerativeModel model, IEnumerable<Model
6061
/// Sends a message using the existing history of this chat as context. If successful, the message
6162
/// and response will be added to the history. If unsuccessful, history will remain unchanged.
6263
/// </summary>
63-
/// <param name="content">The input(s) given to the model as a prompt.</param>
64+
/// <param name="content">The input given to the model as a prompt.</param>
65+
/// <param name="cancellationToken">An optional token to cancel the operation.</param>
6466
/// <returns>The model's response if no error occurred.</returns>
6567
/// <exception cref="VertexAIException">Thrown when an error occurs during content generation.</exception>
6668
public Task<GenerateContentResponse> SendMessageAsync(
67-
params ModelContent[] content) {
68-
return SendMessageAsync((IEnumerable<ModelContent>)content);
69+
ModelContent content, CancellationToken cancellationToken = default) {
70+
return SendMessageAsync(new[] { content }, cancellationToken);
6971
}
7072
/// <summary>
7173
/// Sends a message using the existing history of this chat as context. If successful, the message
7274
/// and response will be added to the history. If unsuccessful, history will remain unchanged.
7375
/// </summary>
7476
/// <param name="text">The text given to the model as a prompt.</param>
77+
/// <param name="cancellationToken">An optional token to cancel the operation.</param>
7578
/// <returns>The model's response if no error occurred.</returns>
7679
/// <exception cref="VertexAIException">Thrown when an error occurs during content generation.</exception>
7780
public Task<GenerateContentResponse> SendMessageAsync(
78-
string text) {
79-
return SendMessageAsync(new ModelContent[] { ModelContent.Text(text) });
81+
string text, CancellationToken cancellationToken = default) {
82+
return SendMessageAsync(new ModelContent[] { ModelContent.Text(text) }, cancellationToken);
8083
}
8184
/// <summary>
8285
/// Sends a message using the existing history of this chat as context. If successful, the message
8386
/// and response will be added to the history. If unsuccessful, history will remain unchanged.
8487
/// </summary>
85-
/// <param name="content">The input(s) given to the model as a prompt.</param>
88+
/// <param name="content">The input given to the model as a prompt.</param>
89+
/// <param name="cancellationToken">An optional token to cancel the operation.</param>
8690
/// <returns>The model's response if no error occurred.</returns>
8791
/// <exception cref="VertexAIException">Thrown when an error occurs during content generation.</exception>
8892
public Task<GenerateContentResponse> SendMessageAsync(
89-
IEnumerable<ModelContent> content) {
90-
return SendMessageAsyncInternal(content);
93+
IEnumerable<ModelContent> content, CancellationToken cancellationToken = default) {
94+
return SendMessageAsyncInternal(content, cancellationToken);
9195
}
9296

9397
/// <summary>
9498
/// Sends a message using the existing history of this chat as context. If successful, the message
9599
/// and response will be added to the history. If unsuccessful, history will remain unchanged.
96100
/// </summary>
97-
/// <param name="content">The input(s) given to the model as a prompt.</param>
101+
/// <param name="content">The input given to the model as a prompt.</param>
102+
/// <param name="cancellationToken">An optional token to cancel the operation.</param>
98103
/// <returns>A stream of generated content responses from the model.</returns>
99104
/// <exception cref="VertexAIException">Thrown when an error occurs during content generation.</exception>
100105
public IAsyncEnumerable<GenerateContentResponse> SendMessageStreamAsync(
101-
params ModelContent[] content) {
102-
return SendMessageStreamAsync((IEnumerable<ModelContent>)content);
106+
ModelContent content, CancellationToken cancellationToken = default) {
107+
return SendMessageStreamAsync(new[] { content }, cancellationToken);
103108
}
104109
/// <summary>
105110
/// Sends a message using the existing history of this chat as context. If successful, the message
106111
/// and response will be added to the history. If unsuccessful, history will remain unchanged.
107112
/// </summary>
108113
/// <param name="text">The text given to the model as a prompt.</param>
114+
/// <param name="cancellationToken">An optional token to cancel the operation.</param>
109115
/// <returns>A stream of generated content responses from the model.</returns>
110116
/// <exception cref="VertexAIException">Thrown when an error occurs during content generation.</exception>
111117
public IAsyncEnumerable<GenerateContentResponse> SendMessageStreamAsync(
112-
string text) {
113-
return SendMessageStreamAsync(new ModelContent[] { ModelContent.Text(text) });
118+
string text, CancellationToken cancellationToken = default) {
119+
return SendMessageStreamAsync(new ModelContent[] { ModelContent.Text(text) }, cancellationToken);
114120
}
115121
/// <summary>
116122
/// Sends a message using the existing history of this chat as context. If successful, the message
117123
/// and response will be added to the history. If unsuccessful, history will remain unchanged.
118124
/// </summary>
119-
/// <param name="content">The input(s) given to the model as a prompt.</param>
125+
/// <param name="content">The input given to the model as a prompt.</param>
126+
/// <param name="cancellationToken">An optional token to cancel the operation.</param>
120127
/// <returns>A stream of generated content responses from the model.</returns>
121128
/// <exception cref="VertexAIException">Thrown when an error occurs during content generation.</exception>
122129
public IAsyncEnumerable<GenerateContentResponse> SendMessageStreamAsync(
123-
IEnumerable<ModelContent> content) {
124-
return SendMessageStreamAsyncInternal(content);
130+
IEnumerable<ModelContent> content, CancellationToken cancellationToken = default) {
131+
return SendMessageStreamAsyncInternal(content, cancellationToken);
125132
}
126133

127134
private async Task<GenerateContentResponse> SendMessageAsyncInternal(
128-
IEnumerable<ModelContent> requestContent) {
135+
IEnumerable<ModelContent> requestContent, CancellationToken cancellationToken = default) {
129136
// Make sure that the requests are set to role "user".
130137
List<ModelContent> fixedRequests = requestContent.Select(VertexAIExtensions.ConvertToUser).ToList();
131138
// Set up the context to send in the request
@@ -134,7 +141,7 @@ private async Task<GenerateContentResponse> SendMessageAsyncInternal(
134141

135142
// Note: GenerateContentAsync can throw exceptions if there was a problem, but
136143
// we allow it to just be passed back to the user.
137-
GenerateContentResponse response = await generativeModel.GenerateContentAsync(fullRequest);
144+
GenerateContentResponse response = await generativeModel.GenerateContentAsync(fullRequest, cancellationToken);
138145

139146
// Only after getting a valid response, add both to the history for later.
140147
// But either way pass the response along to the user.
@@ -149,7 +156,8 @@ private async Task<GenerateContentResponse> SendMessageAsyncInternal(
149156
}
150157

151158
private async IAsyncEnumerable<GenerateContentResponse> SendMessageStreamAsyncInternal(
152-
IEnumerable<ModelContent> requestContent) {
159+
IEnumerable<ModelContent> requestContent,
160+
[EnumeratorCancellation] CancellationToken cancellationToken = default) {
153161
// Make sure that the requests are set to role "user".
154162
List<ModelContent> fixedRequests = requestContent.Select(VertexAIExtensions.ConvertToUser).ToList();
155163
// Set up the context to send in the request
@@ -161,7 +169,7 @@ private async IAsyncEnumerable<GenerateContentResponse> SendMessageStreamAsyncIn
161169
// Note: GenerateContentStreamAsync can throw exceptions if there was a problem, but
162170
// we allow it to just be passed back to the user.
163171
await foreach (GenerateContentResponse response in
164-
generativeModel.GenerateContentStreamAsync(fullRequest)) {
172+
generativeModel.GenerateContentStreamAsync(fullRequest, cancellationToken)) {
165173
// If the response had a problem, we still want to pass it along to the user for context,
166174
// but we don't want to save the history anymore.
167175
if (response.Candidates.Any()) {

vertexai/src/GenerativeModel.cs

Lines changed: 43 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919
using System.IO;
2020
using System.Linq;
2121
using System.Net.Http;
22+
using System.Runtime.CompilerServices;
2223
using System.Text;
24+
using System.Threading;
2325
using System.Threading.Tasks;
2426
using Google.MiniJSON;
2527
using Firebase.VertexAI.Internal;
@@ -81,94 +83,102 @@ internal GenerativeModel(FirebaseApp firebaseApp,
8183
/// <summary>
8284
/// Generates new content from input `ModelContent` given to the model as a prompt.
8385
/// </summary>
84-
/// <param name="content">The input(s) given to the model as a prompt.</param>
86+
/// <param name="content">The input given to the model as a prompt.</param>
87+
/// <param name="cancellationToken">An optional token to cancel the operation.</param>
8588
/// <returns>The generated content response from the model.</returns>
8689
/// <exception cref="VertexAIException">Thrown when an error occurs during content generation.</exception>
8790
public Task<GenerateContentResponse> GenerateContentAsync(
88-
params ModelContent[] content) {
89-
return GenerateContentAsync((IEnumerable<ModelContent>)content);
91+
ModelContent content, CancellationToken cancellationToken = default) {
92+
return GenerateContentAsync(new[] { content }, cancellationToken);
9093
}
9194
/// <summary>
9295
/// Generates new content from input text given to the model as a prompt.
9396
/// </summary>
9497
/// <param name="text">The text given to the model as a prompt.</param>
98+
/// <param name="cancellationToken">An optional token to cancel the operation.</param>
9599
/// <returns>The generated content response from the model.</returns>
96100
/// <exception cref="VertexAIException">Thrown when an error occurs during content generation.</exception>
97101
public Task<GenerateContentResponse> GenerateContentAsync(
98-
string text) {
99-
return GenerateContentAsync(new ModelContent[] { ModelContent.Text(text) });
102+
string text, CancellationToken cancellationToken = default) {
103+
return GenerateContentAsync(new[] { ModelContent.Text(text) }, cancellationToken);
100104
}
101105
/// <summary>
102106
/// Generates new content from input `ModelContent` given to the model as a prompt.
103107
/// </summary>
104-
/// <param name="content">The input(s) given to the model as a prompt.</param>
108+
/// <param name="content">The input given to the model as a prompt.</param>
109+
/// <param name="cancellationToken">An optional token to cancel the operation.</param>
105110
/// <returns>The generated content response from the model.</returns>
106111
/// <exception cref="VertexAIException">Thrown when an error occurs during content generation.</exception>
107112
public Task<GenerateContentResponse> GenerateContentAsync(
108-
IEnumerable<ModelContent> content) {
109-
return GenerateContentAsyncInternal(content);
113+
IEnumerable<ModelContent> content, CancellationToken cancellationToken = default) {
114+
return GenerateContentAsyncInternal(content, cancellationToken);
110115
}
111116

112117
/// <summary>
113118
/// Generates new content as a stream from input `ModelContent` given to the model as a prompt.
114119
/// </summary>
115-
/// <param name="content">The input(s) given to the model as a prompt.</param>
120+
/// <param name="content">The input given to the model as a prompt.</param>
121+
/// <param name="cancellationToken">An optional token to cancel the operation.</param>
116122
/// <returns>A stream of generated content responses from the model.</returns>
117123
/// <exception cref="VertexAIException">Thrown when an error occurs during content generation.</exception>
118124
public IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsync(
119-
params ModelContent[] content) {
120-
return GenerateContentStreamAsync((IEnumerable<ModelContent>)content);
125+
ModelContent content, CancellationToken cancellationToken = default) {
126+
return GenerateContentStreamAsync(new[] { content }, cancellationToken);
121127
}
122128
/// <summary>
123129
/// Generates new content as a stream from input text given to the model as a prompt.
124130
/// </summary>
125131
/// <param name="text">The text given to the model as a prompt.</param>
132+
/// <param name="cancellationToken">An optional token to cancel the operation.</param>
126133
/// <returns>A stream of generated content responses from the model.</returns>
127134
/// <exception cref="VertexAIException">Thrown when an error occurs during content generation.</exception>
128135
public IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsync(
129-
string text) {
130-
return GenerateContentStreamAsync(new ModelContent[] { ModelContent.Text(text) });
136+
string text, CancellationToken cancellationToken = default) {
137+
return GenerateContentStreamAsync(new[] { ModelContent.Text(text) }, cancellationToken);
131138
}
132139
/// <summary>
133140
/// Generates new content as a stream from input `ModelContent` given to the model as a prompt.
134141
/// </summary>
135-
/// <param name="content">The input(s) given to the model as a prompt.</param>
142+
/// <param name="content">The input given to the model as a prompt.</param>
143+
/// <param name="cancellationToken">An optional token to cancel the operation.</param>
136144
/// <returns>A stream of generated content responses from the model.</returns>
137145
/// <exception cref="VertexAIException">Thrown when an error occurs during content generation.</exception>
138146
public IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsync(
139-
IEnumerable<ModelContent> content) {
140-
return GenerateContentStreamAsyncInternal(content);
147+
IEnumerable<ModelContent> content, CancellationToken cancellationToken = default) {
148+
return GenerateContentStreamAsyncInternal(content, cancellationToken);
141149
}
142150

143151
/// <summary>
144152
/// Counts the number of tokens in a prompt using the model's tokenizer.
145153
/// </summary>
146-
/// <param name="content">The input(s) given to the model as a prompt.</param>
154+
/// <param name="content">The input given to the model as a prompt.</param>
147155
/// <returns>The `CountTokensResponse` of running the model's tokenizer on the input.</returns>
148156
/// <exception cref="VertexAIException">Thrown when an error occurs during the request.</exception>
149157
public Task<CountTokensResponse> CountTokensAsync(
150-
params ModelContent[] content) {
151-
return CountTokensAsync((IEnumerable<ModelContent>)content);
158+
ModelContent content, CancellationToken cancellationToken = default) {
159+
return CountTokensAsync(new[] { content }, cancellationToken);
152160
}
153161
/// <summary>
154162
/// Counts the number of tokens in a prompt using the model's tokenizer.
155163
/// </summary>
156164
/// <param name="text">The text input given to the model as a prompt.</param>
165+
/// <param name="cancellationToken">An optional token to cancel the operation.</param>
157166
/// <returns>The `CountTokensResponse` of running the model's tokenizer on the input.</returns>
158167
/// <exception cref="VertexAIException">Thrown when an error occurs during the request.</exception>
159168
public Task<CountTokensResponse> CountTokensAsync(
160-
string text) {
161-
return CountTokensAsync(new ModelContent[] { ModelContent.Text(text) });
169+
string text, CancellationToken cancellationToken = default) {
170+
return CountTokensAsync(new[] { ModelContent.Text(text) }, cancellationToken);
162171
}
163172
/// <summary>
164173
/// Counts the number of tokens in a prompt using the model's tokenizer.
165174
/// </summary>
166-
/// <param name="content">The input(s) given to the model as a prompt.</param>
175+
/// <param name="content">The input given to the model as a prompt.</param>
176+
/// <param name="cancellationToken">An optional token to cancel the operation.</param>
167177
/// <returns>The `CountTokensResponse` of running the model's tokenizer on the input.</returns>
168178
/// <exception cref="VertexAIException">Thrown when an error occurs during the request.</exception>
169179
public Task<CountTokensResponse> CountTokensAsync(
170-
IEnumerable<ModelContent> content) {
171-
return CountTokensAsyncInternal(content);
180+
IEnumerable<ModelContent> content, CancellationToken cancellationToken = default) {
181+
return CountTokensAsyncInternal(content, cancellationToken);
172182
}
173183

174184
/// <summary>
@@ -188,7 +198,8 @@ public Chat StartChat(IEnumerable<ModelContent> history) {
188198
#endregion
189199

190200
private async Task<GenerateContentResponse> GenerateContentAsyncInternal(
191-
IEnumerable<ModelContent> content) {
201+
IEnumerable<ModelContent> content,
202+
CancellationToken cancellationToken) {
192203
HttpRequestMessage request = new(HttpMethod.Post, GetURL() + ":generateContent");
193204

194205
// Set the request headers
@@ -204,7 +215,7 @@ private async Task<GenerateContentResponse> GenerateContentAsyncInternal(
204215

205216
HttpResponseMessage response;
206217
try {
207-
response = await _httpClient.SendAsync(request);
218+
response = await _httpClient.SendAsync(request, cancellationToken);
208219
response.EnsureSuccessStatusCode();
209220
} catch (TaskCanceledException e) when (e.InnerException is TimeoutException) {
210221
throw new VertexAIRequestTimeoutException("Request timed out.", e);
@@ -223,7 +234,8 @@ private async Task<GenerateContentResponse> GenerateContentAsyncInternal(
223234
}
224235

225236
private async IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsyncInternal(
226-
IEnumerable<ModelContent> content) {
237+
IEnumerable<ModelContent> content,
238+
[EnumeratorCancellation] CancellationToken cancellationToken) {
227239
HttpRequestMessage request = new(HttpMethod.Post, GetURL() + ":streamGenerateContent?alt=sse");
228240

229241
// Set the request headers
@@ -239,7 +251,7 @@ private async IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsy
239251

240252
HttpResponseMessage response;
241253
try {
242-
response = await _httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead);
254+
response = await _httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cancellationToken);
243255
response.EnsureSuccessStatusCode();
244256
} catch (TaskCanceledException e) when (e.InnerException is TimeoutException) {
245257
throw new VertexAIRequestTimeoutException("Request timed out.", e);
@@ -266,7 +278,8 @@ private async IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsy
266278
}
267279

268280
private async Task<CountTokensResponse> CountTokensAsyncInternal(
269-
IEnumerable<ModelContent> content) {
281+
IEnumerable<ModelContent> content,
282+
CancellationToken cancellationToken) {
270283
HttpRequestMessage request = new(HttpMethod.Post, GetURL() + ":countTokens");
271284

272285
// Set the request headers
@@ -282,7 +295,7 @@ private async Task<CountTokensResponse> CountTokensAsyncInternal(
282295

283296
HttpResponseMessage response;
284297
try {
285-
response = await _httpClient.SendAsync(request);
298+
response = await _httpClient.SendAsync(request, cancellationToken);
286299
response.EnsureSuccessStatusCode();
287300
} catch (TaskCanceledException e) when (e.InnerException is TimeoutException) {
288301
throw new VertexAIRequestTimeoutException("Request timed out.", e);

0 commit comments

Comments
 (0)