Skip to content

Commit 00d573c

Browse files
committed
grpc-native: Add overrideAuthority on native
Signed-off-by: Johannes Zottele <[email protected]>
1 parent ba97b64 commit 00d573c

File tree

5 files changed

+145
-153
lines changed

5 files changed

+145
-153
lines changed
Lines changed: 8 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -2,135 +2,11 @@
22
* Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
33
*/
44

5-
package kotlinx.rpc.grpc.test.proto
5+
package kotlinx.rpc.grpc.test
66

7-
import hello.HelloRequest
8-
import hello.HelloService
9-
import hello.invoke
10-
import kotlinx.coroutines.test.runTest
11-
import kotlinx.rpc.grpc.GrpcClient
12-
import kotlinx.rpc.grpc.GrpcServer
13-
import kotlinx.rpc.grpc.TlsChannelCredentials
14-
import kotlinx.rpc.grpc.TlsClientAuth
15-
import kotlinx.rpc.grpc.TlsServerCredentials
16-
import kotlinx.rpc.grpc.test.EchoRequest
17-
import kotlinx.rpc.grpc.test.EchoService
18-
import kotlinx.rpc.grpc.test.EchoServiceImpl
19-
import kotlinx.rpc.grpc.test.invoke
20-
import kotlinx.rpc.registerService
21-
import kotlinx.rpc.withService
22-
import kotlin.test.Test
7+
// Certs are taken from grpc-java/testing/src/main/resources/certs
238

24-
private const val PORT = 50051
25-
26-
class GrpcbInTlsTest {
27-
28-
29-
@Test
30-
fun testTlsCall() = runTest {
31-
// uses default TLS credentials
32-
val grpcClient = GrpcClient("grpcb.in", 9001)
33-
val service = grpcClient.withService<HelloService>()
34-
val request = HelloRequest {
35-
greeting = "Postman"
36-
}
37-
val result = service.SayHello(request)
38-
39-
println(result.reply)
40-
41-
// Ensure we don't leak the client channel between tests
42-
grpcClient.shutdown()
43-
grpcClient.awaitTermination()
44-
}
45-
46-
47-
@Test
48-
fun testCustomTls() = runTest {
49-
val serverTls = TlsServerCredentials(SERVER_CERT_PEM, SERVER_KEY_PEM)
50-
51-
val grpcServer = GrpcServer(
52-
PORT,
53-
credentials = serverTls,
54-
builder = {
55-
registerService<EchoService> { EchoServiceImpl() }
56-
})
57-
grpcServer.start()
58-
59-
val clientTls = TlsChannelCredentials {
60-
trustManager(SERVER_CERT_PEM)
61-
}
62-
63-
val grpcClient = GrpcClient(
64-
"localhost", PORT,
65-
credentials = clientTls,
66-
) {}
67-
68-
val service = grpcClient.withService<EchoService>()
69-
val request = EchoRequest {
70-
message = "Postman"
71-
}
72-
73-
try {
74-
service.UnaryEcho(request)
75-
} catch (t: Throwable) {
76-
println("[DEBUG_LOG] TLS test failed: ${t::class.simpleName}: ${t.message}")
77-
t.printStackTrace()
78-
throw t
79-
} finally {
80-
grpcServer.shutdown()
81-
grpcServer.awaitTermination()
82-
grpcClient.shutdown()
83-
grpcClient.awaitTermination()
84-
}
85-
}
86-
87-
@Test
88-
fun testCustomMTls() = runTest {
89-
val serverTls = TlsServerCredentials(SERVER_CERT_PEM, SERVER_KEY_PEM) {
90-
trustManager(CA_PEM)
91-
clientAuth(TlsClientAuth.REQUIRE)
92-
}
93-
94-
val grpcServer = GrpcServer(
95-
PORT,
96-
credentials = serverTls,
97-
builder = {
98-
registerService<EchoService> { EchoServiceImpl() }
99-
})
100-
grpcServer.start()
101-
102-
val clientTls = TlsChannelCredentials {
103-
keyManager(CLIENT_CERT_PEM, CLIENT_KEY_PEM)
104-
trustManager(CA_PEM)
105-
}
106-
107-
val grpcClient = GrpcClient(
108-
"localhost", PORT,
109-
credentials = clientTls,
110-
) {
111-
overrideAuthority("foo.test.google.fr")
112-
}
113-
114-
val service = grpcClient.withService<EchoService>()
115-
val request = EchoRequest {
116-
message = "Postman"
117-
}
118-
119-
try {
120-
service.UnaryEcho(request)
121-
} catch (t: Throwable) {
122-
println("[DEBUG_LOG] TLS test failed: ${t::class.simpleName}: ${t.message}")
123-
t.printStackTrace()
124-
throw t
125-
} finally {
126-
grpcServer.shutdown()
127-
grpcServer.awaitTermination()
128-
grpcClient.shutdown()
129-
grpcClient.awaitTermination()
130-
}
131-
}
132-
133-
private val CA_PEM = """
9+
val CA_PEM = """
13410
-----BEGIN CERTIFICATE-----
13511
MIIDWjCCAkKgAwIBAgIUWrP0VvHcy+LP6UuYNtiL9gBhD5owDQYJKoZIhvcNAQEL
13612
BQAwVjELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM
@@ -153,7 +29,7 @@ class GrpcbInTlsTest {
15329
-----END CERTIFICATE-----
15430
""".trimIndent()
15531

156-
private val SERVER_CERT_PEM = """
32+
val SERVER_CERT_PEM = """
15733
-----BEGIN CERTIFICATE-----
15834
MIIDtDCCApygAwIBAgIUbJfTREJ6k6/+oInWhV1O1j3ZT0IwDQYJKoZIhvcNAQEL
15935
BQAwVjELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM
@@ -178,7 +54,7 @@ class GrpcbInTlsTest {
17854
-----END CERTIFICATE-----
17955
""".trimIndent()
18056

181-
private val SERVER_KEY_PEM = """
57+
val SERVER_KEY_PEM = """
18258
-----BEGIN PRIVATE KEY-----
18359
MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDnE443EknxvxBq
18460
6+hvn/t09hl8hx366EBYvZmVM/NC+7igXRAjiJiA/mIaCvL3MS0Iz5hBLxSGICU+
@@ -209,7 +85,7 @@ class GrpcbInTlsTest {
20985
-----END PRIVATE KEY-----
21086
""".trimIndent()
21187

212-
private val CLIENT_CERT_PEM = """
88+
val CLIENT_CERT_PEM = """
21389
-----BEGIN CERTIFICATE-----
21490
MIIDNzCCAh8CFGyX00RCepOv/qCJ1oVdTtY92U83MA0GCSqGSIb3DQEBCwUAMFYx
21591
CzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRl
@@ -232,7 +108,7 @@ class GrpcbInTlsTest {
232108
-----END CERTIFICATE-----
233109
""".trimIndent()
234110

235-
private val CLIENT_KEY_PEM = """
111+
val CLIENT_KEY_PEM = """
236112
-----BEGIN PRIVATE KEY-----
237113
MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQCyqYRp+DXVp72N
238114
FbQH8hdhTZLycZXOlJhmMsrJmrjn2p7pI/8mTZ/0FC+SGWBGZV+ELiHrmCX5zfaI
@@ -261,6 +137,4 @@ class GrpcbInTlsTest {
261137
7vztKEH85yzp4n02FNL6H7xL4VVILvyZHdolmiORJ4qT2hZnl8pEQ2TYuF4RlHUd
262138
nSwXX+2o0J/nF85fm4AwWKAc
263139
-----END PRIVATE KEY-----
264-
""".trimIndent()
265-
266-
}
140+
""".trimIndent()

grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/test/proto/GrpcProtoTest.kt

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,30 +8,46 @@ import kotlinx.coroutines.sync.Mutex
88
import kotlinx.coroutines.sync.withLock
99
import kotlinx.coroutines.test.runTest
1010
import kotlinx.rpc.RpcServer
11+
import kotlinx.rpc.grpc.ChannelCredentials
1112
import kotlinx.rpc.grpc.GrpcClient
1213
import kotlinx.rpc.grpc.GrpcServer
14+
import kotlinx.rpc.grpc.ServerCredentials
1315

1416
abstract class GrpcProtoTest {
1517
private val serverMutex = Mutex()
1618

1719
abstract fun RpcServer.registerServices()
1820

19-
protected fun runGrpcTest(test: suspend (GrpcClient) -> Unit) = runTest {
21+
protected fun runGrpcTest(
22+
serverCreds: ServerCredentials? = null,
23+
clientCreds: ChannelCredentials? = null,
24+
overrideAuthority: String? = null,
25+
test: suspend (GrpcClient) -> Unit,
26+
) = runTest {
2027
serverMutex.withLock {
21-
val grpcClient = GrpcClient("localhost", PORT) {
22-
usePlaintext()
28+
val grpcClient = GrpcClient("localhost", PORT, credentials = clientCreds) {
29+
if (overrideAuthority != null) overrideAuthority(overrideAuthority)
30+
if (clientCreds == null) {
31+
usePlaintext()
32+
}
2333
}
2434

25-
val grpcServer = GrpcServer(PORT, builder = {
26-
registerServices()
27-
})
35+
val grpcServer = GrpcServer(
36+
PORT,
37+
credentials = serverCreds,
38+
builder = {
39+
registerServices()
40+
})
2841

2942
grpcServer.start()
30-
test(grpcClient)
31-
grpcServer.shutdown()
32-
grpcServer.awaitTermination()
33-
grpcClient.shutdown()
34-
grpcClient.awaitTermination()
43+
try {
44+
test(grpcClient)
45+
} finally {
46+
grpcServer.shutdown()
47+
grpcServer.awaitTermination()
48+
grpcClient.shutdown()
49+
grpcClient.awaitTermination()
50+
}
3551
}
3652
}
3753

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
/*
2+
* Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
3+
*/
4+
5+
package kotlinx.rpc.grpc.test.proto
6+
7+
import hello.HelloRequest
8+
import hello.HelloService
9+
import hello.invoke
10+
import kotlinx.coroutines.test.runTest
11+
import kotlinx.rpc.RpcServer
12+
import kotlinx.rpc.grpc.GrpcClient
13+
import kotlinx.rpc.grpc.TlsChannelCredentials
14+
import kotlinx.rpc.grpc.TlsClientAuth
15+
import kotlinx.rpc.grpc.TlsServerCredentials
16+
import kotlinx.rpc.grpc.test.CA_PEM
17+
import kotlinx.rpc.grpc.test.CLIENT_CERT_PEM
18+
import kotlinx.rpc.grpc.test.CLIENT_KEY_PEM
19+
import kotlinx.rpc.grpc.test.EchoRequest
20+
import kotlinx.rpc.grpc.test.EchoService
21+
import kotlinx.rpc.grpc.test.EchoServiceImpl
22+
import kotlinx.rpc.grpc.test.SERVER_CERT_PEM
23+
import kotlinx.rpc.grpc.test.SERVER_KEY_PEM
24+
import kotlinx.rpc.grpc.test.invoke
25+
import kotlinx.rpc.registerService
26+
import kotlinx.rpc.withService
27+
import kotlin.test.Test
28+
import kotlin.test.assertEquals
29+
30+
class GrpcbTlsTest : GrpcProtoTest() {
31+
32+
override fun RpcServer.registerServices() {
33+
registerService<EchoService> { EchoServiceImpl() }
34+
}
35+
36+
@Test
37+
fun testDefaultTlsCall() = runTest {
38+
// uses default client TLS credentials
39+
val grpcClient = GrpcClient("grpcb.in", 9001)
40+
val service = grpcClient.withService<HelloService>()
41+
val request = HelloRequest {
42+
greeting = "world"
43+
}
44+
val result = service.SayHello(request)
45+
46+
assertEquals("hello world", result.reply)
47+
48+
// Ensure we don't leak the client channel between tests
49+
grpcClient.shutdown()
50+
grpcClient.awaitTermination()
51+
}
52+
53+
54+
@Test
55+
fun testCustomTls() {
56+
val serverTls = TlsServerCredentials(SERVER_CERT_PEM, SERVER_KEY_PEM)
57+
val clientTls = TlsChannelCredentials { trustManager(SERVER_CERT_PEM) }
58+
59+
runGrpcTest(serverTls, clientTls, overrideAuthority = "foo.test.google.fr") { client ->
60+
val service = client.withService<EchoService>()
61+
val request = EchoRequest { message = "Echo" }
62+
val response = service.UnaryEcho(request)
63+
assertEquals("Echo", response.message)
64+
}
65+
}
66+
67+
@Test
68+
fun testCustomMTls() = runTest {
69+
val serverTls = TlsServerCredentials(SERVER_CERT_PEM, SERVER_KEY_PEM) {
70+
trustManager(CA_PEM)
71+
clientAuth(TlsClientAuth.REQUIRE)
72+
}
73+
val clientTls = TlsChannelCredentials {
74+
keyManager(CLIENT_CERT_PEM, CLIENT_KEY_PEM)
75+
trustManager(CA_PEM)
76+
}
77+
78+
runGrpcTest(serverTls, clientTls, overrideAuthority = "foo.test.google.fr") { client ->
79+
val service = client.withService<EchoService>()
80+
val request = EchoRequest { message = "Echo" }
81+
val response = service.UnaryEcho(request)
82+
assertEquals("Echo", response.message)
83+
}
84+
}
85+
86+
}

grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/NativeManagedChannel.kt

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ import cnames.structs.grpc_channel
1010
import kotlinx.atomicfu.atomic
1111
import kotlinx.cinterop.CPointer
1212
import kotlinx.cinterop.ExperimentalForeignApi
13+
import kotlinx.cinterop.alloc
14+
import kotlinx.cinterop.cstr
15+
import kotlinx.cinterop.memScoped
16+
import kotlinx.cinterop.ptr
1317
import kotlinx.coroutines.CompletableDeferred
1418
import kotlinx.coroutines.Job
1519
import kotlinx.coroutines.SupervisorJob
@@ -21,6 +25,9 @@ import kotlinx.rpc.grpc.ManagedChannelPlatform
2125
import libkgrpc.GPR_CLOCK_REALTIME
2226
import libkgrpc.GRPC_PROPAGATE_DEFAULTS
2327
import libkgrpc.gpr_inf_future
28+
import libkgrpc.grpc_arg
29+
import libkgrpc.grpc_arg_type
30+
import libkgrpc.grpc_channel_args
2431
import libkgrpc.grpc_channel_create
2532
import libkgrpc.grpc_channel_create_call
2633
import libkgrpc.grpc_channel_destroy
@@ -55,8 +62,22 @@ internal class NativeManagedChannel(
5562
// the channel's completion queue, handling all request operations
5663
private val cq = CompletionQueue()
5764

58-
internal val raw: CPointer<grpc_channel> = grpc_channel_create(target, credentials.raw, null)
59-
?: error("Failed to create channel")
65+
internal val raw: CPointer<grpc_channel> = memScoped {
66+
val args = authority?.let {
67+
var authorityOverride = alloc<grpc_arg> {
68+
type = grpc_arg_type.GRPC_ARG_STRING
69+
key = "grpc.ssl_target_name_override".cstr.ptr
70+
value.string = authority.cstr.ptr
71+
}
72+
73+
alloc<grpc_channel_args> {
74+
num_args = 1u
75+
args = authorityOverride.ptr
76+
}
77+
}
78+
grpc_channel_create(target, credentials.raw, args?.ptr)
79+
?: error("Failed to create channel")
80+
}
6081

6182
@Suppress("unused")
6283
private val rawCleaner = createCleaner(raw) {
@@ -123,21 +144,19 @@ internal class NativeManagedChannel(
123144
// to construct a valid HTTP/2 path, we must prepend the name with a slash.
124145
// the user does not do this to align it with the java implementation.
125146
val methodNameSlice = "/$methodFullName".toGrpcSlice()
126-
val authoritySlice = authority?.toGrpcSlice()
127147

128148
val rawCall = grpc_channel_create_call(
129149
channel = raw,
130150
parent_call = null,
131151
propagation_mask = GRPC_PROPAGATE_DEFAULTS,
132152
completion_queue = cq.raw,
133153
method = methodNameSlice,
134-
host = authoritySlice,
154+
host = null,
135155
deadline = gpr_inf_future(GPR_CLOCK_REALTIME),
136156
reserved = null
137157
) ?: error("Failed to create call")
138158

139159
grpc_slice_unref(methodNameSlice)
140-
authoritySlice?.let { grpc_slice_unref(it) }
141160

142161
return NativeClientCall(
143162
cq, rawCall, methodDescriptor, callJob

0 commit comments

Comments
 (0)