Skip to content

Commit a79931f

Browse files
authored
feat: implement HostResolver using CRT (#196)
1 parent e368318 commit a79931f

File tree

7 files changed

+376
-7
lines changed

7 files changed

+376
-7
lines changed

aws-crt-kotlin/build.gradle.kts

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,15 @@ import aws.sdk.kotlin.gradle.crt.CMakeBuildType
66
import aws.sdk.kotlin.gradle.crt.cmakeInstallDir
77
import aws.sdk.kotlin.gradle.crt.configureCrtCMakeBuild
88
import aws.sdk.kotlin.gradle.dsl.configurePublishing
9+
import aws.sdk.kotlin.gradle.kmp.NATIVE_ENABLED
910
import aws.sdk.kotlin.gradle.kmp.configureIosSimulatorTasks
1011
import aws.sdk.kotlin.gradle.kmp.configureKmpTargets
1112
import aws.sdk.kotlin.gradle.util.typedProp
1213
import org.jetbrains.kotlin.gradle.plugin.mpp.KotlinNativeTarget
1314
import org.jetbrains.kotlin.konan.target.Family
1415
import org.jetbrains.kotlin.konan.target.HostManager
16+
import java.nio.file.Files
17+
import java.nio.file.Paths
1518

1619
plugins {
1720
alias(libs.plugins.kotlin.multiplatform)
@@ -103,6 +106,58 @@ kotlin {
103106
}
104107
}
105108
}
109+
110+
if (NATIVE_ENABLED && HostManager.hostIsMingw) {
111+
mingwX64 {
112+
val mingwHome = findMingwHome()
113+
val defPath = layout.buildDirectory.file("cinterop/winver.def")
114+
115+
// Dynamically construct def file because of dynamic mingw paths
116+
val defFileTask by tasks.registering {
117+
outputs.file(defPath)
118+
119+
val mingwLibs = Paths.get(mingwHome, "lib").toString().replace("\\", "\\\\") // Windows path shenanigans
120+
121+
doLast {
122+
Files.writeString(
123+
defPath.get().asFile.toPath(),
124+
"""
125+
package = aws.sdk.kotlin.crt.winver
126+
headers = windows.h
127+
compilerOpts = \
128+
-DUNICODE \
129+
-DWINVER=0x0601 \
130+
-D_WIN32_WINNT=0x0601 \
131+
-DWINAPI_FAMILY=3 \
132+
-DOEMRESOURCE \
133+
-Wno-incompatible-pointer-types \
134+
-Wno-deprecated-declarations
135+
libraryPaths = $mingwLibs
136+
staticLibraries = libversion.a
137+
""".trimIndent(),
138+
)
139+
}
140+
}
141+
compilations["main"].cinterops {
142+
create("winver") {
143+
val mingwIncludes = Paths.get(mingwHome, "include").toString()
144+
includeDirs(mingwIncludes)
145+
definitionFile.set(defPath)
146+
147+
// Ensure that the def file is written first
148+
tasks[interopProcessingTaskName].dependsOn(defFileTask)
149+
}
150+
}
151+
152+
// TODO clean up
153+
val compilerArgs = listOf(
154+
"-Xverbose-phases=linker", // Enable verbose linking phase from the compiler
155+
"-linker-option",
156+
"-v",
157+
)
158+
compilerOptions.freeCompilerArgs.addAll(compilerArgs)
159+
}
160+
}
106161
}
107162

108163
configureIosSimulatorTasks()
@@ -168,3 +223,11 @@ private val KotlinNativeTarget.isBuildableOnHost: Boolean
168223
throw Exception("Unsupported host: ${HostManager.host}")
169224
}
170225
}
226+
227+
private fun findMingwHome(): String =
228+
System.getenv("MINGW_PREFIX")?.takeUnless { it.isBlank() }
229+
?: typedProp("mingw.prefix")
230+
?: throw IllegalStateException(
231+
"Cannot determine MinGW prefix location. Please verify MinGW is installed correctly " +
232+
"and that either the `MINGW_PREFIX` environment variable or the `mingw.prefix` Gradle property is set.",
233+
)
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
package aws.sdk.kotlin.crt.util
6+
7+
import aws.sdk.kotlin.crt.winver.*
8+
import kotlinx.cinterop.*
9+
import platform.posix.memcpy
10+
11+
// The following code is used to resolve the OS version on Windows machines.
12+
// This eventually ends up in aws-sdk-kotlin's User-Agent header.
13+
// It briefly lived in smithy-kotlin but was removed because it needs a MinGW installation to compile
14+
// FIXME Can we get the same information without all these API calls?
15+
16+
// The functions below are adapted from C++ SDK:
17+
// https://github.com/aws/aws-sdk-cpp/blob/0e6085bf0dd9a1cb1f27d101c4cf2db6ade6f307/src/aws-cpp-sdk-core/source/platform/windows/OSVersionInfo.cpp#L49-L106
18+
private val wordHexFormat = HexFormat {
19+
upperCase = false
20+
number {
21+
removeLeadingZeros = true
22+
minLength = 4
23+
}
24+
}
25+
26+
private data class LangCodePage(
27+
val language: UShort,
28+
val codePage: UShort,
29+
)
30+
31+
public fun osVersionFromKernel(): String? = memScoped {
32+
withFileVersionInfo("Kernel32.dll") { versionInfoPtr ->
33+
getLangCodePage(versionInfoPtr)?.let { langCodePage ->
34+
getProductVersion(versionInfoPtr, langCodePage)
35+
}
36+
}
37+
}
38+
39+
private inline fun <R> withFileVersionInfo(fileName: String, block: (CPointer<ByteVarOf<Byte>>) -> R?): R? {
40+
val blobSize = GetFileVersionInfoSizeW(fileName, null)
41+
val blob = ByteArray(blobSize.convert())
42+
blob.usePinned { pinned ->
43+
val result = GetFileVersionInfoW(fileName, 0u, blobSize, pinned.addressOf(0))
44+
return if (result == 0) {
45+
null
46+
} else {
47+
block(pinned.addressOf(0))
48+
}
49+
}
50+
}
51+
52+
private fun MemScope.getLangCodePage(versionInfoPtr: CPointer<ByteVarOf<Byte>>): LangCodePage? {
53+
// Get _any_ language pack and codepage since they should all have the same version
54+
val langAndCodePagePtr = alloc<COpaquePointerVar>()
55+
val codePageSize = alloc<UIntVar>()
56+
val result = VerQueryValueW(
57+
versionInfoPtr,
58+
"""\VarFileInfo\Translation""",
59+
langAndCodePagePtr.ptr,
60+
codePageSize.ptr,
61+
)
62+
63+
return if (result == 0) {
64+
null
65+
} else {
66+
val langAndCodePage = langAndCodePagePtr.value!!.reinterpret<UIntVar>().pointed.value
67+
val language = (langAndCodePage and 0x0000ffffu).toUShort() // low WORD
68+
val codePage = (langAndCodePage and 0xffff0000u shr 16).toUShort() // high WORD
69+
LangCodePage(language, codePage)
70+
}
71+
}
72+
73+
private fun MemScope.getProductVersion(versionInfoPtr: CPointer<ByteVarOf<Byte>>, langCodePage: LangCodePage): String? {
74+
val versionId = buildString {
75+
// Something like: \StringFileInfo\04090fb0\ProductVersion
76+
append("""\StringFileInfo\""")
77+
append(langCodePage.language.toHexString(wordHexFormat))
78+
append(langCodePage.codePage.toHexString(wordHexFormat))
79+
append("""\ProductVersion""")
80+
}
81+
82+
// Get the block corresponding to versionId
83+
val block = alloc<COpaquePointerVar>()
84+
val blockSize = alloc<UIntVar>()
85+
val result = VerQueryValueW(versionInfoPtr, versionId, block.ptr, blockSize.ptr)
86+
87+
return if (result == 0) {
88+
null
89+
} else {
90+
// Copy the bytes into a Kotlin byte array
91+
val blockBytes = ByteArray(blockSize.value.convert())
92+
blockBytes.usePinned { pinned ->
93+
memcpy(pinned.addressOf(0), block.value!!.reinterpret<ByteVar>(), blockSize.value.convert())
94+
}
95+
blockBytes.decodeToString()
96+
}
97+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package aws.sdk.kotlin.crt.util
7+
8+
import kotlinx.coroutines.test.runTest
9+
import kotlin.test.Test
10+
import kotlin.test.assertNotNull
11+
12+
class OsVersionTest {
13+
@Test
14+
fun testOsInfo() = runTest {
15+
val version = osVersionFromKernel()
16+
assertNotNull(version)
17+
}
18+
}

aws-crt-kotlin/native/src/aws/sdk/kotlin/crt/io/HostResolverNative.kt

Lines changed: 102 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@ package aws.sdk.kotlin.crt.io
88
import aws.sdk.kotlin.crt.*
99
import aws.sdk.kotlin.crt.util.ShutdownChannel
1010
import aws.sdk.kotlin.crt.util.shutdownChannel
11+
import aws.sdk.kotlin.crt.util.toAwsString
12+
import aws.sdk.kotlin.crt.util.toKString
1113
import kotlinx.cinterop.*
14+
import kotlinx.coroutines.channels.Channel
1215
import libcrt.*
1316

1417
@OptIn(ExperimentalForeignApi::class)
@@ -59,14 +62,112 @@ public actual class HostResolver private constructor(
5962

6063
if (manageElg) elg.close()
6164
}
65+
66+
public suspend fun resolve(hostname: String): List<CrtHostAddress> = memScoped {
67+
val awsHostname = hostname.toAwsString()
68+
val resultCallback = staticCFunction(::awsOnHostResolveFn)
69+
70+
val channel: Channel<Result<List<CrtHostAddress>>> = Channel(Channel.RENDEZVOUS)
71+
val channelStableRef = StableRef.create(channel)
72+
val userData = channelStableRef.asCPointer()
73+
74+
aws_host_resolver_resolve_host(ptr, awsHostname, resultCallback, aws_host_resolver_init_default_resolution_config(), userData)
75+
76+
return channel.receive().getOrThrow().also {
77+
aws_string_destroy(awsHostname)
78+
channelStableRef.dispose()
79+
}
80+
}
6281
}
6382

6483
@OptIn(ExperimentalForeignApi::class)
6584
private fun onShutdownComplete(userData: COpaquePointer?) {
66-
if (userData == null) return
85+
if (userData == null) {
86+
return
87+
}
6788
val stableRef = userData.asStableRef<ShutdownChannel>()
6889
val ch = stableRef.get()
6990
ch.trySend(Unit)
7091
ch.close()
7192
stableRef.dispose()
7293
}
94+
95+
// implementation of `aws_on_host_resolved_result_fn`: https://github.com/awslabs/aws-c-io/blob/db7a1bddc9a29eca18734d0af189c3924775dcf1/include/aws/io/host_resolver.h#L53C14-L53C44
96+
private fun awsOnHostResolveFn(
97+
hostResolver: CPointer<aws_host_resolver>?,
98+
hostName: CPointer<aws_string>?,
99+
errCode: Int,
100+
hostAddresses: CPointer<aws_array_list>?, // list of `aws_host_address`
101+
userData: COpaquePointer?,
102+
): Unit = memScoped {
103+
if (userData == null) {
104+
throw CrtRuntimeException("aws_on_host_resolved_result_fn: userData unexpectedly null")
105+
}
106+
107+
val stableRef = userData.asStableRef<Channel<Result<List<CrtHostAddress>>>>()
108+
val channel = stableRef.get()
109+
110+
try {
111+
if (errCode != AWS_OP_SUCCESS) {
112+
throw CrtRuntimeException("aws_on_host_resolved_result_fn", ec = errCode)
113+
}
114+
115+
val length = aws_array_list_length(hostAddresses)
116+
if (length == 0uL) {
117+
throw CrtRuntimeException("Failed to resolve host address for ${hostName?.toKString()}")
118+
}
119+
120+
val addressList = ArrayList<CrtHostAddress>(length.toInt())
121+
122+
val element = alloc<COpaquePointerVar>()
123+
for (i in 0uL until length) {
124+
awsAssertOpSuccess(
125+
aws_array_list_get_at_ptr(
126+
hostAddresses,
127+
element.ptr,
128+
i,
129+
),
130+
) { "aws_array_list_get_at_ptr failed at index $i" }
131+
132+
val elemOpaque = element.value ?: run {
133+
throw CrtRuntimeException("aws_host_addresses value at index $i unexpectedly null")
134+
}
135+
136+
val addr = elemOpaque.reinterpret<aws_host_address>().pointed
137+
138+
val hostStr = addr.host?.toKString() ?: run {
139+
throw CrtRuntimeException("aws_host_addresses `host` at index $i unexpectedly null")
140+
}
141+
val addressStr = addr.address?.toKString() ?: run {
142+
throw CrtRuntimeException("aws_host_addresses `address` at index $i unexpectedly null")
143+
}
144+
145+
val addressType = when (addr.record_type) {
146+
aws_address_record_type.AWS_ADDRESS_RECORD_TYPE_A -> AddressType.IpV4
147+
aws_address_record_type.AWS_ADDRESS_RECORD_TYPE_AAAA -> AddressType.IpV6
148+
else -> throw CrtRuntimeException("received unsupported aws_host_address `aws_address_record_type`: ${addr.record_type}")
149+
}
150+
151+
addressList += CrtHostAddress(host = hostStr, address = addressStr, addressType)
152+
}
153+
154+
channel.trySend(Result.success(addressList))
155+
} catch (e: Exception) {
156+
channel.trySend(Result.failure(e))
157+
} finally {
158+
channel.close()
159+
}
160+
}
161+
162+
// Minimal wrapper of aws_host_address
163+
// https://github.com/awslabs/aws-c-io/blob/db7a1bddc9a29eca18734d0af189c3924775dcf1/include/aws/io/host_resolver.h#L31
164+
public data class CrtHostAddress(
165+
val host: String,
166+
val address: String,
167+
val addressType: AddressType,
168+
)
169+
170+
public enum class AddressType {
171+
IpV4,
172+
IpV6,
173+
}

0 commit comments

Comments
 (0)