Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
0346992
Add boundary checks to checksums V1877476818
lauzadis Aug 7, 2025
6251616
usePinned and dispose StableRef V1877474026
lauzadis Aug 7, 2025
5eb84bd
Support buffers of empty byte arrays V1877467535
lauzadis Aug 7, 2025
4987757
ktlint
lauzadis Aug 7, 2025
28f22d7
dispose more StableRef
lauzadis Aug 7, 2025
b84fbf7
ktlint
lauzadis Aug 7, 2025
91b5d87
remove comment
lauzadis Aug 7, 2025
5a73317
update comment
lauzadis Aug 7, 2025
2e4da9f
Dispose StableRefs
lauzadis Aug 12, 2025
09edf6d
revert signer StableRef changes
lauzadis Aug 12, 2025
6ec4e2c
lint
lauzadis Aug 12, 2025
e5c4803
Support multiple exception messages
lauzadis Aug 13, 2025
47c29d5
update comment
lauzadis Aug 13, 2025
202c878
Try reverting HttpStreamNative changes
lauzadis Aug 13, 2025
aca4a1e
Merge branch 'kn-main' of github.com:awslabs/aws-crt-kotlin into kn-f…
lauzadis Aug 13, 2025
48cd403
Revert "Try reverting HttpStreamNative changes"
lauzadis Aug 13, 2025
d3aa7b7
try reverting HttpClientConnectionManager changes
lauzadis Aug 13, 2025
c41c2e7
ktlint
lauzadis Aug 13, 2025
25d2284
Only revert acquireConnection StableRef disposal
lauzadis Aug 13, 2025
5490ba9
revert httpConnectionTest changes
lauzadis Aug 13, 2025
af178a7
refactor to handle errors in kShouldSignHeaderFn
lauzadis Aug 13, 2025
5ad95fc
Refactor onResponseHeaders to use StableRef.use
lauzadis Aug 13, 2025
41f384b
Refactor to use `takeUnless`
lauzadis Aug 13, 2025
d434c2d
Fix use of `use`
lauzadis Aug 13, 2025
af0831e
ktlint
lauzadis Aug 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import aws.sdk.kotlin.crt.util.asAwsByteCursor
import aws.sdk.kotlin.crt.util.initFromCursor
import aws.sdk.kotlin.crt.util.toAwsString
import aws.sdk.kotlin.crt.util.toKString
import aws.sdk.kotlin.crt.util.use
import kotlinx.cinterop.*
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.runBlocking
Expand All @@ -31,34 +32,37 @@ public actual object AwsSigner : WithCrt() {
request: HttpRequest,
config: AwsSigningConfig,
): AwsSigningResult = memScoped {
val nativeRequest = request.toNativeRequest().pin()

// Pair of HTTP request and callback channel containing the signature
val userData = nativeRequest to Channel<ByteArray>(1)
val userDataStableRef = StableRef.create(userData)

val signable = checkNotNull(
aws_signable_new_http_request(
allocator = Allocator.Default.allocator,
request = nativeRequest.get(),
),
) { "aws_signable_new_http_request" }

val nativeSigningConfig: CPointer<aws_signing_config_base> = config.toNativeSigningConfig().reinterpret()

awsAssertOpSuccess(
aws_sign_request_aws(
allocator = Allocator.Default.allocator,
signable = signable,
base_config = nativeSigningConfig,
on_complete = staticCFunction(::signCallback),
userdata = userDataStableRef.asCPointer(),
),
) { "sign() aws_sign_request_aws" }

val callbackChannel = userDataStableRef.get().second
val signature = callbackChannel.receive() // wait for async signing to complete....
return AwsSigningResult(nativeRequest.get().toHttpRequest(), signature)
request.toNativeRequest().usePinned { nativeRequest ->
// Pair of HTTP request and callback channel containing the signature
val userData = nativeRequest to Channel<ByteArray>(1)
val userDataStableRef = StableRef.create(userData)

val signable = checkNotNull(
aws_signable_new_http_request(
allocator = Allocator.Default.allocator,
request = nativeRequest.get(),
),
) { "aws_signable_new_http_request" }

val nativeSigningConfig: CPointer<aws_signing_config_base> = config.toNativeSigningConfig().reinterpret()

awsAssertOpSuccess(
aws_sign_request_aws(
allocator = Allocator.Default.allocator,
signable = signable,
base_config = nativeSigningConfig,
on_complete = staticCFunction(::signCallback),
userdata = userDataStableRef.asCPointer(),
),
) { "sign() aws_sign_request_aws" }

val callbackChannel = userDataStableRef.get().second
val signature = callbackChannel.receive() // wait for async signing to complete....
return AwsSigningResult(nativeRequest.get().toHttpRequest(), signature).also {
userDataStableRef.dispose()
callbackChannel.close()
}
}
}

public actual suspend fun signChunk(
Expand Down Expand Up @@ -223,9 +227,11 @@ private fun nativeShouldSignHeaderFn(headerName: CPointer<aws_byte_cursor>?, use
return true
}

val kShouldSignHeaderFn = userData.asStableRef<ShouldSignHeaderFunction>().get()
val kHeaderName = headerName.pointed.toKString()
return kShouldSignHeaderFn(kHeaderName)
userData.asStableRef<ShouldSignHeaderFunction>().use {
val kShouldSignHeaderFn = it.get()
val kHeaderName = headerName.pointed.toKString()
return kShouldSignHeaderFn(kHeaderName)
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ public actual class HttpClientConnectionManager actual constructor(
actual override fun close() {
if (closed.compareAndSet(false, true)) {
aws_http_connection_manager_release(manager)
shutdownCompleteStableRef.dispose()
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import aws.sdk.kotlin.crt.io.ByteCursorBuffer
import aws.sdk.kotlin.crt.util.asAwsByteCursor
import aws.sdk.kotlin.crt.util.initFromCursor
import aws.sdk.kotlin.crt.util.toKString
import aws.sdk.kotlin.crt.util.use
import aws.sdk.kotlin.crt.util.withAwsByteCursor
import kotlinx.atomicfu.atomic
import kotlinx.cinterop.*
Expand Down Expand Up @@ -93,30 +94,33 @@ private fun onResponseHeaders(
numHeaders: size_t,
userdata: COpaquePointer?,
): Int {
val ctx = userdata?.asStableRef<HttpStreamContext>()?.get() ?: return aws_raise_error(AWS_ERROR_HTTP_CALLBACK_FAILURE.toInt())
val stream = ctx.stream ?: return AWS_OP_ERR
val ctxStableRef = userdata?.asStableRef<HttpStreamContext>() ?: return aws_raise_error(AWS_ERROR_HTTP_CALLBACK_FAILURE.toInt())
ctxStableRef.use {
val ctx = it.get()
val stream = ctx.stream ?: return AWS_OP_ERR

val hdrCnt = numHeaders.toInt()
val headers: List<HttpHeader>? = if (hdrCnt > 0 && headerArray != null) {
val kheaders = mutableListOf<HttpHeader>()
for (i in 0 until hdrCnt) {
val nativeHdr = headerArray[i]
val hdr = HttpHeader(nativeHdr.name.toKString(), nativeHdr.value.toKString())
kheaders.add(hdr)
}
kheaders
} else {
null
}

val hdrCnt = numHeaders.toInt()
val headers: List<HttpHeader>? = if (hdrCnt > 0 && headerArray != null) {
val kheaders = mutableListOf<HttpHeader>()
for (i in 0 until hdrCnt) {
val nativeHdr = headerArray[i]
val hdr = HttpHeader(nativeHdr.name.toKString(), nativeHdr.value.toKString())
kheaders.add(hdr)
try {
ctx.handler.onResponseHeaders(stream, stream.responseStatusCode, blockType.value.toInt(), headers)
} catch (ex: Exception) {
log(LogLevel.Error, "onResponseHeaders: $ex")
return aws_raise_error(AWS_ERROR_HTTP_CALLBACK_FAILURE.toInt())
}
kheaders
} else {
null
}

try {
ctx.handler.onResponseHeaders(stream, stream.responseStatusCode, blockType.value.toInt(), headers)
} catch (ex: Exception) {
log(LogLevel.Error, "onResponseHeaders: $ex")
return aws_raise_error(AWS_ERROR_HTTP_CALLBACK_FAILURE.toInt())
return AWS_OP_SUCCESS
}

return AWS_OP_SUCCESS
}

private fun onResponseHeaderBlockDone(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import aws.sdk.kotlin.crt.CrtRuntimeException
import aws.sdk.kotlin.crt.NativeHandle
import aws.sdk.kotlin.crt.awsAssertOpSuccess
import aws.sdk.kotlin.crt.util.asAwsByteCursor
import aws.sdk.kotlin.crt.util.use
import kotlinx.atomicfu.atomic
import kotlinx.cinterop.*
import libcrt.*
Expand Down Expand Up @@ -66,18 +67,19 @@ internal class HttpStreamNative(
throw CrtRuntimeException("aws_input_stream_new_from_cursor()")
}

val req = StableRef.create(WriteChunkRequest(cont, byteBuf, stream))
val chunkOpts = cValue<aws_http1_chunk_options> {
chunk_data_size = chunkData.size.convert()
chunk_data = stream
on_complete = staticCFunction(::onWriteChunkComplete)
user_data = req.asCPointer()
}
awsAssertOpSuccess(
aws_http1_stream_write_chunk(ptr, chunkOpts),
) {
cleanupWriteChunkCbData(req)
"aws_http1_stream_write_chunk()"
StableRef.create(WriteChunkRequest(cont, byteBuf, stream)).use { req ->
val chunkOpts = cValue<aws_http1_chunk_options> {
chunk_data_size = chunkData.size.convert()
chunk_data = stream
on_complete = staticCFunction(::onWriteChunkComplete)
user_data = req.asCPointer()
}
awsAssertOpSuccess(
aws_http1_stream_write_chunk(ptr, chunkOpts),
) {
cleanupWriteChunkCbData(req)
"aws_http1_stream_write_chunk()"
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ public actual class ClientBootstrap private constructor(

actual override fun close() {
aws_client_bootstrap_release(ptr)
channelStableRef.dispose()

if (manageHr) hr.close()
if (manageElg) elg.close()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ public actual class EventLoopGroup actual constructor(maxThreads: Int) :

actual override fun close() {
aws_event_loop_group_release(ptr)
channelStableRef.dispose()
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ public actual class HostResolver private constructor(

actual override fun close() {
aws_host_resolver_release(ptr)
channelStableRef.dispose()

if (manageElg) elg.close()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ private sealed interface InnerBuffer {
init {
pointer.pointed.len = 0.convert()
pointer.pointed.capacity = dest.size.convert()
pointer.pointed.buffer = pinned.addressOf(0).reinterpret()
pointer.pointed.buffer = pinned.takeUnless { dest.isEmpty() }
?.addressOf(0)
?.reinterpret()
}

override fun release() {
Expand Down
12 changes: 12 additions & 0 deletions aws-crt-kotlin/native/src/aws/sdk/kotlin/crt/util/Interop.kt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
*/
package aws.sdk.kotlin.crt.util

import kotlinx.cinterop.StableRef
import kotlinx.coroutines.channels.Channel

/**
Expand All @@ -15,3 +16,14 @@ internal typealias ShutdownChannel = Channel<Unit>
* Create a new shutdown notification channel
*/
internal fun shutdownChannel(): ShutdownChannel = Channel(Channel.RENDEZVOUS)

/**
* Execute [block] using [StableRef], then dispose it.
*/
internal inline fun <T : Any, R> StableRef<T>.use(block: (StableRef<T>) -> R): R {
try {
return block(this)
} finally {
dispose()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ internal class Crc(val checksumFn: AwsChecksumsCrcFunction) : HashFunction {
private var crc = 0U

override fun update(input: ByteArray, offset: Int, length: Int) {
require(offset >= 0) { "offset must not be negative" }
require(length >= 0) { "length must not be negative" }
require(offset + length <= input.size) {
"offset + length must not exceed input size: $offset + $length > ${input.size}"
}

if (input.isEmpty() || length == 0) {
return
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ public class Md5 :
private var md5 = checkNotNull(aws_md5_new(Allocator.Default)) { "aws_md5_new" }

override fun update(input: ByteArray, offset: Int, length: Int) {
require(offset >= 0) { "offset must not be negative" }
require(length >= 0) { "length must not be negative" }
require(offset + length <= input.size) {
"offset + length must not exceed input size: $offset + $length > ${input.size}"
}

if (input.isEmpty() || length == 0) {
return
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ internal class Sha(val initializeFn: InitializeHashFn) : HashFunction {

// aws_hash_update
override fun update(input: ByteArray, offset: Int, length: Int) {
require(offset >= 0) { "offset must not be negative" }
require(length >= 0) { "length must not be negative" }
require(offset + length <= input.size) {
"offset + length must not exceed input size: $offset + $length > ${input.size}"
}

if (input.isEmpty() || length == 0) {
return
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,4 +77,17 @@ class MutableBufferTest : CrtTest() {
assertEquals("a tay is a hammer; a lep is a ball", actual)
}
}

@Test
fun testEmptyByteArray() {
val dest = ByteArray(0)
val buffer = MutableBuffer.of(dest)

assertEquals(0, buffer.writeRemaining)

val written = buffer.write(byteArrayOf(1, 2, 3))
assertEquals(0, written)

buffer.close()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package aws.sdk.kotlin.crt.util.hashing
import aws.sdk.kotlin.crt.util.encodeToHex
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith

class HashFunctionTest {
@Test
Expand Down Expand Up @@ -67,4 +68,48 @@ class HashFunctionTest {
assertEquals(expected, hash.digest().encodeToHex())
}
}

@Test
fun testCrcUpdateOutOfBounds() {
val crc32 = Crc32()
val data = ByteArray(4) { it.toByte() }

// offset + length exceeds the buffer size
assertFailsWith<IllegalArgumentException> {
crc32.update(data, 4, 1)
}
}

@Test
fun testMd5UpdateOutOfBounds() {
val md5 = Md5()
val data = ByteArray(4) { it.toByte() }

// offset + length exceeds the buffer size
assertFailsWith<IllegalArgumentException> {
md5.update(data, 4, 1)
}
}

@Test
fun testSha1UpdateOutOfBounds() {
val sha1 = Sha1()
val data = ByteArray(4) { it.toByte() }

// offset + length exceeds the buffer size
assertFailsWith<IllegalArgumentException> {
sha1.update(data, 4, 1)
}
}

@Test
fun testSha256UpdateOutOfBounds() {
val sha256 = Sha256()
val data = ByteArray(4) { it.toByte() }

// offset + length exceeds the buffer size
assertFailsWith<IllegalArgumentException> {
sha256.update(data, 4, 1)
}
}
}
Loading