Skip to content

Commit 88cb401

Browse files
authored
VertexAI - Make SystemInstructions set a role (#1212)
1 parent b187bbe commit 88cb401

File tree

4 files changed

+30
-26
lines changed

4 files changed

+30
-26
lines changed

vertexai/src/Chat.cs

Lines changed: 5 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
using System.Collections.ObjectModel;
2020
using System.Linq;
2121
using System.Threading.Tasks;
22+
using Firebase.VertexAI.Internal;
2223

2324
namespace Firebase.VertexAI {
2425

@@ -123,26 +124,10 @@ public IAsyncEnumerable<GenerateContentResponse> SendMessageStreamAsync(
123124
return SendMessageStreamAsyncInternal(content);
124125
}
125126

126-
private ModelContent GuaranteeRole(ModelContent content, string role) {
127-
if (content.Role == role) {
128-
return content;
129-
} else {
130-
return new ModelContent(role, content.Parts);
131-
}
132-
}
133-
134-
private ModelContent GuaranteeUser(ModelContent content) {
135-
return GuaranteeRole(content, "user");
136-
}
137-
138-
private ModelContent GuaranteeModel(ModelContent content) {
139-
return GuaranteeRole(content, "model");
140-
}
141-
142127
private async Task<GenerateContentResponse> SendMessageAsyncInternal(
143128
IEnumerable<ModelContent> requestContent) {
144129
// Make sure that the requests are set to to role "user".
145-
List<ModelContent> fixedRequests = requestContent.Select(GuaranteeUser).ToList();
130+
List<ModelContent> fixedRequests = requestContent.Select(VertexAIExtensions.ConvertToUser).ToList();
146131
// Set up the context to send in the request
147132
List<ModelContent> fullRequest = new(chatHistory);
148133
fullRequest.AddRange(fixedRequests);
@@ -157,7 +142,7 @@ private async Task<GenerateContentResponse> SendMessageAsyncInternal(
157142
ModelContent responseContent = response.Candidates.First().Content;
158143

159144
chatHistory.AddRange(fixedRequests);
160-
chatHistory.Add(GuaranteeModel(responseContent));
145+
chatHistory.Add(responseContent.ConvertToModel());
161146
}
162147

163148
return response;
@@ -166,7 +151,7 @@ private async Task<GenerateContentResponse> SendMessageAsyncInternal(
166151
private async IAsyncEnumerable<GenerateContentResponse> SendMessageStreamAsyncInternal(
167152
IEnumerable<ModelContent> requestContent) {
168153
// Make sure that the requests are set to to role "user".
169-
List<ModelContent> fixedRequests = requestContent.Select(GuaranteeUser).ToList();
154+
List<ModelContent> fixedRequests = requestContent.Select(VertexAIExtensions.ConvertToUser).ToList();
170155
// Set up the context to send in the request
171156
List<ModelContent> fullRequest = new(chatHistory);
172157
fullRequest.AddRange(fixedRequests);
@@ -181,7 +166,7 @@ private async IAsyncEnumerable<GenerateContentResponse> SendMessageStreamAsyncIn
181166
// but we don't want to save the history anymore.
182167
if (response.Candidates.Any()) {
183168
ModelContent responseContent = response.Candidates.First().Content;
184-
responseContents.Add(GuaranteeModel(responseContent));
169+
responseContents.Add(responseContent.ConvertToModel());
185170
} else {
186171
saveHistory = false;
187172
}

vertexai/src/GenerativeModel.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
using System.Text;
2323
using System.Threading.Tasks;
2424
using Google.MiniJSON;
25+
using Firebase.VertexAI.Internal;
2526

2627
namespace Firebase.VertexAI {
2728

@@ -66,7 +67,8 @@ internal GenerativeModel(FirebaseApp firebaseApp,
6667
_safetySettings = safetySettings;
6768
_tools = tools;
6869
_toolConfig = toolConfig;
69-
_systemInstruction = systemInstruction;
70+
// Make sure that the system instructions have the role "system".
71+
_systemInstruction = systemInstruction?.ConvertToSystem();
7072
_requestOptions = requestOptions;
7173

7274
// Create a HttpClient using the timeout requested, or the default one.

vertexai/src/Internal/InternalHelpers.cs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,26 @@ public static List<T> ConvertJsonList<T>(this List<object> list,
193193
.Select(converter)
194194
.ToList();
195195
}
196+
197+
public static ModelContent ConvertRole(this ModelContent content, string role) {
198+
if (content.Role == role) {
199+
return content;
200+
} else {
201+
return new ModelContent(role, content.Parts);
202+
}
203+
}
204+
205+
public static ModelContent ConvertToUser(this ModelContent content) {
206+
return content.ConvertRole("user");
207+
}
208+
209+
public static ModelContent ConvertToModel(this ModelContent content) {
210+
return content.ConvertRole("model");
211+
}
212+
213+
public static ModelContent ConvertToSystem(this ModelContent content) {
214+
return content.ConvertRole("system");
215+
}
196216
}
197217

198218
}

vertexai/src/ModelContent.cs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,17 +55,14 @@ public ModelContent(IEnumerable<Part> parts) : this("user", parts) { }
5555
/// <summary>
5656
/// Creates a `ModelContent` with the given role and `Part`s.
5757
/// </summary>
58-
public ModelContent(string role, params Part[] parts) {
59-
_role = role;
60-
_parts = new ReadOnlyCollection<Part>(parts.ToList());
61-
}
58+
public ModelContent(string role, params Part[] parts) : this(role, (IEnumerable<Part>)parts) { }
6259

6360
/// <summary>
6461
/// Creates a `ModelContent` with the given role and `Part`s.
6562
/// </summary>
6663
public ModelContent(string role, IEnumerable<Part> parts) {
6764
_role = role;
68-
_parts = new ReadOnlyCollection<Part>(parts.ToList());
65+
_parts = new ReadOnlyCollection<Part>(parts == null ? new List<Part>() : parts.ToList());
6966
}
7067

7168
#region Helper Factories

0 commit comments

Comments
 (0)