Skip to content

Commit 6932cc5

Browse files
committed
refactor all usages of StableRef in CRT to more carefully handle potential NPEs
1 parent 40462ad commit 6932cc5

File tree

9 files changed

+228
-247
lines changed

9 files changed

+228
-247
lines changed

aws-crt-kotlin/native/src/aws/sdk/kotlin/crt/auth/signing/AwsSignerNative.kt

Lines changed: 15 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,7 @@ package aws.sdk.kotlin.crt.auth.signing
88
import aws.sdk.kotlin.crt.*
99
import aws.sdk.kotlin.crt.auth.credentials.Credentials
1010
import aws.sdk.kotlin.crt.http.*
11-
import aws.sdk.kotlin.crt.util.asAwsByteCursor
12-
import aws.sdk.kotlin.crt.util.initFromCursor
13-
import aws.sdk.kotlin.crt.util.toAwsString
14-
import aws.sdk.kotlin.crt.util.toKString
15-
import aws.sdk.kotlin.crt.util.use
11+
import aws.sdk.kotlin.crt.util.*
1612
import kotlinx.cinterop.*
1713
import kotlinx.coroutines.channels.Channel
1814
import kotlinx.coroutines.runBlocking
@@ -223,15 +219,10 @@ private fun AwsSigningConfig.toNativeSigningConfig(): CPointer<aws_signing_confi
223219
private typealias ShouldSignHeaderFunction = (String) -> Boolean
224220
private fun nativeShouldSignHeaderFn(headerName: CPointer<aws_byte_cursor>?, userData: COpaquePointer?): Boolean {
225221
checkNotNull(headerName) { "aws_should_sign_header_fn expected non-null header name" }
226-
if (userData == null) {
227-
return true
228-
}
229-
230-
userData.asStableRef<ShouldSignHeaderFunction>().use {
231-
val kShouldSignHeaderFn = it.get()
222+
return userData?.withDereferenced<ShouldSignHeaderFunction, _>(dispose = true) { kShouldSignHeaderFn ->
232223
val kHeaderName = headerName.pointed.toKString()
233-
return kShouldSignHeaderFn(kHeaderName)
234-
}
224+
kShouldSignHeaderFn(kHeaderName)
225+
} ?: error("Expected non-null userData")
235226
}
236227

237228
/**
@@ -243,17 +234,17 @@ private fun signCallback(signingResult: CPointer<aws_signing_result>?, errorCode
243234
checkNotNull(signingResult) { "signing callback received null aws_signing_result" }
244235
checkNotNull(userData) { "signing callback received null user data" }
245236

246-
val (pinnedRequestToSign, callbackChannel) = userData
247-
.asStableRef<Pair<Pinned<CPointer<cnames.structs.aws_http_message>>, Channel<ByteArray>>>()
248-
.get()
237+
userData.withDereferenced<Pair<Pinned<CPointer<cnames.structs.aws_http_message>>, Channel<ByteArray>>> { pair ->
238+
val (pinnedRequestToSign, callbackChannel) = pair
249239

250-
val requestToSign = pinnedRequestToSign.get()
240+
val requestToSign = pinnedRequestToSign.get()
251241

252-
awsAssertOpSuccess(aws_apply_signing_result_to_http_request(requestToSign, Allocator.Default.allocator, signingResult)) {
253-
"aws_apply_signing_result_to_http_request"
254-
}
242+
awsAssertOpSuccess(aws_apply_signing_result_to_http_request(requestToSign, Allocator.Default.allocator, signingResult)) {
243+
"aws_apply_signing_result_to_http_request"
244+
}
255245

256-
runBlocking { callbackChannel.send(signingResult.getSignature()) }
246+
runBlocking { callbackChannel.send(signingResult.getSignature()) }
247+
}
257248
}
258249

259250
/**
@@ -264,8 +255,9 @@ private fun signChunkCallback(signingResult: CPointer<aws_signing_result>?, erro
264255
checkNotNull(signingResult) { "signing callback received null aws_signing_result" }
265256
checkNotNull(userData) { "signing callback received null user data" }
266257

267-
val callbackChannel = userData.asStableRef<Channel<ByteArray>>().get()
268-
runBlocking { callbackChannel.send(signingResult.getSignature()) }
258+
userData.withDereferenced<Channel<ByteArray>> { callbackChannel ->
259+
runBlocking { callbackChannel.send(signingResult.getSignature()) }
260+
}
269261
}
270262

271263
private fun Credentials.toNativeCredentials(): CPointer<cnames.structs.aws_credentials>? = aws_credentials_new_from_string(

aws-crt-kotlin/native/src/aws/sdk/kotlin/crt/http/HttpClientConnectionManagerNative.kt

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -183,13 +183,10 @@ private fun SocketDomain.toNativeSocketDomain() = when (this) {
183183
}
184184

185185
private fun onShutdownComplete(userdata: COpaquePointer?) {
186-
if (userdata == null) return
187-
val notify = userdata.asStableRef<ShutdownChannel>()
188-
with(notify.get()) {
189-
trySend(Unit)
190-
close()
186+
userdata?.withDereferenced<ShutdownChannel>(dispose = true) { notify ->
187+
notify.trySend(Unit)
188+
notify.close()
191189
}
192-
notify.dispose()
193190
}
194191

195192
private data class HttpConnectionAcquisitionRequest(
@@ -202,20 +199,16 @@ private fun onConnectionAcquired(
202199
errCode: Int,
203200
userdata: COpaquePointer?,
204201
) {
205-
if (userdata == null) return
206-
val stableRef = userdata.asStableRef<HttpConnectionAcquisitionRequest>()
207-
val request = stableRef.get()
208-
209-
when {
210-
errCode != AWS_OP_SUCCESS -> request.cont.resumeWithException(HttpException(errCode))
211-
conn == null -> request.cont.resumeWithException(
212-
CrtRuntimeException("acquireConnection(): http connection null", ec = errCode),
213-
)
214-
else -> {
215-
val kconn = HttpClientConnectionNative(request.manager, conn)
216-
request.cont.resume(kconn)
202+
userdata?.withDereferenced<HttpConnectionAcquisitionRequest>(dispose = true) { request ->
203+
when {
204+
errCode != AWS_OP_SUCCESS -> request.cont.resumeWithException(HttpException(errCode))
205+
conn == null -> request.cont.resumeWithException(
206+
CrtRuntimeException("acquireConnection(): http connection null", ec = errCode),
207+
)
208+
else -> {
209+
val kconn = HttpClientConnectionNative(request.manager, conn)
210+
request.cont.resume(kconn)
211+
}
217212
}
218213
}
219-
220-
stableRef.dispose()
221214
}

aws-crt-kotlin/native/src/aws/sdk/kotlin/crt/http/HttpClientConnectionNative.kt

Lines changed: 61 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -91,121 +91,95 @@ private fun onResponseHeaders(
9191
headerArray: CPointer<aws_http_header>?,
9292
numHeaders: size_t,
9393
userdata: COpaquePointer?,
94-
): Int {
95-
val stableRef = dereferenceUserdata(userdata) ?: return callbackError()
96-
val ctx = stableRef.safeGet() ?: return callbackError()
97-
val stream = ctx.stream ?: return callbackError()
98-
99-
val hdrCnt = numHeaders.toInt()
100-
val headers: List<HttpHeader>? = if (hdrCnt > 0 && headerArray != null) {
101-
val kheaders = mutableListOf<HttpHeader>()
102-
for (i in 0 until hdrCnt) {
103-
val nativeHdr = headerArray[i]
104-
val hdr = HttpHeader(nativeHdr.name.toKString(), nativeHdr.value.toKString())
105-
kheaders.add(hdr)
106-
}
107-
kheaders
108-
} else {
109-
null
110-
}
111-
112-
try {
113-
ctx.handler.onResponseHeaders(stream, stream.responseStatusCode, blockType.value.toInt(), headers)
114-
} catch (ex: Exception) {
115-
log(LogLevel.Error, "onResponseHeaders: $ex")
116-
return callbackError()
117-
}
94+
): Int =
95+
userdata?.withDereferenced<HttpStreamContext, _> { ctx ->
96+
ctx.stream?.let { stream ->
97+
val hdrCnt = numHeaders.toInt()
98+
val headers: List<HttpHeader>? = if (hdrCnt > 0 && headerArray != null) {
99+
val kheaders = mutableListOf<HttpHeader>()
100+
for (i in 0 until hdrCnt) {
101+
val nativeHdr = headerArray[i]
102+
val hdr = HttpHeader(nativeHdr.name.toKString(), nativeHdr.value.toKString())
103+
kheaders.add(hdr)
104+
}
105+
kheaders
106+
} else {
107+
null
108+
}
118109

119-
return AWS_OP_SUCCESS
120-
}
110+
try {
111+
ctx.handler.onResponseHeaders(stream, stream.responseStatusCode, blockType.value.toInt(), headers)
112+
AWS_OP_SUCCESS
113+
} catch (ex: Exception) {
114+
log(LogLevel.Error, "onResponseHeaders: $ex")
115+
null
116+
}
117+
}
118+
} ?: callbackError()
121119

122120
private fun onResponseHeaderBlockDone(
123121
nativeStream: CPointer<cnames.structs.aws_http_stream>?,
124122
blockType: aws_http_header_block,
125123
userdata: COpaquePointer?,
126-
): Int {
127-
val stableRef = dereferenceUserdata(userdata) ?: return callbackError()
128-
val ctx = stableRef.safeGet() ?: return callbackError()
129-
val stream = ctx.stream ?: return callbackError()
130-
131-
try {
132-
ctx.handler.onResponseHeadersDone(stream, blockType.value.toInt())
133-
} catch (ex: Exception) {
134-
log(LogLevel.Error, "onResponseHeaderBlockDone: $ex")
135-
return callbackError()
136-
}
137-
138-
return AWS_OP_SUCCESS
139-
}
124+
): Int =
125+
userdata?.withDereferenced<HttpStreamContext, _> { ctx ->
126+
ctx.stream?.let { stream ->
127+
try {
128+
ctx.handler.onResponseHeadersDone(stream, blockType.value.toInt())
129+
AWS_OP_SUCCESS
130+
} catch (ex: Exception) {
131+
log(LogLevel.Error, "onResponseHeaderBlockDone: $ex")
132+
null
133+
}
134+
}
135+
} ?: callbackError()
140136

141137
private fun onIncomingBody(
142138
nativeStream: CPointer<cnames.structs.aws_http_stream>?,
143139
data: CPointer<aws_byte_cursor>?,
144140
userdata: COpaquePointer?,
145-
): Int {
146-
val stableRef = dereferenceUserdata(userdata) ?: return callbackError()
147-
val ctx = stableRef.safeGet() ?: return callbackError()
148-
val stream = ctx.stream ?: return callbackError()
149-
150-
try {
151-
val body = if (data != null) ByteCursorBuffer(data) else Buffer.Empty
152-
val windowIncrement = ctx.handler.onResponseBody(stream, body)
153-
if (windowIncrement < 0) {
154-
return callbackError()
155-
}
156-
157-
if (windowIncrement > 0) {
158-
aws_http_stream_update_window(nativeStream, windowIncrement.convert())
141+
): Int =
142+
userdata?.withDereferenced<HttpStreamContext, _> { ctx ->
143+
ctx.stream?.let { stream ->
144+
try {
145+
val body = if (data != null) ByteCursorBuffer(data) else Buffer.Empty
146+
val windowIncrement = ctx.handler.onResponseBody(stream, body)
147+
148+
if (windowIncrement < 0) {
149+
null
150+
} else {
151+
if (windowIncrement > 0) {
152+
aws_http_stream_update_window(nativeStream, windowIncrement.convert())
153+
}
154+
AWS_OP_SUCCESS
155+
}
156+
} catch (ex: Exception) {
157+
log(LogLevel.Error, "onIncomingBody: $ex")
158+
null
159+
}
159160
}
160-
} catch (ex: Exception) {
161-
log(LogLevel.Error, "onIncomingBody: $ex")
162-
return callbackError()
163-
}
164-
165-
return AWS_OP_SUCCESS
166-
}
161+
} ?: callbackError()
167162

168163
private fun onStreamComplete(
169164
nativeStream: CPointer<cnames.structs.aws_http_stream>?,
170165
errorCode: Int,
171166
userdata: COpaquePointer?,
172167
) {
173-
val stableRef = dereferenceUserdata(userdata) ?: return
174-
try {
175-
val ctx = stableRef.safeGet() ?: return
168+
userdata?.withDereferenced<HttpStreamContext>(dispose = true) { ctx ->
176169
try {
177170
val stream = ctx.stream ?: return
178171
ctx.handler.onResponseComplete(stream, errorCode)
172+
} catch (ex: Exception) {
173+
log(LogLevel.Error, "onStreamComplete: $ex")
174+
// close connection if callback throws an exception
175+
aws_http_connection_close(aws_http_stream_get_connection(nativeStream))
179176
} finally {
180177
// cleanup request object
181178
aws_http_message_release(ctx.nativeReq)
182179
}
183-
} catch (ex: Exception) {
184-
log(LogLevel.Error, "onStreamComplete: $ex")
185-
// close connection if callback throws an exception
186-
aws_http_connection_close(aws_http_stream_get_connection(nativeStream))
187-
} finally {
188-
// cleanup userdata
189-
stableRef.dispose()
190180
}
191181
}
192182

193-
private fun dereferenceUserdata(userdata: COpaquePointer?): StableRef<HttpStreamContext>? =
194-
try {
195-
userdata?.asStableRef<HttpStreamContext>()
196-
} catch (_: NullPointerException) {
197-
// `asStableRef()` can throw `NullPointerException` when target type can't be coerced to HttpStreamContext
198-
null
199-
}
200-
201-
private fun <T : Any> StableRef<T>.safeGet(): T? =
202-
try {
203-
get()
204-
} catch (_: NullPointerException) {
205-
// `get()` can throw `NullPointerException` when stream has been canceled and CRT is cleaning up resources
206-
null
207-
}
208-
209183
internal fun HttpRequest.toNativeRequest(): CPointer<cnames.structs.aws_http_message> {
210184
val nativeReq = checkNotNull(
211185
aws_http_message_new_request(Allocator.Default),

aws-crt-kotlin/native/src/aws/sdk/kotlin/crt/http/HttpStreamNative.kt

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import aws.sdk.kotlin.crt.NativeHandle
1010
import aws.sdk.kotlin.crt.awsAssertOpSuccess
1111
import aws.sdk.kotlin.crt.util.asAwsByteCursor
1212
import aws.sdk.kotlin.crt.util.use
13+
import aws.sdk.kotlin.crt.util.withDereferenced
1314
import kotlinx.atomicfu.atomic
1415
import kotlinx.cinterop.*
1516
import libcrt.*
@@ -67,12 +68,13 @@ internal class HttpStreamNative(
6768
throw CrtRuntimeException("aws_input_stream_new_from_cursor()")
6869
}
6970

70-
StableRef.create(WriteChunkRequest(cont, byteBuf, stream)).use { req ->
71+
val req = WriteChunkRequest(cont, byteBuf, stream)
72+
StableRef.create(req).use { stableRef ->
7173
val chunkOpts = cValue<aws_http1_chunk_options> {
7274
chunk_data_size = chunkData.size.convert()
7375
chunk_data = stream
7476
on_complete = staticCFunction(::onWriteChunkComplete)
75-
user_data = req.asCPointer()
77+
user_data = stableRef.asCPointer()
7678
}
7779
awsAssertOpSuccess(
7880
aws_http1_stream_write_chunk(ptr, chunkOpts),
@@ -113,19 +115,18 @@ private fun onWriteChunkComplete(
113115
userData: COpaquePointer?,
114116
) {
115117
if (stream == null) return
116-
val stableRef = userData?.asStableRef<WriteChunkRequest>() ?: return
117-
val req = stableRef.get()
118-
when {
119-
errCode != AWS_OP_SUCCESS -> req.cont.resumeWithException(HttpException(errCode))
120-
else -> req.cont.resume(Unit)
118+
userData?.withDereferenced<WriteChunkRequest> { req ->
119+
checkNotNull(req) { "Received null request in onWriteChunkComplete" }
120+
when {
121+
errCode != AWS_OP_SUCCESS -> req.cont.resumeWithException(HttpException(errCode))
122+
else -> req.cont.resume(Unit)
123+
}
124+
cleanupWriteChunkCbData(req)
121125
}
122-
cleanupWriteChunkCbData(stableRef)
123126
}
124127

125-
private fun cleanupWriteChunkCbData(stableRef: StableRef<WriteChunkRequest>) {
126-
val req = stableRef.get()
128+
private fun cleanupWriteChunkCbData(req: WriteChunkRequest) {
127129
aws_input_stream_destroy(req.inputStream)
128130
aws_byte_buf_clean_up(req.chunkData)
129131
Allocator.Default.free(req.inputStream)
130-
stableRef.dispose()
131132
}

0 commit comments

Comments
 (0)