Skip to content

Commit fc487d9

Browse files
committed
feat: Add some async extension methods
1 parent 156790d commit fc487d9

File tree

1 file changed

+86
-3
lines changed

1 file changed

+86
-3
lines changed

apis/Google.Cloud.VertexAI.Extensions/Google.Cloud.VertexAI.Extensions/VertexAIExtensions.cs

Lines changed: 86 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
using Google.Cloud.AIPlatform.V1;
1717
using Microsoft.Extensions.AI;
1818
using System;
19+
using System.Threading;
20+
using System.Threading.Tasks;
1921

2022
namespace Google.Cloud.VertexAI.Extensions;
2123

@@ -52,7 +54,8 @@ public static IChatClient AsIChatClient(
5254
/// <returns>An <see cref="IChatClient"/> that wraps the built client.</returns>
5355
/// <exception cref="ArgumentNullException"><paramref name="builder"/> is <see langword="null"/>.</exception>
5456
public static IChatClient BuildIChatClient(
55-
this PredictionServiceClientBuilder builder, IServiceProvider? provider = null, string? defaultModelId = null)
57+
this PredictionServiceClientBuilder builder,
58+
IServiceProvider? provider = null, string? defaultModelId = null)
5659
{
5760
GaxPreconditions.CheckNotNull(builder, nameof(builder));
5861

@@ -63,6 +66,32 @@ public static IChatClient BuildIChatClient(
6366
return client.AsIChatClient(defaultModelId);
6467
}
6568

69+
/// <summary>
70+
/// Builds a <see cref="PredictionServiceClient"/> and creates an <see cref="IChatClient"/> wrapper around it.
71+
/// </summary>
72+
/// <param name="builder">The <see cref="PredictionServiceClientBuilder"/> with which to build the <see cref="PredictionServiceClient"/>.</param>
73+
/// <param name="provider">An optional <see cref="IServiceProvider"/> from which services are requested when building the client.</param>
74+
/// <param name="defaultModelId">
75+
/// The default model ID to use for chat requests if not specified in <see cref="ChatOptions.ModelId"/>.
76+
/// This must be the full resource name of the model, e.g. "projects/{projectId}/locations/{location}/publishers/{publisher}/models/{model}".
77+
/// </param>
78+
/// <param name="cancellationToken">A token to cancel the async operation.</param>
79+
/// <returns>An <see cref="IChatClient"/> that wraps the built client.</returns>
80+
/// <exception cref="ArgumentNullException"><paramref name="builder"/> is <see langword="null"/>.</exception>
81+
public static async Task<IChatClient> BuildIChatClientAsync(
82+
this PredictionServiceClientBuilder builder,
83+
IServiceProvider? provider = null, string? defaultModelId = null,
84+
CancellationToken cancellationToken = default)
85+
{
86+
GaxPreconditions.CheckNotNull(builder, nameof(builder));
87+
88+
PredictionServiceClient client = await (provider is not null ?
89+
builder.BuildAsync(provider, cancellationToken) :
90+
builder.BuildAsync(cancellationToken)).ConfigureAwait(false);
91+
92+
return client.AsIChatClient(defaultModelId);
93+
}
94+
6695
/// <summary>
6796
/// Creates an <see cref="IEmbeddingGenerator{String, Embedding}"/> wrapper around the specified <see cref="PredictionServiceClient"/>.
6897
/// </summary>
@@ -94,7 +123,8 @@ public static IEmbeddingGenerator<string, Embedding<float>> AsIEmbeddingGenerato
94123
/// <returns>An <see cref="IEmbeddingGenerator{String, Embedding}"/> that wraps the built client.</returns>
95124
/// <exception cref="ArgumentNullException"><paramref name="builder"/> is <see langword="null"/>.</exception>
96125
public static IEmbeddingGenerator<string, Embedding<float>> BuildIEmbeddingGenerator(
97-
this PredictionServiceClientBuilder builder, IServiceProvider? provider = null, string? defaultModelId = null)
126+
this PredictionServiceClientBuilder builder,
127+
IServiceProvider? provider = null, string? defaultModelId = null)
98128
{
99129
GaxPreconditions.CheckNotNull(builder, nameof(builder));
100130

@@ -105,6 +135,32 @@ public static IEmbeddingGenerator<string, Embedding<float>> BuildIEmbeddingGener
105135
return client.AsIEmbeddingGenerator(defaultModelId);
106136
}
107137

138+
/// <summary>
139+
/// Builds a <see cref="PredictionServiceClient"/> and creates an <see cref="IEmbeddingGenerator{String, Embedding}"/> wrapper around it.
140+
/// </summary>
141+
/// <param name="builder">The <see cref="PredictionServiceClientBuilder"/> with which to build the <see cref="PredictionServiceClient"/>.</param>
142+
/// <param name="provider">An optional <see cref="IServiceProvider"/> from which services are requested when building the client.</param>
143+
/// <param name="defaultModelId">
144+
/// The default model ID to use for chat requests if not specified in <see cref="EmbeddingGenerationOptions.ModelId"/>.
145+
/// This must be the full resource name of the model, e.g. "projects/{projectId}/locations/{location}/publishers/{publisher}/models/{model}".
146+
/// </param>
147+
/// <param name="cancellationToken">A token to cancel the async operation.</param>
148+
/// <returns>An <see cref="IEmbeddingGenerator{String, Embedding}"/> that wraps the built client.</returns>
149+
/// <exception cref="ArgumentNullException"><paramref name="builder"/> is <see langword="null"/>.</exception>
150+
public static async Task<IEmbeddingGenerator<string, Embedding<float>>> BuildIEmbeddingGeneratorAsync(
151+
this PredictionServiceClientBuilder builder,
152+
IServiceProvider? provider = null, string? defaultModelId = null,
153+
CancellationToken cancellationToken = default)
154+
{
155+
GaxPreconditions.CheckNotNull(builder, nameof(builder));
156+
157+
PredictionServiceClient client = await (provider is not null ?
158+
builder.BuildAsync(provider, cancellationToken) :
159+
builder.BuildAsync(cancellationToken)).ConfigureAwait(false);
160+
161+
return client.AsIEmbeddingGenerator(defaultModelId);
162+
}
163+
108164
/// <summary>
109165
/// Creates an <see cref="IImageGenerator"/> wrapper around the specified <see cref="PredictionServiceClient"/>.
110166
/// </summary>
@@ -135,7 +191,8 @@ public static IImageGenerator AsIImageGenerator(
135191
/// <returns>An <see cref="IImageGenerator"/> that wraps the built client.</returns>
136192
/// <exception cref="ArgumentNullException"><paramref name="builder"/> is <see langword="null"/>.</exception>
137193
public static IImageGenerator BuildIImageGenerator(
138-
this PredictionServiceClientBuilder builder, IServiceProvider? provider = null, string? defaultModelId = null)
194+
this PredictionServiceClientBuilder builder,
195+
IServiceProvider? provider = null, string? defaultModelId = null)
139196
{
140197
GaxPreconditions.CheckNotNull(builder, nameof(builder));
141198

@@ -146,6 +203,32 @@ public static IImageGenerator BuildIImageGenerator(
146203
return client.AsIImageGenerator(defaultModelId);
147204
}
148205

206+
/// <summary>
207+
/// Builds a <see cref="PredictionServiceClient"/> and creates an <see cref="IImageGenerator"/> wrapper around it.
208+
/// </summary>
209+
/// <param name="builder">The <see cref="PredictionServiceClientBuilder"/> with which to build the <see cref="PredictionServiceClient"/>.</param>
210+
/// <param name="defaultModelId">
211+
/// The default model ID to use for chat requests if not specified in <see cref="ImageGenerationOptions.ModelId"/>.
212+
/// This must be the full resource name of the model, e.g. "projects/{projectId}/locations/{location}/publishers/{publisher}/models/{model}".
213+
/// </param>
214+
/// <param name="provider">An optional <see cref="IServiceProvider"/> from which services are requested when building the client.</param>
215+
/// <param name="cancellationToken">A token to cancel the async operation.</param>
216+
/// <returns>An <see cref="IImageGenerator"/> that wraps the built client.</returns>
217+
/// <exception cref="ArgumentNullException"><paramref name="builder"/> is <see langword="null"/>.</exception>
218+
public static async Task<IImageGenerator> BuildIImageGeneratorAsync(
219+
this PredictionServiceClientBuilder builder,
220+
IServiceProvider? provider = null, string? defaultModelId = null,
221+
CancellationToken cancellationToken = default)
222+
{
223+
GaxPreconditions.CheckNotNull(builder, nameof(builder));
224+
225+
PredictionServiceClient client = await (provider is not null ?
226+
builder.BuildAsync(provider, cancellationToken) :
227+
builder.BuildAsync(cancellationToken)).ConfigureAwait(false);
228+
229+
return client.AsIImageGenerator(defaultModelId);
230+
}
231+
149232
/// <summary>Creates an <see cref="AITool"/> to represent a raw <see cref="Tool"/>.</summary>
150233
/// <param name="tool">The tool to wrap as an <see cref="AITool"/>.</param>
151234
/// <returns>The <paramref name="tool"/> wrapped as an <see cref="AITool"/>.</returns>

0 commit comments

Comments
 (0)