19
19
using System . IO ;
20
20
using System . Linq ;
21
21
using System . Net . Http ;
22
+ using System . Runtime . CompilerServices ;
22
23
using System . Text ;
24
+ using System . Threading ;
23
25
using System . Threading . Tasks ;
24
26
using Google . MiniJSON ;
25
27
using Firebase . VertexAI . Internal ;
@@ -81,94 +83,102 @@ internal GenerativeModel(FirebaseApp firebaseApp,
81
83
/// <summary>
/// Generates new content from a single `ModelContent` prompt.
/// </summary>
/// <param name="content">The input given to the model as a prompt.</param>
/// <param name="cancellationToken">An optional token to cancel the operation.</param>
/// <returns>The generated content response from the model.</returns>
/// <exception cref="VertexAIException">Thrown when an error occurs during content generation.</exception>
public Task<GenerateContentResponse> GenerateContentAsync(
    ModelContent content, CancellationToken cancellationToken = default) {
  // Wrap the single prompt and defer to the IEnumerable overload.
  ModelContent[] wrapped = { content };
  return GenerateContentAsync(wrapped, cancellationToken);
}
91
94
/// <summary>
/// Generates new content from a plain-text prompt given to the model.
/// </summary>
/// <param name="text">The text given to the model as a prompt.</param>
/// <param name="cancellationToken">An optional token to cancel the operation.</param>
/// <returns>The generated content response from the model.</returns>
/// <exception cref="VertexAIException">Thrown when an error occurs during content generation.</exception>
public Task<GenerateContentResponse> GenerateContentAsync(
    string text, CancellationToken cancellationToken = default) =>
  // Convert the text to ModelContent and defer to the IEnumerable overload.
  GenerateContentAsync(new[] { ModelContent.Text(text) }, cancellationToken);
101
105
/// <summary>
/// Generates new content from a sequence of `ModelContent` inputs given to the model as a prompt.
/// </summary>
/// <param name="content">The input given to the model as a prompt.</param>
/// <param name="cancellationToken">An optional token to cancel the operation.</param>
/// <returns>The generated content response from the model.</returns>
/// <exception cref="VertexAIException">Thrown when an error occurs during content generation.</exception>
public Task<GenerateContentResponse> GenerateContentAsync(
    IEnumerable<ModelContent> content, CancellationToken cancellationToken = default) =>
  // All public overloads funnel into this internal implementation.
  GenerateContentAsyncInternal(content, cancellationToken);
111
116
112
117
/// <summary>
/// Generates new content as a stream from a single `ModelContent` prompt.
/// </summary>
/// <param name="content">The input given to the model as a prompt.</param>
/// <param name="cancellationToken">An optional token to cancel the operation.</param>
/// <returns>A stream of generated content responses from the model.</returns>
/// <exception cref="VertexAIException">Thrown when an error occurs during content generation.</exception>
public IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsync(
    ModelContent content, CancellationToken cancellationToken = default) {
  // Wrap the single prompt and defer to the IEnumerable overload.
  ModelContent[] wrapped = { content };
  return GenerateContentStreamAsync(wrapped, cancellationToken);
}
122
128
/// <summary>
/// Generates new content as a stream from a plain-text prompt given to the model.
/// </summary>
/// <param name="text">The text given to the model as a prompt.</param>
/// <param name="cancellationToken">An optional token to cancel the operation.</param>
/// <returns>A stream of generated content responses from the model.</returns>
/// <exception cref="VertexAIException">Thrown when an error occurs during content generation.</exception>
public IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsync(
    string text, CancellationToken cancellationToken = default) =>
  // Convert the text to ModelContent and defer to the IEnumerable overload.
  GenerateContentStreamAsync(new[] { ModelContent.Text(text) }, cancellationToken);
132
139
/// <summary>
/// Generates new content as a stream from a sequence of `ModelContent` inputs given to the model
/// as a prompt.
/// </summary>
/// <param name="content">The input given to the model as a prompt.</param>
/// <param name="cancellationToken">An optional token to cancel the operation.</param>
/// <returns>A stream of generated content responses from the model.</returns>
/// <exception cref="VertexAIException">Thrown when an error occurs during content generation.</exception>
public IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsync(
    IEnumerable<ModelContent> content, CancellationToken cancellationToken = default) =>
  // All public streaming overloads funnel into this internal implementation.
  GenerateContentStreamAsyncInternal(content, cancellationToken);
142
150
143
151
/// <summary>
/// Counts the number of tokens in a prompt using the model's tokenizer.
/// </summary>
/// <param name="content">The input given to the model as a prompt.</param>
/// <param name="cancellationToken">An optional token to cancel the operation.</param>
/// <returns>The `CountTokensResponse` of running the model's tokenizer on the input.</returns>
/// <exception cref="VertexAIException">Thrown when an error occurs during the request.</exception>
public Task<CountTokensResponse> CountTokensAsync(
    ModelContent content, CancellationToken cancellationToken = default) {
  // Wrap the single prompt and defer to the IEnumerable overload.
  return CountTokensAsync(new[] { content }, cancellationToken);
}
153
161
/// <summary>
/// Counts the number of tokens in a plain-text prompt using the model's tokenizer.
/// </summary>
/// <param name="text">The text input given to the model as a prompt.</param>
/// <param name="cancellationToken">An optional token to cancel the operation.</param>
/// <returns>The `CountTokensResponse` of running the model's tokenizer on the input.</returns>
/// <exception cref="VertexAIException">Thrown when an error occurs during the request.</exception>
public Task<CountTokensResponse> CountTokensAsync(
    string text, CancellationToken cancellationToken = default) =>
  // Convert the text to ModelContent and defer to the IEnumerable overload.
  CountTokensAsync(new[] { ModelContent.Text(text) }, cancellationToken);
163
172
/// <summary>
/// Counts the number of tokens in a sequence of `ModelContent` inputs using the model's tokenizer.
/// </summary>
/// <param name="content">The input given to the model as a prompt.</param>
/// <param name="cancellationToken">An optional token to cancel the operation.</param>
/// <returns>The `CountTokensResponse` of running the model's tokenizer on the input.</returns>
/// <exception cref="VertexAIException">Thrown when an error occurs during the request.</exception>
public Task<CountTokensResponse> CountTokensAsync(
    IEnumerable<ModelContent> content, CancellationToken cancellationToken = default) =>
  // All public CountTokens overloads funnel into this internal implementation.
  CountTokensAsyncInternal(content, cancellationToken);
173
183
174
184
/// <summary>
@@ -188,7 +198,8 @@ public Chat StartChat(IEnumerable<ModelContent> history) {
188
198
#endregion
189
199
190
200
private async Task < GenerateContentResponse > GenerateContentAsyncInternal (
191
- IEnumerable < ModelContent > content ) {
201
+ IEnumerable < ModelContent > content ,
202
+ CancellationToken cancellationToken ) {
192
203
HttpRequestMessage request = new ( HttpMethod . Post , GetURL ( ) + ":generateContent" ) ;
193
204
194
205
// Set the request headers
@@ -204,7 +215,7 @@ private async Task<GenerateContentResponse> GenerateContentAsyncInternal(
204
215
205
216
HttpResponseMessage response ;
206
217
try {
207
- response = await _httpClient . SendAsync ( request ) ;
218
+ response = await _httpClient . SendAsync ( request , cancellationToken ) ;
208
219
response . EnsureSuccessStatusCode ( ) ;
209
220
} catch ( TaskCanceledException e ) when ( e . InnerException is TimeoutException ) {
210
221
throw new VertexAIRequestTimeoutException ( "Request timed out." , e ) ;
@@ -223,7 +234,8 @@ private async Task<GenerateContentResponse> GenerateContentAsyncInternal(
223
234
}
224
235
225
236
private async IAsyncEnumerable < GenerateContentResponse > GenerateContentStreamAsyncInternal (
226
- IEnumerable < ModelContent > content ) {
237
+ IEnumerable < ModelContent > content ,
238
+ [ EnumeratorCancellation ] CancellationToken cancellationToken ) {
227
239
HttpRequestMessage request = new ( HttpMethod . Post , GetURL ( ) + ":streamGenerateContent?alt=sse" ) ;
228
240
229
241
// Set the request headers
@@ -239,7 +251,7 @@ private async IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsy
239
251
240
252
HttpResponseMessage response ;
241
253
try {
242
- response = await _httpClient . SendAsync ( request , HttpCompletionOption . ResponseHeadersRead ) ;
254
+ response = await _httpClient . SendAsync ( request , HttpCompletionOption . ResponseHeadersRead , cancellationToken ) ;
243
255
response . EnsureSuccessStatusCode ( ) ;
244
256
} catch ( TaskCanceledException e ) when ( e . InnerException is TimeoutException ) {
245
257
throw new VertexAIRequestTimeoutException ( "Request timed out." , e ) ;
@@ -266,7 +278,8 @@ private async IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsy
266
278
}
267
279
268
280
private async Task < CountTokensResponse > CountTokensAsyncInternal (
269
- IEnumerable < ModelContent > content ) {
281
+ IEnumerable < ModelContent > content ,
282
+ CancellationToken cancellationToken ) {
270
283
HttpRequestMessage request = new ( HttpMethod . Post , GetURL ( ) + ":countTokens" ) ;
271
284
272
285
// Set the request headers
@@ -282,7 +295,7 @@ private async Task<CountTokensResponse> CountTokensAsyncInternal(
282
295
283
296
HttpResponseMessage response ;
284
297
try {
285
- response = await _httpClient . SendAsync ( request ) ;
298
+ response = await _httpClient . SendAsync ( request , cancellationToken ) ;
286
299
response . EnsureSuccessStatusCode ( ) ;
287
300
} catch ( TaskCanceledException e ) when ( e . InnerException is TimeoutException ) {
288
301
throw new VertexAIRequestTimeoutException ( "Request timed out." , e ) ;
0 commit comments