Skip to content

Commit 11572b5

Browse files
committed
Implement HostResolver using CRT
1 parent b2a2625 commit 11572b5

File tree

3 files changed

+180
-1
lines changed

3 files changed

+180
-1
lines changed
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package aws.sdk.kotlin.crt
7+
8+
// Indicates this API is internal-only and should not be depended on
9+
public annotation class InternalApi

aws-crt-kotlin/native/src/aws/sdk/kotlin/crt/io/HostResolverNative.kt

Lines changed: 91 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@ package aws.sdk.kotlin.crt.io
88
import aws.sdk.kotlin.crt.*
99
import aws.sdk.kotlin.crt.util.ShutdownChannel
1010
import aws.sdk.kotlin.crt.util.shutdownChannel
11+
import aws.sdk.kotlin.crt.util.toAwsString
12+
import aws.sdk.kotlin.crt.util.toKString
1113
import kotlinx.cinterop.*
14+
import kotlinx.coroutines.channels.Channel
1215
import libcrt.*
1316

1417
@OptIn(ExperimentalForeignApi::class)
@@ -59,14 +62,101 @@ public actual class HostResolver private constructor(
5962

6063
if (manageElg) elg.close()
6164
}
65+
66+
public suspend fun resolve(hostname: String): List<CrtHostAddress> = memScoped {
67+
val awsHostname = hostname.toAwsString()
68+
val resultCallback = staticCFunction(::awsOnHostResolveFn)
69+
70+
val channel: Channel<List<CrtHostAddress>> = Channel(Channel.RENDEZVOUS)
71+
val channelStableRef = StableRef.create(channel)
72+
val userData = channelStableRef.asCPointer()
73+
74+
aws_host_resolver_resolve_host(ptr, awsHostname, resultCallback, aws_host_resolver_init_default_resolution_config(), userData)
75+
76+
return channel.receive().also {
77+
aws_string_destroy(awsHostname)
78+
channelStableRef.dispose()
79+
}
80+
}
6281
}
6382

6483
@OptIn(ExperimentalForeignApi::class)
6584
private fun onShutdownComplete(userData: COpaquePointer?) {
66-
if (userData == null) return
85+
if (userData == null) {
86+
return
87+
}
6788
val stableRef = userData.asStableRef<ShutdownChannel>()
6889
val ch = stableRef.get()
6990
ch.trySend(Unit)
7091
ch.close()
7192
stableRef.dispose()
7293
}
94+
95+
// implementation of `aws_on_host_resolved_result_fn`: https://github.com/awslabs/aws-c-io/blob/db7a1bddc9a29eca18734d0af189c3924775dcf1/include/aws/io/host_resolver.h#L53C14-L53C44
96+
private fun awsOnHostResolveFn(
97+
hostResolver: CPointer<aws_host_resolver>?,
98+
hostName: CPointer<aws_string>?,
99+
errCode: Int,
100+
hostAddresses: CPointer<aws_array_list>?, // list of `aws_host_address`
101+
userData: COpaquePointer?,
102+
): Unit = memScoped {
103+
println("In callback awsOnHostResolveFn")
104+
if (errCode != AWS_OP_SUCCESS) {
105+
throw CrtRuntimeException("aws_on_host_resolved_result_fn", ec = errCode)
106+
}
107+
if (userData == null) {
108+
throw CrtRuntimeException("aws_on_host_resolved_result_fn: userData unexpectedly null")
109+
}
110+
111+
val length = aws_array_list_length(hostAddresses)
112+
if (length == 0uL) {
113+
throw CrtRuntimeException("Failed to resolve host address for ${hostName?.toKString()}")
114+
}
115+
116+
val result = ArrayList<CrtHostAddress>(length.toInt())
117+
118+
val element = alloc<COpaquePointerVar>()
119+
for (i in 0uL until length) {
120+
awsAssertOpSuccess(
121+
aws_array_list_get_at_ptr(
122+
hostAddresses,
123+
element.ptr,
124+
i,
125+
),
126+
) { "aws_array_list_get_at_ptr failed at index $i" }
127+
128+
val elemOpaque = element.value ?: throw CrtRuntimeException("aws_host_addresses value at index $i unexpectedly null")
129+
val addr = elemOpaque.reinterpret<aws_host_address>().pointed
130+
131+
val hostStr = addr.host?.toKString() ?: throw CrtRuntimeException("aws_host_addresses `host` at index $i unexpectedly null")
132+
val addressStr = addr.address?.toKString() ?: throw CrtRuntimeException("aws_host_addresses `address` at index $i unexpectedly null")
133+
val addressType = when (addr.record_type) {
134+
aws_address_record_type.AWS_ADDRESS_RECORD_TYPE_A -> AddressType.IpV4
135+
aws_address_record_type.AWS_ADDRESS_RECORD_TYPE_AAAA -> AddressType.IpV6
136+
else -> throw CrtRuntimeException("received unsupported aws_host_address `aws_address_record_type`: ${addr.record_type}")
137+
}
138+
139+
result += CrtHostAddress(host = hostStr, address = addressStr, addressType)
140+
}
141+
142+
// Send results back through the channel
143+
val stableRef = userData.asStableRef<Channel<List<CrtHostAddress>>>()
144+
val channel = stableRef.get()
145+
channel.trySend(result)
146+
channel.close()
147+
}
148+
149+
// Minimal wrapper of aws_host_address
150+
// https://github.com/awslabs/aws-c-io/blob/db7a1bddc9a29eca18734d0af189c3924775dcf1/include/aws/io/host_resolver.h#L31
151+
@InternalApi
152+
public data class CrtHostAddress(
153+
val host: String,
154+
val address: String,
155+
val addressType: AddressType,
156+
)
157+
158+
@InternalApi
159+
public enum class AddressType {
160+
IpV4,
161+
IpV6,
162+
}
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package aws.sdk.kotlin.crt.io
7+
8+
import kotlinx.coroutines.test.runTest
9+
import kotlin.test.Test
10+
import kotlin.test.assertEquals
11+
import kotlin.test.assertFails
12+
import kotlin.test.assertNotNull
13+
import kotlin.test.assertTrue
14+
15+
// Copied from smithy-kotlin
16+
class CrtHostResolverTest {
17+
private val ipv4Regex = Regex("""^(\d{1,3}\.){3}\d{1,3}$""")
18+
private val ipv6Regex = Regex("""^(([0-9a-fA-F]{1,4}:){1,7}[0-9a-fA-F]{1,4}|([0-9a-fA-F]{1,4}:){1,7}:|::([0-9a-fA-F]{1,4}:){0,6}[0-9a-fA-F]{1,4})$""")
19+
20+
@Test
21+
fun testResolveLocalhost() = runTest {
22+
val hr = HostResolver()
23+
val addresses = hr.resolve("localhost")
24+
assertTrue(addresses.isNotEmpty())
25+
26+
addresses.forEach { addr ->
27+
assertEquals("localhost", addr.host)
28+
29+
val localHostAddress = when (addr.addressType) {
30+
AddressType.IpV4 -> "127.0.0.1"
31+
AddressType.IpV6 -> "::1"
32+
}
33+
assertEquals(localHostAddress, addr.address)
34+
}
35+
}
36+
37+
@Test
38+
fun testResolveIpv4Address() = runTest {
39+
val addresses = HostResolver().resolve("127.0.0.1")
40+
assertTrue(addresses.isNotEmpty())
41+
42+
addresses.forEach { addr ->
43+
assertEquals(AddressType.IpV4, addr.addressType)
44+
assertEquals("127.0.0.1", addr.address)
45+
}
46+
}
47+
48+
@Test
49+
fun testResolveIpv6Address() = runTest {
50+
val addresses = HostResolver().resolve("::1")
51+
assertTrue(addresses.isNotEmpty())
52+
53+
addresses.forEach { addr ->
54+
assertEquals(AddressType.IpV6, addr.addressType)
55+
assertEquals("::1", addr.address)
56+
}
57+
}
58+
59+
@Test
60+
fun testResolveExampleDomain() = runTest {
61+
val addresses = HostResolver().resolve("example.com")
62+
assertNotNull(addresses)
63+
assertTrue(addresses.isNotEmpty())
64+
65+
addresses.forEach { addr ->
66+
assertEquals("example.com", addr.host)
67+
when (addr.addressType) {
68+
AddressType.IpV4 -> assertEquals(3, addr.address.count { it == '.' })
69+
AddressType.IpV6 -> assertEquals(if (addr.address.contains("::")) 6 else 7, addr.address.count { it == ':' })
70+
}
71+
}
72+
}
73+
74+
@Test
75+
fun testResolveInvalidDomain() = runTest {
76+
assertFails {
77+
HostResolver().resolve("this-domain-definitely-does-not-exist-12345.local")
78+
}
79+
}
80+
}

0 commit comments

Comments
 (0)