Skip to content

Commit 22d011b

Browse files
sebsto and Copilot authored
Add support for cross region inference (#25)
* support cross region inference
* swift-format
* Update Sources/InvokeModel/BedrockService+InvokeModelImage.swift

Co-authored-by: Copilot <[email protected]>
1 parent e5ad4e7 commit 22d011b

25 files changed: +131 additions, −89 deletions

Examples/web-playground/frontend/helpers/chatModelData.js

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ export const chatModels = [
125125
},
126126
{
127127
modelName: "Anthropic Claude 3.5 Haiku",
128-
modelId: "us.anthropic.claude-3-5-haiku-20241022-v1:0",
128+
modelId: "anthropic.claude-3-5-haiku-20241022-v1:0",
129129
temperatureRange: {
130130
default: 1,
131131
min: 0,
@@ -287,7 +287,7 @@ export const chatModels = [
287287
// DeepSeek
288288
// {
289289
// modelName: "Deep Seek",
290-
// modelId: "us.deepseek.r1-v1:0",
290+
// modelId: "deepseek.r1-v1:0",
291291
// topPRange: {
292292
// max: 1,
293293
// default: 1,

Examples/web-playground/frontend/helpers/modelData.js

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ export const models = [
6060
},
6161
{
6262
modelName: "Anthropic Claude 3.5 Haiku",
63-
modelId: "us.anthropic.claude-3-5-haiku-20241022-v1:0",
63+
modelId: "anthropic.claude-3-5-haiku-20241022-v1:0",
6464
temperatureRange: {
6565
min: 0,
6666
max: 1,
@@ -173,7 +173,7 @@ export const models = [
173173
},
174174
// {
175175
// modelName: "Deep Seek",
176-
// modelId: "us.deepseek.r1-v1:0",
176+
// modelId: "deepseek.r1-v1:0",
177177
// temperatureRange: {
178178
// min: 0,
179179
// max: 1,

Examples/web-playground/frontend/helpers/reasoningModelData.js

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
export const defaultModel = {
22
modelName: "Claude V3.7 Sonnet",
3-
modelId: "us.anthropic.claude-3-7-sonnet-20250219-v1:0",
3+
modelId: "anthropic.claude-3-7-sonnet-20250219-v1:0",
44
topKRange: {
55
max: 500,
66
default: 0,
@@ -32,7 +32,7 @@ export const models = [
3232
defaultModel,
3333
// {
3434
// modelName: "Deep Seek",
35-
// modelId: "us.deepseek.r1-v1:0",
35+
// modelId: "deepseek.r1-v1:0",
3636
// topPRange: {
3737
// max: 1,
3838
// default: 1,

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -925,7 +925,7 @@ You can now create instances for any of the models that follow the request and r
925925
```swift
926926
extension BedrockModel {
927927
public static let llama3_3_70b_instruct: BedrockModel = BedrockModel(
928-
id: "us.meta.llama3-3-70b-instruct-v1:0",
928+
id: "meta.llama3-3-70b-instruct-v1:0",
929929
name: "Llama 3.3 70B Instruct",
930930
modality: LlamaText(
931931
parameters: TextGenerationParameters(

Sources/BedrockModel.swift

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ import Foundation
1818
public struct BedrockModel: Hashable, Sendable, Equatable, RawRepresentable {
1919
public var rawValue: String { id }
2020

21-
public var id: String
22-
public var name: String
21+
public let id: String
22+
public let name: String
2323
public let modality: any Modality
2424

2525
/// Creates a new BedrockModel instance
@@ -106,6 +106,17 @@ public struct BedrockModel: Hashable, Sendable, Equatable, RawRepresentable {
106106
}
107107
}
108108

109+
// MARK: Cross region inference
/// Returns the model identifier to send to Bedrock, adding the cross-region
/// inference-profile prefix when the model's modality supports it.
///
/// - Parameter region: The region used to select the appropriate prefix.
/// - Returns: `"\(prefix)\(id)"` for models whose modality conforms to
///   `CrossRegionInferenceModality`; the unmodified `id` otherwise.
public func getModelIdWithCrossRegionInferencePrefix(region: Region) -> String {
    if let inferenceModality = modality as? CrossRegionInferenceModality {
        // Supported: prepend the region-specific inference-profile prefix.
        return "\(inferenceModality.crossRegionPrefix(forRegion: region))\(id)"
    }
    // Not supported: the plain model ID is used as-is.
    return id
}
119+
109120
// MARK: Modality checks
110121

111122
// MARK - Text completion

Sources/Converse/BedrockService+Converse.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ extension BedrockService {
8484
)
8585

8686
logger.trace("Creating ConverseInput")
87-
let input = try converseRequest.getConverseInput()
87+
let input = try converseRequest.getConverseInput(forRegion: self.region)
8888

8989
logger.trace(
9090
"Sending ConverseInput to BedrockRuntimeClient",

Sources/Converse/BedrockService+ConverseStreaming.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ extension BedrockService {
9292
)
9393

9494
logger.trace("Creating ConverseStreamingInput")
95-
let input = try converseRequest.getConverseStreamingInput()
95+
let input = try converseRequest.getConverseStreamingInput(forRegion: region)
9696

9797
logger.trace(
9898
"Sending ConverseStreaminInput to BedrockRuntimeClient",

Sources/Converse/ConverseRequest.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,12 @@ public struct ConverseRequest {
5353
}
5454
}
5555

56-
func getConverseInput() throws -> ConverseInput {
56+
func getConverseInput(forRegion region: Region) throws -> ConverseInput {
5757
ConverseInput(
5858
additionalModelRequestFields: try getAdditionalModelRequestFields(),
5959
inferenceConfig: inferenceConfig?.getSDKInferenceConfig(),
6060
messages: try getSDKMessages(),
61-
modelId: model.id,
61+
modelId: model.getModelIdWithCrossRegionInferencePrefix(region: region),
6262
system: getSDKSystemPrompts(),
6363
toolConfig: try toolConfig?.getSDKToolConfig()
6464
)

Sources/Converse/ConverseRequestStreaming.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,12 @@
1717

1818
public typealias ConverseStreamingRequest = ConverseRequest
1919
extension ConverseStreamingRequest {
20-
func getConverseStreamingInput() throws -> ConverseStreamInput {
20+
func getConverseStreamingInput(forRegion region: Region) throws -> ConverseStreamInput {
2121
ConverseStreamInput(
2222
additionalModelRequestFields: try getAdditionalModelRequestFields(),
2323
inferenceConfig: inferenceConfig?.getSDKInferenceConfig(),
2424
messages: try getSDKMessages(),
25-
modelId: model.id,
25+
modelId: model.getModelIdWithCrossRegionInferencePrefix(region: region),
2626
system: getSDKSystemPrompts(),
2727
toolConfig: try toolConfig?.getSDKToolConfig()
2828
)

Sources/InvokeModel/BedrockService+InvokeModelImage.swift

Lines changed: 30 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -77,31 +77,9 @@ extension BedrockService {
7777
quality: quality,
7878
resolution: resolution
7979
)
80-
let input: InvokeModelInput = try request.getInvokeModelInput()
81-
logger.trace(
82-
"Sending request to invokeModel",
83-
metadata: [
84-
"model": .string(model.id), "request": .string(String(describing: input)),
85-
]
86-
)
87-
let response = try await self.bedrockRuntimeClient.invokeModel(input: input)
88-
guard let responseBody = response.body else {
89-
logger.trace(
90-
"Invalid response",
91-
metadata: [
92-
"response": .string(String(describing: response)),
93-
"hasBody": .stringConvertible(response.body != nil),
94-
]
95-
)
96-
throw BedrockLibraryError.invalidSDKResponse(
97-
"Something went wrong while extracting body from response."
98-
)
99-
}
100-
let invokemodelResponse: InvokeModelResponse = try InvokeModelResponse.createImageResponse(
101-
body: responseBody,
102-
model: model
103-
)
104-
return try invokemodelResponse.getGeneratedImage()
80+
81+
return try await sendRequest(request: request, model: model)
82+
10583
} catch {
10684
try handleCommonError(error, context: "listing foundation models")
10785
}
@@ -174,34 +152,39 @@ extension BedrockService {
174152
quality: quality,
175153
resolution: resolution
176154
)
177-
let input: InvokeModelInput = try request.getInvokeModelInput()
155+
return try await sendRequest(request: request, model: model)
156+
} catch {
157+
try handleCommonError(error, context: "invoking image model")
158+
}
159+
}
160+
161+
/// Sends the request to invoke the model and returns the generated image(s)
162+
private func sendRequest(request: InvokeModelRequest, model: BedrockModel) async throws -> ImageGenerationOutput {
163+
let input: InvokeModelInput = try request.getInvokeModelInput(forRegion: self.region)
164+
logger.trace(
165+
"Sending request to invokeModel",
166+
metadata: [
167+
"model": .string(model.id), "request": .string(String(describing: input)),
168+
]
169+
)
170+
let response = try await self.bedrockRuntimeClient.invokeModel(input: input)
171+
guard let responseBody = response.body else {
178172
logger.trace(
179-
"Sending request to invokeModel",
173+
"Invalid response",
180174
metadata: [
181-
"model": .string(model.id), "request": .string(String(describing: input)),
175+
"response": .string(String(describing: response)),
176+
"hasBody": .stringConvertible(response.body != nil),
182177
]
183178
)
184-
let response = try await self.bedrockRuntimeClient.invokeModel(input: input)
185-
guard let responseBody = response.body else {
186-
logger.trace(
187-
"Invalid response",
188-
metadata: [
189-
"response": .string(String(describing: response)),
190-
"hasBody": .stringConvertible(response.body != nil),
191-
]
192-
)
193-
throw BedrockLibraryError.invalidSDKResponse(
194-
"Something went wrong while extracting body from response."
195-
)
196-
}
197-
let invokemodelResponse: InvokeModelResponse = try InvokeModelResponse.createImageResponse(
198-
body: responseBody,
199-
model: model
179+
throw BedrockLibraryError.invalidSDKResponse(
180+
"Something went wrong while extracting body from response."
200181
)
201-
return try invokemodelResponse.getGeneratedImage()
202-
} catch {
203-
try handleCommonError(error, context: "listing foundation models")
204182
}
183+
let invokemodelResponse: InvokeModelResponse = try InvokeModelResponse.createImageResponse(
184+
body: responseBody,
185+
model: model
186+
)
187+
return try invokemodelResponse.getGeneratedImage()
205188
}
206189

207190
/// Generates 1 to 5 image variation(s) from reference images and a text prompt using a specific model

0 commit comments

Comments (0)