Skip to content

Commit 432d70b

Browse files
committed
[Firebase AI] Replace Any with TemplateInputRepresentable
1 parent 287dd12 commit 432d70b

File tree

5 files changed

+103
-61
lines changed

5 files changed

+103
-61
lines changed

FirebaseAI/Sources/TemplateChatSession.swift

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,15 +50,14 @@ public final class TemplateChatSession: Sendable {
5050
/// - Returns: The content generated by the model.
5151
/// - Throws: A ``GenerateContentError`` if the request failed.
5252
public func sendMessage(_ content: [ModelContent],
53-
inputs: [String: Any],
53+
inputs: [String: any TemplateInputRepresentable],
5454
options: RequestOptions = RequestOptions()) async throws
5555
-> GenerateContentResponse {
56-
let templateInputs = try inputs.mapValues { try TemplateInput(value: $0) }
5756
let newContent = content.map(populateContentRole)
5857
let response = try await model.generateContentWithHistory(
5958
history: _history.history + newContent,
6059
template: templateID,
61-
inputs: templateInputs,
60+
inputs: inputs.mapValues { $0.templateInputRepresentation },
6261
options: options
6362
)
6463
_history.append(contentsOf: newContent)
@@ -80,7 +79,7 @@ public final class TemplateChatSession: Sendable {
8079
/// - Returns: The content generated by the model.
8180
/// - Throws: A ``GenerateContentError`` if the request failed.
8281
public func sendMessage(_ message: any PartsRepresentable,
83-
inputs: [String: Any],
82+
inputs: [String: any TemplateInputRepresentable],
8483
options: RequestOptions = RequestOptions()) async throws
8584
-> GenerateContentResponse {
8685
return try await sendMessage([ModelContent(parts: message.partsValue)],
@@ -101,15 +100,14 @@ public final class TemplateChatSession: Sendable {
101100
/// - Returns: An `AsyncThrowingStream` that yields `GenerateContentResponse` objects.
102101
/// - Throws: A ``GenerateContentError`` if the request failed.
103102
public func sendMessageStream(_ content: [ModelContent],
104-
inputs: [String: Any],
103+
inputs: [String: any TemplateInputRepresentable],
105104
options: RequestOptions = RequestOptions()) throws
106105
-> AsyncThrowingStream<GenerateContentResponse, Error> {
107-
let templateInputs = try inputs.mapValues { try TemplateInput(value: $0) }
108106
let newContent = content.map(populateContentRole)
109107
let stream = try model.generateContentStreamWithHistory(
110108
history: _history.history + newContent,
111109
template: templateID,
112-
inputs: templateInputs,
110+
inputs: inputs.mapValues { $0.templateInputRepresentation },
113111
options: options
114112
)
115113
return AsyncThrowingStream { continuation in
@@ -156,7 +154,7 @@ public final class TemplateChatSession: Sendable {
156154
/// - Returns: An `AsyncThrowingStream` that yields `GenerateContentResponse` objects.
157155
/// - Throws: A ``GenerateContentError`` if the request failed.
158156
public func sendMessageStream(_ message: any PartsRepresentable,
159-
inputs: [String: Any],
157+
inputs: [String: any TemplateInputRepresentable],
160158
options: RequestOptions = RequestOptions()) throws
161159
-> AsyncThrowingStream<GenerateContentResponse, Error> {
162160
return try sendMessageStream([ModelContent(parts: message.partsValue)],

FirebaseAI/Sources/TemplateGenerativeModel.swift

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,13 @@ public final class TemplateGenerativeModel: Sendable {
4040
/// - Returns: The content generated by the model.
4141
/// - Throws: A ``GenerateContentError`` if the request failed.
4242
public func generateContent(templateID: String,
43-
inputs: [String: Any],
43+
inputs: [String: any TemplateInputRepresentable],
4444
options: RequestOptions = RequestOptions()) async throws
4545
-> GenerateContentResponse {
46-
let templateInputs = try inputs.mapValues { try TemplateInput(value: $0) }
4746
return try await generateContentWithHistory(
4847
history: [],
4948
template: templateID,
50-
inputs: templateInputs,
49+
inputs: inputs.mapValues { $0.templateInputRepresentation },
5150
options: options
5251
)
5352
}
@@ -90,13 +89,12 @@ public final class TemplateGenerativeModel: Sendable {
9089
/// - Returns: An `AsyncThrowingStream` that yields `GenerateContentResponse` objects.
9190
/// - Throws: A ``GenerateContentError`` if the request failed.
9291
public func generateContentStream(templateID: String,
93-
inputs: [String: Any],
92+
inputs: [String: any TemplateInputRepresentable],
9493
options: RequestOptions = RequestOptions()) throws
9594
-> AsyncThrowingStream<GenerateContentResponse, Error> {
96-
let templateInputs = try inputs.mapValues { try TemplateInput(value: $0) }
9795
let request = TemplateGenerateContentRequest(
9896
template: templateID,
99-
inputs: templateInputs,
97+
inputs: inputs.mapValues { $0.templateInputRepresentation },
10098
history: [],
10199
projectID: generativeAIService.firebaseInfo.projectID,
102100
stream: true,

FirebaseAI/Sources/TemplateImagenModel.swift

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,13 @@ public final class TemplateImagenModel: Sendable {
3939
/// - Returns: The images generated by the model.
4040
/// - Throws: An error if the request failed.
4141
public func generateImages(templateID: String,
42-
inputs: [String: Any],
42+
inputs: [String: any TemplateInputRepresentable],
4343
options: RequestOptions = RequestOptions()) async throws
4444
-> ImagenGenerationResponse<ImagenInlineImage> {
45-
let templateInputs = try inputs.mapValues { try TemplateInput(value: $0) }
4645
let projectID = generativeAIService.firebaseInfo.projectID
4746
let request = TemplateImagenGenerationRequest<ImagenInlineImage>(
4847
template: templateID,
49-
inputs: templateInputs,
48+
inputs: inputs.mapValues { $0.templateInputRepresentation },
5049
projectID: projectID,
5150
apiConfig: apiConfig,
5251
options: options

FirebaseAI/Sources/TemplateInput.swift

Lines changed: 36 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -14,53 +14,45 @@
1414

1515
import Foundation
1616

17-
enum TemplateInput: Encodable, Sendable {
18-
case string(String)
19-
case int(Int)
20-
case double(Double)
21-
case bool(Bool)
22-
case array([TemplateInput])
23-
case dictionary([String: TemplateInput])
17+
public struct TemplateInput: Sendable {
18+
let kind: Kind
2419

25-
init(value: Any) throws {
26-
switch value {
27-
case let value as String:
28-
self = .string(value)
29-
case let value as Int:
30-
self = .int(value)
31-
case let value as Double:
32-
self = .double(value)
33-
case let value as Float:
34-
self = .double(Double(value))
35-
case let value as Bool:
36-
self = .bool(value)
37-
case let value as [Any]:
38-
self = try .array(value.map { try TemplateInput(value: $0) })
39-
case let value as [String: Any]:
40-
self = try .dictionary(value.mapValues { try TemplateInput(value: $0) })
41-
default:
42-
throw EncodingError.invalidValue(
43-
value,
44-
EncodingError.Context(codingPath: [], debugDescription: "Invalid value")
45-
)
46-
}
20+
public init(_ input: some TemplateInputRepresentable) {
21+
self = .init(kind: input.templateInputRepresentation.kind)
22+
}
23+
24+
init(kind: Kind) {
25+
self.kind = kind
4726
}
4827

49-
func encode(to encoder: Encoder) throws {
50-
var container = encoder.singleValueContainer()
51-
switch self {
52-
case let .string(value):
53-
try container.encode(value)
54-
case let .int(value):
55-
try container.encode(value)
56-
case let .double(value):
57-
try container.encode(value)
58-
case let .bool(value):
59-
try container.encode(value)
60-
case let .array(value):
61-
try container.encode(value)
62-
case let .dictionary(value):
63-
try container.encode(value)
28+
enum Kind: Encodable, Sendable {
29+
case string(String)
30+
case int(Int)
31+
case double(Double)
32+
case bool(Bool)
33+
case array([Kind])
34+
case dictionary([String: Kind])
35+
36+
func encode(to encoder: Encoder) throws {
37+
var container = encoder.singleValueContainer()
38+
switch self {
39+
case let .string(value):
40+
try container.encode(value)
41+
case let .int(value):
42+
try container.encode(value)
43+
case let .double(value):
44+
try container.encode(value)
45+
case let .bool(value):
46+
try container.encode(value)
47+
case let .array(value):
48+
try container.encode(value)
49+
case let .dictionary(value):
50+
try container.encode(value)
51+
}
6452
}
6553
}
6654
}
55+
56+
extension TemplateInput: TemplateInputRepresentable {
57+
public var templateInputRepresentation: TemplateInput { self }
58+
}
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import Foundation
16+
17+
/// A type that can be represented as a ``TemplateInput``.
18+
public protocol TemplateInputRepresentable: Encodable, Sendable {
19+
var templateInputRepresentation: TemplateInput { get }
20+
}
21+
22+
extension String: TemplateInputRepresentable {
23+
public var templateInputRepresentation: TemplateInput { TemplateInput(kind: .string(self)) }
24+
}
25+
26+
extension Int: TemplateInputRepresentable {
27+
public var templateInputRepresentation: TemplateInput { TemplateInput(kind: .int(self)) }
28+
}
29+
30+
extension Double: TemplateInputRepresentable {
31+
public var templateInputRepresentation: TemplateInput { TemplateInput(kind: .double(self)) }
32+
}
33+
34+
extension Float: TemplateInputRepresentable {
35+
public var templateInputRepresentation: TemplateInput {
36+
TemplateInput(kind: .double(Double(self)))
37+
}
38+
}
39+
40+
extension Bool: TemplateInputRepresentable {
41+
public var templateInputRepresentation: TemplateInput { TemplateInput(kind: .bool(self)) }
42+
}
43+
44+
extension Array: TemplateInputRepresentable where Element: TemplateInputRepresentable {
45+
public var templateInputRepresentation: TemplateInput {
46+
TemplateInput(kind: .array(map { TemplateInput($0).kind }))
47+
}
48+
}
49+
50+
extension Dictionary: TemplateInputRepresentable
51+
where Key == String, Value: TemplateInputRepresentable {
52+
public var templateInputRepresentation: TemplateInput {
53+
TemplateInput(kind: .dictionary(mapValues { TemplateInput($0).kind }))
54+
}
55+
}

0 commit comments

Comments
 (0)