Skip to content

Commit 4227caf

Browse files
authored
feat: implement recursion detection middleware (#602)
1 parent 2338e00 commit 4227caf

File tree

5 files changed

+237
-0
lines changed

5 files changed

+237
-0
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
{
2+
"id": "2d52d36d-564b-4e31-a7fa-14d8839d4a96",
3+
"type": "feature",
4+
"description": "Implement recursion detection middleware."
5+
}
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package aws.sdk.kotlin.runtime.http.middleware
7+
8+
import aws.sdk.kotlin.runtime.InternalSdkApi
9+
import aws.smithy.kotlin.runtime.http.operation.ModifyRequestMiddleware
10+
import aws.smithy.kotlin.runtime.http.operation.SdkHttpRequest
11+
import aws.smithy.kotlin.runtime.util.EnvironmentProvider
12+
import aws.smithy.kotlin.runtime.util.Platform
13+
import aws.smithy.kotlin.runtime.util.text.percentEncodeTo
14+
15+
internal const val ENV_FUNCTION_NAME = "AWS_LAMBDA_FUNCTION_NAME"
16+
internal const val ENV_TRACE_ID = "_X_AMZN_TRACE_ID"
17+
internal const val HEADER_TRACE_ID = "X-Amzn-Trace-Id"
18+
19+
/**
20+
* HTTP middleware to add the recursion detection header where required.
21+
*/
22+
@InternalSdkApi
23+
public class RecursionDetection(
24+
private val env: EnvironmentProvider = Platform
25+
) : ModifyRequestMiddleware {
26+
override suspend fun modifyRequest(req: SdkHttpRequest): SdkHttpRequest {
27+
if (req.subject.headers.contains(HEADER_TRACE_ID)) return req
28+
29+
val traceId = env.getenv(ENV_TRACE_ID)
30+
if (env.getenv(ENV_FUNCTION_NAME) == null || traceId == null) return req
31+
32+
req.subject.headers[HEADER_TRACE_ID] = traceId.percentEncode()
33+
return req
34+
}
35+
}
36+
37+
/**
38+
* Percent-encode ISO control characters for the purposes of this specific header.
39+
*
40+
* The existing `Char::isISOControl` check cannot be used here, because that matches against characters in
41+
* `[0x00, 0x1f] U [0x7f, 0x9f]`. The SEP for recursion detection dictates we should only encode across
42+
* `[0x00, 0x1f]`.
43+
*/
44+
private fun String.percentEncode(): String {
45+
val sb = StringBuilder(this.length)
46+
val data = this.encodeToByteArray()
47+
for (cbyte in data) {
48+
val chr = cbyte.toInt().toChar()
49+
if (chr.code in 0x00..0x1f) {
50+
cbyte.percentEncodeTo(sb)
51+
} else {
52+
sb.append(chr)
53+
}
54+
}
55+
return sb.toString()
56+
}
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package aws.sdk.kotlin.runtime.http.middleware
7+
8+
import aws.sdk.kotlin.runtime.testing.TestPlatformProvider
9+
import aws.smithy.kotlin.runtime.client.ExecutionContext
10+
import aws.smithy.kotlin.runtime.http.Headers
11+
import aws.smithy.kotlin.runtime.http.HttpBody
12+
import aws.smithy.kotlin.runtime.http.HttpStatusCode
13+
import aws.smithy.kotlin.runtime.http.engine.HttpClientEngineBase
14+
import aws.smithy.kotlin.runtime.http.operation.*
15+
import aws.smithy.kotlin.runtime.http.request.HttpRequest
16+
import aws.smithy.kotlin.runtime.http.request.HttpRequestBuilder
17+
import aws.smithy.kotlin.runtime.http.response.HttpCall
18+
import aws.smithy.kotlin.runtime.http.response.HttpResponse
19+
import aws.smithy.kotlin.runtime.http.sdkHttpClient
20+
import aws.smithy.kotlin.runtime.time.Instant
21+
import aws.smithy.kotlin.runtime.util.get
22+
import kotlinx.coroutines.ExperimentalCoroutinesApi
23+
import kotlinx.coroutines.test.runTest
24+
import kotlin.test.Test
25+
import kotlin.test.assertEquals
26+
import kotlin.test.assertFalse
27+
28+
@OptIn(ExperimentalCoroutinesApi::class)
29+
class RecursionDetectionTest {
30+
private class TraceHeaderSerializer(
31+
private val traceHeader: String
32+
) : HttpSerialize<Unit> {
33+
override suspend fun serialize(context: ExecutionContext, input: Unit): HttpRequestBuilder {
34+
val builder = HttpRequestBuilder()
35+
builder.headers[HEADER_TRACE_ID] = traceHeader
36+
return builder
37+
}
38+
}
39+
40+
private val mockEngine = object : HttpClientEngineBase("test") {
41+
override suspend fun roundTrip(request: HttpRequest): HttpCall {
42+
val resp = HttpResponse(HttpStatusCode.fromValue(200), Headers.Empty, HttpBody.Empty)
43+
val now = Instant.now()
44+
return HttpCall(request, resp, now, now)
45+
}
46+
}
47+
48+
private val client = sdkHttpClient(mockEngine)
49+
50+
private suspend fun test(
51+
env: Map<String, String>,
52+
existingTraceHeader: String?,
53+
expectedTraceHeader: String?
54+
) {
55+
val op = SdkHttpOperation.build<Unit, HttpResponse> {
56+
serializer = if (existingTraceHeader != null) TraceHeaderSerializer(existingTraceHeader) else UnitSerializer
57+
deserializer = IdentityDeserializer
58+
context {
59+
service = "Test Service"
60+
operationName = "testOperation"
61+
}
62+
}
63+
64+
val provider = TestPlatformProvider(env)
65+
op.install(RecursionDetection(provider))
66+
op.roundTrip(client, Unit)
67+
68+
val request = op.context[HttpOperationContext.HttpCallList].last().request
69+
if (expectedTraceHeader != null) {
70+
assertEquals(expectedTraceHeader, request.headers[HEADER_TRACE_ID])
71+
} else {
72+
assertFalse(request.headers.contains(HEADER_TRACE_ID))
73+
}
74+
}
75+
76+
@Test
77+
fun `it noops if env unset`() = runTest {
78+
test(
79+
emptyMap(),
80+
null,
81+
null
82+
)
83+
}
84+
85+
@Test
86+
fun `it sets header when both envs are present`() = runTest {
87+
test(
88+
mapOf(
89+
ENV_FUNCTION_NAME to "some-function",
90+
ENV_TRACE_ID to "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1;lineage=a87bd80c:0,68fd508a:5,c512fbe3:2"
91+
),
92+
null,
93+
"Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1;lineage=a87bd80c:0,68fd508a:5,c512fbe3:2"
94+
)
95+
}
96+
97+
@Test
98+
fun `it noops if trace env set but no lambda env`() = runTest {
99+
test(
100+
mapOf(
101+
ENV_TRACE_ID to "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1;lineage=a87bd80c:0,68fd508a:5,c512fbe3:2"
102+
),
103+
null,
104+
null
105+
)
106+
}
107+
108+
@Test
109+
fun `it respects existing trace header`() = runTest {
110+
test(
111+
mapOf(
112+
ENV_FUNCTION_NAME to "some-function",
113+
ENV_TRACE_ID to "EnvValue"
114+
),
115+
"OriginalValue",
116+
"OriginalValue"
117+
)
118+
}
119+
120+
@Test
121+
fun `it url encodes new trace header`() = runTest {
122+
test(
123+
mapOf(
124+
ENV_FUNCTION_NAME to "some-function",
125+
ENV_TRACE_ID to "first\nsecond"
126+
),
127+
null,
128+
"first%0Asecond"
129+
)
130+
}
131+
132+
@Test
133+
fun `ignores other chars that are usually percent encoded`() = runTest {
134+
test(
135+
mapOf(
136+
ENV_FUNCTION_NAME to "some-function",
137+
ENV_TRACE_ID to "test123-=;:+&[]{}\"'"
138+
),
139+
null,
140+
"test123-=;:+&[]{}\"'"
141+
)
142+
}
143+
}

codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/protocols/core/AwsHttpBindingProtocolGenerator.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import aws.sdk.kotlin.codegen.AwsKotlinDependency
88
import aws.sdk.kotlin.codegen.AwsRuntimeTypes
99
import aws.sdk.kotlin.codegen.protocols.eventstream.EventStreamParserGenerator
1010
import aws.sdk.kotlin.codegen.protocols.eventstream.EventStreamSerializerGenerator
11+
import aws.sdk.kotlin.codegen.protocols.middleware.RecursionDetectionMiddleware
1112
import aws.sdk.kotlin.codegen.protocols.middleware.ResolveAwsEndpointMiddleware
1213
import aws.sdk.kotlin.codegen.protocols.middleware.UserAgentMiddleware
1314
import aws.sdk.kotlin.codegen.protocols.protocoltest.AwsHttpProtocolUnitTestErrorGenerator
@@ -48,6 +49,7 @@ abstract class AwsHttpBindingProtocolGenerator : HttpBindingProtocolGenerator()
4849
}.toMutableList()
4950

5051
middleware.add(UserAgentMiddleware())
52+
middleware.add(RecursionDetectionMiddleware())
5153
return middleware
5254
}
5355

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package aws.sdk.kotlin.codegen.protocols.middleware
7+
8+
import aws.sdk.kotlin.codegen.AwsKotlinDependency
9+
import software.amazon.smithy.kotlin.codegen.core.KotlinWriter
10+
import software.amazon.smithy.kotlin.codegen.model.buildSymbol
11+
import software.amazon.smithy.kotlin.codegen.model.namespace
12+
import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerator
13+
import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolMiddleware
14+
import software.amazon.smithy.model.shapes.OperationShape
15+
16+
/**
17+
* HTTP middleware to add the recursion detection header where required.
18+
*/
19+
class RecursionDetectionMiddleware : ProtocolMiddleware {
20+
override val name: String = "RecursionDetection"
21+
override val order: Byte = 30
22+
23+
private val middlewareSymbol = buildSymbol {
24+
name = "RecursionDetection"
25+
namespace(AwsKotlinDependency.AWS_HTTP, subpackage = "middleware")
26+
}
27+
28+
override fun render(ctx: ProtocolGenerator.GenerationContext, op: OperationShape, writer: KotlinWriter) {
29+
writer.write("op.install(#T())", middlewareSymbol)
30+
}
31+
}

0 commit comments

Comments
 (0)