Commit b187bbe

VertexAI - Implement Chat (#1209)

* VertexAI - Implement Chat, with tests
* Update Chat.cs

1 parent 8e999b0 commit b187bbe

File tree

3 files changed: +315 −31 lines

vertexai/src/Chat.cs

Lines changed: 152 additions & 28 deletions
@@ -16,61 +16,185 @@
 
 using System;
 using System.Collections.Generic;
+using System.Collections.ObjectModel;
+using System.Linq;
 using System.Threading.Tasks;
 
 namespace Firebase.VertexAI {
 
+/// <summary>
+/// An object that represents a back-and-forth chat with a model, capturing the history and saving
+/// the context in memory between each message sent.
+/// </summary>
 public class Chat {
+  private readonly GenerativeModel generativeModel;
+  private readonly List<ModelContent> chatHistory;
 
-  public IEnumerable<ModelContent> History { get; }
+  /// <summary>
+  /// The previous content from the chat that has been successfully sent and received from the
+  /// model. This will be provided to the model for each message sent as context for the discussion.
+  /// </summary>
+  public IEnumerable<ModelContent> History => new ReadOnlyCollection<ModelContent>(chatHistory);
 
-  // Note: The generation functions are the same as the ones in GenerativeModel
+  // Note: No public constructor, get one through GenerativeModel.StartChat
+  private Chat(GenerativeModel model, IEnumerable<ModelContent> initialHistory) {
+    generativeModel = model;
+
+    if (initialHistory != null) {
+      chatHistory = new List<ModelContent>(initialHistory);
+    } else {
+      chatHistory = new List<ModelContent>();
+    }
+  }
 
-  public Task<GenerateContentResponse> GenerateContentAsync(
+  /// <summary>
+  /// Intended for internal use only.
+  /// Use `GenerativeModel.StartChat` instead to ensure proper initialization and configuration of the `Chat`.
+  /// </summary>
+  internal static Chat InternalCreateChat(GenerativeModel model, IEnumerable<ModelContent> initialHistory) {
+    return new Chat(model, initialHistory);
+  }
+
+  /// <summary>
+  /// Sends a message using the existing history of this chat as context. If successful, the message
+  /// and response will be added to the history. If unsuccessful, history will remain unchanged.
+  /// </summary>
+  /// <param name="content">The input(s) given to the model as a prompt.</param>
+  /// <returns>The model's response if no error occurred.</returns>
+  /// <exception cref="VertexAIException">Thrown when an error occurs during content generation.</exception>
+  public Task<GenerateContentResponse> SendMessageAsync(
       params ModelContent[] content) {
-    throw new NotImplementedException();
+    return SendMessageAsync((IEnumerable<ModelContent>)content);
   }
-  public Task<GenerateContentResponse> GenerateContentAsync(
+  /// <summary>
+  /// Sends a message using the existing history of this chat as context. If successful, the message
+  /// and response will be added to the history. If unsuccessful, history will remain unchanged.
+  /// </summary>
+  /// <param name="text">The text given to the model as a prompt.</param>
+  /// <returns>The model's response if no error occurred.</returns>
+  /// <exception cref="VertexAIException">Thrown when an error occurs during content generation.</exception>
+  public Task<GenerateContentResponse> SendMessageAsync(
       string text) {
-    throw new NotImplementedException();
+    return SendMessageAsync(new ModelContent[] { ModelContent.Text(text) });
   }
-  public Task<GenerateContentResponse> GenerateContentAsync(
+  /// <summary>
+  /// Sends a message using the existing history of this chat as context. If successful, the message
+  /// and response will be added to the history. If unsuccessful, history will remain unchanged.
+  /// </summary>
+  /// <param name="content">The input(s) given to the model as a prompt.</param>
+  /// <returns>The model's response if no error occurred.</returns>
+  /// <exception cref="VertexAIException">Thrown when an error occurs during content generation.</exception>
+  public Task<GenerateContentResponse> SendMessageAsync(
       IEnumerable<ModelContent> content) {
-    throw new NotImplementedException();
+    return SendMessageAsyncInternal(content);
   }
 
-  // The build logic isn't able to resolve IAsyncEnumerable for some reason, even
-  // though it is usable in Unity 2021.3. Will need to investigate further.
-  /*
-  public IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsync(
+  /// <summary>
+  /// Sends a message using the existing history of this chat as context. If successful, the message
+  /// and response will be added to the history. If unsuccessful, history will remain unchanged.
+  /// </summary>
+  /// <param name="content">The input(s) given to the model as a prompt.</param>
+  /// <returns>A stream of generated content responses from the model.</returns>
+  /// <exception cref="VertexAIException">Thrown when an error occurs during content generation.</exception>
+  public IAsyncEnumerable<GenerateContentResponse> SendMessageStreamAsync(
       params ModelContent[] content) {
-    throw new NotImplementedException();
+    return SendMessageStreamAsync((IEnumerable<ModelContent>)content);
   }
-  public IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsync(
+  /// <summary>
+  /// Sends a message using the existing history of this chat as context. If successful, the message
+  /// and response will be added to the history. If unsuccessful, history will remain unchanged.
+  /// </summary>
+  /// <param name="text">The text given to the model as a prompt.</param>
+  /// <returns>A stream of generated content responses from the model.</returns>
+  /// <exception cref="VertexAIException">Thrown when an error occurs during content generation.</exception>
+  public IAsyncEnumerable<GenerateContentResponse> SendMessageStreamAsync(
       string text) {
-    throw new NotImplementedException();
+    return SendMessageStreamAsync(new ModelContent[] { ModelContent.Text(text) });
   }
-  public IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsync(
+  /// <summary>
+  /// Sends a message using the existing history of this chat as context. If successful, the message
+  /// and response will be added to the history. If unsuccessful, history will remain unchanged.
+  /// </summary>
+  /// <param name="content">The input(s) given to the model as a prompt.</param>
+  /// <returns>A stream of generated content responses from the model.</returns>
+  /// <exception cref="VertexAIException">Thrown when an error occurs during content generation.</exception>
+  public IAsyncEnumerable<GenerateContentResponse> SendMessageStreamAsync(
       IEnumerable<ModelContent> content) {
-    throw new NotImplementedException();
+    return SendMessageStreamAsyncInternal(content);
   }
-  */
 
-  public Task<CountTokensResponse> CountTokensAsync(
-      params ModelContent[] content) {
-    throw new NotImplementedException();
+  private ModelContent GuaranteeRole(ModelContent content, string role) {
+    if (content.Role == role) {
+      return content;
+    } else {
+      return new ModelContent(role, content.Parts);
+    }
   }
-  public Task<CountTokensResponse> CountTokensAsync(
-      string text) {
-    throw new NotImplementedException();
+
+  private ModelContent GuaranteeUser(ModelContent content) {
+    return GuaranteeRole(content, "user");
   }
-  public Task<CountTokensResponse> CountTokensAsync(
-      IEnumerable<ModelContent> content) {
-    throw new NotImplementedException();
+
+  private ModelContent GuaranteeModel(ModelContent content) {
+    return GuaranteeRole(content, "model");
   }
 
-  // Note: No public constructor, get one through GenerativeModel.StartChat
+  private async Task<GenerateContentResponse> SendMessageAsyncInternal(
+      IEnumerable<ModelContent> requestContent) {
+    // Make sure that the requests are set to role "user".
+    List<ModelContent> fixedRequests = requestContent.Select(GuaranteeUser).ToList();
+    // Set up the context to send in the request.
+    List<ModelContent> fullRequest = new(chatHistory);
+    fullRequest.AddRange(fixedRequests);
+
+    // Note: GenerateContentAsync can throw exceptions if there was a problem, but
+    // we allow it to just be passed back to the user.
+    GenerateContentResponse response = await generativeModel.GenerateContentAsync(fullRequest);
+
+    // Only after getting a valid response, add both to the history for later.
+    // But either way pass the response along to the user.
+    if (response.Candidates.Any()) {
+      ModelContent responseContent = response.Candidates.First().Content;
+
+      chatHistory.AddRange(fixedRequests);
+      chatHistory.Add(GuaranteeModel(responseContent));
+    }
+
+    return response;
+  }
+
+  private async IAsyncEnumerable<GenerateContentResponse> SendMessageStreamAsyncInternal(
+      IEnumerable<ModelContent> requestContent) {
+    // Make sure that the requests are set to role "user".
+    List<ModelContent> fixedRequests = requestContent.Select(GuaranteeUser).ToList();
+    // Set up the context to send in the request.
+    List<ModelContent> fullRequest = new(chatHistory);
+    fullRequest.AddRange(fixedRequests);
+
+    List<ModelContent> responseContents = new();
+    bool saveHistory = true;
+    // Note: GenerateContentStreamAsync can throw exceptions if there was a problem, but
+    // we allow it to just be passed back to the user.
+    await foreach (GenerateContentResponse response in
+        generativeModel.GenerateContentStreamAsync(fullRequest)) {
+      // If the response had a problem, we still want to pass it along to the user for context,
+      // but we don't want to save the history anymore.
+      if (response.Candidates.Any()) {
+        ModelContent responseContent = response.Candidates.First().Content;
+        responseContents.Add(GuaranteeModel(responseContent));
+      } else {
+        saveHistory = false;
+      }
 
+      yield return response;
+    }
+
+    // After getting all the responses, and they were all valid, add everything to the history.
+    if (saveHistory) {
+      chatHistory.AddRange(fixedRequests);
+      chatHistory.AddRange(responseContents);
+    }
+  }
 }
 
 }
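
For reference, a minimal usage sketch of the new chat API (not part of this commit). It assumes a configured GenerativeModel obtained elsewhere, and it relies only on members visible in these diffs: StartChat (added in GenerativeModel.cs below), SendMessageAsync, SendMessageStreamAsync, History, and the Candidates/Content/Role/Parts members that the Chat implementation itself reads.

using System;
using System.Linq;
using System.Threading.Tasks;
using Firebase.VertexAI;

public static class ChatUsageSketch {
  // `model` is assumed to be a configured GenerativeModel obtained elsewhere.
  public static async Task RunAsync(GenerativeModel model) {
    // Start a chat with no prior history.
    Chat chat = model.StartChat();

    // Each message is sent together with the accumulated history as context.
    GenerateContentResponse response =
        await chat.SendMessageAsync("Hello, what can you do?");

    // Read the reply the same way the Chat implementation does.
    if (response.Candidates.Any()) {
      ModelContent reply = response.Candidates.First().Content;
      Console.WriteLine($"Reply role: {reply.Role}, parts: {reply.Parts.Count()}");
    }

    // A successful round trip appends both the user message and the model
    // reply, so History now holds two entries.
    Console.WriteLine($"History entries: {chat.History.Count()}");

    // Streaming variant: chunks are yielded as they arrive. The exchange is
    // only committed to History once the stream completes and every chunk
    // carried candidates.
    await foreach (GenerateContentResponse chunk in
        chat.SendMessageStreamAsync("Tell me more.")) {
      Console.WriteLine($"Chunk candidates: {chunk.Candidates.Count()}");
    }
  }
}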

vertexai/src/GenerativeModel.cs

Lines changed: 9 additions & 2 deletions
@@ -151,12 +151,19 @@ public Task<CountTokensResponse> CountTokensAsync(
     return CountTokensAsyncInternal(content);
   }
 
+  /// <summary>
+  /// Creates a new chat conversation using this model with the provided history.
+  /// </summary>
+  /// <param name="history">Initial content history to start with.</param>
   public Chat StartChat(params ModelContent[] history) {
     return StartChat((IEnumerable<ModelContent>)history);
   }
+  /// <summary>
+  /// Creates a new chat conversation using this model with the provided history.
+  /// </summary>
+  /// <param name="history">Initial content history to start with.</param>
   public Chat StartChat(IEnumerable<ModelContent> history) {
-    // TODO: Implementation
-    throw new NotImplementedException();
+    return Chat.InternalCreateChat(this, history);
   }
   #endregion
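
A short sketch of the two StartChat overloads, likewise not part of the commit. Note that the Chat constructor copies a seeded history as-is; role normalization only happens for messages sent afterwards, so seeded turns should already carry the intended "user"/"model" roles. The seeded turn below is hypothetical, since the default role of ModelContent.Text is not visible in this diff.

using System.Collections.Generic;
using Firebase.VertexAI;

public static class StartChatSketch {
  public static void Sketch(GenerativeModel model) {
    // The params overload allows starting with no history at all.
    Chat fresh = model.StartChat();

    // The IEnumerable overload accepts any existing collection of turns.
    // Hypothetical seeded turn; make sure its role is what you intend,
    // since the Chat constructor stores the provided history unchanged.
    List<ModelContent> priorTurns = new() {
      ModelContent.Text("What is Firebase?"),
    };
    Chat resumed = model.StartChat(priorTurns);
  }
}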
