Skip to content

Commit f5863f8

Browse files
authored
VertexAI - Add logic for text input/output (#1189)
* VertexAI - Add logic for text input/output * Remove debug logs * Update GenerativeModel.cs * Address feedback
1 parent ba46a89 commit f5863f8

File tree

6 files changed

+553
-64
lines changed

6 files changed

+553
-64
lines changed

vertexai/src/Candidate.cs

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
using System;
1818
using System.Collections.Generic;
19+
using System.Collections.ObjectModel;
1920

2021
namespace Firebase.VertexAI {
2122

@@ -32,11 +33,57 @@ public enum FinishReason {
3233
MalformedFunctionCall,
3334
}
3435

36+
/// <summary>
37+
/// A struct representing a possible reply to a content generation prompt.
38+
/// Each content generation prompt may produce multiple candidate responses.
39+
/// </summary>
3540
public readonly struct Candidate {
41+
private readonly ReadOnlyCollection<SafetyRating> _safetyRatings;
42+
43+
/// <summary>
44+
/// The response’s content.
45+
/// </summary>
3646
public ModelContent Content { get; }
37-
public IEnumerable<SafetyRating> SafetyRatings { get; }
47+
48+
/// <summary>
49+
/// The safety rating of the response content.
50+
/// </summary>
51+
public IEnumerable<SafetyRating> SafetyRatings =>
52+
_safetyRatings ?? new ReadOnlyCollection<SafetyRating>(new List<SafetyRating>());
53+
54+
/// <summary>
55+
/// The reason the model stopped generating content, if it exists;
56+
/// for example, if the model generated a predefined stop sequence.
57+
/// </summary>
3858
public FinishReason? FinishReason { get; }
59+
60+
/// <summary>
61+
/// Cited works in the model’s response content, if it exists.
62+
/// </summary>
3963
public CitationMetadata? CitationMetadata { get; }
64+
65+
// Hidden constructor, users don't need to make this, though they still technically can.
66+
internal Candidate(ModelContent content, List<SafetyRating> safetyRatings,
67+
FinishReason? finishReason, CitationMetadata? citationMetadata) {
68+
Content = content;
69+
_safetyRatings = new ReadOnlyCollection<SafetyRating>(safetyRatings ?? new List<SafetyRating>());
70+
FinishReason = finishReason;
71+
CitationMetadata = citationMetadata;
72+
}
73+
74+
internal static Candidate FromJson(Dictionary<string, object> jsonDict) {
75+
ModelContent content = new();
76+
if (jsonDict.TryGetValue("content", out object contentObj)) {
77+
if (contentObj is not Dictionary<string, object> contentDict) {
78+
throw new VertexAISerializationException("Invalid JSON format: 'content' is not a dictionary.");
79+
}
80+
// We expect this to be another dictionary to convert
81+
content = ModelContent.FromJson(contentDict);
82+
}
83+
84+
// TODO: Parse SafetyRatings, FinishReason, and CitationMetadata
85+
return new Candidate(content, null, null, null);
86+
}
4087
}
4188

4289
}

vertexai/src/GenerateContentResponse.cs

Lines changed: 73 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,87 @@
1515
*/
1616

1717
using System.Collections.Generic;
18+
using System.Collections.ObjectModel;
19+
using System.Linq;
20+
using Google.MiniJSON;
1821

1922
namespace Firebase.VertexAI {
2023

24+
/// <summary>
25+
/// The model's response to a generate content request.
26+
/// </summary>
2127
public readonly struct GenerateContentResponse {
22-
public IEnumerable<Candidate> Candidates { get; }
28+
private readonly ReadOnlyCollection<Candidate> _candidates;
29+
30+
/// <summary>
31+
/// A list of candidate response content, ordered from best to worst.
32+
/// </summary>
33+
public IEnumerable<Candidate> Candidates =>
34+
_candidates ?? new ReadOnlyCollection<Candidate>(new List<Candidate>());
35+
36+
/// <summary>
37+
/// A value containing the safety ratings for the response, or,
38+
/// if the request was blocked, a reason for blocking the request.
39+
/// </summary>
2340
public PromptFeedback? PromptFeedback { get; }
41+
42+
/// <summary>
43+
/// Token usage metadata for processing the generate content request.
44+
/// </summary>
2445
public UsageMetadata? UsageMetadata { get; }
2546

26-
// Helper properties
27-
// The response's content as text, if it exists
28-
public string Text { get; }
47+
/// <summary>
48+
/// The response's content as text, if it exists.
49+
/// </summary>
50+
public string Text {
51+
get {
52+
// Concatenate all of the text parts from the first candidate.
53+
return string.Join(" ",
54+
Candidates.FirstOrDefault().Content.Parts
55+
.OfType<ModelContent.TextPart>().Select(tp => tp.Text));
56+
}
57+
}
58+
59+
/// <summary>
60+
/// Returns function calls found in any `Part`s of the first candidate of the response, if any.
61+
/// </summary>
62+
public IEnumerable<ModelContent.FunctionCallPart> FunctionCalls {
63+
get {
64+
return Candidates.FirstOrDefault().Content.Parts.OfType<ModelContent.FunctionCallPart>();
65+
}
66+
}
67+
68+
// Hidden constructor, users don't need to make this, though they still technically can.
69+
internal GenerateContentResponse(List<Candidate> candidates, PromptFeedback? promptFeedback,
70+
UsageMetadata? usageMetadata) {
71+
_candidates = new ReadOnlyCollection<Candidate>(candidates ?? new List<Candidate>());
72+
PromptFeedback = promptFeedback;
73+
UsageMetadata = usageMetadata;
74+
}
75+
76+
internal static GenerateContentResponse FromJson(string jsonString) {
77+
return FromJson(Json.Deserialize(jsonString) as Dictionary<string, object>);
78+
}
79+
80+
internal static GenerateContentResponse FromJson(Dictionary<string, object> jsonDict) {
81+
// Parse the Candidates
82+
List<Candidate> candidates = new();
83+
if (jsonDict.TryGetValue("candidates", out object candidatesObject)) {
84+
if (candidatesObject is not List<object> listOfCandidateObjects) {
85+
throw new VertexAISerializationException("Invalid JSON format: 'candidates' is not a list.");
86+
}
87+
88+
candidates = listOfCandidateObjects
89+
.Select(o => o as Dictionary<string, object>)
90+
.Where(dict => dict != null)
91+
.Select(Candidate.FromJson)
92+
.ToList();
93+
}
94+
95+
// TODO: Parse PromptFeedback and UsageMetadata
2996

30-
// Returns function calls found in any Parts of the first candidate of the response, if any.
31-
public IEnumerable<ModelContent.FunctionCallPart> FunctionCalls { get; }
97+
return new GenerateContentResponse(candidates, null, null);
98+
}
3299
}
33100

34101
public enum BlockReason {

vertexai/src/GenerativeModel.cs

Lines changed: 147 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,64 +14,190 @@
1414
* limitations under the License.
1515
*/
1616

17+
// For now, using this to hide some functions causing problems with the build.
18+
#define HIDE_IASYNCENUMERABLE
19+
1720
using System;
1821
using System.Collections.Generic;
22+
using System.Linq;
23+
using System.Net.Http;
24+
using System.Text;
1925
using System.Threading.Tasks;
26+
using Google.MiniJSON;
2027

2128
namespace Firebase.VertexAI {
2229

30+
/// <summary>
31+
/// A type that represents a remote multimodal model (like Gemini), with the ability to generate
32+
/// content based on various input types.
33+
/// </summary>
2334
public class GenerativeModel {
35+
private FirebaseApp _firebaseApp;
36+
37+
// Various setting fields provided by the user.
38+
private string _location;
39+
private string _modelName;
40+
private GenerationConfig? _generationConfig;
41+
private SafetySetting[] _safetySettings;
42+
private Tool[] _tools;
43+
private ToolConfig? _toolConfig;
44+
private ModelContent? _systemInstruction;
45+
private RequestOptions? _requestOptions;
46+
47+
HttpClient _httpClient;
48+
49+
internal GenerativeModel(FirebaseApp firebaseApp,
50+
string location,
51+
string modelName,
52+
GenerationConfig? generationConfig = null,
53+
SafetySetting[] safetySettings = null,
54+
Tool[] tools = null,
55+
ToolConfig? toolConfig = null,
56+
ModelContent? systemInstruction = null,
57+
RequestOptions? requestOptions = null) {
58+
_firebaseApp = firebaseApp;
59+
_location = location;
60+
_modelName = modelName;
61+
_generationConfig = generationConfig;
62+
_safetySettings = safetySettings;
63+
_tools = tools;
64+
_toolConfig = toolConfig;
65+
_systemInstruction = systemInstruction;
66+
_requestOptions = requestOptions;
67+
68+
// Create a HttpClient using the timeout requested, or the default one.
69+
_httpClient = new HttpClient() {
70+
Timeout = requestOptions?.Timeout ?? RequestOptions.DefaultTimeout
71+
};
72+
}
73+
74+
#region Public API
75+
/// <summary>
76+
/// Generates new content from input `ModelContent` given to the model as a prompt.
77+
/// </summary>
78+
/// <param name="content">The input(s) given to the model as a prompt.</param>
79+
/// <returns>The generated content response from the model.</returns>
80+
/// <exception cref="VertexAIException">Thrown when an error occurs during content generation.</exception>
2481
public Task<GenerateContentResponse> GenerateContentAsync(
2582
params ModelContent[] content) {
26-
throw new NotImplementedException();
83+
return GenerateContentAsync((IEnumerable<ModelContent>)content);
2784
}
85+
/// <summary>
86+
/// Generates new content from input text given to the model as a prompt.
87+
/// </summary>
88+
/// <param name="content">The text given to the model as a prompt.</param>
89+
/// <returns>The generated content response from the model.</returns>
90+
/// <exception cref="VertexAIException">Thrown when an error occurs during content generation.</exception>
2891
public Task<GenerateContentResponse> GenerateContentAsync(
29-
IEnumerable<ModelContent> content) {
30-
throw new NotImplementedException();
92+
string text) {
93+
return GenerateContentAsync(new ModelContent[] { ModelContent.Text(text) });
3194
}
95+
/// <summary>
96+
/// Generates new content from input `ModelContent` given to the model as a prompt.
97+
/// </summary>
98+
/// <param name="content">The input(s) given to the model as a prompt.</param>
99+
/// <returns>The generated content response from the model.</returns>
100+
/// <exception cref="VertexAIException">Thrown when an error occurs during content generation.</exception>
32101
public Task<GenerateContentResponse> GenerateContentAsync(
33-
string text) {
34-
throw new NotImplementedException();
102+
IEnumerable<ModelContent> content) {
103+
return GenerateContentAsyncInternal(content);
35104
}
36105

37-
// The build logic isn't able to resolve IAsyncEnumerable for some reason, even
38-
// though it is usable in Unity 2021.3. Will need to investigate further.
39-
/*
106+
#if !HIDE_IASYNCENUMERABLE
40107
public IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsync(
41108
params ModelContent[] content) {
42-
throw new NotImplementedException();
109+
return GenerateContentStreamAsync((IEnumerable<ModelContent>)content);
43110
}
44111
public IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsync(
45-
IEnumerable<ModelContent> content) {
46-
throw new NotImplementedException();
112+
string text) {
113+
return GenerateContentStreamAsync(new ModelContent[] { ModelContent.Text(text) });
47114
}
48115
public IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsync(
49-
string text) {
50-
throw new NotImplementedException();
116+
IEnumerable<ModelContent> content) {
117+
return GenerateContentStreamAsyncInternal(content);
51118
}
52-
*/
119+
#endif
53120

54121
public Task<CountTokensResponse> CountTokensAsync(
55122
params ModelContent[] content) {
56-
throw new NotImplementedException();
123+
return CountTokensAsync((IEnumerable<ModelContent>)content);
57124
}
58125
public Task<CountTokensResponse> CountTokensAsync(
59-
IEnumerable<ModelContent> content) {
60-
throw new NotImplementedException();
126+
string text) {
127+
return CountTokensAsync(new ModelContent[] { ModelContent.Text(text) });
61128
}
62129
public Task<CountTokensResponse> CountTokensAsync(
63-
string text) {
64-
throw new NotImplementedException();
130+
IEnumerable<ModelContent> content) {
131+
return CountTokensAsyncInternal(content);
65132
}
66133

67134
public Chat StartChat(params ModelContent[] history) {
68-
throw new NotImplementedException();
135+
return StartChat((IEnumerable<ModelContent>)history);
69136
}
70137
public Chat StartChat(IEnumerable<ModelContent> history) {
138+
// TODO: Implementation
71139
throw new NotImplementedException();
72140
}
141+
#endregion
142+
143+
private async Task<GenerateContentResponse> GenerateContentAsyncInternal(
144+
IEnumerable<ModelContent> content) {
145+
string bodyJson = ModelContentsToJson(content);
146+
147+
HttpRequestMessage request = new(HttpMethod.Post, GetURL() + ":generateContent");
148+
149+
// Set the request headers
150+
request.Headers.Add("x-goog-api-key", _firebaseApp.Options.ApiKey);
151+
request.Headers.Add("x-goog-api-client", "genai-csharp/0.1.0");
152+
153+
// Set the content
154+
request.Content = new StringContent(bodyJson, Encoding.UTF8, "application/json");
155+
156+
HttpResponseMessage response = await _httpClient.SendAsync(request);
157+
// TODO: Convert any timeout exception into a VertexAI equivalent
158+
// TODO: Convert any HttpRequestExceptions, see:
159+
// https://learn.microsoft.com/en-us/dotnet/api/system.net.http.httpclient.sendasync?view=net-9.0
160+
// https://learn.microsoft.com/en-us/dotnet/api/system.net.http.httpresponsemessage.ensuresuccessstatuscode?view=net-9.0
161+
response.EnsureSuccessStatusCode();
162+
163+
string result = await response.Content.ReadAsStringAsync();
164+
165+
return GenerateContentResponse.FromJson(result);
166+
}
73167

74-
// Note: No public constructor, get one through VertexAI.GetGenerativeModel
168+
#if !HIDE_IASYNCENUMERABLE
169+
private async IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsyncInternal(
170+
IEnumerable<ModelContent> content) {
171+
// TODO: Implementation
172+
await Task.CompletedTask;
173+
yield return new GenerateContentResponse();
174+
throw new NotImplementedException();
175+
}
176+
#endif
177+
178+
private async Task<CountTokensResponse> CountTokensAsyncInternal(
179+
IEnumerable<ModelContent> content) {
180+
// TODO: Implementation
181+
await Task.CompletedTask;
182+
throw new NotImplementedException();
183+
}
184+
185+
private string GetURL() {
186+
return "https://firebaseml.googleapis.com/v2beta" +
187+
"/projects/" + _firebaseApp.Options.ProjectId +
188+
"/locations/" + _location +
189+
"/publishers/google/models/" + _modelName;
190+
}
191+
192+
private string ModelContentsToJson(IEnumerable<ModelContent> contents) {
193+
Dictionary<string, object> jsonDict = new() {
194+
// Convert the Contents into a list of Json dictionaries
195+
["contents"] = contents.Select(c => c.ToJson()).ToList()
196+
};
197+
// TODO: All the other settings
198+
199+
return Json.Serialize(jsonDict);
200+
}
75201
}
76202

77203
}

0 commit comments

Comments
 (0)