Skip to content

Commit 50d0d29

Browse files
andrewheard, Morgan Chen, paulb777, peterfriese, ryanwilson
committed
Fork google-generative-ai for Firebase (#12564)
Co-authored-by: Morgan Chen <[email protected]> Co-authored-by: Paul Beusterien <[email protected]> Co-authored-by: Peter Friese <[email protected]> Co-authored-by: Ryan Wilson <[email protected]>
1 parent 1bf861c commit 50d0d29

17 files changed

+2067
-1
lines changed

FirebaseVertexAI/Sources/Chat.swift

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
// Copyright 2023 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import Foundation
16+
17+
/// An object that represents a back-and-forth chat with a model, capturing the history and saving
/// the context in memory between each message sent.
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
public class Chat {
  // The underlying model used to generate content and content streams.
  private let model: GenerativeModel

  /// Initializes a new chat representing a 1:1 conversation between model and user.
  /// - Parameters:
  ///   - model: The model that requests are sent to.
  ///   - history: Prior conversation turns used to seed the chat; may be empty.
  init(model: GenerativeModel, history: [ModelContent]) {
    self.model = model
    self.history = history
  }

  /// The previous content from the chat that has been successfully sent and received from the
  /// model. This will be provided to the model for each message sent as context for the discussion.
  public var history: [ModelContent]

  /// See ``sendMessage(_:)-3ify5``.
  public func sendMessage(_ parts: any ThrowingPartsRepresentable...) async throws
    -> GenerateContentResponse {
    // Convenience overload: wraps the variadic parts into a single chat message.
    return try await sendMessage([ModelContent(parts: parts)])
  }

  /// Sends a message using the existing history of this chat as context. If successful, the message
  /// and response will be added to the history. If unsuccessful, history will remain unchanged.
  /// - Parameter content: The new content to send as a single chat message.
  /// - Returns: The model's response if no error occurred.
  /// - Throws: A ``GenerateContentError`` if an error occurred.
  public func sendMessage(_ content: @autoclosure () throws -> [ModelContent]) async throws
    -> GenerateContentResponse {
    // Ensure that the new content has the role set.
    let newContent: [ModelContent]
    do {
      newContent = try content().map(populateContentRole(_:))
    } catch let underlying {
      // Map any thrown error onto the public `GenerateContentError` surface; image conversion
      // failures get a dedicated case so callers can distinguish them.
      if let contentError = underlying as? ImageConversionError {
        throw GenerateContentError.promptImageContentError(underlying: contentError)
      } else {
        throw GenerateContentError.internalError(underlying: underlying)
      }
    }

    // Send the history alongside the new message as context.
    let request = history + newContent
    let result = try await model.generateContent(request)
    guard let reply = result.candidates.first?.content else {
      // A response with no candidate content is surfaced as an internal error.
      let error = NSError(domain: "com.google.generative-ai",
                          code: -1,
                          userInfo: [
                            NSLocalizedDescriptionKey: "No candidates with content available.",
                          ])
      throw GenerateContentError.internalError(underlying: error)
    }

    // Make sure we inject the role into the content received.
    let toAdd = ModelContent(role: "model", parts: reply.parts)

    // Append the request and successful result to history, then return the value.
    // History is only mutated after a successful round trip, which is what keeps it
    // unchanged when any of the above throws.
    history.append(contentsOf: newContent)
    history.append(toAdd)
    return result
  }

  /// See ``sendMessageStream(_:)-4abs3``.
  @available(macOS 12.0, *)
  public func sendMessageStream(_ parts: any ThrowingPartsRepresentable...)
    -> AsyncThrowingStream<GenerateContentResponse, Error> {
    // Convenience overload: wraps the variadic parts into a single chat message. The
    // `ModelContent` initializer is evaluated lazily via the autoclosure parameter below.
    return try sendMessageStream([ModelContent(parts: parts)])
  }

  /// Sends a message using the existing history of this chat as context. If successful, the message
  /// and response will be added to the history. If unsuccessful, history will remain unchanged.
  /// - Parameter content: The new content to send as a single chat message.
  /// - Returns: A stream containing the model's response or an error if an error occurred.
  @available(macOS 12.0, *)
  public func sendMessageStream(_ content: @autoclosure () throws -> [ModelContent])
    -> AsyncThrowingStream<GenerateContentResponse, Error> {
    // Evaluate the autoclosure up front so a thrown error can be delivered through the
    // returned stream instead of requiring this method itself to be `throws`.
    let resolvedContent: [ModelContent]
    do {
      resolvedContent = try content()
    } catch let underlying {
      // Surface the mapped error as the stream's terminal state.
      return AsyncThrowingStream { continuation in
        let error: Error
        if let contentError = underlying as? ImageConversionError {
          error = GenerateContentError.promptImageContentError(underlying: contentError)
        } else {
          error = GenerateContentError.internalError(underlying: underlying)
        }
        continuation.finish(throwing: error)
      }
    }

    return AsyncThrowingStream { continuation in
      // NOTE(review): this unstructured Task mutates `history` off the caller's context;
      // concurrent sends on the same Chat could race — confirm thread-safety expectations.
      Task {
        var aggregatedContent: [ModelContent] = []

        // Ensure that the new content has the role set.
        let newContent: [ModelContent] = resolvedContent.map(populateContentRole(_:))

        // Send the history alongside the new message as context.
        let request = history + newContent
        let stream = model.generateContentStream(request)
        do {
          for try await chunk in stream {
            // Capture any content that's streaming. This should be populated if there's no error.
            if let chunkContent = chunk.candidates.first?.content {
              aggregatedContent.append(chunkContent)
            }

            // Pass along the chunk.
            continuation.yield(chunk)
          }
        } catch {
          // Rethrow the error that the underlying stream threw. Don't add anything to history.
          continuation.finish(throwing: error)
          return
        }

        // Save the request.
        history.append(contentsOf: newContent)

        // Aggregate the content to add it to the history before we finish.
        let aggregated = aggregatedChunks(aggregatedContent)
        history.append(aggregated)

        continuation.finish()
      }
    }
  }

  /// Collapses a sequence of streamed content chunks into a single `ModelContent` with the
  /// "model" role: consecutive text parts are concatenated into one text part, while data
  /// parts are kept as-is in their original order.
  private func aggregatedChunks(_ chunks: [ModelContent]) -> ModelContent {
    var parts: [ModelContent.Part] = []
    var combinedText = ""
    for aggregate in chunks {
      // Loop through all the parts, aggregating the text and adding the images.
      for part in aggregate.parts {
        switch part {
        case let .text(str):
          combinedText += str

        case .data(mimetype: _, _):
          // Don't combine it, just add to the content. If there's any text pending, add that as
          // a part.
          if !combinedText.isEmpty {
            parts.append(.text(combinedText))
            combinedText = ""
          }

          parts.append(part)
        }
      }
    }

    // Flush any trailing text that hasn't been emitted as a part yet.
    if !combinedText.isEmpty {
      parts.append(.text(combinedText))
    }

    return ModelContent(role: "model", parts: parts)
  }

  /// Populates the `role` field with `user` if it doesn't exist. Required in chat sessions.
  private func populateContentRole(_ content: ModelContent) -> ModelContent {
    if content.role != nil {
      return content
    } else {
      return ModelContent(role: "user", parts: content.parts)
    }
  }
}
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// Copyright 2023 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import Foundation
16+
17+
/// A request to count the number of tokens in `contents`, addressed to the given model.
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
struct CountTokensRequest {
  // Model identifier; interpolated into the request URL (see the `GenerativeAIRequest`
  // conformance) rather than encoded in the request body.
  let model: String
  // The content whose tokens are to be counted.
  let contents: [ModelContent]
  // Per-request configuration; supplies the API version used to build the URL.
  let options: RequestOptions
}
23+
24+
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
extension CountTokensRequest: Encodable {
  // Only `contents` is serialized into the request body: `model` travels in the URL path
  // and `options` configures the request itself, so both are deliberately omitted here.
  enum CodingKeys: CodingKey {
    case contents
  }
}
30+
31+
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
extension CountTokensRequest: GenerativeAIRequest {
  typealias Response = CountTokensResponse

  // The `:countTokens` endpoint for this model, built from the shared base URL and the
  // API version carried in `options`.
  // NOTE(review): the force-unwrap assumes baseURL/apiVersion/model always compose into a
  // valid URL string — confirm inputs are never user-controlled or malformed.
  var url: URL {
    URL(string: "\(GenerativeAISwift.baseURL)/\(options.apiVersion)/\(model):countTokens")!
  }
}
39+
40+
/// The model's response to a count tokens request.
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
public struct CountTokensResponse: Decodable {
  /// The total number of tokens in the input given to the model as a prompt.
  // Decoding is compiler-synthesized; the wire field name must match `totalTokens`.
  public let totalTokens: Int
}

0 commit comments

Comments
 (0)