Skip to content

Commit 95bb916

Browse files
authored
feat: support ML host customization (#459)
1 parent 5f29a17 commit 95bb916

File tree

6 files changed

+123
-2
lines changed

6 files changed

+123
-2
lines changed

AWSClientRuntime/Sources/Middlewares/EndpointResolverMiddleware.swift

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,12 @@ public struct EndpointResolverMiddleware<OperationStackOutput: HttpResponseBindi
3131
do {
3232
let awsEndpoint = try endpointResolver.resolve(serviceId: serviceId,
3333
region: region)
34-
let host = "\(context.getHostPrefix() ?? "")\(awsEndpoint.endpoint.host)"
34+
var host = ""
35+
if let overrideHost = context.getHost() {
36+
host = overrideHost
37+
} else {
38+
host = "\(context.getHostPrefix() ?? "")\(awsEndpoint.endpoint.host)"
39+
}
3540

3641
if let protocolType = awsEndpoint.endpoint.protocolType {
3742
input.withProtocol(protocolType)

codegen/protocol-test-codegen/build.gradle.kts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ val enabledProtocols = listOf(
3737
// service specific tests
3838
//ProtocolTest("apigateway", "com.amazonaws.apigateway#BackplaneControlService"),
3939
ProtocolTest("glacier", "com.amazonaws.glacier#Glacier", "GlacierTestSDK"),
40-
ProtocolTest("s3", "com.amazonaws.s3#AmazonS3", "S3TestSDK")
40+
ProtocolTest("s3", "com.amazonaws.s3#AmazonS3", "S3TestSDK"),
41+
ProtocolTest("machinelearning", "com.amazonaws.machinelearning#AmazonML_20141212", "MachineLearningTestSDK")
4142
)
4243

4344
// This project doesn't produce a JAR.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
package software.amazon.smithy.aws.swift.codegen.customization.machinelearning
2+
3+
import software.amazon.smithy.aws.swift.codegen.middleware.PredictInputEndpointURLHostMiddlewareRenderable
4+
import software.amazon.smithy.aws.swift.codegen.middleware.handlers.PredictInputEndpointURLHostMiddlewareHandler
5+
import software.amazon.smithy.model.Model
6+
import software.amazon.smithy.model.shapes.OperationShape
7+
import software.amazon.smithy.model.shapes.ServiceShape
8+
import software.amazon.smithy.swift.codegen.MiddlewareGenerator
9+
import software.amazon.smithy.swift.codegen.SwiftDelegator
10+
import software.amazon.smithy.swift.codegen.SwiftDependency
11+
import software.amazon.smithy.swift.codegen.SwiftSettings
12+
import software.amazon.smithy.swift.codegen.core.CodegenContext
13+
import software.amazon.smithy.swift.codegen.core.toProtocolGenerationContext
14+
import software.amazon.smithy.swift.codegen.integration.ProtocolGenerator
15+
import software.amazon.smithy.swift.codegen.integration.SwiftIntegration
16+
import software.amazon.smithy.swift.codegen.integration.middlewares.handlers.MiddlewareShapeUtils
17+
import software.amazon.smithy.swift.codegen.middleware.MiddlewareStep
18+
import software.amazon.smithy.swift.codegen.middleware.OperationMiddleware
19+
import software.amazon.smithy.swift.codegen.model.expectShape
20+
21+
internal val ENABLED_OPERATIONS: Map<String, Set<String>> = mapOf(
22+
"com.amazonaws.machinelearning#AmazonML_20141212" to setOf(
23+
"com.amazonaws.machinelearning#Predict"
24+
)
25+
)
26+
27+
class PredictEndpointIntegration(private val enabledOperations: Map<String, Set<String>> = ENABLED_OPERATIONS) : SwiftIntegration {
28+
29+
override fun enabledForService(model: Model, settings: SwiftSettings): Boolean {
30+
val currentServiceId = model.expectShape<ServiceShape>(settings.service).id.toString()
31+
return enabledOperations.keys.contains(currentServiceId)
32+
}
33+
override fun writeAdditionalFiles(ctx: CodegenContext, delegator: SwiftDelegator) {
34+
val serviceShape = ctx.model.expectShape<ServiceShape>(ctx.settings.service)
35+
val protocolGeneratorContext = ctx.toProtocolGenerationContext(serviceShape, delegator)?.let { it } ?: run { return }
36+
val service = ctx.model.expectShape<ServiceShape>(ctx.settings.service)
37+
val operationsToGenerate = enabledOperations.getOrDefault(service.id.toString(), setOf())
38+
39+
operationsToGenerate.forEach { operation ->
40+
val op = ctx.model.expectShape<OperationShape>(operation)
41+
val inputSymbol = MiddlewareShapeUtils.inputSymbol(ctx.symbolProvider, ctx.model, op)
42+
val outputSymbol = MiddlewareShapeUtils.outputSymbol(ctx.symbolProvider, ctx.model, op)
43+
val outputErrorSymbol = MiddlewareShapeUtils.outputErrorSymbol(op)
44+
45+
val inputType = op.input.get()
46+
delegator.useFileWriter("${ctx.settings.moduleName}/models/$inputType+EndpointURLHostMiddleware.swift") { writer ->
47+
writer.addImport(SwiftDependency.CLIENT_RUNTIME.target)
48+
val predictMiddleware = PredictInputEndpointURLHostMiddlewareHandler(writer, protocolGeneratorContext, inputSymbol, outputSymbol, outputErrorSymbol)
49+
MiddlewareGenerator(writer, predictMiddleware).generate()
50+
}
51+
}
52+
}
53+
54+
override fun customizeMiddleware(
55+
ctx: ProtocolGenerator.GenerationContext,
56+
operationShape: OperationShape,
57+
operationMiddleware: OperationMiddleware
58+
) {
59+
if (enabledOperations.values.contains(setOf(operationShape.id.toString()))) {
60+
operationMiddleware.removeMiddleware(operationShape, MiddlewareStep.INITIALIZESTEP, "OperationInputUrlHostMiddleware")
61+
operationMiddleware.appendMiddleware(operationShape, PredictInputEndpointURLHostMiddlewareRenderable())
62+
}
63+
}
64+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
package software.amazon.smithy.aws.swift.codegen.middleware
2+
3+
import software.amazon.smithy.model.shapes.OperationShape
4+
import software.amazon.smithy.swift.codegen.SwiftWriter
5+
import software.amazon.smithy.swift.codegen.middleware.MiddlewarePosition
6+
import software.amazon.smithy.swift.codegen.middleware.MiddlewareRenderable
7+
import software.amazon.smithy.swift.codegen.middleware.MiddlewareStep
8+
9+
class PredictInputEndpointURLHostMiddlewareRenderable : MiddlewareRenderable {
10+
override val name = "PredictInputEndpointURLHostMiddleware"
11+
12+
override val middlewareStep = MiddlewareStep.INITIALIZESTEP
13+
14+
override val position = MiddlewarePosition.AFTER
15+
16+
override fun render(writer: SwiftWriter, op: OperationShape, operationStackName: String) {
17+
writer.write("$operationStackName.${middlewareStep.stringValue()}.intercept(position: ${position.stringValue()}, middleware: PredictInputEndpointURLHostMiddleware())")
18+
}
19+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package software.amazon.smithy.aws.swift.codegen.middleware.handlers
2+
3+
import software.amazon.smithy.codegen.core.Symbol
4+
import software.amazon.smithy.swift.codegen.ClientRuntimeTypes
5+
import software.amazon.smithy.swift.codegen.Middleware
6+
import software.amazon.smithy.swift.codegen.SwiftWriter
7+
import software.amazon.smithy.swift.codegen.integration.ProtocolGenerator
8+
import software.amazon.smithy.swift.codegen.integration.steps.OperationInitializeStep
9+
10+
class PredictInputEndpointURLHostMiddlewareHandler(
11+
private val writer: SwiftWriter,
12+
ctx: ProtocolGenerator.GenerationContext,
13+
inputSymbol: Symbol,
14+
outputSymbol: Symbol,
15+
outputErrorSymbol: Symbol
16+
) : Middleware(writer, inputSymbol, OperationInitializeStep(inputSymbol, outputSymbol, outputErrorSymbol)) {
17+
18+
override val typeName = "${inputSymbol.name}EndpointURLHostMiddleware"
19+
20+
override fun generateInit() {
21+
writer.write("public init() { }")
22+
}
23+
24+
override fun generateMiddlewareClosure() {
25+
writer.openBlock("if let endpoint = input.predictEndpoint, let url = \$N(string: endpoint), let host = url.host {", "}", ClientRuntimeTypes.Core.URL) {
26+
writer.write("var copiedContext = context")
27+
writer.write("copiedContext.attributes.set(key: AttributeKey<String>(name: \"Host\"), value: host)")
28+
writer.write("return next.handle(context: copiedContext, input: input)")
29+
}
30+
}
31+
}

codegen/smithy-aws-swift-codegen/src/main/resources/META-INF/services/software.amazon.smithy.swift.codegen.integration.SwiftIntegration

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,6 @@ software.amazon.smithy.aws.swift.codegen.customization.glacier.GlacierAccountIdD
77
software.amazon.smithy.aws.swift.codegen.customization.glacier.GlacierChecksum
88
software.amazon.smithy.aws.swift.codegen.customization.BoxServices
99
software.amazon.smithy.aws.swift.codegen.customization.PresignableModelIntegration
10+
software.amazon.smithy.aws.swift.codegen.customization.machinelearning.PredictEndpointIntegration
1011
software.amazon.smithy.aws.swift.codegen.customization.polly.PollyGetPresignerIntegration
1112
software.amazon.smithy.aws.swift.codegen.PresignerGenerator

0 commit comments

Comments
 (0)