Skip to content

Commit f34c5a5

Browse files
authored
VertexAI - Add more functionality around options in, and feedback out (#1192)
* VertexAI - Add logic for text input/output * Remove debug logs * Update GenerativeModel.cs * Implement more of the VertexAI setting structs * Minor updates * Fix SafetyRatings missing in PromptFeedback * Comments for Citation * Fix merge issues * Comment cleanup * More comment cleanup
1 parent f5863f8 commit f34c5a5

File tree

8 files changed

+853
-88
lines changed

8 files changed

+853
-88
lines changed

vertexai/src/Candidate.cs

Lines changed: 59 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,54 @@
1717
using System;
1818
using System.Collections.Generic;
1919
using System.Collections.ObjectModel;
20+
using System.Linq;
21+
using Firebase.VertexAI.Internal;
2022

2123
namespace Firebase.VertexAI {
2224

25+
/// <summary>
26+
/// Represents the reason why the model stopped generating content.
27+
/// </summary>
2328
public enum FinishReason {
24-
Unknown,
29+
/// <summary>
30+
/// A new and not yet supported value.
31+
/// </summary>
32+
Unknown = 0,
33+
/// <summary>
34+
/// Natural stop point of the model or provided stop sequence.
35+
/// </summary>
2536
Stop,
37+
/// <summary>
38+
/// The maximum number of tokens as specified in the request was reached.
39+
/// </summary>
2640
MaxTokens,
41+
/// <summary>
42+
/// The token generation was stopped because the response was flagged for safety reasons.
43+
/// </summary>
2744
Safety,
45+
/// <summary>
46+
/// The token generation was stopped because the response was flagged for unauthorized citations.
47+
/// </summary>
2848
Recitation,
49+
/// <summary>
50+
/// All other reasons that stopped token generation.
51+
/// </summary>
2952
Other,
53+
/// <summary>
54+
/// Token generation was stopped because the response contained forbidden terms.
55+
/// </summary>
3056
Blocklist,
57+
/// <summary>
58+
/// Token generation was stopped because the response contained potentially prohibited content.
59+
/// </summary>
3160
ProhibitedContent,
61+
/// <summary>
62+
/// Token generation was stopped because of Sensitive Personally Identifiable Information (SPII).
63+
/// </summary>
3264
SPII,
65+
/// <summary>
66+
/// Token generation was stopped because the function call generated by the model was invalid.
67+
/// </summary>
3368
MalformedFunctionCall,
3469
}
3570

@@ -62,27 +97,37 @@ public readonly struct Candidate {
6297
/// </summary>
6398
public CitationMetadata? CitationMetadata { get; }
6499

65-
// Hidden constructor, users don't need to make this, though they still technically can.
66-
internal Candidate(ModelContent content, List<SafetyRating> safetyRatings,
100+
// Hidden constructor, users don't need to make this.
101+
private Candidate(ModelContent content, List<SafetyRating> safetyRatings,
67102
FinishReason? finishReason, CitationMetadata? citationMetadata) {
68103
Content = content;
69104
_safetyRatings = new ReadOnlyCollection<SafetyRating>(safetyRatings ?? new List<SafetyRating>());
70105
FinishReason = finishReason;
71106
CitationMetadata = citationMetadata;
72107
}
73108

74-
internal static Candidate FromJson(Dictionary<string, object> jsonDict) {
75-
ModelContent content = new();
76-
if (jsonDict.TryGetValue("content", out object contentObj)) {
77-
if (contentObj is not Dictionary<string, object> contentDict) {
78-
throw new VertexAISerializationException("Invalid JSON format: 'content' is not a dictionary.");
79-
}
80-
// We expect this to be another dictionary to convert
81-
content = ModelContent.FromJson(contentDict);
82-
}
109+
private static FinishReason ParseFinishReason(string str) {
110+
return str switch {
111+
"STOP" => Firebase.VertexAI.FinishReason.Stop,
112+
"MAX_TOKENS" => Firebase.VertexAI.FinishReason.MaxTokens,
113+
"SAFETY" => Firebase.VertexAI.FinishReason.Safety,
114+
"RECITATION" => Firebase.VertexAI.FinishReason.Recitation,
115+
"OTHER" => Firebase.VertexAI.FinishReason.Other,
116+
"BLOCKLIST" => Firebase.VertexAI.FinishReason.Blocklist,
117+
"PROHIBITED_CONTENT" => Firebase.VertexAI.FinishReason.ProhibitedContent,
118+
"SPII" => Firebase.VertexAI.FinishReason.SPII,
119+
"MALFORMED_FUNCTION_CALL" => Firebase.VertexAI.FinishReason.MalformedFunctionCall,
120+
_ => Firebase.VertexAI.FinishReason.Unknown,
121+
};
122+
}
83123

84-
// TODO: Parse SafetyRatings, FinishReason, and CitationMetadata
85-
return new Candidate(content, null, null, null);
124+
internal static Candidate FromJson(Dictionary<string, object> jsonDict) {
125+
return new Candidate(
126+
jsonDict.ParseObject("content", ModelContent.FromJson, defaultValue: new ModelContent("model")),
127+
jsonDict.ParseObjectList("safetyRatings", SafetyRating.FromJson),
128+
jsonDict.ParseNullableEnum("finishReason", ParseFinishReason),
129+
jsonDict.ParseNullableObject("citationMetadata",
130+
Firebase.VertexAI.CitationMetadata.FromJson));
86131
}
87132
}
88133

vertexai/src/Citation.cs

Lines changed: 79 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,25 +14,101 @@
1414
* limitations under the License.
1515
*/
1616

17+
using System;
1718
using System.Collections.Generic;
19+
using System.Collections.ObjectModel;
20+
using Firebase.VertexAI.Internal;
1821

1922
namespace Firebase.VertexAI {
2023

24+
/// <summary>
25+
/// A collection of source attributions for a piece of content.
26+
/// </summary>
2127
public readonly struct CitationMetadata {
22-
public IEnumerable<Citation> Citations { get; }
28+
private readonly ReadOnlyCollection<Citation> _citations;
2329

24-
// Hidden constructor, users don't need to make this
30+
/// <summary>
31+
/// A list of individual cited sources and the parts of the content to which they apply.
32+
/// </summary>
33+
public IEnumerable<Citation> Citations =>
34+
_citations ?? new ReadOnlyCollection<Citation>(new List<Citation>());
35+
36+
// Hidden constructor, users don't need to make this.
37+
private CitationMetadata(List<Citation> citations) {
38+
_citations = new ReadOnlyCollection<Citation>(citations ?? new List<Citation>());
39+
}
40+
41+
internal static CitationMetadata FromJson(Dictionary<string, object> jsonDict) {
42+
return new CitationMetadata(
43+
jsonDict.ParseObjectList("citations", Citation.FromJson));
44+
}
2545
}
2646

47+
/// <summary>
48+
/// A struct describing a source attribution.
49+
/// </summary>
2750
public readonly struct Citation {
51+
/// <summary>
52+
/// The inclusive beginning of a sequence in a model response that derives from a cited source.
53+
/// </summary>
2854
public int StartIndex { get; }
55+
/// <summary>
56+
/// The exclusive end of a sequence in a model response that derives from a cited source.
57+
/// </summary>
2958
public int EndIndex { get; }
59+
/// <summary>
60+
/// A link to the cited source, if available.
61+
/// </summary>
3062
public System.Uri Uri { get; }
63+
/// <summary>
64+
/// The title of the cited source, if available.
65+
/// </summary>
3166
public string Title { get; }
67+
/// <summary>
68+
/// The license the cited source work is distributed under, if specified.
69+
/// </summary>
3270
public string License { get; }
71+
/// <summary>
72+
/// The publication date of the cited source, if available.
73+
/// </summary>
3374
public System.DateTime? PublicationDate { get; }
3475

35-
// Hidden constructor, users don't need to make this
76+
// Hidden constructor, users don't need to make this.
77+
private Citation(int startIndex, int endIndex, Uri uri, string title,
78+
string license, DateTime? publicationDate) {
79+
StartIndex = startIndex;
80+
EndIndex = endIndex;
81+
Uri = uri;
82+
Title = title;
83+
License = license;
84+
PublicationDate = publicationDate;
85+
}
86+
87+
internal static Citation FromJson(Dictionary<string, object> jsonDict) {
88+
// If there is a Uri, need to convert it.
89+
Uri uri = null;
90+
if (jsonDict.TryParseValue("uri", out string uriString)) {
91+
uri = new Uri(uriString);
92+
}
93+
94+
// If there is a publication date, we need to convert it.
95+
DateTime? pubDate = null;
96+
if (jsonDict.TryParseValue("publicationDate", out Dictionary<string, object> dateDict)) {
97+
// Make sure that if any key is missing, it has a default value that will work with DateTime.
98+
pubDate = new DateTime(
99+
dateDict.ParseValue<int>("year", defaultValue: 1),
100+
dateDict.ParseValue<int>("month", defaultValue: 1),
101+
dateDict.ParseValue<int>("day", defaultValue: 1));
102+
}
103+
104+
return new Citation(
105+
jsonDict.ParseValue<int>("startIndex"),
106+
jsonDict.ParseValue<int>("endIndex"),
107+
uri,
108+
jsonDict.ParseValue<string>("title"),
109+
jsonDict.ParseValue<string>("license"),
110+
pubDate);
111+
}
36112
}
37113

38114
}

vertexai/src/GenerateContentResponse.cs

Lines changed: 95 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
using System.Collections.ObjectModel;
1919
using System.Linq;
2020
using Google.MiniJSON;
21+
using Firebase.VertexAI.Internal;
2122

2223
namespace Firebase.VertexAI {
2324

@@ -65,8 +66,8 @@ public IEnumerable<ModelContent.FunctionCallPart> FunctionCalls {
6566
}
6667
}
6768

68-
// Hidden constructor, users don't need to make this, though they still technically can.
69-
internal GenerateContentResponse(List<Candidate> candidates, PromptFeedback? promptFeedback,
69+
// Hidden constructor, users don't need to make this.
70+
private GenerateContentResponse(List<Candidate> candidates, PromptFeedback? promptFeedback,
7071
UsageMetadata? usageMetadata) {
7172
_candidates = new ReadOnlyCollection<Candidate>(candidates ?? new List<Candidate>());
7273
PromptFeedback = promptFeedback;
@@ -78,48 +79,119 @@ internal static GenerateContentResponse FromJson(string jsonString) {
7879
}
7980

8081
internal static GenerateContentResponse FromJson(Dictionary<string, object> jsonDict) {
81-
// Parse the Candidates
82-
List<Candidate> candidates = new();
83-
if (jsonDict.TryGetValue("candidates", out object candidatesObject)) {
84-
if (candidatesObject is not List<object> listOfCandidateObjects) {
85-
throw new VertexAISerializationException("Invalid JSON format: 'candidates' is not a list.");
86-
}
87-
88-
candidates = listOfCandidateObjects
89-
.Select(o => o as Dictionary<string, object>)
90-
.Where(dict => dict != null)
91-
.Select(Candidate.FromJson)
92-
.ToList();
93-
}
94-
95-
// TODO: Parse PromptFeedback and UsageMetadata
96-
97-
return new GenerateContentResponse(candidates, null, null);
82+
return new GenerateContentResponse(
83+
jsonDict.ParseObjectList("candidates", Candidate.FromJson),
84+
jsonDict.ParseNullableObject("promptFeedback",
85+
Firebase.VertexAI.PromptFeedback.FromJson),
86+
jsonDict.ParseNullableObject("usageMetadata",
87+
Firebase.VertexAI.UsageMetadata.FromJson));
9888
}
9989
}
10090

91+
/// <summary>
92+
/// A type describing possible reasons to block a prompt.
93+
/// </summary>
10194
public enum BlockReason {
102-
Unknown,
95+
/// <summary>
96+
/// A new and not yet supported value.
97+
/// </summary>
98+
Unknown = 0,
99+
/// <summary>
100+
/// The prompt was blocked because it was deemed unsafe.
101+
/// </summary>
103102
Safety,
103+
/// <summary>
104+
/// All other block reasons.
105+
/// </summary>
104106
Other,
107+
/// <summary>
108+
/// The prompt was blocked because it contained terms from the terminology blocklist.
109+
/// </summary>
105110
Blocklist,
111+
/// <summary>
112+
/// The prompt was blocked due to prohibited content.
113+
/// </summary>
106114
ProhibitedContent,
107115
}
108116

117+
/// <summary>
118+
/// A metadata struct containing any feedback the model had on the prompt it was provided.
119+
/// </summary>
109120
public readonly struct PromptFeedback {
121+
private readonly ReadOnlyCollection<SafetyRating> _safetyRatings;
122+
123+
/// <summary>
124+
/// The reason a prompt was blocked, if it was blocked.
125+
/// </summary>
110126
public BlockReason? BlockReason { get; }
127+
/// <summary>
128+
/// A human-readable description of the `BlockReason`.
129+
/// </summary>
111130
public string BlockReasonMessage { get; }
112-
public IEnumerable<SafetyRating> SafetyRatings { get; }
131+
/// <summary>
132+
/// The safety ratings of the prompt.
133+
/// </summary>
134+
public IEnumerable<SafetyRating> SafetyRatings =>
135+
_safetyRatings ?? new ReadOnlyCollection<SafetyRating>(new List<SafetyRating>());
136+
137+
// Hidden constructor, users don't need to make this.
138+
private PromptFeedback(BlockReason? blockReason, string blockReasonMessage,
139+
List<SafetyRating> safetyRatings) {
140+
BlockReason = blockReason;
141+
BlockReasonMessage = blockReasonMessage;
142+
_safetyRatings = new ReadOnlyCollection<SafetyRating>(safetyRatings ?? new List<SafetyRating>());
143+
}
113144

114-
// Hidden constructor, users don't need to make this
145+
private static BlockReason ParseBlockReason(string str) {
146+
return str switch {
147+
"SAFETY" => Firebase.VertexAI.BlockReason.Safety,
148+
"OTHER" => Firebase.VertexAI.BlockReason.Other,
149+
"BLOCKLIST" => Firebase.VertexAI.BlockReason.Blocklist,
150+
"PROHIBITED_CONTENT" => Firebase.VertexAI.BlockReason.ProhibitedContent,
151+
_ => Firebase.VertexAI.BlockReason.Unknown,
152+
};
153+
}
154+
155+
internal static PromptFeedback FromJson(Dictionary<string, object> jsonDict) {
156+
return new PromptFeedback(
157+
jsonDict.ParseNullableEnum("blockReason", ParseBlockReason),
158+
jsonDict.ParseValue<string>("blockReasonMessage"),
159+
jsonDict.ParseObjectList("safetyRatings", SafetyRating.FromJson));
160+
}
115161
}
116162

163+
/// <summary>
164+
/// Token usage metadata for processing the generate content request.
165+
/// </summary>
117166
public readonly struct UsageMetadata {
167+
/// <summary>
168+
/// The number of tokens in the request prompt.
169+
/// </summary>
118170
public int PromptTokenCount { get; }
171+
/// <summary>
172+
/// The total number of tokens across the generated response candidates.
173+
/// </summary>
119174
public int CandidatesTokenCount { get; }
175+
/// <summary>
176+
/// The total number of tokens in both the request and response.
177+
/// </summary>
120178
public int TotalTokenCount { get; }
121179

122-
// Hidden constructor, users don't need to make this
180+
// TODO: New fields about ModalityTokenCount
181+
182+
// Hidden constructor, users don't need to make this.
183+
private UsageMetadata(int promptTC, int candidatesTC, int totalTC) {
184+
PromptTokenCount = promptTC;
185+
CandidatesTokenCount = candidatesTC;
186+
TotalTokenCount = totalTC;
187+
}
188+
189+
internal static UsageMetadata FromJson(Dictionary<string, object> jsonDict) {
190+
return new UsageMetadata(
191+
jsonDict.ParseValue<int>("promptTokenCount"),
192+
jsonDict.ParseValue<int>("candidatesTokenCount"),
193+
jsonDict.ParseValue<int>("totalTokenCount"));
194+
}
123195
}
124196

125197
}

0 commit comments

Comments
 (0)