Skip to content

Commit 28ee19a

Browse files
authored
feat: enable machinelearning endpoint customization (#378)
1 parent d634ec1 commit 28ee19a

File tree

5 files changed

+113
-4
lines changed

5 files changed

+113
-4
lines changed

codegen/protocol-tests/build.gradle.kts

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ val enabledProtocols = listOf(
3737

3838
// service specific tests
3939
ProtocolTest("apigateway", "com.amazonaws.apigateway#BackplaneControlService"),
40-
ProtocolTest("glacier", "com.amazonaws.glacier#Glacier")
40+
ProtocolTest("glacier", "com.amazonaws.glacier#Glacier"),
41+
ProtocolTest("machinelearning", "com.amazonaws.machinelearning#AmazonML_20141212", sdkId = "Machine Learning"),
4142
)
4243

4344
// This project doesn't produce a JAR.
@@ -79,7 +80,7 @@ tasks.create<SmithyBuild>("generateSdk") {
7980
// force rebuild every time while developing
8081
tasks["generateSdk"].outputs.upToDateWhen { false }
8182

82-
data class ProtocolTest(val projectionName: String, val serviceShapeId: String) {
83+
data class ProtocolTest(val projectionName: String, val serviceShapeId: String, val sdkId: String? = null) {
8384
val packageName: String
8485
get() = projectionName.toLowerCase().filter { it.isLetterOrDigit() }
8586
}
@@ -90,6 +91,7 @@ data class ProtocolTest(val projectionName: String, val serviceShapeId: String)
9091
// it's rebuilt each time codegen is performed.
9192
fun generateSmithyBuild(tests: List<ProtocolTest>): String {
9293
val projections = tests.joinToString(",") { test ->
94+
val sdkIdEntry = test.sdkId?.let { """"sdkId": "$it",""" } ?: ""
9395
"""
9496
"${test.projectionName}": {
9597
"transforms": [
@@ -109,6 +111,7 @@ fun generateSmithyBuild(tests: List<ProtocolTest>): String {
109111
"name": "aws.sdk.kotlin.services.${test.packageName}",
110112
"version": "1.0"
111113
},
114+
$sdkIdEntry
112115
"build": {
113116
"rootProject": true,
114117
"optInAnnotations": [
@@ -131,7 +134,6 @@ fun generateSmithyBuild(tests: List<ProtocolTest>): String {
131134
""".trimIndent()
132135
}
133136

134-
135137
open class ProtocolTestTask : DefaultTask() {
136138
/**
137139
* The protocol name
@@ -194,7 +196,6 @@ enabledProtocols.forEach {
194196
}
195197
}
196198

197-
198199
tasks.register("testAllProtocols") {
199200
group = "Verification"
200201
val allTests = tasks.withType<ProtocolTestTask>()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0.
4+
*/
5+
6+
package aws.sdk.kotlin.codegen.customization.machinelearning
7+
8+
import aws.sdk.kotlin.codegen.protocols.middleware.ResolveAwsEndpointMiddleware
9+
import aws.sdk.kotlin.codegen.sdkId
10+
import software.amazon.smithy.kotlin.codegen.KotlinSettings
11+
import software.amazon.smithy.kotlin.codegen.core.KotlinWriter
12+
import software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration
13+
import software.amazon.smithy.kotlin.codegen.model.buildSymbol
14+
import software.amazon.smithy.kotlin.codegen.model.expectShape
15+
import software.amazon.smithy.kotlin.codegen.rendering.protocol.*
16+
import software.amazon.smithy.model.Model
17+
import software.amazon.smithy.model.shapes.OperationShape
18+
import software.amazon.smithy.model.shapes.ServiceShape
19+
20+
class MachineLearningEndpointCustomization : KotlinIntegration {
21+
override fun customizeMiddleware(
22+
ctx: ProtocolGenerator.GenerationContext,
23+
resolved: List<ProtocolMiddleware>
24+
): List<ProtocolMiddleware> =
25+
super
26+
.customizeMiddleware(ctx, resolved)
27+
.replace(endpointResolverMiddleware) { it is ResolveAwsEndpointMiddleware }
28+
29+
private val endpointResolverMiddleware = object : HttpFeatureMiddleware() {
30+
override val name: String = "ResolvePredictEndpoint"
31+
32+
override fun isEnabledFor(ctx: ProtocolGenerator.GenerationContext, op: OperationShape): Boolean =
33+
op.id.name == "Predict"
34+
35+
override fun renderConfigure(writer: KotlinWriter) {
36+
writer.addImport(machineLearningSymbol("ResolvePredictEndpoint"))
37+
}
38+
39+
private fun machineLearningSymbol(name: String) = buildSymbol {
40+
this.name = name
41+
namespace = "aws.sdk.kotlin.services.machinelearning.internal"
42+
}
43+
}
44+
45+
override fun enabledForService(model: Model, settings: KotlinSettings): Boolean =
46+
model.expectShape<ServiceShape>(settings.service).sdkId.equals("Machine Learning", ignoreCase = true)
47+
}

codegen/smithy-aws-kotlin-codegen/src/main/resources/META-INF/services/software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,4 @@ aws.sdk.kotlin.codegen.customization.glacier.GlacierAccountIdDefault
1515
aws.sdk.kotlin.codegen.customization.polly.PollyPresigner
1616
aws.sdk.kotlin.codegen.customization.BoxServices
1717
aws.sdk.kotlin.codegen.customization.glacier.GlacierBodyChecksum
18+
aws.sdk.kotlin.codegen.customization.machinelearning.MachineLearningEndpointCustomization
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0.
4+
*/
5+
6+
package aws.sdk.kotlin.services.machinelearning.internal
7+
8+
import aws.sdk.kotlin.runtime.endpoint.AwsEndpoint
9+
import aws.smithy.kotlin.runtime.client.ExecutionContext
10+
import aws.smithy.kotlin.runtime.util.AttributeKey
11+
12+
internal val predictEndpointKey = AttributeKey<AwsEndpoint>("PredictEndpointKey")
13+
14+
internal var ExecutionContext.predictEndpoint: AwsEndpoint?
15+
get() = getOrNull(predictEndpointKey)
16+
set(value) = set(predictEndpointKey, value!!)
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0.
4+
*/
5+
6+
package aws.sdk.kotlin.services.machinelearning.internal
7+
8+
import aws.sdk.kotlin.runtime.endpoint.AwsEndpoint
9+
import aws.sdk.kotlin.services.machinelearning.model.MachineLearningException
10+
import aws.sdk.kotlin.services.machinelearning.model.PredictRequest
11+
import aws.smithy.kotlin.runtime.http.Feature
12+
import aws.smithy.kotlin.runtime.http.FeatureKey
13+
import aws.smithy.kotlin.runtime.http.HttpClientFeatureFactory
14+
import aws.smithy.kotlin.runtime.http.middleware.setRequestEndpoint
15+
import aws.smithy.kotlin.runtime.http.operation.SdkHttpOperation
16+
17+
internal class ResolvePredictEndpoint : Feature {
18+
companion object Feature : HttpClientFeatureFactory<Unit, ResolvePredictEndpoint> {
19+
override val key: FeatureKey<ResolvePredictEndpoint> = FeatureKey("ResolvePredictEndpoint")
20+
override fun create(block: Unit.() -> Unit): ResolvePredictEndpoint = ResolvePredictEndpoint()
21+
}
22+
23+
override fun <I, O> install(operation: SdkHttpOperation<I, O>) {
24+
operation.execution.initialize.intercept { req, next ->
25+
val input = req.subject as PredictRequest
26+
if (input.predictEndpoint == null || input.predictEndpoint.isBlank()) {
27+
throw MachineLearningException("Predict requires predictEnpoint to be set to a non-empty value")
28+
}
29+
// Stash the endpoint for later use by the mutate interceptor
30+
req.context.predictEndpoint = AwsEndpoint(input.predictEndpoint)
31+
32+
next.call(req)
33+
}
34+
35+
operation.execution.mutate.intercept { req, next ->
36+
// This should've been set by the initialize interceptor
37+
val endpoint = req.context.predictEndpoint
38+
requireNotNull(endpoint) { "Predict endpoint wasn't set by middleware." }
39+
setRequestEndpoint(req, endpoint.endpoint)
40+
41+
next.call(req)
42+
}
43+
}
44+
}

0 commit comments

Comments
 (0)