diff --git a/grpc/grpc-core/build.gradle.kts b/grpc/grpc-core/build.gradle.kts index a6f4f5803..d45233654 100644 --- a/grpc/grpc-core/build.gradle.kts +++ b/grpc/grpc-core/build.gradle.kts @@ -14,6 +14,8 @@ plugins { alias(libs.plugins.kotlinx.rpc) } + + kotlin { compilerOptions { freeCompilerArgs.add("-Xexpect-actual-classes") @@ -27,6 +29,7 @@ kotlin { api(libs.coroutines.core) implementation(libs.atomicfu) + implementation(libs.kotlinx.io.core) } } @@ -58,6 +61,12 @@ kotlin { } } + nativeMain { + dependencies { + implementation(libs.kotlinx.collections.immutable) + } + } + nativeTest { dependencies { implementation(kotlin("test")) @@ -83,7 +92,7 @@ kotlin { val buildGrpcppCLib = tasks.register("buildGrpcppCLib") { group = "build" workingDir = grpcppCLib - commandLine("bash", "-c", "bazel build :grpcpp_c_static --config=release") + commandLine("bash", "-c", "bazel build :grpcpp_c_static :protowire_static --config=release") inputs.files(fileTree(grpcppCLib) { exclude("bazel-*/**") }) outputs.dir(grpcppCLib.resolve("bazel-bin")) @@ -108,6 +117,22 @@ kotlin { tasks.named(interopTask, CInteropProcess::class) { dependsOn(buildGrpcppCLib) } + + + val libprotowire by creating { + includeDirs( + grpcppCLib.resolve("include") + ) + extraOpts( + "-libraryPath", "${grpcppCLib.resolve("bazel-out/darwin_arm64-opt/bin")}", + ) + } + + val libUpbTask = "cinterop${libprotowire.name.capitalized()}${it.targetName.capitalized()}" + tasks.named(libUpbTask, CInteropProcess::class) { + dependsOn(buildGrpcppCLib) + } + } } } diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/KTag.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/KTag.kt new file mode 100644 index 000000000..cd9102250 --- /dev/null +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/KTag.kt @@ -0,0 +1,51 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.rpc.grpc.internal + +import kotlinx.rpc.grpc.internal.KTag.Companion.K_TAG_TYPE_BITS + +internal enum class WireType { + VARINT, // 0 + FIXED64, // 1 + LENGTH_DELIMITED, // 2 + START_GROUP, // 3 + END_GROUP, // 4 + FIXED32, // 5 +} + +internal data class KTag(val fieldNr: Int, val wireType: WireType) { + + init { + check(isValidFieldNr(fieldNr)) { "Invalid field number: $fieldNr" } + } + + companion object { + // Number of bits in a tag which identify the wire type. + const val K_TAG_TYPE_BITS: Int = 3; + + // Mask for those bits. (just 0b111) + val K_TAG_TYPE_MASK: UInt = (1u shl K_TAG_TYPE_BITS) - 1u + } +} + +internal fun KTag.toRawKTag(): UInt { + return (fieldNr.toUInt() shl K_TAG_TYPE_BITS) or wireType.ordinal.toUInt() +} + +internal fun KTag.Companion.fromOrNull(rawKTag: UInt): KTag? { + val type = (rawKTag and K_TAG_TYPE_MASK).toInt() + val field = (rawKTag shr K_TAG_TYPE_BITS).toInt() + if (!isValidFieldNr(field)) { + return null + } + if (type >= WireType.entries.size) { + return null + } + return KTag(field, WireType.entries[type]) +} + +internal fun KTag.Companion.isValidFieldNr(fieldNr: Int): Boolean { + return 1 <= fieldNr && fieldNr <= 536_870_911 +} \ No newline at end of file diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/WireDecoder.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/WireDecoder.kt new file mode 100644 index 000000000..ece585248 --- /dev/null +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/WireDecoder.kt @@ -0,0 +1,81 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.rpc.grpc.internal + +import kotlinx.io.Buffer + +/** + * A platform-specific decoder for wire format data. + * + * This decoder is used by first calling [readTag], than looking up the field based on the field number in the returned, + * tag and then calling the actual `read*()` method to read the value to the corresponding field. + * + * [hadError] indicates an error during decoding. While calling `read*()` is safe, the returned values + * are meaningless if [hadError] returns `true`. + * + * NOTE: If the [hadError] after a call to `read*()` returns `false`, it doesn't mean that the + * value is correctly decoded. E.g., the following test will pass: + * ```kt + * val fieldNr = 1 + * val buffer = Buffer() + * + * val encoder = WireEncoder(buffer) + * assertTrue(encoder.writeInt32(fieldNr, 12312)) + * encoder.flush() + * + * WireDecoder(buffer).use { decoder -> + * decoder.readTag() + * decoder.readBool() + * assertFalse(decoder.hasError()) + * } + * ``` + */ +internal interface WireDecoder : AutoCloseable { + fun hadError(): Boolean + fun readTag(): KTag? + fun readBool(): Boolean + fun readInt32(): Int + fun readInt64(): Long + fun readUInt32(): UInt + fun readUInt64(): ULong + fun readSInt32(): Int + fun readSInt64(): Long + fun readFixed32(): UInt + fun readFixed64(): ULong + fun readSFixed32(): Int + fun readSFixed64(): Long + fun readFloat(): Float + fun readDouble(): Double + + fun readEnum(): Int + fun readString(): String + fun readBytes(): ByteArray + fun readPackedBool(): List + fun readPackedInt32(): List + fun readPackedInt64(): List + fun readPackedSInt32(): List + fun readPackedSInt64(): List + fun readPackedUInt32(): List + fun readPackedUInt64(): List + fun readPackedFixed32(): List + fun readPackedFixed64(): List + fun readPackedSFixed32(): List + fun readPackedSFixed64(): List + fun readPackedFloat(): List + fun readPackedDouble(): List + fun readPackedEnum(): List +} + +/** + * Creates a platform-specific [WireDecoder]. + * + * This constructor takes a [Buffer] instead of a [kotlinx.io.Source] because + * the native implementation (`WireDecoderNative`) depends on [Buffer]'s internal structure. + * + * NOTE: Do not use the [source] buffer while the [WireDecoder] is still open. + * + * @param source The buffer containing the encoded wire-format data. + */ +internal expect fun WireDecoder(source: Buffer): WireDecoder \ No newline at end of file diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/WireEncoder.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/WireEncoder.kt new file mode 100644 index 000000000..918b30bef --- /dev/null +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/WireEncoder.kt @@ -0,0 +1,54 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.rpc.grpc.internal + +import kotlinx.io.Sink + +/** + * A platform-specific class that encodes values into protobuf's wire format. + * + * If one `write*()` method returns false, the encoding of the value failed + * and no further encodings can be performed on this [WireEncoder]. + * + * [flush] must be called to ensure that all data is written to the [Sink]. + */ +@OptIn(ExperimentalUnsignedTypes::class) +internal interface WireEncoder { + fun flush() + fun writeBool(field: Int, value: Boolean): Boolean + fun writeInt32(fieldNr: Int, value: Int): Boolean + fun writeInt64(fieldNr: Int, value: Long): Boolean + fun writeUInt32(fieldNr: Int, value: UInt): Boolean + fun writeUInt64(fieldNr: Int, value: ULong): Boolean + fun writeSInt32(fieldNr: Int, value: Int): Boolean + fun writeSInt64(fieldNr: Int, value: Long): Boolean + fun writeFixed32(fieldNr: Int, value: UInt): Boolean + fun writeFixed64(fieldNr: Int, value: ULong): Boolean + fun writeSFixed32(fieldNr: Int, value: Int): Boolean + fun writeSFixed64(fieldNr: Int, value: Long): Boolean + fun writeFloat(fieldNr: Int, value: Float): Boolean + fun writeDouble(fieldNr: Int, value: Double): Boolean + fun writeEnum(fieldNr: Int, value: Int): Boolean + fun writeBytes(fieldNr: Int, value: ByteArray): Boolean + fun writeString(fieldNr: Int, value: String): Boolean + fun writePackedBool(fieldNr: Int, value: List, fieldSize: Int): Boolean + fun writePackedInt32(fieldNr: Int, value: List, fieldSize: Int): Boolean + fun writePackedInt64(fieldNr: Int, value: List, fieldSize: Int): Boolean + fun writePackedUInt32(fieldNr: Int, value: List, fieldSize: Int): Boolean + fun writePackedUInt64(fieldNr: Int, value: List, fieldSize: Int): Boolean + fun writePackedSInt32(fieldNr: Int, value: List, fieldSize: Int): Boolean + fun writePackedSInt64(fieldNr: Int, value: List, fieldSize: Int): Boolean + fun writePackedFixed32(fieldNr: Int, value: List): Boolean + fun writePackedFixed64(fieldNr: Int, value: List): Boolean + fun writePackedSFixed32(fieldNr: Int, value: List): Boolean + fun writePackedSFixed64(fieldNr: Int, value: List): Boolean + fun writePackedFloat(fieldNr: Int, value: List): Boolean + fun writePackedDouble(fieldNr: Int, value: List): Boolean + fun writePackedEnum(fieldNr: Int, value: List, fieldSize: Int) = + writePackedInt32(fieldNr, value, fieldSize) +} + + +internal expect fun WireEncoder(sink: Sink): WireEncoder \ No newline at end of file diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/WireSize.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/WireSize.kt new file mode 100644 index 000000000..161a4fe92 --- /dev/null +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/WireSize.kt @@ -0,0 +1,24 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.rpc.grpc.internal + +internal object WireSize + +internal expect fun WireSize.int32(value: Int): UInt +internal expect fun WireSize.int64(value: Long): UInt +internal expect fun WireSize.uInt32(value: UInt): UInt +internal expect fun WireSize.uInt64(value: ULong): UInt +internal expect fun WireSize.sInt32(value: Int): UInt +internal expect fun WireSize.sInt64(value: Long): UInt + +internal fun WireSize.bool(value: Boolean) = int32(if (value) 1 else 0) +internal fun WireSize.enum(value: Int) = int32(value) +internal fun WireSize.packedInt32(value: List) = value.sumOf { int32(it) } +internal fun WireSize.packedInt64(value: List) = value.sumOf { int64(it) } +internal fun WireSize.packedUInt32(value: List) = value.sumOf { uInt32(it) } +internal fun WireSize.packedUInt64(value: List) = value.sumOf { uInt64(it) } +internal fun WireSize.packedSInt32(value: List) = value.sumOf { sInt32(it) } +internal fun WireSize.packedSInt64(value: List) = value.sumOf { sInt64(it) } +internal fun WireSize.packedEnum(value: List) = value.sumOf { enum(it) } diff --git a/grpc/grpc-core/src/jsMain/kotlin/kotlinx/rpc/grpc/internal/WireDecoder.js.kt b/grpc/grpc-core/src/jsMain/kotlin/kotlinx/rpc/grpc/internal/WireDecoder.js.kt new file mode 100644 index 000000000..f56885157 --- /dev/null +++ b/grpc/grpc-core/src/jsMain/kotlin/kotlinx/rpc/grpc/internal/WireDecoder.js.kt @@ -0,0 +1,12 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.rpc.grpc.internal + +import kotlinx.io.Buffer +import kotlinx.io.Source + +internal actual fun WireDecoder(source: Buffer): WireDecoder { + TODO("Not yet implemented") +} \ No newline at end of file diff --git a/grpc/grpc-core/src/jsMain/kotlin/kotlinx/rpc/grpc/internal/WireEncoder.js.kt b/grpc/grpc-core/src/jsMain/kotlin/kotlinx/rpc/grpc/internal/WireEncoder.js.kt new file mode 100644 index 000000000..00c0b3246 --- /dev/null +++ b/grpc/grpc-core/src/jsMain/kotlin/kotlinx/rpc/grpc/internal/WireEncoder.js.kt @@ -0,0 +1,11 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.rpc.grpc.internal + +import kotlinx.io.Sink + +internal actual fun WireEncoder(sink: Sink): WireEncoder { + TODO("Not yet implemented") +} \ No newline at end of file diff --git a/grpc/grpc-core/src/jsMain/kotlin/kotlinx/rpc/grpc/internal/WireSize.js.kt b/grpc/grpc-core/src/jsMain/kotlin/kotlinx/rpc/grpc/internal/WireSize.js.kt new file mode 100644 index 000000000..e70b9bd00 --- /dev/null +++ b/grpc/grpc-core/src/jsMain/kotlin/kotlinx/rpc/grpc/internal/WireSize.js.kt @@ -0,0 +1,29 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.rpc.grpc.internal + +internal actual fun WireSize.int32(value: Int): UInt { + TODO("Not yet implemented") +} + +internal actual fun WireSize.int64(value: Long): UInt { + TODO("Not yet implemented") +} + +internal actual fun WireSize.uInt32(value: UInt): UInt { + TODO("Not yet implemented") +} + +internal actual fun WireSize.uInt64(value: ULong): UInt { + TODO("Not yet implemented") +} + +internal actual fun WireSize.sInt32(value: Int): UInt { + TODO("Not yet implemented") +} + +internal actual fun WireSize.sInt64(value: Long): UInt { + TODO("Not yet implemented") +} \ No newline at end of file diff --git a/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/internal/WireDecoder.jvm.kt b/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/internal/WireDecoder.jvm.kt new file mode 100644 index 000000000..f56885157 --- /dev/null +++ b/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/internal/WireDecoder.jvm.kt @@ -0,0 +1,12 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.rpc.grpc.internal + +import kotlinx.io.Buffer +import kotlinx.io.Source + +internal actual fun WireDecoder(source: Buffer): WireDecoder { + TODO("Not yet implemented") +} \ No newline at end of file diff --git a/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/internal/WireEncoder.jvm.kt b/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/internal/WireEncoder.jvm.kt new file mode 100644 index 000000000..00c0b3246 --- /dev/null +++ b/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/internal/WireEncoder.jvm.kt @@ -0,0 +1,11 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.rpc.grpc.internal + +import kotlinx.io.Sink + +internal actual fun WireEncoder(sink: Sink): WireEncoder { + TODO("Not yet implemented") +} \ No newline at end of file diff --git a/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/internal/WireSize.jvm.kt b/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/internal/WireSize.jvm.kt new file mode 100644 index 000000000..e70b9bd00 --- /dev/null +++ b/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/internal/WireSize.jvm.kt @@ -0,0 +1,29 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.rpc.grpc.internal + +internal actual fun WireSize.int32(value: Int): UInt { + TODO("Not yet implemented") +} + +internal actual fun WireSize.int64(value: Long): UInt { + TODO("Not yet implemented") +} + +internal actual fun WireSize.uInt32(value: UInt): UInt { + TODO("Not yet implemented") +} + +internal actual fun WireSize.uInt64(value: ULong): UInt { + TODO("Not yet implemented") +} + +internal actual fun WireSize.sInt32(value: Int): UInt { + TODO("Not yet implemented") +} + +internal actual fun WireSize.sInt64(value: Long): UInt { + TODO("Not yet implemented") +} \ No newline at end of file diff --git a/grpc/grpc-core/src/nativeInterop/cinterop/libprotowire.def b/grpc/grpc-core/src/nativeInterop/cinterop/libprotowire.def new file mode 100644 index 000000000..3868c5fc6 --- /dev/null +++ b/grpc/grpc-core/src/nativeInterop/cinterop/libprotowire.def @@ -0,0 +1,6 @@ +headers = protowire.h +headerFilter = protowire.h + +noStringConversion = pw_encoder_write_string + +staticLibraries = libprotowire_static.a \ No newline at end of file diff --git a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/WireDecoder.native.kt b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/WireDecoder.native.kt new file mode 100644 index 000000000..782305b1f --- /dev/null +++ b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/WireDecoder.native.kt @@ -0,0 +1,337 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.rpc.grpc.internal + +import kotlinx.cinterop.* +import kotlinx.collections.immutable.persistentListOf +import kotlinx.io.Buffer +import libprotowire.* +import kotlin.experimental.ExperimentalNativeApi +import kotlin.math.min +import kotlin.native.ref.createCleaner + +// TODO: Evaluate if this buffer size is suitable for all targets (KRPC-186) +// maximum buffer size to allocate as contiguous memory in bytes +private const val MAX_PACKED_BULK_SIZE: Int = 1_000_000 + +@OptIn(ExperimentalForeignApi::class, ExperimentalNativeApi::class) +internal class WireDecoderNative(private val source: Buffer) : WireDecoder { + + private var hadError = false; + + // wraps the source in a class that allows to pass data from the source buffer to the C++ encoder + // without copying it to an intermediate byte array. + private val zeroCopyInput = StableRef.create(ZeroCopyInputSource(source)) + + // construct the pw_decoder_t by passing a pw_zero_copy_input_t that provides a bridge between + // the CodedInputStream and the given source buffer. it passes functions that call the respective + // ZeroCopyInputSource methods. + internal val raw: CPointer = run { + // construct the pw_zero_copy_input_t that functions as a bridge to the ZeroCopyInputSource + val zeroCopyCInput = cValue { + ctx = zeroCopyInput.asCPointer() + next = staticCFunction { ctx, data, size -> + ctx!!.asStableRef().get().next(data!!.reinterpret(), size!!.reinterpret()) + } + backUp = staticCFunction { ctx, count -> + ctx!!.asStableRef().get().backUp(count) + } + skip = staticCFunction { ctx, count -> + ctx!!.asStableRef().get().skip(count) + } + byteCount = staticCFunction { ctx -> + ctx!!.asStableRef().get().byteCount() + } + } + pw_decoder_new(zeroCopyCInput) + ?: error("Failed to create proto wire decoder") + } + + val rawCleaner = createCleaner(raw) { + pw_decoder_delete(it) + } + + + override fun close() { + // this will fix the position in the source buffer + // (done by deconstructor of CodedInputStream) + pw_decoder_close(raw) + + zeroCopyInput.get().close() + zeroCopyInput.dispose() + } + + override fun hadError(): Boolean { + return hadError; + } + + override fun readTag(): KTag? { + val tag = pw_decoder_read_tag(raw) + if (tag == 0u) return null.withError() + val kTag = KTag.fromOrNull(tag) + if (kTag == null) { + hadError = true + } + return kTag + } + + override fun readBool(): Boolean = memScoped { + val value = alloc() + pw_decoder_read_bool(raw, value.ptr).checkError() + return value.value + } + + override fun readInt32(): Int = memScoped { + val value = alloc() + pw_decoder_read_int32(raw, value.ptr).checkError() + return value.value + } + + override fun readInt64(): Long = memScoped { + val value = alloc() + pw_decoder_read_int64(raw, value.ptr).checkError() + return value.value + } + + override fun readUInt32(): UInt = memScoped { + val value = alloc() + pw_decoder_read_uint32(raw, value.ptr).checkError() + return value.value + } + + override fun readUInt64(): ULong = memScoped { + val value = alloc() + pw_decoder_read_uint64(raw, value.ptr).checkError() + return value.value + } + + override fun readSInt32(): Int = memScoped { + val value = alloc() + pw_decoder_read_sint32(raw, value.ptr).checkError() + return value.value + } + + override fun readSInt64(): Long = memScoped { + val value = alloc() + pw_decoder_read_sint64(raw, value.ptr).checkError() + return value.value + } + + override fun readFixed32(): UInt = memScoped { + val value = alloc() + pw_decoder_read_fixed32(raw, value.ptr).checkError() + return value.value + } + + override fun readFixed64(): ULong = memScoped { + val value = alloc() + pw_decoder_read_fixed64(raw, value.ptr).checkError() + return value.value + } + + override fun readSFixed32(): Int = memScoped { + val value = alloc() + pw_decoder_read_sfixed32(raw, value.ptr).checkError() + return value.value + } + + override fun readSFixed64(): Long = memScoped { + val value = alloc() + pw_decoder_read_sfixed64(raw, value.ptr).checkError() + return value.value + } + + override fun readFloat(): Float = memScoped { + val value = alloc() + pw_decoder_read_float(raw, value.ptr).checkError() + return value.value + } + + override fun readDouble(): Double = memScoped { + val value = alloc() + pw_decoder_read_double(raw, value.ptr).checkError() + return value.value + } + + override fun readEnum(): Int = memScoped { + val value = alloc() + pw_decoder_read_enum(raw, value.ptr).checkError() + return value.value + } + + // TODO: Is it possible to avoid copying the c_str, by directly allocating a K/N String (as in readBytes)? KRPC-187 + override fun readString(): String = memScoped { + val str = alloc>() + pw_decoder_read_string(raw, str.ptr).checkError() + try { + if (hadError) return "" + return pw_string_c_str(str.value)?.toKString() ?: "".also { hadError = true } + } finally { + pw_string_delete(str.value) + } + } + + // TODO: Should readBytes return a buffer, to prevent allocation of large contiguous memory blocks ? KRPC-182 + override fun readBytes(): ByteArray { + val length = readInt32() + if (hadError) return ByteArray(0) + if (length < 0) return ByteArray(0).withError() + // check if the remaining buffer size is less than the set length, + // we can early abort, without allocating unnecessary memory + if (source.size < length) return ByteArray(0).withError() + if (length == 0) return ByteArray(0) // actually an empty array (no error) + val bytes = ByteArray(length) + bytes.usePinned { + pw_decoder_read_raw_bytes(raw, it.addressOf(0), length).checkError() + } + if (hadError) return ByteArray(0) + return bytes + } + + override fun readPackedBool() = readPackedVarInternal(this::readBool) + override fun readPackedInt32() = readPackedVarInternal(this::readInt32) + override fun readPackedInt64() = readPackedVarInternal(this::readInt64) + override fun readPackedUInt32() = readPackedVarInternal(this::readUInt32) + override fun readPackedUInt64() = readPackedVarInternal(this::readUInt64) + override fun readPackedSInt32() = readPackedVarInternal(this::readSInt32) + override fun readPackedSInt64() = readPackedVarInternal(this::readSInt64) + override fun readPackedEnum() = readPackedVarInternal(this::readEnum) + + override fun readPackedFixed32() = readPackedFixedInternal( + UInt.SIZE_BYTES, + ::UIntArray, + Pinned::addressOf, + UIntArray::asList, + ) + + override fun readPackedFixed64() = readPackedFixedInternal( + ULong.SIZE_BYTES, + ::ULongArray, + Pinned::addressOf, + ULongArray::asList, + ) + + override fun readPackedSFixed32() = readPackedFixedInternal( + Int.SIZE_BYTES, + ::IntArray, + Pinned::addressOf, + IntArray::asList, + ) + + override fun readPackedSFixed64() = readPackedFixedInternal( + Long.SIZE_BYTES, + ::LongArray, + Pinned::addressOf, + LongArray::asList, + ) + + override fun readPackedFloat() = readPackedFixedInternal( + Float.SIZE_BYTES, + ::FloatArray, + Pinned::addressOf, + FloatArray::asList, + ) + + override fun readPackedDouble() = readPackedFixedInternal( + Double.SIZE_BYTES, + ::DoubleArray, + Pinned::addressOf, + DoubleArray::asList, + ) + + private inline fun readPackedVarInternal( + crossinline readFn: () -> T + ): List { + val byteLen = readInt32() + if (hadError) return emptyList() + if (byteLen < 0) return emptyList().withError() + if (source.size < byteLen) return emptyList().withError() + if (byteLen == 0) return emptyList() // actually an empty list (no error) + + val limit = pw_decoder_push_limit(raw, byteLen) + + val result = mutableListOf() + + while (pw_decoder_bytes_until_limit(raw) > 0) { + val elem = readFn() + if (hadError) break + result.add(elem) + } + + pw_decoder_pop_limit(raw, limit) + return result + } + + /* + * Based on the length of the packed repeated field, one of two list strategies is chosen. + * If the length is less or equal a specific threshold (MAX_PACKED_BULK_SIZE), + * a single array list is filled with the buffer-packed value (two copies). + * Otherwise, a kotlinx.collections.immutable.PersistentList is used to split allocation in several chunks. + * To build the persistent list, a buffer array is allocated that is used for fast copy from C++ to Kotlin. + * + * Note that this implementation assumes a little endian memory order. + */ + private inline fun readPackedFixedInternal( + sizeBytes: Int, + crossinline createArray: (Int) -> R, + crossinline getAddress: Pinned.(Int) -> COpaquePointer, + crossinline asList: (R) -> List + ): List { + // fetch the size of the packed repeated field + var byteLen = readInt32() + if (hadError) return emptyList() + if (byteLen < 0) return emptyList().withError() + if (source.size < byteLen) return emptyList().withError() + if (byteLen % sizeBytes != 0) return emptyList().withError() + if (byteLen == 0) return emptyList() // actually an empty list (no error) + + // allocate the buffer array (has at most MAX_PACKED_BULK_SIZE bytes) + val bufByteLen = minOf(byteLen, MAX_PACKED_BULK_SIZE) + val bufElemCount = bufByteLen / sizeBytes + val buffer = createArray(bufElemCount) + + buffer.usePinned { + val bufAddr = it.getAddress(0) + + if (byteLen == bufByteLen) { + // the whole packed field fits into the buffer -> copy into buffer and returns it as a list. + pw_decoder_read_raw_bytes(raw, bufAddr, byteLen).checkError() + return asList(buffer) + } else { + // the packed field is too large for the buffer, so we load it into a persistent list + var chunkedList = persistentListOf() + + while (byteLen > 0) { + // copy data into the buffer. + val copySize = min(bufByteLen, byteLen) + pw_decoder_read_raw_bytes(raw, bufAddr, copySize).checkError() + if (hadError) return emptyList() + + // add buffer to the chunked list + chunkedList = if (copySize == bufByteLen) { + chunkedList.addAll(asList(buffer)) + } else { + chunkedList.addAll(asList(buffer).subList(0, copySize / sizeBytes)) + } + + byteLen -= copySize + } + + return chunkedList + } + } + } + + private fun Boolean.checkError() { + hadError = !this || hadError; + } + + private fun T.withError(): T { + hadError = true + return this + } +} + +internal actual fun WireDecoder(source: Buffer): WireDecoder = WireDecoderNative(source) \ No newline at end of file diff --git a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/WireEncoder.native.kt b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/WireEncoder.native.kt new file mode 100644 index 000000000..9944ce0af --- /dev/null +++ b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/WireEncoder.native.kt @@ -0,0 +1,187 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.rpc.grpc.internal + +import kotlinx.cinterop.* +import kotlinx.io.Sink +import libprotowire.* +import kotlin.experimental.ExperimentalNativeApi +import kotlin.native.ref.createCleaner + + +@OptIn(ExperimentalForeignApi::class, ExperimentalNativeApi::class) +internal class WireEncoderNative(private val sink: Sink) : WireEncoder { + /** + * The context object provides a stable reference to the kotlin context. + * This is required, as functions must be static and cannot capture environment references. + * With this context, the write callback (called by the pw_encoder_t) is able + * to write the data to the [sink]. + */ + private inner class Ctx { + fun write(buf: CPointer, size: Int): Boolean { + sink.writeFully(buf, 0L, size.toLong()) + return true + } + } + + // create context as a stable reference that can be passed to static function callback + private val context = StableRef.create(this.Ctx()) + + // construct encoder with a callback that calls write() on this.context + internal val raw: CPointer = run { + pw_encoder_new(context.asCPointer(), staticCFunction { ctx, buf, size -> + if (buf == null || ctx == null) { + return@staticCFunction false + } + ctx.asStableRef().get().write(buf.reinterpret(), size) + }) ?: error("Failed to create proto wire encoder") + } + + private val contextCleaner = createCleaner(context) { + it.dispose() + } + private val rawCleaner = createCleaner(raw) { + pw_encoder_delete(it) + } + + override fun flush() { + pw_encoder_flush(raw) + } + + override fun writeBool(field: Int, value: Boolean): Boolean { + return pw_encoder_write_bool(raw, field, value) + } + + override fun writeInt32(fieldNr: Int, value: Int): Boolean { + return pw_encoder_write_int32(raw, fieldNr, value) + } + + override fun writeInt64(fieldNr: Int, value: Long): Boolean { + return pw_encoder_write_int64(raw, fieldNr, value) + } + + override fun writeUInt32(fieldNr: Int, value: UInt): Boolean { + return pw_encoder_write_uint32(raw, fieldNr, value) + } + + override fun writeUInt64(fieldNr: Int, value: ULong): Boolean { + return pw_encoder_write_uint64(raw, fieldNr, value) + } + + override fun writeSInt32(fieldNr: Int, value: Int): Boolean { + return pw_encoder_write_sint32(raw, fieldNr, value) + } + + override fun writeSInt64(fieldNr: Int, value: Long): Boolean { + return pw_encoder_write_sint64(raw, fieldNr, value) + } + + override fun writeFixed32(fieldNr: Int, value: UInt): Boolean { + return pw_encoder_write_fixed32(raw, fieldNr, value) + } + + override fun writeFixed64(fieldNr: Int, value: ULong): Boolean { + return pw_encoder_write_fixed64(raw, fieldNr, value) + } + + override fun writeSFixed32(fieldNr: Int, value: Int): Boolean { + return pw_encoder_write_sfixed32(raw, fieldNr, value) + } + + override fun writeSFixed64(fieldNr: Int, value: Long): Boolean { + return pw_encoder_write_sfixed64(raw, fieldNr, value) + } + + override fun writeFloat(fieldNr: Int, value: Float): Boolean { + return pw_encoder_write_float(raw, fieldNr, value) + } + + override fun writeDouble(fieldNr: Int, value: Double): Boolean { + return pw_encoder_write_double(raw, fieldNr, value) + } + + override fun writeEnum(fieldNr: Int, value: Int): Boolean { + return pw_encoder_write_enum(raw, fieldNr, value) + } + + override fun writeString(fieldNr: Int, value: String): Boolean = memScoped { + if (value.isEmpty()) { + return pw_encoder_write_string(raw, fieldNr, null, 0) + } + val cStr = value.cstr + return pw_encoder_write_string(raw, fieldNr, cStr.ptr, cStr.size) + } + + override fun writeBytes(fieldNr: Int, value: ByteArray): Boolean { + if (value.isEmpty()) { + return pw_encoder_write_bytes(raw, fieldNr, null, 0) + } + return value.usePinned { + pw_encoder_write_bytes(raw, fieldNr, it.addressOf(0), value.size) + } + } + + override fun writePackedBool(fieldNr: Int, value: List, fieldSize: Int) = + writePackedInternal(fieldNr, value, fieldSize, ::pw_encoder_write_bool_no_tag) + + override fun writePackedInt32(fieldNr: Int, value: List, fieldSize: Int) = + writePackedInternal(fieldNr, value, fieldSize, ::pw_encoder_write_int32_no_tag) + + override fun writePackedInt64(fieldNr: Int, value: List, fieldSize: Int) = + writePackedInternal(fieldNr, value, fieldSize, ::pw_encoder_write_int64_no_tag) + + override fun writePackedUInt32(fieldNr: Int, value: List, fieldSize: Int) = + writePackedInternal(fieldNr, value, fieldSize, ::pw_encoder_write_uint32_no_tag) + + override fun writePackedUInt64(fieldNr: Int, value: List, fieldSize: Int) = + writePackedInternal(fieldNr, value, fieldSize, ::pw_encoder_write_uint64_no_tag) + + override fun writePackedSInt32(fieldNr: Int, value: List, fieldSize: Int) = + writePackedInternal(fieldNr, value, fieldSize, ::pw_encoder_write_sint32_no_tag) + + override fun writePackedSInt64(fieldNr: Int, value: List, fieldSize: Int) = + writePackedInternal(fieldNr, value, fieldSize, ::pw_encoder_write_sint64_no_tag) + + override fun writePackedFixed32(fieldNr: Int, value: List) = + writePackedInternal(fieldNr, value, value.size * UInt.SIZE_BYTES, ::pw_encoder_write_fixed32_no_tag) + + override fun writePackedFixed64(fieldNr: Int, value: List) = + writePackedInternal(fieldNr, value, value.size * ULong.SIZE_BYTES, ::pw_encoder_write_fixed64_no_tag) + + override fun writePackedSFixed32(fieldNr: Int, value: List) = + writePackedInternal(fieldNr, value, value.size * Int.SIZE_BYTES, ::pw_encoder_write_sfixed32_no_tag) + + override fun writePackedSFixed64(fieldNr: Int, value: List) = + writePackedInternal(fieldNr, value, value.size * Long.SIZE_BYTES, ::pw_encoder_write_sfixed64_no_tag) + + override fun writePackedFloat(fieldNr: Int, value: List) = + writePackedInternal(fieldNr, value, value.size * Float.SIZE_BYTES, ::pw_encoder_write_float_no_tag) + + override fun writePackedDouble(fieldNr: Int, value: List) = + writePackedInternal(fieldNr, value, value.size * Double.SIZE_BYTES, ::pw_encoder_write_double_no_tag) +} + +internal actual fun WireEncoder(sink: Sink): WireEncoder = WireEncoderNative(sink) + + +// the current implementation is slow, as it iterates through the list, to write each element individually, +// which can be speed up in case of fixed sized types, that are not compressed. KRPC-183 +@OptIn(ExperimentalForeignApi::class) +private inline fun WireEncoderNative.writePackedInternal( + fieldNr: Int, + value: List, + fieldSize: Int, + crossinline writer: (CValuesRef?, T) -> Boolean +): Boolean { + pw_encoder_write_tag(raw, fieldNr, WireType.LENGTH_DELIMITED.ordinal) + // write the field size of the packed field + pw_encoder_write_int32_no_tag(raw, fieldSize) + for (v in value) { + if (!writer(raw, v)) { + return false + } + } + return true +} \ No newline at end of file diff --git a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/WireSize.native.kt b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/WireSize.native.kt new file mode 100644 index 000000000..432479129 --- /dev/null +++ b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/WireSize.native.kt @@ -0,0 +1,18 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +@file:OptIn(ExperimentalForeignApi::class) + +package kotlinx.rpc.grpc.internal + +import kotlinx.cinterop.ExperimentalForeignApi +import libprotowire.* + +internal actual fun WireSize.int32(value: Int) = pw_size_int32(value) +internal actual fun WireSize.int64(value: Long) = pw_size_int64(value) +internal actual fun WireSize.uInt32(value: UInt) = pw_size_uint32(value) +internal actual fun WireSize.uInt64(value: ULong) = pw_size_uint64(value) +internal actual fun WireSize.sInt32(value: Int) = pw_size_sint32(value) +internal actual fun WireSize.sInt64(value: Long) = pw_size_sint64(value) + diff --git a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/ZeroCopyInputSource.kt b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/ZeroCopyInputSource.kt new file mode 100644 index 000000000..e20df0607 --- /dev/null +++ b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/ZeroCopyInputSource.kt @@ -0,0 +1,191 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.rpc.grpc.internal + +import kotlinx.cinterop.* +import kotlinx.io.Buffer +import kotlinx.io.EOFException +import kotlinx.io.InternalIoApi +import kotlinx.io.UnsafeIoApi +import kotlinx.io.unsafe.UnsafeBufferOperations + + +/** + * Handles (almost) zero-copy input operations on a [Buffer], allowing efficient transfer of data + * without creating intermediate copies. This class provides mechanisms to iterate over + * buffer data in a zero-copy manner and supports backing up, skipping, and advancing data. + * + * This class is intended for internal use only and specifically designed to be used + * as a bridge between [Buffer] and the C++ protobuf `ZeroCopyInputStream`. + * This implementation assumes that the data of a buffer segment is backed by a [ByteArray]. + * + * The inner [Buffer] MUST NOT be accessed while using the [ZeroCopyInputSource] as it highly + * depends on tracked state of [Buffer] internals. Each read or write to the underlying buffer + * will result in undefined behavior. + * Additionally, [ZeroCopyInputSource] is not thread-safe, so concurrent access might also + * result in undefined behavior. + * + * Unlike [Buffer.readByte], the [ZeroCopyInputSource.next] does directly consume the data + * in the [Buffer]. This has two reasons: + * 1. The underlying [ByteArray] must stay valid after the `next()` call + * 2. The [ZeroCopyInputSource.backUp] method might preserve bytes that where already read + * during the last [ZeroCopyInputSource.next] call. This method is required by the + * `ZeroCopyInputStream` interface of protobuf. + * Because of this, the inner [Buffer] is in an invalid read position during the use of + * [ZeroCopyInputSource]. After closing the [ZeroCopyInputSource] the inner [Buffer] + * is valid again. However, the buffer might be further advanced than expected, + * depending on whether the user called [backUp] for the unused bytes. + * + * The that memory received by a call to [ZeroCopyInputSource.next] is only valid until the next + * invocation of any method of [ZeroCopyInputSource]. + * + * @param inner The underlying [Buffer] to read data from. If must not be accessed while using + * [ZeroCopyInputSource]. + */ +@OptIn(ExperimentalForeignApi::class, InternalIoApi::class, UnsafeIoApi::class) +internal class ZeroCopyInputSource(private val inner: Buffer) : AutoCloseable { + + // number of bytes read since construction + private var byteCount = 0L + // the array segment that was read by the latest call to next() + // while it was already read by the ZeroCopyInputSource user, it is not yet + // released by the buffer. this is done by releaseLatestReadSegement + // which releases the segment in the inner buffer. + private var latestReadSegementArray: Pinned? = null + private var closed = false; + + /** + * Get access to a segment of continuous bytes in the underlying [Buffer]. + * The returned memory gets invalid with a call to `next(), backUp(), skip()` or `close()`. + * If the method returns `false`, if the inner buffer is exhausted. The `outData` and + * `outSize` remain unset in this case. + * + * @return false if the buffer is exhausted, otherwise true + */ + fun next(outData: CPointer>, outSize: CPointer): Boolean { + check(!closed) { "ZeroCopyInputSource has already been closed." } + if (latestReadSegementArray != null) { + // if there is some unreleased segment array, we must release it first. + // this will advance the head of the buffer to the correct position. + releaseLatestReadSegment() + } + if (inner.exhausted()) { + return false + } + // perform access to the underlying array of the buffer's current segment + UnsafeBufferOperations.readFromHead(inner.buffer) { arr, start, end -> + check(latestReadSegementArray == null) { "currArr must be null at this point"} + // fix the array so it does not move in memory, which is important as we pass its + // memory address as a result to the caller. + latestReadSegementArray = arr.pin() + + val segmentSize = end - start + outData.pointed.value = latestReadSegementArray!!.addressOf(start) + outSize.pointed.value = segmentSize; + + byteCount += segmentSize; + + // we are not yet advancing the inner buffer head. + // this ensures that the segment array is not released by the buffer and remains valid + 0 + } + return true; + } + + /** + * Allows to replay [count] many bytes of the previously read segment. + * This is useful when writing procedures that are only supposed to read up + * to a certain point in the input, then return. If [next] returns a + * buffer that goes beyond what you wanted to read, you can use [backUp] + * to return to the point where you intended to finish. + * ```kt + * next(...) // access the current buffer segment + * backUp(10) // back up the last 10 bytes of the previous accessed segment + * next(...) // read the 10 last bytes of the previous accessed segment again + * ``` + * This is only possible if [next] was the last method called. + * + * @throws IllegalStateException if [count] is greater than size of the last read segment (retrieved from [next]). + * + */ + fun backUp(count: Int) { + check(!closed) { "ZeroCopyInputSource has already been closed." } + check(latestReadSegementArray != null) { "next() must be immediately before backUp()" } + val readBytes = releaseLatestReadSegment(count) + check(readBytes >= 0) { "backUp() must not be called more than the number of bytes that were read in next()" } + byteCount -= count; + } + + /** + * Skip [count] bytes of the buffer. + * @return `false` iff the buffer is exhausted before skipping completed, `true` otherwise + */ + fun skip(count: Int): Boolean { + check(!closed) { "ZeroCopyInputSource has already been closed." } + if (latestReadSegementArray != null) { + releaseLatestReadSegment() + } + try { + byteCount += count + inner.skip(count.toLong()) + return true + } catch (_: EOFException) { + return false + } + } + + /** + * The number of bytes read since the object got created. + * If [backUp] is called, it will decrement the number of read bytes by the given amount. + */ + fun byteCount(): Long { + return byteCount + } + + /** + * Releases the latest read segment that was not yet released. + * It won't close the underlying [Buffer]. After closing this, the underlying [Buffer] is + * valid and can be used again. + * + * This [ZeroCopyInputSource] must not be used after closing it. + */ + override fun close() { + if (latestReadSegementArray != null) { + releaseLatestReadSegment() + } + closed = true; + } + + /** + * Releases the segment that was previously read using [next] but not yet released by the buffer. + * It also unpins it, so it can be collected by the GC. This must only be called if [next] was previously called. + * + * The [backUpCount] defines how many bytes of the segment should stay valid (not released). This is used by the + * [backUp] to allow users to replay reading of the latest read segment. + * If the [backUpCount] is greater than the segment size, 0 bytes are read. + * + * @return number of bytes released, based on [backUpCount]. This value might be negative + * if the [backUpCount] is greater than the latest read segment (indicates a user side error). + */ + private fun releaseLatestReadSegment(backUpCount: Int = 0): Int { + check(latestReadSegementArray != null) { "currArr must be not null" } + var readBytes: Int + // the return value of the readFromHead defines the number of bytes that are getting released in the underlying + // buffer. + UnsafeBufferOperations.readFromHead(inner.buffer) { arr, start, end -> + check(latestReadSegementArray?.get() == arr) { + "array to advance must be the SAME as the currArr, was there some access to the underlying buffer?" } + // release the whole segmentSize - the backup count. + readBytes = end - start - backUpCount + // prevent the value from being negative. + val safeReadBytes = maxOf(readBytes, 0) + safeReadBytes + } + // remove tracking of the released segment + latestReadSegementArray?.unpin() + latestReadSegementArray = null; + return readBytes + } +} \ No newline at end of file diff --git a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/bridge/GrpcByteBuffer.kt b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/bridge/GrpcByteBuffer.kt similarity index 70% rename from grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/bridge/GrpcByteBuffer.kt rename to grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/bridge/GrpcByteBuffer.kt index fb7317e97..f5df196b2 100644 --- a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/bridge/GrpcByteBuffer.kt +++ b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/bridge/GrpcByteBuffer.kt @@ -2,20 +2,28 @@ * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. */ -package kotlinx.rpc.grpc.bridge +package kotlinx.rpc.grpc.internal.bridge import kotlinx.cinterop.* import libgrpcpp_c.* +import kotlin.experimental.ExperimentalNativeApi +import kotlin.native.ref.createCleaner -@OptIn(ExperimentalForeignApi::class) +@OptIn(ExperimentalForeignApi::class, ExperimentalNativeApi::class) internal class GrpcByteBuffer internal constructor( internal val cByteBuffer: CPointer -) : AutoCloseable { +) { constructor(slice: GrpcSlice) : this(memScoped { grpc_raw_byte_buffer_create(slice.cSlice, 1u) ?: error("Failed to create byte buffer") }) + init { + createCleaner(cByteBuffer) { + grpc_byte_buffer_destroy(it) + } + } + fun intoSlice(): GrpcSlice { memScoped { val respSlice = alloc() @@ -24,8 +32,4 @@ internal class GrpcByteBuffer internal constructor( } } - override fun close() { - grpc_byte_buffer_destroy(cByteBuffer) - } - } diff --git a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/bridge/GrpcClient.kt b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/bridge/GrpcClient.kt similarity index 88% rename from grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/bridge/GrpcClient.kt rename to grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/bridge/GrpcClient.kt index f79764b87..642712dfa 100644 --- a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/bridge/GrpcClient.kt +++ b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/bridge/GrpcClient.kt @@ -2,19 +2,27 @@ * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. */ -package kotlinx.rpc.grpc.bridge +package kotlinx.rpc.grpc.internal.bridge import kotlinx.cinterop.* import kotlinx.coroutines.suspendCancellableCoroutine import libgrpcpp_c.* import kotlin.coroutines.resume import kotlin.coroutines.resumeWithException +import kotlin.experimental.ExperimentalNativeApi +import kotlin.native.ref.createCleaner -@OptIn(ExperimentalForeignApi::class) -internal class GrpcClient(target: String) : AutoCloseable { +@OptIn(ExperimentalForeignApi::class, ExperimentalNativeApi::class) +internal class GrpcClient(target: String) { private var clientPtr: CPointer = grpc_client_create_insecure(target) ?: error("Failed to create client") + init { + createCleaner(clientPtr) { + grpc_client_delete(it) + } + } + fun callUnaryBlocking(method: String, req: GrpcSlice): GrpcSlice { memScoped { val result = alloc() @@ -62,9 +70,4 @@ internal class GrpcClient(target: String) : AutoCloseable { cbCtxStable.dispose() }) } - - override fun close() { - grpc_client_delete(clientPtr) - } - } diff --git a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/bridge/GrpcSlice.kt b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/bridge/GrpcSlice.kt similarity index 67% rename from grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/bridge/GrpcSlice.kt rename to grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/bridge/GrpcSlice.kt index 9f9e06271..70ab9a515 100644 --- a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/bridge/GrpcSlice.kt +++ b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/bridge/GrpcSlice.kt @@ -2,7 +2,7 @@ * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. */ -package kotlinx.rpc.grpc.bridge +package kotlinx.rpc.grpc.internal.bridge import kotlinx.cinterop.CValue import kotlinx.cinterop.ExperimentalForeignApi @@ -11,9 +11,11 @@ import kotlinx.cinterop.usePinned import libgrpcpp_c.grpc_slice import libgrpcpp_c.grpc_slice_from_copied_buffer import libgrpcpp_c.grpc_slice_unref +import kotlin.experimental.ExperimentalNativeApi +import kotlin.native.ref.createCleaner -@OptIn(ExperimentalForeignApi::class) -internal class GrpcSlice internal constructor(internal val cSlice: CValue) : AutoCloseable { +@OptIn(ExperimentalForeignApi::class, ExperimentalNativeApi::class) +internal class GrpcSlice internal constructor(internal val cSlice: CValue) { constructor(buffer: ByteArray) : this( buffer.usePinned { pinned -> @@ -21,7 +23,9 @@ internal class GrpcSlice internal constructor(internal val cSlice: CValue, offset: Long, length: Long) { + var consumed = 0L + while (consumed < length) { + UnsafeBufferOperations.writeToTail(this.buffer, 1) { array, start, endExclusive -> + val size = minOf(length - consumed, (endExclusive - start).toLong()) + + array.usePinned { + memcpy(it.addressOf(start), buffer + offset + consumed, size.convert()) + } + + consumed += size + size.toInt() + } + } +} \ No newline at end of file diff --git a/grpc/grpc-core/src/nativeTest/kotlin/kotlinx/rpc/grpc/BridgeTest.kt b/grpc/grpc-core/src/nativeTest/kotlin/kotlinx/rpc/grpc/BridgeTest.kt deleted file mode 100644 index ead336b5f..000000000 --- a/grpc/grpc-core/src/nativeTest/kotlin/kotlinx/rpc/grpc/BridgeTest.kt +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. - */ - -package kotlinx.rpc.grpc - -import kotlinx.coroutines.runBlocking -import kotlinx.rpc.grpc.bridge.GrpcByteBuffer -import kotlinx.rpc.grpc.bridge.GrpcClient -import kotlinx.rpc.grpc.bridge.GrpcSlice -import libgrpcpp_c.pb_decode_greeter_sayhello_response -import kotlin.test.Test - -@OptIn(kotlinx.cinterop.ExperimentalForeignApi::class) -class BridgeTest { - - @Test - fun testBasicUnaryAsyncCall() { - runBlocking { - GrpcClient("localhost:50051").use { client -> - GrpcSlice(byteArrayOf(8, 4)).use { request -> - GrpcByteBuffer(request).use { req_buf -> - client.callUnary("/Greeter/SayHello", req_buf) - .use { result -> - result.intoSlice().use { response -> - val value = pb_decode_greeter_sayhello_response(response.cSlice) - println("Response received: $value") - } - - } - } - } - } - } - } -} diff --git a/grpc/grpc-core/src/nativeTest/kotlin/kotlinx/rpc/grpc/internal/BridgeTest.kt b/grpc/grpc-core/src/nativeTest/kotlin/kotlinx/rpc/grpc/internal/BridgeTest.kt new file mode 100644 index 000000000..5c3b8b97a --- /dev/null +++ b/grpc/grpc-core/src/nativeTest/kotlin/kotlinx/rpc/grpc/internal/BridgeTest.kt @@ -0,0 +1,39 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.rpc.grpc.internal + +import kotlinx.cinterop.ExperimentalForeignApi +import kotlinx.coroutines.runBlocking +import kotlinx.rpc.grpc.internal.bridge.GrpcByteBuffer +import kotlinx.rpc.grpc.internal.bridge.GrpcClient +import kotlinx.rpc.grpc.internal.bridge.GrpcSlice +import libgrpcpp_c.pb_decode_greeter_sayhello_response +import kotlin.native.runtime.GC +import kotlin.native.runtime.NativeRuntimeApi +import kotlin.test.Test +import kotlin.test.fail + +@OptIn(ExperimentalForeignApi::class) +class BridgeTest { + + @OptIn(NativeRuntimeApi::class) + @Test + fun testBasicUnaryAsyncCall() = runBlocking { + try { + val client = GrpcClient("localhost:50051") + val request = GrpcSlice(byteArrayOf(8, 4)) + val reqBuf = GrpcByteBuffer(request) + val result = client.callUnary("/Greeter/SayHello", reqBuf) + val response = result.intoSlice() + val value = pb_decode_greeter_sayhello_response(response.cSlice) + println("Response received: $value") + } catch (e: Exception) { + // trigger GC collection, otherwise there will be a leak + GC.collect() + fail("Got an exception: ${e.message}", e) + } + } + +} diff --git a/grpc/grpc-core/src/nativeTest/kotlin/kotlinx/rpc/grpc/internal/WireCodecTest.kt b/grpc/grpc-core/src/nativeTest/kotlin/kotlinx/rpc/grpc/internal/WireCodecTest.kt new file mode 100644 index 000000000..d93132fc1 --- /dev/null +++ b/grpc/grpc-core/src/nativeTest/kotlin/kotlinx/rpc/grpc/internal/WireCodecTest.kt @@ -0,0 +1,800 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.rpc.grpc.internal + +import kotlinx.cinterop.ExperimentalForeignApi +import kotlinx.io.Buffer +import kotlin.experimental.ExperimentalNativeApi +import kotlin.test.* + +// TODO: Move this to the commonTest +@OptIn(ExperimentalForeignApi::class, ExperimentalNativeApi::class) +class WireCodecTest { + + @Test + fun testBoolEncodeDecode() { + val fieldNr = 3 + val buffer = Buffer() + + val encoder = WireEncoder(buffer) + assertTrue(encoder.writeBool(fieldNr, true)) + encoder.flush() + + val decoder = WireDecoder(buffer) + + val tag = decoder.readTag() + assertFalse(decoder.hadError()) + assertNotNull(tag) + assertEquals(WireType.VARINT, tag.wireType) + assertEquals(fieldNr, tag.fieldNr) + + val value = decoder.readBool() + assertNotNull(value) + assertTrue(value) + + decoder.close() + assertTrue(buffer.exhausted()) + } + + @Test + fun testInt32EncodeDecode() { + val fieldNr = 5 + val testValue = 42 + val buffer = Buffer() + + val encoder = WireEncoder(buffer) + assertTrue(encoder.writeInt32(fieldNr, testValue)) + encoder.flush() + + val decoder = WireDecoder(buffer) + + val tag = decoder.readTag() + assertNotNull(tag) + assertEquals(WireType.VARINT, tag.wireType) + assertEquals(fieldNr, tag.fieldNr) + + val value = decoder.readInt32() + assertNotNull(value) + assertEquals(testValue, value) + + decoder.close() + assertTrue(buffer.exhausted()) + } + + @Test + fun testInt64EncodeDecode() { + val fieldNr = 6 + val testValue = Long.MAX_VALUE + val buffer = Buffer() + + val encoder = WireEncoder(buffer) + assertTrue(encoder.writeInt64(fieldNr, testValue)) + encoder.flush() + + val decoder = WireDecoder(buffer) + + val tag = decoder.readTag() + assertNotNull(tag) + assertEquals(WireType.VARINT, tag.wireType) + assertEquals(fieldNr, tag.fieldNr) + + val value = decoder.readInt64() + assertNotNull(value) + assertEquals(testValue, value) + + decoder.close() + assertTrue(buffer.exhausted()) + } + + @Test + fun testUInt32EncodeDecode() { + val fieldNr = 7 + val testValue = UInt.MAX_VALUE + val buffer = Buffer() + + val encoder = WireEncoder(buffer) + assertTrue(encoder.writeUInt32(fieldNr, testValue)) + encoder.flush() + + val decoder = WireDecoder(buffer) + + val tag = decoder.readTag() + assertNotNull(tag) + assertEquals(WireType.VARINT, tag.wireType) + assertEquals(fieldNr, tag.fieldNr) + + val value = decoder.readUInt32() + assertNotNull(value) + assertEquals(testValue, value) + + decoder.close() + assertTrue(buffer.exhausted()) + } + + @Test + fun testUInt64EncodeDecode() { + val fieldNr = 8 + val testValue = ULong.MAX_VALUE + val buffer = Buffer() + + val encoder = WireEncoder(buffer) + assertTrue(encoder.writeUInt64(fieldNr, testValue)) + encoder.flush() + + val decoder = WireDecoder(buffer) + + val tag = decoder.readTag() + assertNotNull(tag) + assertEquals(WireType.VARINT, tag.wireType) + assertEquals(fieldNr, tag.fieldNr) + + val value = decoder.readUInt64() + assertNotNull(value) + assertEquals(testValue, value) + + decoder.close() + assertTrue(buffer.exhausted()) + } + + @Test + fun testSInt32EncodeDecode() { + val fieldNr = 9 + val testValue = Int.MIN_VALUE + val buffer = Buffer() + + val encoder = WireEncoder(buffer) + assertTrue(encoder.writeSInt32(fieldNr, testValue)) + encoder.flush() + + val decoder = WireDecoder(buffer) + + val tag = decoder.readTag() + assertNotNull(tag) + assertEquals(WireType.VARINT, tag.wireType) + assertEquals(fieldNr, tag.fieldNr) + + val value = decoder.readSInt32() + assertNotNull(value) + assertEquals(testValue, value) + + decoder.close() + assertTrue(buffer.exhausted()) + } + + @Test + fun testSInt64EncodeDecode() { + val fieldNr = 10 + val testValue = Long.MIN_VALUE // Min long value + val buffer = Buffer() + + val encoder = WireEncoder(buffer) + assertTrue(encoder.writeSInt64(fieldNr, testValue)) + encoder.flush() + + val decoder = WireDecoder(buffer) + + val tag = decoder.readTag() + assertNotNull(tag) + assertEquals(WireType.VARINT, tag.wireType) + assertEquals(fieldNr, tag.fieldNr) + + val value = decoder.readSInt64() + assertNotNull(value) + assertEquals(testValue, value) + + decoder.close() + assertTrue(buffer.exhausted()) + } + + @Test + fun testFixed32EncodeDecode() { + val fieldNr = 11 + val testValue = UInt.MAX_VALUE + val buffer = Buffer() + + val encoder = WireEncoder(buffer) + assertTrue(encoder.writeFixed32(fieldNr, testValue)) + encoder.flush() + + val decoder = WireDecoder(buffer) + + val tag = decoder.readTag() + assertNotNull(tag) + assertEquals(WireType.FIXED32, tag.wireType) + assertEquals(fieldNr, tag.fieldNr) + + val value = decoder.readFixed32() + assertNotNull(value) + assertEquals(testValue, value) + + decoder.close() + assertTrue(buffer.exhausted()) + } + + @Test + fun testFixed64EncodeDecode() { + val fieldNr = 12 + val testValue = ULong.MAX_VALUE + val buffer = Buffer() + + val encoder = WireEncoder(buffer) + assertTrue(encoder.writeFixed64(fieldNr, testValue)) + encoder.flush() + + val decoder = WireDecoder(buffer) + + val tag = decoder.readTag() + assertNotNull(tag) + assertEquals(WireType.FIXED64, tag.wireType) + assertEquals(fieldNr, tag.fieldNr) + + val value = decoder.readFixed64() + assertNotNull(value) + assertEquals(testValue, value) + + decoder.close() + assertTrue(buffer.exhausted()) + } + + @Test + fun testSFixed32EncodeDecode() { + val fieldNr = 13 + val testValue = Int.MIN_VALUE + val buffer = Buffer() + + val encoder = WireEncoder(buffer) + assertTrue(encoder.writeSFixed32(fieldNr, testValue)) + encoder.flush() + + val decoder = WireDecoder(buffer) + + val tag = decoder.readTag() + assertNotNull(tag) + assertEquals(WireType.FIXED32, tag.wireType) + assertEquals(fieldNr, tag.fieldNr) + + val value = decoder.readSFixed32() + assertNotNull(value) + assertEquals(testValue, value) + + decoder.close() + assertTrue(buffer.exhausted()) + } + + @Test + fun testSFixed64EncodeDecode() { + val fieldNr = 14 + val testValue = Long.MIN_VALUE + val buffer = Buffer() + + val encoder = WireEncoder(buffer) + assertTrue(encoder.writeSFixed64(fieldNr, testValue)) + encoder.flush() + + val decoder = WireDecoder(buffer) + + val tag = decoder.readTag() + assertNotNull(tag) + assertEquals(WireType.FIXED64, tag.wireType) + assertEquals(fieldNr, tag.fieldNr) + + val value = decoder.readSFixed64() + assertNotNull(value) + assertEquals(testValue, value) + + decoder.close() + assertTrue(buffer.exhausted()) + } + + @Test + fun testEnumEncodeDecode() { + val fieldNr = 15 + val testValue = 42 + val buffer = Buffer() + + val encoder = WireEncoder(buffer) + assertTrue(encoder.writeEnum(fieldNr, testValue)) + encoder.flush() + + val decoder = WireDecoder(buffer) + + val tag = decoder.readTag() + assertNotNull(tag) + assertEquals(WireType.VARINT, tag.wireType) + assertEquals(fieldNr, tag.fieldNr) + + val value = decoder.readEnum() + assertNotNull(value) + assertEquals(testValue, value) + + decoder.close() + assertTrue(buffer.exhausted()) + } + + @Test + fun testStringEncodeDecode() { + val fieldNr = 16 + val testValue = "Hello, World!" + val buffer = Buffer() + + val encoder = WireEncoder(buffer) + assertTrue(encoder.writeString(fieldNr, testValue)) + encoder.flush() + + val decoder = WireDecoder(buffer) + + val tag = decoder.readTag() + assertNotNull(tag) + assertEquals(WireType.LENGTH_DELIMITED, tag.wireType) + assertEquals(fieldNr, tag.fieldNr) + + val value = decoder.readString() + assertNotNull(value) + assertEquals(testValue, value) + + decoder.close() + assertTrue(buffer.exhausted()) + } + + @Test + fun testEmptyBufferDecoding() { + val buffer = Buffer() + + val decoder = WireDecoder(buffer) + decoder.readTag() + assertTrue(decoder.hadError()) + } + + @Test + fun testMissingFlush() { + val fieldNr = 17 + val buffer = Buffer() + + val encoder = WireEncoder(buffer) + encoder.writeBool(fieldNr, true) + // Intentionally not calling flush() + + // The data is not being written to the buffer yet + val decoder = WireDecoder(buffer) + assertNull(decoder.readTag()) + decoder.close() + } + + @Test + fun testMultipleFieldsEncodeDecode() { + val buffer = Buffer() + val encoder = WireEncoder(buffer) + + // Write multiple fields of different types + assertTrue(encoder.writeBool(1, true)) + assertTrue(encoder.writeInt32(2, 42)) + assertTrue(encoder.writeString(3, "Hello")) + assertTrue(encoder.writeFixed64(4, 123456789uL)) + encoder.flush() + + val decoder = WireDecoder(buffer) + + // Read and verify each field + val tag1 = decoder.readTag() + assertNotNull(tag1) + assertEquals(1, tag1.fieldNr) + assertEquals(WireType.VARINT, tag1.wireType) + val bool = decoder.readBool() + assertNotNull(bool) + assertTrue(bool) + + val tag2 = decoder.readTag() + assertNotNull(tag2) + assertEquals(2, tag2.fieldNr) + assertEquals(WireType.VARINT, tag2.wireType) + val int32 = decoder.readInt32() + assertNotNull(int32) + assertEquals(42, int32) + + val tag3 = decoder.readTag() + assertNotNull(tag3) + assertEquals(3, tag3.fieldNr) + assertEquals(WireType.LENGTH_DELIMITED, tag3.wireType) + val string = decoder.readString() + assertNotNull(string) + assertEquals("Hello", string) + + val tag4 = decoder.readTag() + assertNotNull(tag4) + assertEquals(4, tag4.fieldNr) + assertEquals(WireType.FIXED64, tag4.wireType) + val fixed64 = decoder.readFixed64() + assertNotNull(fixed64) + assertEquals(123456789uL, fixed64) + + // No more tags + assertNull(decoder.readTag()) + + decoder.close() + assertTrue(buffer.exhausted()) + } + + @Test + fun testReadAfterClose() { + val fieldNr = 19 + val buffer = Buffer() + + val encoder = WireEncoder(buffer) + assertTrue(encoder.writeBool(fieldNr, true)) + encoder.flush() + + val decoder = WireDecoder(buffer) + decoder.close() + + // Reading after close should either return null or throw an exception + try { + val tag = decoder.readTag() + assertNull(tag) + } catch (e: Exception) { + // Expected exception in some implementations + } + } + + @Test + fun testWriteAfterFlush() { + val buffer = Buffer() + + val encoder = WireEncoder(buffer) + assertTrue(encoder.writeBool(1, true)) + encoder.flush() + + // Writing after flush should still work + assertTrue(encoder.writeInt32(2, 42)) + encoder.flush() + + val decoder = WireDecoder(buffer) + + // Verify both values were written + val tag1 = decoder.readTag() + assertNotNull(tag1) + assertEquals(1, tag1.fieldNr) + val bool = decoder.readBool() + assertNotNull(bool) + assertTrue(bool) + + val tag2 = decoder.readTag() + assertNotNull(tag2) + assertEquals(2, tag2.fieldNr) + val int32 = decoder.readInt32() + assertNotNull(int32) + assertEquals(42, int32) + + decoder.close() + assertTrue(buffer.exhausted()) + } + + @Test + fun testUnicodeStringEncodeDecode() { + val fieldNr = 20 + val testValue = "Hello, δΈ–η•Œ! 😊" + val buffer = Buffer() + + val encoder = WireEncoder(buffer) + assertTrue(encoder.writeString(fieldNr, testValue)) + encoder.flush() + + val decoder = WireDecoder(buffer) + + val tag = decoder.readTag() + assertNotNull(tag) + assertEquals(WireType.LENGTH_DELIMITED, tag.wireType) + assertEquals(fieldNr, tag.fieldNr) + + val value = decoder.readString() + assertNotNull(value) + assertEquals(testValue, value) + + decoder.close() + assertTrue(buffer.exhausted()) + } + + @Test + fun testBufferNotExhausted() { + val fieldNr = 1 + val buffer = Buffer() + + val encoder = WireEncoder(buffer) + assertTrue(encoder.writeBool(fieldNr, true)) + assertTrue(encoder.writeBool(fieldNr + 1, true)) + encoder.flush() + + WireDecoder(buffer).use { decoder -> + decoder.readTag() + assertNotNull(decoder.readString()) + } + assertFalse(buffer.exhausted()) + } + + @Test + fun testBufferUsedByMultipleDecoders() { + val buffer = Buffer() + + val field1Nr = 1 + val field2Nr = 2 + val field1Str = "a".repeat(1000000) + val field2Str = "b".repeat(1000000) + + val encoder = WireEncoder(buffer) + assertTrue(encoder.writeString(field1Nr, field1Str)) + assertTrue(encoder.writeString(field2Nr, field2Str)) + encoder.flush() + + WireDecoder(buffer).use { decoder -> + val tag = decoder.readTag() + assertEquals(field1Nr, tag?.fieldNr) + assertEquals(field1Str, decoder.readString()) + } + assertFalse(buffer.exhausted()) + + WireDecoder(buffer).use { decoder -> + val tag = decoder.readTag() + assertEquals(field2Nr, tag?.fieldNr) + assertEquals(field2Str, decoder.readString()) + } + assertTrue(buffer.exhausted()) + } + + @Test + fun testEmptyString() { + val buffer = Buffer() + + val encoder = WireEncoder(buffer) + assertTrue(encoder.writeString(1, "")) + encoder.flush() + + val decoder = WireDecoder(buffer) + + val tag = decoder.readTag() + assertNotNull(tag) + assertEquals(1, tag.fieldNr) + assertEquals(WireType.LENGTH_DELIMITED, tag.wireType) + + val str = decoder.readString() + assertNotNull(str) + assertEquals("", str) + } + + @Test + fun testEmptyByteArray() { + val buffer = Buffer() + + val encoder = WireEncoder(buffer) + assertTrue(encoder.writeBytes(1, ByteArray(0))) + encoder.flush() + + val decoder = WireDecoder(buffer) + + val tag = decoder.readTag() + assertNotNull(tag) + assertEquals(1, tag.fieldNr) + assertEquals(WireType.LENGTH_DELIMITED, tag.wireType) + + val bytes = decoder.readBytes() + assertNotNull(bytes) + assertEquals(0, bytes.size) + } + + @Test + fun testBytesEncodeDecode() { + val buffer = Buffer() + val encoder = WireEncoder(buffer) + + val bytes = ByteArray(1000000) { it.toByte() } + + assertTrue(encoder.writeBytes(1, bytes)) + encoder.flush() + + val decoder = WireDecoder(buffer) + val tag = decoder.readTag() + assertNotNull(tag) + assertEquals(1, tag.fieldNr) + assertEquals(WireType.LENGTH_DELIMITED, tag.wireType) + + val actualBytes = decoder.readBytes() + assertNotNull(actualBytes) + assertEquals(1000000, actualBytes.size) + assertTrue(bytes.contentEquals(actualBytes)) + + decoder.close() + assertTrue(buffer.exhausted()) + } + + @Test + fun testDoubleEncodeDecode() { + val fieldNr = 21 + val testValue = 3.14159265359 + val buffer = Buffer() + + val encoder = WireEncoder(buffer) + assertTrue(encoder.writeDouble(fieldNr, testValue)) + encoder.flush() + + val decoder = WireDecoder(buffer) + + val tag = decoder.readTag() + assertNotNull(tag) + assertEquals(WireType.FIXED64, tag.wireType) + assertEquals(fieldNr, tag.fieldNr) + + val value = decoder.readDouble() + assertNotNull(value) + assertEquals(testValue, value) + + decoder.close() + assertTrue(buffer.exhausted()) + } + + @Test + fun testFloatEncodeDecode() { + val fieldNr = 22 + val testValue = 3.14159f + val buffer = Buffer() + + val encoder = WireEncoder(buffer) + assertTrue(encoder.writeFloat(fieldNr, testValue)) + encoder.flush() + + val decoder = WireDecoder(buffer) + + val tag = decoder.readTag() + assertNotNull(tag) + assertEquals(WireType.FIXED32, tag.wireType) + assertEquals(fieldNr, tag.fieldNr) + + val value = decoder.readFloat() + assertNotNull(value) + assertEquals(testValue, value) + + decoder.close() + assertTrue(buffer.exhausted()) + } + + private fun runPackedFixedTest( + list: List, + write: WireEncoder.(Int, List) -> Boolean, + read: WireDecoder.() -> List?, + ) { + val buf = Buffer() + with(WireEncoder(buf)) { + assertTrue(write(1, list)) + flush() + } + WireDecoder(buf).use { dec -> + dec.readTag()!!.apply { + assertEquals(1, fieldNr) + assertEquals(WireType.LENGTH_DELIMITED, wireType) + } + val test = dec.read() + assertEquals(list, test) + } + assertTrue(buf.exhausted()) + } + + @Test + fun testPackedFixed32() = runPackedFixedTest( + List(1_000_000) { UInt.MAX_VALUE + it.toUInt() }, + WireEncoder::writePackedFixed32, + WireDecoder::readPackedFixed32 + ) + + @Test + fun testPackedFixed64() = runPackedFixedTest( + List(1_000_000) { UInt.MAX_VALUE + it.toULong() }, + WireEncoder::writePackedFixed64, + WireDecoder::readPackedFixed64, + ) + + @Test + fun testPackedSFixed32() = runPackedFixedTest( + List(1_000_000) { Int.MAX_VALUE + it }, + WireEncoder::writePackedSFixed32, + WireDecoder::readPackedSFixed32 + ) + + @Test + fun testPackedSFixed64() = runPackedFixedTest( + List(1_000_000) { Long.MAX_VALUE + it }, + WireEncoder::writePackedSFixed64, + WireDecoder::readPackedSFixed64 + ) + + @Test + fun testPackedFloat() = runPackedFixedTest( + List(1_000_000) { it.toFloat() / 3.3f * ((it and 1) * 2 - 1) }, + WireEncoder::writePackedFloat, + WireDecoder::readPackedFloat, + ) + + @Test + fun testPackedDouble() = runPackedFixedTest( + List(1_000_000) { it.toDouble() / 3.3 * ((it and 1) * 2 - 1) }, + WireEncoder::writePackedDouble, + WireDecoder::readPackedDouble, + ) + + private fun runPackedVarTest( + list: List, + sizeFn: (List) -> UInt, + write: WireEncoder.(Int, List, Int) -> Boolean, + read: WireDecoder.() -> List?, + ) { + val buf = Buffer() + with(WireEncoder(buf)) { + assertTrue(write(1, list, sizeFn(list).toInt())) + flush() + } + WireDecoder(buf).use { dec -> + dec.readTag()!!.apply { + assertEquals(1, fieldNr) + assertEquals(WireType.LENGTH_DELIMITED, wireType) + } + val test = dec.read() + assertEquals(list, test) + } + assertTrue(buf.exhausted()) + } + + @Test + fun testPackedInt32() = runPackedVarTest( + List(1_000_000) { Int.MAX_VALUE + it }, + WireSize::packedInt32, + WireEncoder::writePackedInt32, + WireDecoder::readPackedInt32 + ) + + @Test + fun testPackedInt64() = runPackedVarTest( + List(1_000_000) { Long.MAX_VALUE + it.toLong() }, + WireSize::packedInt64, + WireEncoder::writePackedInt64, + WireDecoder::readPackedInt64 + ) + + @Test + fun testPackedUInt32() = runPackedVarTest( + List(1_000_000) { UInt.MAX_VALUE + it.toUInt() }, + WireSize::packedUInt32, + WireEncoder::writePackedUInt32, + WireDecoder::readPackedUInt32 + ) + + @Test + fun testPackedUInt64() = runPackedVarTest( + List(1_000_000) { ULong.MAX_VALUE + it.toULong() }, + WireSize::packedUInt64, + WireEncoder::writePackedUInt64, + WireDecoder::readPackedUInt64 + ) + + @Test + fun testPackedSInt32() = runPackedVarTest( + List(1_000_000) { Int.MAX_VALUE + it }, + WireSize::packedSInt32, + WireEncoder::writePackedSInt32, + WireDecoder::readPackedSInt32 + ) + + @Test + fun testPackedSInt64() = runPackedVarTest( + List(1_000_000) { Long.MAX_VALUE + it.toLong() }, + WireSize::packedSInt64, + WireEncoder::writePackedSInt64, + WireDecoder::readPackedSInt64 + ) + + @Test + fun testPackedEnum() = runPackedVarTest( + List(1_000_000) { it }, + WireSize::packedEnum, + WireEncoder::writePackedEnum, + WireDecoder::readPackedEnum + ) + +} diff --git a/grpc/grpc-core/src/nativeTest/kotlin/kotlinx/rpc/grpc/internal/ZeroCopyInputSourceTest.kt b/grpc/grpc-core/src/nativeTest/kotlin/kotlinx/rpc/grpc/internal/ZeroCopyInputSourceTest.kt new file mode 100644 index 000000000..234d0941a --- /dev/null +++ b/grpc/grpc-core/src/nativeTest/kotlin/kotlinx/rpc/grpc/internal/ZeroCopyInputSourceTest.kt @@ -0,0 +1,330 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.rpc.grpc.internal + +import kotlinx.cinterop.ByteVar +import kotlinx.cinterop.CPointerVar +import kotlinx.cinterop.ExperimentalForeignApi +import kotlinx.cinterop.IntVar +import kotlinx.cinterop.addressOf +import kotlinx.cinterop.alloc +import kotlinx.cinterop.memScoped +import kotlinx.cinterop.ptr +import kotlinx.cinterop.usePinned +import kotlinx.cinterop.value +import kotlinx.io.Buffer +import platform.posix.memcpy +import kotlin.experimental.ExperimentalNativeApi +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertFalse +import kotlin.test.assertNotEquals +import kotlin.test.assertTrue +import kotlin.test.fail + +@OptIn(ExperimentalForeignApi::class, ExperimentalNativeApi::class) +class ZeroCopyInputSourceTest { + + @Test + fun testEmptyBuffer() { + val buffer = Buffer() + val source = ZeroCopyInputSource(buffer) + + // next() should return false for empty buffer + memScoped { + val data = alloc>() + val size = alloc() + assertFalse(source.next(data.ptr, size.ptr)) + } + + // byteCount should be 0 + assertEquals(0, source.byteCount()) + + // skip should return false for empty buffer + assertFalse(source.skip(10)) + + // close should work without errors + source.close() + } + + @Test + fun testNextMethod() { + val buffer = Buffer() + val testData = ByteArray(100) { it.toByte() } + buffer.write(testData) + + val source = ZeroCopyInputSource(buffer) + + // First next() call should return true and provide data + val firstRead = source.nextIntoArray() + assertEquals(100, firstRead.size) + assertTrue(firstRead.contentEquals(testData)) + + // Second next() call should return false (buffer exhausted) + val secondRead = source.nextIntoArray() + assertEquals(0, secondRead.size) + + // byteCount should reflect the bytes read + assertEquals(100, source.byteCount()) + + source.close() + assertTrue(buffer.exhausted()) + } + + @Test + fun testBackUpMethod() { + val buffer = Buffer() + val testData = ByteArray(100) { it.toByte() } + buffer.write(testData) + + val source = ZeroCopyInputSource(buffer) + + // Read all data + val firstRead = source.nextIntoArray() + assertEquals(100, firstRead.size) + + // Back up 20 bytes + source.backUp(20) + + // byteCount should be reduced by backup amount + assertEquals(80, source.byteCount()) + + // Next read should return the backed-up bytes + val secondRead = source.nextIntoArray() + assertEquals(20, secondRead.size) + + // Verify the backed-up bytes are correct (last 20 bytes of original data) + for (i in 0 until 20) { + assertEquals(testData[80 + i], secondRead[i]) + } + + // Buffer should be exhausted now + val thirdRead = source.nextIntoArray() + assertEquals(0, thirdRead.size) + + source.close() + assertTrue(buffer.exhausted()) + } + + @Test + fun testInvalidBackUpSequence() { + val buffer = Buffer() + buffer.write(ByteArray(10)) + val source = ZeroCopyInputSource(buffer) + + // Calling backUp() without a preceding next() should throw + assertFailsWith { + source.backUp(5) + } + + source.close() + } + + @Test + fun testSkipMethod() { + val buffer = Buffer() + val testData = ByteArray(100) { it.toByte() } + buffer.write(testData) + + val source = ZeroCopyInputSource(buffer) + + // Skip 30 bytes + assertTrue(source.skip(30)) + assertEquals(30, source.byteCount()) + + // Reading all left segments + val allLeftBytes = assertNBytesLeft(source, 70) + + // Verify we're reading from the correct position + for (i in 0 until 70) { + assertEquals(testData[30 + i], allLeftBytes.readByte()) + } + + assertEquals(100, source.byteCount()) + + // Skip beyond the end should return false + assertFalse(source.skip(10)) + + source.close() + } + + @Test + fun testMultipleSegments() { + val buffer = Buffer() + // Create multiple segments by writing small chunks + for (i in 0 until 100) { + buffer.write(ByteArray(100) { (i * 100 + it).toByte() }) + } + + val source = ZeroCopyInputSource(buffer) + val allData = mutableListOf() + + // Read all segments + var segmentCount = 0 + while (true) { + val segment = source.nextIntoArray() + if (segment.isEmpty()) break + segmentCount++ + allData.addAll(segment.toList()) + } + + // assert there were more than 1 segment; otherwise the test is useless + assertTrue(segmentCount > 1) + + // Verify we read all 100 bytes + assertEquals(100 * 100, allData.size) + assertEquals(100 * 100, source.byteCount()) + + // Verify the data is correct + for (i in 0 until 100 * 100) { + assertEquals(i.toByte(), allData[i]) + } + + source.close() + } + + @Test + fun testCloseMethod() { + val buffer = Buffer() + val testData = ByteArray(100) { it.toByte() } + buffer.write(testData) + + val source = ZeroCopyInputSource(buffer) + + // Read the data from source + source.nextIntoArray() + // Back up 20 bytes which have to be available in the original buffer after closing + source.backUp(20) + + // Close the source + source.close() + + // After closing, the buffer should be valid for reading + assertFalse(buffer.exhausted()) + assertEquals(20, buffer.size) + + // Original buffer should contain last 20 bytes of test data + for (i in 0 until 20) { + assertEquals(testData[80 + i], buffer.readByte()) + } + + + // But the source should not be usable + assertFailsWith(message = "ZeroCopyInputSource has already been closed.") { + memScoped { + val data = alloc>() + val size = alloc() + source.next(data.ptr, size.ptr) + fail("Should not be able to use ZeroCopyInputSource after closing") + } + } + } + + @Test + fun testOutOfBoundsBackup() { + val buffer = Buffer() + val testData = ByteArray(100) { it.toByte() } + buffer.write(testData) + + val source = ZeroCopyInputSource(buffer) + + // Read all data + val read = source.nextIntoArray() + assertEquals(100, read.size) + + // Try to back up more bytes than we read + assertFailsWith(message = + "backUp() must not be called more than the number of bytes that were read in next()" + ) { source.backUp(200) } + + source.close() + } + + @Test + fun testMultiChunkConsistency() { + val buffer = Buffer() + fillWithChunks(buffer, 1000, 10) + val total = 1000L * 10 + + val source = ZeroCopyInputSource(buffer) + val seg1 = source.nextIntoArray() + + assertEquals(seg1.size.toLong(), source.byteCount()) + + source.close() + + assertEquals(total - seg1.size, buffer.size) + } + + @Test + fun testMultiChunkBackupConsistency() { + val buffer = Buffer() + fillWithChunks(buffer, 1000, 10) + val total = 1000L * 10 + + val source = ZeroCopyInputSource(buffer) + val seg1 = source.nextIntoArray() + + source.backUp(100) + + assertEquals(seg1.size.toLong() - 100, source.byteCount()) + + source.close() + + assertEquals(total - seg1.size + 100, buffer.size) + } + + @Test + fun testMultiChunkBackup() { + val buffer = Buffer() + fillWithChunks(buffer, 1000, 10) + val total = 1000L * 10 + + val source = ZeroCopyInputSource(buffer) + val seg1 = source.nextIntoArray() + + assertEquals(source.byteCount(), seg1.size.toLong()) + + source.close() + + assertEquals(total - seg1.size, buffer.size) + } + + private fun fillWithChunks(buffer: Buffer, numberOfChunks: Int, chunkSize: Int) { + repeat(numberOfChunks) { i -> + buffer.write(ByteArray(chunkSize) { i.toByte() }) + } + } + +} + +private fun assertNBytesLeft(source: ZeroCopyInputSource, n: Long): Buffer { + // Reading all left segments + val combined = Buffer() + while (combined.size < n) { + val read = source.nextIntoArray() + assertNotEquals(0, read.size) + combined.write(read) + } + assertEquals(n, combined.size) + return combined +} + +@OptIn(ExperimentalForeignApi::class) +private fun ZeroCopyInputSource.nextIntoArray(): ByteArray = memScoped { + val data = alloc>() + val size = alloc() + + if (!next(data.ptr, size.ptr)) { + return ByteArray(0) + } + + val result = ByteArray(size.value) + result.usePinned { + memcpy(it.addressOf(0), data.value, size.value.toULong()) + } + result +} \ No newline at end of file diff --git a/grpc/grpc-core/src/wasmJsMain/kotlin/kotlinx/rpc/grpc/internal/WireDecoder.wasmJs.kt b/grpc/grpc-core/src/wasmJsMain/kotlin/kotlinx/rpc/grpc/internal/WireDecoder.wasmJs.kt new file mode 100644 index 000000000..f56885157 --- /dev/null +++ b/grpc/grpc-core/src/wasmJsMain/kotlin/kotlinx/rpc/grpc/internal/WireDecoder.wasmJs.kt @@ -0,0 +1,12 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.rpc.grpc.internal + +import kotlinx.io.Buffer +import kotlinx.io.Source + +internal actual fun WireDecoder(source: Buffer): WireDecoder { + TODO("Not yet implemented") +} \ No newline at end of file diff --git a/grpc/grpc-core/src/wasmJsMain/kotlin/kotlinx/rpc/grpc/internal/WireEncoder.wasmJs.kt b/grpc/grpc-core/src/wasmJsMain/kotlin/kotlinx/rpc/grpc/internal/WireEncoder.wasmJs.kt new file mode 100644 index 000000000..00c0b3246 --- /dev/null +++ b/grpc/grpc-core/src/wasmJsMain/kotlin/kotlinx/rpc/grpc/internal/WireEncoder.wasmJs.kt @@ -0,0 +1,11 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.rpc.grpc.internal + +import kotlinx.io.Sink + +internal actual fun WireEncoder(sink: Sink): WireEncoder { + TODO("Not yet implemented") +} \ No newline at end of file diff --git a/grpc/grpc-core/src/wasmJsMain/kotlin/kotlinx/rpc/grpc/internal/WireSize.wasmJs.kt b/grpc/grpc-core/src/wasmJsMain/kotlin/kotlinx/rpc/grpc/internal/WireSize.wasmJs.kt new file mode 100644 index 000000000..e70b9bd00 --- /dev/null +++ b/grpc/grpc-core/src/wasmJsMain/kotlin/kotlinx/rpc/grpc/internal/WireSize.wasmJs.kt @@ -0,0 +1,29 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.rpc.grpc.internal + +internal actual fun WireSize.int32(value: Int): UInt { + TODO("Not yet implemented") +} + +internal actual fun WireSize.int64(value: Long): UInt { + TODO("Not yet implemented") +} + +internal actual fun WireSize.uInt32(value: UInt): UInt { + TODO("Not yet implemented") +} + +internal actual fun WireSize.uInt64(value: ULong): UInt { + TODO("Not yet implemented") +} + +internal actual fun WireSize.sInt32(value: Int): UInt { + TODO("Not yet implemented") +} + +internal actual fun WireSize.sInt64(value: Long): UInt { + TODO("Not yet implemented") +} \ No newline at end of file diff --git a/grpc/grpc-ktor-server/gradle.properties b/grpc/grpc-ktor-server/gradle.properties index b68c20f8d..1c532ca1e 100644 --- a/grpc/grpc-ktor-server/gradle.properties +++ b/grpc/grpc-ktor-server/gradle.properties @@ -3,3 +3,18 @@ # kotlinx.rpc.exclude.wasmWasi=true +kotlinx.rpc.exclude.iosArm64=true +kotlinx.rpc.exclude.iosX64=true +kotlinx.rpc.exclude.iosSimulatorArm64=true +kotlinx.rpc.exclude.linuxArm64=true +kotlinx.rpc.exclude.linuxX64=true +kotlinx.rpc.exclude.macosX64=true +kotlinx.rpc.exclude.mingwX64=true +kotlinx.rpc.exclude.tvosArm64=true +kotlinx.rpc.exclude.tvosSimulatorArm64=true +kotlinx.rpc.exclude.tvosX64=true +kotlinx.rpc.exclude.watchosArm32=true +kotlinx.rpc.exclude.watchosArm64=true +kotlinx.rpc.exclude.watchosDeviceArm64=true +kotlinx.rpc.exclude.watchosSimulatorArm64=true +kotlinx.rpc.exclude.watchosX64=true diff --git a/grpc/grpcpp-c/BUILD.bazel b/grpc/grpcpp-c/BUILD.bazel index 9bdb4e927..4e3157af4 100644 --- a/grpc/grpcpp-c/BUILD.bazel +++ b/grpc/grpcpp-c/BUILD.bazel @@ -1,5 +1,20 @@ load("@rules_cc//cc:defs.bzl", "cc_library") +cc_binary( + name = "testdemo", + srcs = ["src/main.cpp"], + deps = [ + ":protowire", + ], +) + +cc_static_library( + name = "grpcpp_c_static", + deps = [ + ":grpcpp_c", + ], +) + cc_library( name = "grpcpp_c", srcs = ["src/grpcpp_c.cpp"], @@ -8,14 +23,30 @@ cc_library( includes = ["include"], visibility = ["//visibility:public"], deps = [ + # TODO: Reduce the dependencies and only use required once. KRPC-185 + "@com_github_grpc_grpc//:channelz", + "@com_github_grpc_grpc//:generic_stub", "@com_github_grpc_grpc//:grpc++", + "@com_github_grpc_grpc//:grpc_credentials_util", "@com_google_protobuf//:protobuf", ], ) cc_static_library( - name = "grpcpp_c_static", + name = "protowire_static", + deps = [ + ":protowire", + ], +) + +cc_library( + name = "protowire", + srcs = ["src/protowire.cpp"], + hdrs = glob(["include/*.h"]), + copts = ["-std=c++20"], + includes = ["include"], + visibility = ["//visibility:public"], deps = [ - "grpcpp_c", + "@com_google_protobuf//:protobuf_lite", ], ) diff --git a/grpc/grpcpp-c/include/protowire.h b/grpc/grpcpp-c/include/protowire.h new file mode 100644 index 000000000..e53d3f41a --- /dev/null +++ b/grpc/grpcpp-c/include/protowire.h @@ -0,0 +1,155 @@ +#ifndef PROTOWIRE_H +#define PROTOWIRE_H + +#include +#include + +// Defines the C wrapper around the C++ Wire Format encoding/decoding API +// (WireFormatLite, Coded(Input|Output)Stream, ZeroCopyInputStream, CopingOutputStream) +#ifdef __cplusplus +extern "C" { +#endif + + //// STD::STRING WRAPPER //// + + // A std::string wrapper that helps reduce copies when C++ api returns std::strings + typedef struct pw_string pw_string_t; + + pw_string_t * pw_string_new(const char *str); + void pw_string_delete(pw_string_t *self); + const char * pw_string_c_str(pw_string_t *self); + + //// WIRE ENCODER //// + + typedef struct pw_encoder pw_encoder_t; + + /** + * Create a new pw_encoder_t that wraps a CodedOutputStream to encode values into a wire format stream. + * + * @param ctx a stable pointer to a Kotlin managed object, used by the K/N sink callback to access Kotlin objects. + * @param sink_fn the K/N callback function to write encoded data into the kotlinx.io.Sink. + */ + pw_encoder_t *pw_encoder_new(void* ctx, bool (*sink_fn)(void* ctx, const void* buf, int size)); + void pw_encoder_delete(pw_encoder_t *self); + bool pw_encoder_flush(pw_encoder_t *self); + + bool pw_encoder_write_bool(pw_encoder_t *self, int field_no, bool value); + bool pw_encoder_write_int32(pw_encoder_t *self, int field_no, int32_t value); + bool pw_encoder_write_int64(pw_encoder_t *self, int field_no, int64_t value); + bool pw_encoder_write_uint32(pw_encoder_t *self, int field_no, uint32_t value); + bool pw_encoder_write_uint64(pw_encoder_t *self, int field_no, uint64_t value); + bool pw_encoder_write_sint32(pw_encoder_t *self, int field_no, int32_t value); + bool pw_encoder_write_sint64(pw_encoder_t *self, int field_no, int64_t value); + bool pw_encoder_write_fixed32(pw_encoder_t *self, int field_no, uint32_t value); + bool pw_encoder_write_fixed64(pw_encoder_t *self, int field_no, uint64_t value); + bool pw_encoder_write_sfixed32(pw_encoder_t *self, int field_no, int32_t value); + bool pw_encoder_write_sfixed64(pw_encoder_t *self, int field_no, int64_t value); + bool pw_encoder_write_float(pw_encoder_t *self, int field_no, float value); + bool pw_encoder_write_double(pw_encoder_t *self, int field_no, double value); + bool pw_encoder_write_enum(pw_encoder_t *self, int field_no, int value); + bool pw_encoder_write_string(pw_encoder_t *self, int field_no, const char *data, int size); + bool pw_encoder_write_bytes(pw_encoder_t *self, int field_no, const void *data, int size); + + // No tag writers + bool pw_encoder_write_tag(pw_encoder_t *self, int field_no, int wire_type); + bool pw_encoder_write_bool_no_tag(pw_encoder_t *self, bool value); + bool pw_encoder_write_int32_no_tag(pw_encoder_t *self, int32_t value); + bool pw_encoder_write_int64_no_tag(pw_encoder_t *self, int64_t value); + bool pw_encoder_write_uint32_no_tag(pw_encoder_t *self, uint32_t value); + bool pw_encoder_write_uint64_no_tag(pw_encoder_t *self, uint64_t value); + bool pw_encoder_write_sint32_no_tag(pw_encoder_t *self, int32_t value); + bool pw_encoder_write_sint64_no_tag(pw_encoder_t *self, int64_t value); + bool pw_encoder_write_fixed32_no_tag(pw_encoder_t *self, uint32_t value); + bool pw_encoder_write_fixed64_no_tag(pw_encoder_t *self, uint64_t value); + bool pw_encoder_write_sfixed32_no_tag(pw_encoder_t *self, int32_t value); + bool pw_encoder_write_sfixed64_no_tag(pw_encoder_t *self, int64_t value); + bool pw_encoder_write_float_no_tag(pw_encoder_t *self, float value); + bool pw_encoder_write_double_no_tag(pw_encoder_t *self, double value); + + //// WIRE DECODER //// + + typedef struct pw_decoder pw_decoder_t; + + /** + * Holds callbacks corresponding to the methods of a ZeroCopyInputStream. + * They are called to retrieve data from the K/N side with a minimal number of copies. + * + * For method documentation see the ZeroCopyInputStream (C++) interface and the ZeroCopyInputSource (Kotlin) class. + */ + typedef struct pw_zero_copy_input { + void *ctx; + bool (*next)(void *ctx, const void **data, int *size); + void (*backUp)(void *ctx, int size); + bool (*skip)(void *ctx, int size); + int64_t (*byteCount)(void *ctx); + } pw_zero_copy_input_t; + + + /** + * Create a new pw_decoder_t that wraps a CodedInputStream to decode values from a wire format stream. + * + * @param zero_copy_input holds callbacks to the K/N side, matching the ZeroCopyInputStream interface. + */ + pw_decoder_t * pw_decoder_new(pw_zero_copy_input_t zero_copy_input); + void pw_decoder_delete(pw_decoder_t *self); + void pw_decoder_close(pw_decoder_t *self); + + uint32_t pw_decoder_read_tag(pw_decoder_t *self); + bool pw_decoder_read_bool(pw_decoder_t *self, bool *value); + bool pw_decoder_read_int32(pw_decoder_t *self, int32_t *value); + bool pw_decoder_read_int64(pw_decoder_t *self, int64_t *value); + bool pw_decoder_read_uint32(pw_decoder_t *self, uint32_t *value); + bool pw_decoder_read_uint64(pw_decoder_t *self, uint64_t *value); + bool pw_decoder_read_sint32(pw_decoder_t *self, int32_t *value); + bool pw_decoder_read_sint64(pw_decoder_t *self, int64_t *value); + bool pw_decoder_read_fixed32(pw_decoder_t *self, uint32_t *value); + bool pw_decoder_read_fixed64(pw_decoder_t *self, uint64_t *value); + bool pw_decoder_read_sfixed32(pw_decoder_t *self, int32_t *value); + bool pw_decoder_read_sfixed64(pw_decoder_t *self, int64_t *value); + bool pw_decoder_read_float(pw_decoder_t *self, float *value); + bool pw_decoder_read_double(pw_decoder_t *self, double *value); + bool pw_decoder_read_enum(pw_decoder_t *self, int *value); + bool pw_decoder_read_string(pw_decoder_t *self, pw_string_t **opaque_string); + // To read an actual bytes field, you must combine read_int32 and this function + bool pw_decoder_read_raw_bytes(pw_decoder_t *self, void* buffer, int size); + + /** + * Pushes the limit of the underlying pb::io::CodedStream to a certain value. + * + * This is required for reading packed fields that don't have fixed size (like int32). + * In this case, the user must push the limit by the length of the decoded LEN and read until + * the limit is reached (which indicates that the end of the repeated field is reached). + */ + int pw_decoder_push_limit(pw_decoder_t *self, int limit); + /** + * Resets the limit previously pushed with pw_decoder_push_limit. + * The limit argument must be the value returned by pw_decoder_push_limit. + * + * This is typically used called after the user reached the end of a packed field and wants to + * reset the stream to the state before the read started. + */ + void pw_decoder_pop_limit(pw_decoder_t *self, int limit); + /** + * Returns the number of bytes until the limit of the underlying pb::io::CodedStream is reached. + * + * This is used to know when to stop reading a packed field. It must be used in combination with + * pw_decoder_push_limit and pw_decoder_pop_limit. + */ + int pw_decoder_bytes_until_limit(pw_decoder_t *self); + + + /// Size Calculation Functions /// + + uint32_t pw_size_int32(int32_t value); + uint32_t pw_size_int64(int64_t value); + uint32_t pw_size_uint32(uint32_t value); + uint32_t pw_size_uint64(uint64_t value); + uint32_t pw_size_sint32(int32_t value); + uint32_t pw_size_sint64(int64_t value); + + +#ifdef __cplusplus + } +#endif + +#endif //PROTOWIRE_H diff --git a/grpc/grpcpp-c/src/protowire.cpp b/grpc/grpcpp-c/src/protowire.cpp new file mode 100644 index 000000000..50700b47d --- /dev/null +++ b/grpc/grpcpp-c/src/protowire.cpp @@ -0,0 +1,293 @@ +// +// Created by Johannes Zottele on 17.07.25. +// + +#include "protowire.h" + +#include + +#include "src/google/protobuf/io/zero_copy_stream_impl_lite.h" +#include "src/google/protobuf/io/coded_stream.h" +#include "src/google/protobuf/wire_format_lite.h" + +namespace pb = google::protobuf; +typedef pb::internal::WireFormatLite WireFormatLite; + +namespace protowire { + /** + * A bridge that passes write calls to the K/N side, to write the data into a kotlinx.io.Sink. + * + * This reduces the amount of copying, as the callback on the K/N may directly use the + * buffer pointer to copy the whole chunk at into the stream. + */ + class SinkStream final : public pb::io::CopyingOutputStream { + public: + /** + * Constructs the stream with a ctx pointer and a callback to the K/N side. + * The ctx pointer is used to on the K/N to reference a Kotlin managed object + * from within its static callback function. + * + * @param thisRef the context used by the K/N side to reference Kotlin managed objects. + * @param sink_fn the K/N callback to write data into the sink + */ + SinkStream(void *thisRef, bool(*sink_fn)(void *ctx, const void *buffer, int size)) + : ctx(thisRef), + sink_fn(sink_fn) { + } + + bool Write(const void *buffer, int size) override { + return sink_fn(ctx, buffer, size); + } + + private: + void *ctx; + bool (*sink_fn)(void *ctx, const void *buffer, int size); + }; + + /** + * A bridge that passes read calls to the K/N side, to read data from a kotlinx.io.Buffer. + * + * This allows efficient data reading from the K/N side buffer, as it allows + * directly accessing continuous memory blocks from within the buffer, instead of copying them + * via C-Interop. + * + * All ZeroCopyInputStream methods are delegated to the K/N call back functions, hold in + * the pw_zero_copy_input_t. + */ + class BufferSourceStream final : public pb::io::ZeroCopyInputStream { + public: + /** + * Constructs the BufferSourceStream to access kotlinx.io.Buffer segments directly, without + * copying them via C-Interop. + * + * @param input a struct containing K/N callbacks for all methods of the ZeroCopyInputStream. + */ + explicit BufferSourceStream(const pw_zero_copy_input_t &input) + : input(input) { + } + + bool Next(const void **data, int *size) override { + auto result = input.next(input.ctx, data, size); + return result; + }; + + void BackUp(int count) override { + return input.backUp(input.ctx, count); + }; + + bool Skip(int count) override { + return input.skip(input.ctx, count); + }; + + int64_t ByteCount() const override { + return input.byteCount(input.ctx); + }; + + private: + pw_zero_copy_input_t input; + }; + +} + +struct pw_string { + std::string str; +}; + +struct pw_encoder { + protowire::SinkStream sinkStream; + pb::io::CopyingOutputStreamAdaptor copyingOutputStreamAdaptor; + pb::io::CodedOutputStream codedOutputStream; + + explicit pw_encoder(protowire::SinkStream sink) + : sinkStream(std::move(sink)), + copyingOutputStreamAdaptor(&sinkStream), + codedOutputStream(©ingOutputStreamAdaptor) { + codedOutputStream.EnableAliasing(true); + } +}; + +struct pw_decoder { + protowire::BufferSourceStream bufferSourceStream; + pb::io::CodedInputStream codedInputStream; + + explicit pw_decoder(pw_zero_copy_input_t input) + : bufferSourceStream(input), + codedInputStream(&bufferSourceStream) {} +}; + + +extern "C" { + + /// STD::STRING WRAPPER IMPLEMENTATION /// + + pw_string_t *pw_string_new(const char *str) { + return new pw_string_t{str }; + } + void pw_string_delete(pw_string_t *self) { + delete self; + } + const char *pw_string_c_str(pw_string_t *self) { + return self->str.c_str(); + } + + /// ENCODER IMPLEMENTATION /// + + pw_encoder_t *pw_encoder_new(void* ctx, bool (* sink_fn)(void* ctx, const void* buf, int size)) { + auto sink = protowire::SinkStream(ctx, sink_fn); + return new pw_encoder(std::move(sink)); + } + + void pw_encoder_delete(pw_encoder_t *self) { + delete self; + } + bool pw_encoder_flush(pw_encoder_t *self) { + self->codedOutputStream.Trim(); + if (!self->copyingOutputStreamAdaptor.Flush()) { + return false; + } + return !self->codedOutputStream.HadError(); + } + + // check that there was no error + static bool check(pw_encoder_t *self) { + return !self->codedOutputStream.HadError(); + } + +#define WRITE_FIELD_FUNC( funcSuffix, wireTy, cTy) \ + bool pw_encoder_write_##funcSuffix(pw_encoder_t *self, int field_no, cTy value) { \ + WireFormatLite::Write##wireTy(field_no, value, &self->codedOutputStream); \ + return check(self); \ + } + + WRITE_FIELD_FUNC( bool, Bool, bool) + WRITE_FIELD_FUNC( int32, Int32, int32_t) + WRITE_FIELD_FUNC( int64, Int64, int64_t) + WRITE_FIELD_FUNC( uint32, UInt32, uint32_t) + WRITE_FIELD_FUNC( uint64, UInt64, uint64_t) + WRITE_FIELD_FUNC( sint32, SInt32, int32_t) + WRITE_FIELD_FUNC( sint64, SInt64, int64_t) + WRITE_FIELD_FUNC( fixed32, Fixed32, uint32_t) + WRITE_FIELD_FUNC( fixed64, Fixed64, uint64_t) + WRITE_FIELD_FUNC( sfixed32, SFixed32, int32_t) + WRITE_FIELD_FUNC( sfixed64, SFixed64, int64_t) + WRITE_FIELD_FUNC( float, Float, float) + WRITE_FIELD_FUNC( double, Double, double) + WRITE_FIELD_FUNC( enum, Enum, int) + + bool pw_encoder_write_string(pw_encoder_t *self, int field_no, const char *data, int size) { + return pw_encoder_write_bytes(self, field_no, data, size); + } + bool pw_encoder_write_bytes(pw_encoder_t *self, int field_no, const void *data, int size) { + WireFormatLite::WriteTag(field_no, WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &self->codedOutputStream); + self->codedOutputStream.WriteVarint32(size); + self->codedOutputStream.WriteRawMaybeAliased(data, size); + return check(self); + } + + bool pw_encoder_write_tag(pw_encoder_t *self, int field_no, int wire_type) { + WireFormatLite::WriteTag(field_no, static_cast(wire_type), &self->codedOutputStream); + return check(self); + } + +#define WRITE_FIELD_NO_TAG( funcSuffix, wireTy, cTy) \ + bool pw_encoder_write_##funcSuffix##_no_tag(pw_encoder_t *self, cTy value) { \ + WireFormatLite::Write##wireTy##NoTag(value, &self->codedOutputStream); \ + return check(self); \ + } + + WRITE_FIELD_NO_TAG( bool, Bool, bool) + WRITE_FIELD_NO_TAG( int32, Int32, int32_t) + WRITE_FIELD_NO_TAG( int64, Int64, int64_t) + WRITE_FIELD_NO_TAG( uint32, UInt32, uint32_t) + WRITE_FIELD_NO_TAG( uint64, UInt64, uint64_t) + WRITE_FIELD_NO_TAG( sint32, SInt32, int32_t) + WRITE_FIELD_NO_TAG( sint64, SInt64, int64_t) + WRITE_FIELD_NO_TAG( fixed32, Fixed32, uint32_t) + WRITE_FIELD_NO_TAG( fixed64, Fixed64, uint64_t) + WRITE_FIELD_NO_TAG( sfixed32, SFixed32, int32_t) + WRITE_FIELD_NO_TAG( sfixed64, SFixed64, int64_t) + WRITE_FIELD_NO_TAG( float, Float, float) + WRITE_FIELD_NO_TAG( double, Double, double) + WRITE_FIELD_NO_TAG( enum, Enum, int) + + + /// DECODER IMPLEMENTATION /// + + pw_decoder_t *pw_decoder_new(pw_zero_copy_input_t zero_copy_input) { + return new pw_decoder_t(zero_copy_input); + } + + void pw_decoder_delete(pw_decoder_t *self) { + delete self; + } + + void pw_decoder_close(pw_decoder_t *self) { + // the deconstructor backs the stream up to the current position. + self->codedInputStream.~CodedInputStream(); + } + + uint32_t pw_decoder_read_tag(pw_decoder_t *self) { + return self->codedInputStream.ReadTag(); + } + +#define READ_VAL_FUNC( funcSuffix, wireTy, cTy) \ + bool pw_decoder_read_##funcSuffix(pw_decoder_t *self, cTy *value_ref) { \ + return WireFormatLite::ReadPrimitive(&self->codedInputStream, value_ref); \ + } + + READ_VAL_FUNC( bool, BOOL, bool) + READ_VAL_FUNC( int32, INT32, int32_t) + READ_VAL_FUNC( int64, INT64, int64_t) + READ_VAL_FUNC( uint32, UINT32, uint32_t) + READ_VAL_FUNC( uint64, UINT64, uint64_t) + READ_VAL_FUNC( sint32, SINT32, int32_t) + READ_VAL_FUNC( sint64, SINT64, int64_t) + READ_VAL_FUNC( fixed32, FIXED32, uint32_t) + READ_VAL_FUNC( fixed64, FIXED64, uint64_t) + READ_VAL_FUNC( sfixed32, SFIXED32, int32_t) + READ_VAL_FUNC( sfixed64, SFIXED64, int64_t) + READ_VAL_FUNC( float, FLOAT, float) + READ_VAL_FUNC( double, DOUBLE, double) + READ_VAL_FUNC( enum, ENUM, int) + + bool pw_decoder_read_string(pw_decoder_t *self, pw_string_t **string_ref) { + *string_ref = new pw_string_t; + return WireFormatLite::ReadString(&self->codedInputStream, &(*string_ref)->str); + } + + bool pw_decoder_read_raw_bytes(pw_decoder_t *self, void* buffer, int size) { + return self->codedInputStream.ReadRaw(buffer, size); + } + + int pw_decoder_push_limit(pw_decoder_t *self, int limit) { + return self->codedInputStream.PushLimit(limit); + } + + void pw_decoder_pop_limit(pw_decoder_t *self, int limit) { + self->codedInputStream.PopLimit(limit); + } + + int pw_decoder_bytes_until_limit(pw_decoder_t *self) { + return self->codedInputStream.BytesUntilLimit(); + } + + uint32_t pw_size_int32(int32_t value) { + return WireFormatLite::Int32Size(value); + } + uint32_t pw_size_int64(int64_t value) { + return WireFormatLite::Int64Size(value); + } + uint32_t pw_size_uint32(uint32_t value) { + return WireFormatLite::UInt32Size(value); + } + uint32_t pw_size_uint64(uint64_t value) { + return WireFormatLite::UInt64Size(value); + } + uint32_t pw_size_sint32(int32_t value) { + return WireFormatLite::SInt32Size(value); + } + uint32_t pw_size_sint64(int64_t value) { + return WireFormatLite::SInt64Size(value); + } + +} diff --git a/versions-root/libs.versions.toml b/versions-root/libs.versions.toml index 8c7137f4a..3be9fa2ec 100644 --- a/versions-root/libs.versions.toml +++ b/versions-root/libs.versions.toml @@ -20,6 +20,8 @@ junit5 = "5.13.2" intellij = "241.19416.19" gradle-doctor = "0.11.0" kotlinx-browser = "0.3" +kotlinx-io = "0.8.0" +kotlinx-collections = "0.4.0" dokka = "2.0.0" puppeteer = "24.9.0" atomicfu = "0.29.0" @@ -59,6 +61,8 @@ kotlin-compiler-test-framework = { module = "org.jetbrains.kotlin:kotlin-compile serialization-plugin = { module = "org.jetbrains.kotlin:kotlin-serialization-compiler-plugin", version.ref = "kotlin-compiler" } serialization-plugin-forIde = { module = "org.jetbrains.kotlin:kotlinx-serialization-compiler-plugin-for-ide", version.ref = "kotlin-compiler" } kotlinx-browser = { module = "org.jetbrains.kotlinx:kotlinx-browser", version.ref = "kotlinx-browser" } +kotlinx-io-core = { module = "org.jetbrains.kotlinx:kotlinx-io-core", version.ref = "kotlinx-io" } +kotlinx-collections-immutable = { module = "org.jetbrains.kotlinx:kotlinx-collections-immutable", version.ref = "kotlinx-collections"} # serialization serialization-core = { module = "org.jetbrains.kotlinx:kotlinx-serialization-core", version.ref = "serialization" }