Skip to content

Commit 5c784c7

Browse files
committed
Add sendMessage modelContent variations
1 parent bc4ab28 commit 5c784c7

File tree

4 files changed

+107
-15
lines changed

4 files changed

+107
-15
lines changed

FirebaseAI/Sources/TemplateChatSession.swift

Lines changed: 56 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -43,48 +43,71 @@ public final class TemplateChatSession: Sendable {
4343
/// **Public Preview**: This API is a public preview and may be subject to change.
4444
///
4545
/// - Parameters:
46-
/// - message: The message to send to the model.
46+
/// - content: The message to send to the model.
4747
/// - inputs: A dictionary of variables to substitute into the template.
48-
/// - options: The ``RequestOptions`` for the request, currently used to override default request timeout.
48+
/// - options: The ``RequestOptions`` for the request, currently used to override default
49+
/// request timeout.
4950
/// - Returns: The content generated by the model.
5051
/// - Throws: A ``GenerateContentError`` if the request failed.
51-
public func sendMessage(_ message: any PartsRepresentable,
52+
public func sendMessage(_ content: [ModelContent],
5253
inputs: [String: Any],
5354
options: RequestOptions = RequestOptions()) async throws
5455
-> GenerateContentResponse {
5556
let templateInputs = try inputs.mapValues { try TemplateInput(value: $0) }
56-
let newContent = populateContentRole(ModelContent(parts: message.partsValue))
57+
let newContent = content.map(populateContentRole)
5758
let response = try await model.generateContentWithHistory(
58-
history: _history.history + [newContent],
59+
history: _history.history + newContent,
5960
template: templateID,
6061
inputs: templateInputs,
6162
options: options
6263
)
63-
_history.append(newContent)
64+
_history.append(contentsOf: newContent)
6465
if let modelResponse = response.candidates.first {
6566
_history.append(modelResponse.content)
6667
}
6768
return response
6869
}
6970

70-
/// Sends a message to the model and returns the response as a stream of `GenerateContentResponse`s.
71+
/// Sends a message to the model and returns the response.
7172
///
7273
/// **Public Preview**: This API is a public preview and may be subject to change.
7374
///
7475
/// - Parameters:
7576
/// - message: The message to send to the model.
7677
/// - inputs: A dictionary of variables to substitute into the template.
77-
/// - options: The ``RequestOptions`` for the request, currently used to override default request timeout.
78+
/// - options: The ``RequestOptions`` for the request, currently used to override default
79+
/// request timeout.
80+
/// - Returns: The content generated by the model.
81+
/// - Throws: A ``GenerateContentError`` if the request failed.
82+
public func sendMessage(_ message: any PartsRepresentable,
83+
inputs: [String: Any],
84+
options: RequestOptions = RequestOptions()) async throws
85+
-> GenerateContentResponse {
86+
return try await sendMessage([ModelContent(parts: message.partsValue)],
87+
inputs: inputs,
88+
options: options)
89+
}
90+
91+
/// Sends a message to the model and returns the response as a stream of
92+
/// `GenerateContentResponse`s.
93+
///
94+
/// **Public Preview**: This API is a public preview and may be subject to change.
95+
///
96+
/// - Parameters:
97+
/// - content: The message to send to the model.
98+
/// - inputs: A dictionary of variables to substitute into the template.
99+
/// - options: The ``RequestOptions`` for the request, currently used to override default
100+
/// request timeout.
78101
/// - Returns: An `AsyncThrowingStream` that yields `GenerateContentResponse` objects.
79102
/// - Throws: A ``GenerateContentError`` if the request failed.
80-
public func sendMessageStream(_ message: any PartsRepresentable,
103+
public func sendMessageStream(_ content: [ModelContent],
81104
inputs: [String: Any],
82105
options: RequestOptions = RequestOptions()) throws
83106
-> AsyncThrowingStream<GenerateContentResponse, Error> {
84107
let templateInputs = try inputs.mapValues { try TemplateInput(value: $0) }
85-
let newContent = populateContentRole(ModelContent(parts: message.partsValue))
108+
let newContent = content.map(populateContentRole)
86109
let stream = try model.generateContentStreamWithHistory(
87-
history: _history.history + [newContent],
110+
history: _history.history + newContent,
88111
template: templateID,
89112
inputs: templateInputs,
90113
options: options
@@ -110,7 +133,7 @@ public final class TemplateChatSession: Sendable {
110133
}
111134

112135
// Save the request.
113-
_history.append(newContent)
136+
_history.append(contentsOf: newContent)
114137

115138
// Aggregate the content to add it to the history before we finish.
116139
let aggregated = _history.aggregatedChunks(aggregatedContent)
@@ -120,6 +143,27 @@ public final class TemplateChatSession: Sendable {
120143
}
121144
}
122145

146+
/// Sends a message to the model and returns the response as a stream of
147+
/// `GenerateContentResponse`s.
148+
///
149+
/// **Public Preview**: This API is a public preview and may be subject to change.
150+
///
151+
/// - Parameters:
152+
/// - message: The message to send to the model.
153+
/// - inputs: A dictionary of variables to substitute into the template.
154+
/// - options: The ``RequestOptions`` for the request, currently used to override default
155+
/// request timeout.
156+
/// - Returns: An `AsyncThrowingStream` that yields `GenerateContentResponse` objects.
157+
/// - Throws: A ``GenerateContentError`` if the request failed.
158+
public func sendMessageStream(_ message: any PartsRepresentable,
159+
inputs: [String: Any],
160+
options: RequestOptions = RequestOptions()) throws
161+
-> AsyncThrowingStream<GenerateContentResponse, Error> {
162+
return try sendMessageStream([ModelContent(parts: message.partsValue)],
163+
inputs: inputs,
164+
options: options)
165+
}
166+
123167
private func populateContentRole(_ content: ModelContent) -> ModelContent {
124168
if content.role != nil {
125169
return content

FirebaseAI/Sources/TemplateGenerativeModel.swift

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ public final class TemplateGenerativeModel: Sendable {
3535
/// - Parameters:
3636
/// - templateID: The ID of the prompt template to use.
3737
/// - inputs: A dictionary of variables to substitute into the template.
38-
/// - options: The ``RequestOptions`` for the request, currently used to override default request timeout.
38+
/// - options: The ``RequestOptions`` for the request, currently used to override default
39+
/// request timeout.
3940
/// - Returns: The content generated by the model.
4041
/// - Throws: A ``GenerateContentError`` if the request failed.
4142
public func generateContent(templateID: String,
@@ -84,7 +85,8 @@ public final class TemplateGenerativeModel: Sendable {
8485
/// - Parameters:
8586
/// - templateID: The ID of the prompt template to use.
8687
/// - inputs: A dictionary of variables to substitute into the template.
87-
/// - options: The ``RequestOptions`` for the request, currently used to override default request timeout.
88+
/// - options: The ``RequestOptions`` for the request, currently used to override default
89+
/// request timeout.
8890
/// - Returns: An `AsyncThrowingStream` that yields `GenerateContentResponse` objects.
8991
/// - Throws: A ``GenerateContentError`` if the request failed.
9092
public func generateContentStream(templateID: String,

FirebaseAI/Sources/TemplateImagenModel.swift

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ public final class TemplateImagenModel: Sendable {
3434
/// - Parameters:
3535
/// - template: The prompt template to use.
3636
/// - variables: A dictionary of variables to substitute into the template.
37-
/// - options: The ``RequestOptions`` for the request, currently used to override default request timeout.
37+
/// - options: The ``RequestOptions`` for the request, currently used to override default
38+
/// request timeout.
3839
/// - Returns: The images generated by the model.
3940
/// - Throws: An error if the request failed.
4041
public func generateImages(templateID: String,

FirebaseAI/Tests/Unit/TemplateChatSessionTests.swift

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,4 +73,49 @@ final class TemplateChatSessionTests: XCTestCase {
7373
XCTAssertEqual((chat.history[0].parts.first as? TextPart)?.text, "Hello")
7474
XCTAssertEqual(chat.history[1].role, "model")
7575
}
76+
77+
func testSendMessageWithModelContent() async throws {
78+
MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
79+
forResource: "unary-success-basic-reply-short",
80+
withExtension: "json",
81+
subdirectory: "mock-responses/googleai",
82+
isTemplateRequest: true
83+
)
84+
let chat = model.startChat(templateID: "test-template")
85+
let response = try await chat.sendMessage(
86+
[ModelContent(parts: [TextPart("Hello")])],
87+
inputs: ["name": "test"]
88+
)
89+
XCTAssertEqual(chat.history.count, 2)
90+
XCTAssertEqual(chat.history[0].role, "user")
91+
XCTAssertEqual((chat.history[0].parts.first as? TextPart)?.text, "Hello")
92+
XCTAssertEqual(chat.history[1].role, "model")
93+
XCTAssertEqual(
94+
(chat.history[1].parts.first as? TextPart)?.text,
95+
"Google's headquarters, also known as the Googleplex, is located in **Mountain View, California**.\n"
96+
)
97+
XCTAssertEqual(response.candidates.count, 1)
98+
}
99+
100+
func testSendMessageStreamWithModelContent() async throws {
101+
MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
102+
forResource: "streaming-success-basic-reply-short",
103+
withExtension: "txt",
104+
subdirectory: "mock-responses/googleai",
105+
isTemplateRequest: true
106+
)
107+
let chat = model.startChat(templateID: "test-template")
108+
let stream = try chat.sendMessageStream(
109+
[ModelContent(parts: [TextPart("Hello")])],
110+
inputs: ["name": "test"]
111+
)
112+
113+
let content = try await GenerativeModelTestUtil.collectTextFromStream(stream)
114+
115+
XCTAssertEqual(content, "The capital of Wyoming is **Cheyenne**.\n")
116+
XCTAssertEqual(chat.history.count, 2)
117+
XCTAssertEqual(chat.history[0].role, "user")
118+
XCTAssertEqual((chat.history[0].parts.first as? TextPart)?.text, "Hello")
119+
XCTAssertEqual(chat.history[1].role, "model")
120+
}
76121
}

0 commit comments

Comments
 (0)