Skip to content
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions aws-crt-kotlin/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@ import aws.sdk.kotlin.gradle.crt.CMakeBuildType
import aws.sdk.kotlin.gradle.crt.cmakeInstallDir
import aws.sdk.kotlin.gradle.crt.configureCrtCMakeBuild
import aws.sdk.kotlin.gradle.dsl.configurePublishing
import aws.sdk.kotlin.gradle.kmp.NATIVE_ENABLED
import aws.sdk.kotlin.gradle.kmp.configureIosSimulatorTasks
import aws.sdk.kotlin.gradle.kmp.configureKmpTargets
import aws.sdk.kotlin.gradle.util.typedProp
import org.jetbrains.kotlin.gradle.plugin.mpp.KotlinNativeTarget
import org.jetbrains.kotlin.konan.target.Family
import org.jetbrains.kotlin.konan.target.HostManager
import java.nio.file.Files
import java.nio.file.Paths

plugins {
alias(libs.plugins.kotlin.multiplatform)
Expand Down Expand Up @@ -103,6 +106,58 @@ kotlin {
}
}
}

if (NATIVE_ENABLED && HostManager.hostIsMingw) {
mingwX64 {
val mingwHome = findMingwHome()
val defPath = layout.buildDirectory.file("cinterop/winver.def")

// Dynamically construct def file because of dynamic mingw paths
val defFileTask by tasks.registering {
outputs.file(defPath)

val mingwLibs = Paths.get(mingwHome, "lib").toString().replace("\\", "\\\\") // Windows path shenanigans

doLast {
Files.writeString(
defPath.get().asFile.toPath(),
"""
package = aws.sdk.kotlin.crt.winver
headers = windows.h
compilerOpts = \
-DUNICODE \
-DWINVER=0x0601 \
-D_WIN32_WINNT=0x0601 \
-DWINAPI_FAMILY=3 \
-DOEMRESOURCE \
-Wno-incompatible-pointer-types \
-Wno-deprecated-declarations
libraryPaths = $mingwLibs
staticLibraries = libversion.a
""".trimIndent(),
)
}
}
compilations["main"].cinterops {
create("winver") {
val mingwIncludes = Paths.get(mingwHome, "include").toString()
includeDirs(mingwIncludes)
definitionFile.set(defPath)

// Ensure that the def file is written first
tasks[interopProcessingTaskName].dependsOn(defFileTask)
}
}

// TODO clean up
val compilerArgs = listOf(
"-Xverbose-phases=linker", // Enable verbose linking phase from the compiler
"-linker-option",
"-v",
)
compilerOptions.freeCompilerArgs.addAll(compilerArgs)
}
}
}

configureIosSimulatorTasks()
Expand Down Expand Up @@ -168,3 +223,11 @@ private val KotlinNativeTarget.isBuildableOnHost: Boolean
throw Exception("Unsupported host: ${HostManager.host}")
}
}

private fun findMingwHome(): String =
System.getenv("MINGW_PREFIX")?.takeUnless { it.isBlank() }
?: typedProp("mingw.prefix")
?: throw IllegalStateException(
"Cannot determine MinGW prefix location. Please verify MinGW is installed correctly " +
"and that either the `MINGW_PREFIX` environment variable or the `mingw.prefix` Gradle property is set.",
)
93 changes: 93 additions & 0 deletions aws-crt-kotlin/mingw/src/aws/sdk/kotlin/crt/util/OsVersion.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package aws.sdk.kotlin.crt.util

import aws.sdk.kotlin.crt.winver.*
import kotlinx.cinterop.*
import platform.posix.memcpy

// The functions below are adapted from C++ SDK:
// https://github.com/aws/aws-sdk-cpp/blob/0e6085bf0dd9a1cb1f27d101c4cf2db6ade6f307/src/aws-cpp-sdk-core/source/platform/windows/OSVersionInfo.cpp#L49-L106
Comment on lines +16 to +17
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Let's add a comment explaining why this is here so that Future We are not confused.


private val wordHexFormat = HexFormat {
upperCase = false
number {
removeLeadingZeros = true
minLength = 4
}
}

private data class LangCodePage(
val language: UShort,
val codePage: UShort,
)

public fun osVersionFromKernel(): String? = memScoped {
withFileVersionInfo("Kernel32.dll") { versionInfoPtr ->
getLangCodePage(versionInfoPtr)?.let { langCodePage ->
getProductVersion(versionInfoPtr, langCodePage)
}
}
}

private inline fun <R> withFileVersionInfo(fileName: String, block: (CPointer<ByteVarOf<Byte>>) -> R?): R? {
val blobSize = GetFileVersionInfoSizeW(fileName, null)
val blob = ByteArray(blobSize.convert())
blob.usePinned { pinned ->
val result = GetFileVersionInfoW(fileName, 0u, blobSize, pinned.addressOf(0))
return if (result == 0) {
null
} else {
block(pinned.addressOf(0))
}
}
}

private fun MemScope.getLangCodePage(versionInfoPtr: CPointer<ByteVarOf<Byte>>): LangCodePage? {
// Get _any_ language pack and codepage since they should all have the same version
val langAndCodePagePtr = alloc<COpaquePointerVar>()
val codePageSize = alloc<UIntVar>()
val result = VerQueryValueW(
versionInfoPtr,
"""\VarFileInfo\Translation""",
langAndCodePagePtr.ptr,
codePageSize.ptr,
)

return if (result == 0) {
null
} else {
val langAndCodePage = langAndCodePagePtr.value!!.reinterpret<UIntVar>().pointed.value
val language = (langAndCodePage and 0x0000ffffu).toUShort() // low WORD
val codePage = (langAndCodePage and 0xffff0000u shr 16).toUShort() // high WORD
LangCodePage(language, codePage)
}
}

private fun MemScope.getProductVersion(versionInfoPtr: CPointer<ByteVarOf<Byte>>, langCodePage: LangCodePage): String? {
val versionId = buildString {
// Something like: \StringFileInfo\04090fb0\ProductVersion
append("""\StringFileInfo\""")
append(langCodePage.language.toHexString(wordHexFormat))
append(langCodePage.codePage.toHexString(wordHexFormat))
append("""\ProductVersion""")
}

// Get the block corresponding to versionId
val block = alloc<COpaquePointerVar>()
val blockSize = alloc<UIntVar>()
val result = VerQueryValueW(versionInfoPtr, versionId, block.ptr, blockSize.ptr)

return if (result == 0) {
null
} else {
// Copy the bytes into a Kotlin byte array
val blockBytes = ByteArray(blockSize.value.convert())
blockBytes.usePinned { pinned ->
memcpy(pinned.addressOf(0), block.value!!.reinterpret<ByteVar>(), blockSize.value.convert())
}
blockBytes.decodeToString()
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

package aws.sdk.kotlin.crt.util

import kotlinx.coroutines.test.runTest
import kotlin.test.Test
import kotlin.test.assertNotNull

class OsVersionTest {
@Test
fun testOsInfo() = runTest {
val version = osVersionFromKernel()
assertNotNull(version)
}
Comment on lines +13 to +17
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: This only runs on Windows so we should verify the correct OS family is returned:

val version = osVersionFromKernel()
assertEquals(OsFamily.Windows, version.family)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

osVersionFromKernel just returns the string version, not the family

}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@ package aws.sdk.kotlin.crt.io
import aws.sdk.kotlin.crt.*
import aws.sdk.kotlin.crt.util.ShutdownChannel
import aws.sdk.kotlin.crt.util.shutdownChannel
import aws.sdk.kotlin.crt.util.toAwsString
import aws.sdk.kotlin.crt.util.toKString
import kotlinx.cinterop.*
import kotlinx.coroutines.channels.Channel
import libcrt.*

@OptIn(ExperimentalForeignApi::class)
Expand Down Expand Up @@ -59,14 +62,112 @@ public actual class HostResolver private constructor(

if (manageElg) elg.close()
}

public suspend fun resolve(hostname: String): List<CrtHostAddress> = memScoped {
val awsHostname = hostname.toAwsString()
val resultCallback = staticCFunction(::awsOnHostResolveFn)

val channel: Channel<Result<List<CrtHostAddress>>> = Channel(Channel.RENDEZVOUS)
val channelStableRef = StableRef.create(channel)
val userData = channelStableRef.asCPointer()

aws_host_resolver_resolve_host(ptr, awsHostname, resultCallback, aws_host_resolver_init_default_resolution_config(), userData)

return channel.receive().getOrThrow().also {
aws_string_destroy(awsHostname)
channelStableRef.dispose()
}
}
}

@OptIn(ExperimentalForeignApi::class)
private fun onShutdownComplete(userData: COpaquePointer?) {
if (userData == null) return
if (userData == null) {
return
}
val stableRef = userData.asStableRef<ShutdownChannel>()
val ch = stableRef.get()
ch.trySend(Unit)
ch.close()
stableRef.dispose()
}

// implementation of `aws_on_host_resolved_result_fn`: https://github.com/awslabs/aws-c-io/blob/db7a1bddc9a29eca18734d0af189c3924775dcf1/include/aws/io/host_resolver.h#L53C14-L53C44
private fun awsOnHostResolveFn(
hostResolver: CPointer<aws_host_resolver>?,
hostName: CPointer<aws_string>?,
errCode: Int,
hostAddresses: CPointer<aws_array_list>?, // list of `aws_host_address`
userData: COpaquePointer?,
): Unit = memScoped {
if (userData == null) {
throw CrtRuntimeException("aws_on_host_resolved_result_fn: userData unexpectedly null")
}

val stableRef = userData.asStableRef<Channel<Result<List<CrtHostAddress>>>>()
val channel = stableRef.get()

try {
if (errCode != AWS_OP_SUCCESS) {
throw CrtRuntimeException("aws_on_host_resolved_result_fn", ec = errCode)
}

val length = aws_array_list_length(hostAddresses)
if (length == 0uL) {
throw CrtRuntimeException("Failed to resolve host address for ${hostName?.toKString()}")
}

val addressList = ArrayList<CrtHostAddress>(length.toInt())

val element = alloc<COpaquePointerVar>()
for (i in 0uL until length) {
awsAssertOpSuccess(
aws_array_list_get_at_ptr(
hostAddresses,
element.ptr,
i,
),
) { "aws_array_list_get_at_ptr failed at index $i" }

val elemOpaque = element.value ?: run {
throw CrtRuntimeException("aws_host_addresses value at index $i unexpectedly null")
}

val addr = elemOpaque.reinterpret<aws_host_address>().pointed

val hostStr = addr.host?.toKString() ?: run {
throw CrtRuntimeException("aws_host_addresses `host` at index $i unexpectedly null")
}
val addressStr = addr.address?.toKString() ?: run {
throw CrtRuntimeException("aws_host_addresses `address` at index $i unexpectedly null")
}

val addressType = when (addr.record_type) {
aws_address_record_type.AWS_ADDRESS_RECORD_TYPE_A -> AddressType.IpV4
aws_address_record_type.AWS_ADDRESS_RECORD_TYPE_AAAA -> AddressType.IpV6
else -> throw CrtRuntimeException("received unsupported aws_host_address `aws_address_record_type`: ${addr.record_type}")
}

addressList += CrtHostAddress(host = hostStr, address = addressStr, addressType)
}

channel.trySend(Result.success(addressList))
} catch (e: Exception) {
channel.trySend(Result.failure(e))
} finally {
channel.close()
}
}

// Minimal wrapper of aws_host_address
// https://github.com/awslabs/aws-c-io/blob/db7a1bddc9a29eca18734d0af189c3924775dcf1/include/aws/io/host_resolver.h#L31
public data class CrtHostAddress(
val host: String,
val address: String,
val addressType: AddressType,
)

public enum class AddressType {
IpV4,
IpV6,
}
Loading