Skip to content

Commit e5efeee

Browse files
committed
reafctor BedrockTypes
1 parent f152047 commit e5efeee

File tree

12 files changed

+170
-88
lines changed

12 files changed

+170
-88
lines changed

BedrockTypes/BedrockTypes.swift

Lines changed: 78 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,88 @@ import Foundation
99

1010
// MARK: - Data structures
1111

12-
public typealias BedrockModelIdentifier = String
12+
// model enum
13+
public struct BedrockModel: RawRepresentable, Equatable, Hashable {
14+
public var rawValue: String
15+
16+
public init(rawValue: String) {
17+
self.rawValue = rawValue
18+
}
19+
}
1320

21+
// Anthropic
22+
public extension BedrockModel {
23+
static var instant: BedrockModel { .init(rawValue: "anthropic.claude-instant-v1") }
24+
static var claudev1: BedrockModel { .init(rawValue: "anthropic.claude-v1") }
25+
static var claudev2: BedrockModel { .init(rawValue: "anthropic.claude-v2") }
26+
static var claudev2_1: BedrockModel { .init(rawValue: "anthropic.claude-v2:1") }
27+
func isAnthropic() -> Bool {
28+
switch self {
29+
case .instant, .claudev1, .claudev2, .claudev2_1: return true
30+
default: return false
31+
}
32+
}
33+
}
1434

15-
//struct AnyStringKey: CodingKey, Hashable, ExpressibleByStringLiteral {
16-
// var stringValue: String
17-
// init(stringValue: String) { self.stringValue = stringValue }
18-
// init(_ stringValue: String) { self.init(stringValue: stringValue) }
19-
// // the three lines below are not used, they are there to comply to `CodingKey` protocol
20-
// var intValue: Int?
21-
// init?(intValue: Int) { nil }
22-
// init(stringLiteral value: String) { self.init(value) }
35+
// Meta
36+
public extension BedrockModel {
37+
static var llama2_13b: BedrockModel { .init(rawValue: "meta.llama2.13b") }
38+
static var llama2_70b: BedrockModel { .init(rawValue: "meta.llama2.70b") }
39+
}
40+
41+
public extension BedrockModel {
42+
init?(from: String?) {
43+
guard let model = from else {
44+
return nil
45+
}
46+
self.init(rawValue: model)
47+
switch self {
48+
case .instant,
49+
.claudev1,
50+
.claudev2,
51+
.claudev2_1,
52+
.llama2_13b: return
53+
default: return nil
54+
}
55+
}
56+
}
57+
58+
//public enum BedrockModel: Hashable {
59+
// case anthropicModel(AnthropicModel)
60+
// case metaModel(MetaModel)
61+
//
62+
// public func id() -> BedrockModelIdentifier {
63+
// switch self {
64+
// case .anthropicModel(let anthropic): return anthropic.rawValue
65+
// case .metaModel(let meta): return meta.rawValue
66+
// }
67+
// }
2368
//}
2469

70+
public protocol BedrockResponse: Decodable {
71+
init(from data: Data) throws
72+
}
73+
74+
public extension BedrockResponse {
75+
static func decode<T: Decodable>(_ data: Data) throws -> T {
76+
let decoder = JSONDecoder()
77+
return try decoder.decode(T.self, from: data)
78+
}
79+
static func decode<T: Decodable>(json: String) throws -> T {
80+
let data = json.data(using: .utf8)!
81+
return try self.decode(data)
82+
}
83+
}
2584

85+
public protocol BedrockRequest: Encodable {
86+
func encode() throws -> Data
87+
}
2688

89+
public extension BedrockRequest {
90+
func encode() throws -> Data {
91+
let encoder = JSONEncoder()
92+
encoder.keyEncodingStrategy = .convertToSnakeCase
93+
return try encoder.encode(self)
94+
}
95+
}
2796

BedrockTypes/ClaudeTypes.swift

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,8 @@
77

88
import Foundation
99

10-
public enum BedrockClaudeModel : BedrockModelIdentifier {
11-
case instant = "anthropic.claude-instant-v1"
12-
case claudev1 = "anthropic.claude-v1"
13-
case claudev2 = "anthropic.claude-v2"
14-
}
15-
16-
public struct ClaudeParameters: Encodable {
10+
public struct ClaudeRequest: BedrockRequest {
11+
1712
public init(prompt: String) {
1813
self.prompt = "Human: \(prompt)\n\nAssistant:"
1914
}
@@ -23,21 +18,14 @@ public struct ClaudeParameters: Encodable {
2318
public let topK: Int = 250
2419
public let maxTokensToSample: Int = 8191
2520
public let stopSequences: [String] = ["\n\nHuman:"]
26-
27-
public func encode() throws -> Data {
28-
let encoder = JSONEncoder()
29-
encoder.keyEncodingStrategy = .convertToSnakeCase
30-
return try encoder.encode(self)
31-
}
3221
}
3322

34-
public struct ClaudeInvokeResponse: Decodable {
23+
public struct ClaudeResponse: BedrockResponse {
24+
3525
public let completion: String
3626
public let stop_reason: String
3727

3828
public init(from data: Data) throws {
39-
let decoder = JSONDecoder()
40-
decoder.keyDecodingStrategy = .convertFromSnakeCase
41-
self = try decoder.decode(ClaudeInvokeResponse.self, from: data)
29+
self = try ClaudeResponse.decode(data)
4230
}
4331
}

BedrockTypes/CohereTypes.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ struct CohereEmbedDocument: Encodable {
2929
return try encoder.encode(self)
3030
}
3131
}
32-
struct CohereEmbedResponse: Decodable, CustomStringConvertible {
32+
struct CohereEmbedResponse: BedrockResponse, CustomStringConvertible {
3333
let embeddings: [[Double]]
3434
let id: String
3535
let texts: [String]

BedrockTypes/TitanTypes.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ struct TitanEmbedDocument: Encodable {
1616
}
1717
}
1818

19-
struct TitanEmbedResponse: Decodable, CustomStringConvertible {
19+
struct TitanEmbedResponse: BedrockResponse, CustomStringConvertible {
2020

2121
let embedding: [Double]
2222
let inputTextTokenCount: Int

SwiftFMPlayground/Bedrock/Bedrock.swift

Lines changed: 14 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ struct Bedrock {
1616

1717
private var logger = Logger(label: "Bedrock")
1818

19-
var region: String = "us-east-1"
19+
var region: String = "us-east-1" //TODO: add a type safe way to express regions
2020

2121
init() {
2222
#if DEBUG
@@ -42,44 +42,33 @@ struct Bedrock {
4242
return try await bedrockClient().listFoundationModels(input: request)
4343
}
4444

45-
// TODO: make it generic for all Bedrock Models (returns T.AssociateType ?)
46-
func invokeModel(withId modelId: BedrockClaudeModel, prompt: String) async throws -> ClaudeInvokeResponse {
45+
func invokeClaude(model: BedrockModel, request: ClaudeRequest) async throws -> ClaudeResponse {
46+
let expectedModel = "anthropic.claude"
47+
guard model.rawValue.starts(with: expectedModel) else {
48+
throw BedrockError.invalidModel("Expecting \(expectedModel)*")
49+
}
50+
let data = try await self.invokeModel(withId: model, params: request.encode())
51+
return try ClaudeResponse(from: data)
52+
}
53+
private func invokeModel(withId model: BedrockModel, params: Data) async throws -> Data {
4754

48-
let params = ClaudeParameters(prompt: prompt)
49-
let request = InvokeModelInput(body: try self.encode(params),
55+
let request = InvokeModelInput(body: params,
5056
contentType: "application/json",
51-
modelId: modelId.rawValue)
57+
modelId: model.rawValue)
5258
let response = try await bedrockRuntimeClient().invokeModel(input: request)
5359

5460
guard response.contentType == "application/json",
5561
let data = response.body else {
5662
logger.debug("Invalid Bedrock response: \(response)")
5763
throw BedrockError.invalidResponse(response.body)
5864
}
59-
return try self.decode(data)
60-
}
61-
62-
private func decode<T: Decodable>(_ data: Data) throws -> T {
63-
let decoder = JSONDecoder()
64-
return try decoder.decode(T.self, from: data)
65-
}
66-
private func decode<T: Decodable>(json: String) throws -> T {
67-
let data = json.data(using: .utf8)!
68-
return try self.decode(data)
69-
}
70-
private func encode<T: Encodable>(_ value: T) throws -> Data {
71-
let encoder = JSONEncoder()
72-
encoder.keyEncodingStrategy = .convertToSnakeCase
73-
return try encoder.encode(value)
74-
}
75-
private func encode<T: Encodable>(_ value: T) throws -> String {
76-
let data : Data = try self.encode(value)
77-
return String(data: data, encoding: .utf8) ?? "error when encoding the string"
65+
return data
7866
}
7967
}
8068

8169
// MARK: - Errors
8270

8371
enum BedrockError: Error {
8472
case invalidResponse(Data?)
73+
case invalidModel(String)
8574
}

SwiftFMPlayground/Model/BedrockModelParameterUI.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import BedrockTypes
1010

1111
// modelid => list of model parameters
12-
public typealias AllModelParameters = [BedrockModelIdentifier: ModelParameters]
12+
public typealias AllModelParameters = [BedrockModel: ModelParameters]
1313

1414
public enum BedrockModelParameter: Encodable {
1515
case number(BedrockModelParameterNumber)

SwiftFMPlayground/Model/BedrockTypesUI.swift

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ struct BedrockModelSummaryUI: Hashable, Identifiable {
4343
responseStreamingSupported: model.responseStreamingSupported?.description ?? "unknown")
4444
}
4545
}
46+
47+
func bedrockModel() {
48+
49+
}
4650
}
4751

4852
enum InputCapabilities: String {
@@ -61,12 +65,13 @@ extension Array<BedrockModelSummaryUI> {
6165
/**
6266
Return the list of modelId for the given provider
6367
*/
64-
private func modelsId(for provider: String?) -> [String] {
68+
private func modelsId(for provider: String?) -> [BedrockModelSummaryUI] { // }[String] {
6569
self.filter {
6670
$0.providerName == provider
67-
}.map {
68-
$0.modelId
6971
}
72+
// .map {
73+
// $0.modelId
74+
// }
7075
}
7176

7277
/**
@@ -124,7 +129,7 @@ extension Array<BedrockModelSummaryUI> {
124129
*/
125130
func modelsId(forProvider provider: String?,
126131
withInputCapability inputCapability: InputCapabilities?,
127-
andOutpuCapability outputCapability: OutputCapabilities?) -> [String] {
132+
andOutpuCapability outputCapability: OutputCapabilities?) -> [BedrockModelSummaryUI] { //[String] {
128133

129134
// when no capability is passed, return all models for this provider
130135
guard let inputCapability, let outputCapability else {
@@ -147,18 +152,18 @@ extension Array<BedrockModelSummaryUI> {
147152
.contains(outputCapability.rawValue.lowercased())
148153
}
149154

150-
// 2. extract just the model id
151-
.map {
152-
$0.modelId
153-
}
154-
155-
// 3. keep only unique values
156-
.uniqued()
157-
158-
// 4. tranform to [String]
159-
.map {
160-
$0
161-
}
155+
// // 2. extract just the model id
156+
// .map {
157+
// $0.modelId
158+
// }
159+
//
160+
// // 3. keep only unique values
161+
// .uniqued()
162+
//
163+
// // 4. tranform to [String]
164+
// .map {
165+
// $0
166+
// }
162167
}
163168

164169
}

SwiftFMPlayground/Model/Model.swift

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,14 @@ struct Model {
1313
var listFoundationModels: [BedrockModelSummaryUI] = []
1414

1515
// allows to inject values for the Mock
16-
func modelsParameters(for selectedModel: BedrockModelIdentifier) -> ModelParameters? { return allModelsParameters[selectedModel] }
16+
func modelsParameters(for selectedModel: BedrockModel) -> ModelParameters? { return allModelsParameters[selectedModel] }
1717
private var allModelsParameters: AllModelParameters =
1818
[
1919
// https://docs.anthropic.com/claude/reference/complete_post
20-
BedrockClaudeModel.instant.rawValue : Model.claudeModelParameters,
21-
BedrockClaudeModel.claudev1.rawValue : Model.claudeModelParameters,
22-
BedrockClaudeModel.claudev2.rawValue : Model.claudeModelParameters
20+
BedrockModel.instant : Model.claudeModelParameters,
21+
BedrockModel.claudev1 : Model.claudeModelParameters,
22+
BedrockModel.claudev2 : Model.claudeModelParameters,
23+
BedrockModel.claudev2_1 : Model.claudeModelParameters
2324
]
2425

2526
// this methods returns a container used by the UI
@@ -32,6 +33,11 @@ struct Model {
3233
]
3334
}
3435

36+
// mock
37+
extension BedrockModel {
38+
static var mock1: BedrockModel { .init(rawValue: "id1") }
39+
}
40+
3541
extension Model {
3642
static func mock() -> Model {
3743

@@ -61,7 +67,7 @@ extension Model {
6167
providerName: "provider 2",
6268
responseStreamingSupported: "yes"))
6369

64-
m.allModelsParameters["id1"] = Model.claudeModelParameters
70+
m.allModelsParameters[.mock1] = Model.claudeModelParameters
6571
return m
6672
}
6773
}

SwiftFMPlayground/Model/ViewModel.swift

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ import Foundation
99
import Logging
1010
import OrderedCollections
1111
import BedrockTypes
12-
import AWSBedrock
1312

1413
@MainActor
1514
final class ViewModel: ObservableObject {
@@ -19,7 +18,7 @@ final class ViewModel: ObservableObject {
1918
@Published var data : Model
2019
@Published var state: Status = .ready()
2120

22-
@Published var selectedModel: BedrockModelIdentifier = ""
21+
@Published var selectedModel: BedrockModelSummaryUI? = nil
2322

2423
init(model: Model = Model()) {
2524
#if DEBUG
@@ -38,10 +37,13 @@ final class ViewModel: ObservableObject {
3837
}
3938
return self.data.listFoundationModels
4039
}
41-
42-
func selectedModelParameter() -> OrderedDictionary<String, BedrockModelParameter> {
40+
func selectedBedrockModel() -> BedrockModel? {
41+
return BedrockModel.init(from: selectedModel?.modelId)
42+
}
43+
func selectedModelParameter() throws -> OrderedDictionary<String, BedrockModelParameter> {
4344

44-
guard let rawParameters = self.data.modelsParameters(for: selectedModel) else {
45+
guard let selectedModel = selectedBedrockModel(),
46+
let rawParameters = self.data.modelsParameters(for: selectedModel) else {
4547
return [:]
4648
}
4749

@@ -69,6 +71,24 @@ final class ViewModel: ObservableObject {
6971

7072
return parameters
7173
}
74+
75+
func invoke(with text: String) async throws -> String {
76+
guard let model = selectedBedrockModel() else {
77+
return "model is nil"
78+
}
79+
80+
let bedrock = Bedrock()
81+
82+
if model.isAnthropic() {
83+
// TODO: update status bar
84+
// TODO: show a progress() UI
85+
let request = ClaudeRequest(prompt: text)
86+
let response = try await bedrock.invokeClaude(model: model, request: request)
87+
return response.completion
88+
} else {
89+
return "not implemented"
90+
}
91+
}
7292
}
7393

7494
// MARK: Control the status bar

0 commit comments

Comments
 (0)