Skip to content

Commit ca1b218

Browse files
committed
Use Result<T> as channel so exceptions are handled properly
1 parent 4ba7c5b commit ca1b218

File tree

1 file changed

+51
-37
lines changed

1 file changed

+51
-37
lines changed

aws-crt-kotlin/native/src/aws/sdk/kotlin/crt/io/HostResolverNative.kt

Lines changed: 51 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -67,13 +67,13 @@ public actual class HostResolver private constructor(
6767
val awsHostname = hostname.toAwsString()
6868
val resultCallback = staticCFunction(::awsOnHostResolveFn)
6969

70-
val channel: Channel<List<CrtHostAddress>> = Channel(Channel.RENDEZVOUS)
70+
val channel: Channel<Result<List<CrtHostAddress>>> = Channel(Channel.RENDEZVOUS)
7171
val channelStableRef = StableRef.create(channel)
7272
val userData = channelStableRef.asCPointer()
7373

7474
aws_host_resolver_resolve_host(ptr, awsHostname, resultCallback, aws_host_resolver_init_default_resolution_config(), userData)
7575

76-
return channel.receive().also {
76+
return channel.receive().getOrThrow().also {
7777
aws_string_destroy(awsHostname)
7878
channelStableRef.dispose()
7979
}
@@ -100,49 +100,63 @@ private fun awsOnHostResolveFn(
100100
hostAddresses: CPointer<aws_array_list>?, // list of `aws_host_address`
101101
userData: COpaquePointer?,
102102
): Unit = memScoped {
103-
if (errCode != AWS_OP_SUCCESS) {
104-
throw CrtRuntimeException("aws_on_host_resolved_result_fn", ec = errCode)
105-
}
106103
if (userData == null) {
107104
throw CrtRuntimeException("aws_on_host_resolved_result_fn: userData unexpectedly null")
108105
}
109106

110-
val length = aws_array_list_length(hostAddresses)
111-
if (length == 0uL) {
112-
throw CrtRuntimeException("Failed to resolve host address for ${hostName?.toKString()}")
113-
}
107+
val stableRef = userData.asStableRef<Channel<Result<List<CrtHostAddress>>>>()
108+
val channel = stableRef.get()
114109

115-
val result = ArrayList<CrtHostAddress>(length.toInt())
116-
117-
val element = alloc<COpaquePointerVar>()
118-
for (i in 0uL until length) {
119-
awsAssertOpSuccess(
120-
aws_array_list_get_at_ptr(
121-
hostAddresses,
122-
element.ptr,
123-
i,
124-
),
125-
) { "aws_array_list_get_at_ptr failed at index $i" }
126-
127-
val elemOpaque = element.value ?: throw CrtRuntimeException("aws_host_addresses value at index $i unexpectedly null")
128-
val addr = elemOpaque.reinterpret<aws_host_address>().pointed
129-
130-
val hostStr = addr.host?.toKString() ?: throw CrtRuntimeException("aws_host_addresses `host` at index $i unexpectedly null")
131-
val addressStr = addr.address?.toKString() ?: throw CrtRuntimeException("aws_host_addresses `address` at index $i unexpectedly null")
132-
val addressType = when (addr.record_type) {
133-
aws_address_record_type.AWS_ADDRESS_RECORD_TYPE_A -> AddressType.IpV4
134-
aws_address_record_type.AWS_ADDRESS_RECORD_TYPE_AAAA -> AddressType.IpV6
135-
else -> throw CrtRuntimeException("received unsupported aws_host_address `aws_address_record_type`: ${addr.record_type}")
110+
try {
111+
if (errCode != AWS_OP_SUCCESS) {
112+
throw CrtRuntimeException("aws_on_host_resolved_result_fn", ec = errCode)
136113
}
137114

138-
result += CrtHostAddress(host = hostStr, address = addressStr, addressType)
139-
}
115+
val length = aws_array_list_length(hostAddresses)
116+
if (length == 0uL) {
117+
throw CrtRuntimeException("Failed to resolve host address for ${hostName?.toKString()}")
118+
}
140119

141-
// Send results back through the channel
142-
val stableRef = userData.asStableRef<Channel<List<CrtHostAddress>>>()
143-
val channel = stableRef.get()
144-
channel.trySend(result)
145-
channel.close()
120+
val addressList = ArrayList<CrtHostAddress>(length.toInt())
121+
122+
val element = alloc<COpaquePointerVar>()
123+
for (i in 0uL until length) {
124+
awsAssertOpSuccess(
125+
aws_array_list_get_at_ptr(
126+
hostAddresses,
127+
element.ptr,
128+
i,
129+
),
130+
) { "aws_array_list_get_at_ptr failed at index $i" }
131+
132+
val elemOpaque = element.value ?: run {
133+
throw CrtRuntimeException("aws_host_addresses value at index $i unexpectedly null")
134+
}
135+
136+
val addr = elemOpaque.reinterpret<aws_host_address>().pointed
137+
138+
val hostStr = addr.host?.toKString() ?: run {
139+
throw CrtRuntimeException("aws_host_addresses `host` at index $i unexpectedly null")
140+
}
141+
val addressStr = addr.address?.toKString() ?: run {
142+
throw CrtRuntimeException("aws_host_addresses `address` at index $i unexpectedly null")
143+
}
144+
145+
val addressType = when (addr.record_type) {
146+
aws_address_record_type.AWS_ADDRESS_RECORD_TYPE_A -> AddressType.IpV4
147+
aws_address_record_type.AWS_ADDRESS_RECORD_TYPE_AAAA -> AddressType.IpV6
148+
else -> throw CrtRuntimeException("received unsupported aws_host_address `aws_address_record_type`: ${addr.record_type}")
149+
}
150+
151+
addressList += CrtHostAddress(host = hostStr, address = addressStr, addressType)
152+
}
153+
154+
channel.trySend(Result.success(addressList))
155+
} catch (e: Exception) {
156+
channel.trySend(Result.failure(e))
157+
} finally {
158+
channel.close()
159+
}
146160
}
147161

148162
// Minimal wrapper of aws_host_address

0 commit comments

Comments
 (0)