Skip to content

Commit 6ff6b51

Browse files
committed
cxx output stream tested
1 parent 3eb3d52 commit 6ff6b51

File tree

3 files changed

+213
-62
lines changed

3 files changed

+213
-62
lines changed

kj-rs/io/lib.rs

Lines changed: 26 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
use std::future::Future;
2-
3-
use cxx::UniquePtr;
1+
use std::{future::Future, pin::Pin};
42

53
pub type Result<T> = std::io::Result<T>;
64

@@ -406,7 +404,7 @@ impl<T> AsyncWriteAdapter<T> {
406404

407405
impl<T: AsyncInputStream + Unpin> futures::io::AsyncRead for AsyncReadAdapter<T> {
408406
fn poll_read(
409-
mut self: std::pin::Pin<&mut Self>,
407+
mut self: Pin<&mut Self>,
410408
cx: &mut std::task::Context<'_>,
411409
buf: &mut [u8],
412410
) -> std::task::Poll<std::io::Result<usize>> {
@@ -426,7 +424,7 @@ impl<T: AsyncInputStream + Unpin> futures::io::AsyncRead for AsyncReadAdapter<T>
426424

427425
impl<T: AsyncOutputStream + Unpin> futures::io::AsyncWrite for AsyncWriteAdapter<T> {
428426
fn poll_write(
429-
mut self: std::pin::Pin<&mut Self>,
427+
mut self: Pin<&mut Self>,
430428
cx: &mut std::task::Context<'_>,
431429
buf: &[u8],
432430
) -> std::task::Poll<std::io::Result<usize>> {
@@ -443,15 +441,15 @@ impl<T: AsyncOutputStream + Unpin> futures::io::AsyncWrite for AsyncWriteAdapter
443441
}
444442

445443
fn poll_flush(
446-
self: std::pin::Pin<&mut Self>,
444+
self: Pin<&mut Self>,
447445
_cx: &mut std::task::Context<'_>,
448446
) -> std::task::Poll<std::io::Result<()>> {
449447
// KJ streams don't have explicit flush, so we just return ready
450448
std::task::Poll::Ready(Ok(()))
451449
}
452450

453451
fn poll_close(
454-
self: std::pin::Pin<&mut Self>,
452+
self: Pin<&mut Self>,
455453
_cx: &mut std::task::Context<'_>,
456454
) -> std::task::Poll<std::io::Result<()>> {
457455
// KJ streams don't have explicit close, so we just return ready
@@ -563,26 +561,20 @@ pub fn cxx_to_io_error(e: cxx::Exception) -> std::io::Error {
563561
}
564562

565563
/// Rust wrapper for the `CxxAsyncInputStream` FFI type
566-
pub struct CxxAsyncInputStream {
567-
inner: UniquePtr<ffi::CxxAsyncInputStream>,
564+
pub struct CxxAsyncInputStream<'a> {
565+
inner: Pin<&'a mut ffi::CxxAsyncInputStream>,
568566
}
569567

570-
impl CxxAsyncInputStream {
571-
/// Create a new `CxxAsyncInputStream` from the FFI type
572-
pub fn new(inner: UniquePtr<ffi::CxxAsyncInputStream>) -> Self {
568+
impl<'a> CxxAsyncInputStream<'a> {
569+
pub fn new(inner: Pin<&'a mut ffi::CxxAsyncInputStream>) -> Self {
573570
Self { inner }
574571
}
575572
}
576573

577-
impl AsyncInputStream for CxxAsyncInputStream {
578-
async fn try_read(
579-
&mut self,
580-
buffer: &mut [u8],
581-
min_bytes: usize,
582-
) -> Result<usize> {
574+
impl<'a> AsyncInputStream for CxxAsyncInputStream<'a> {
575+
async fn try_read(&mut self, buffer: &mut [u8], min_bytes: usize) -> Result<usize> {
583576
self.inner
584577
.as_mut()
585-
.expect("CxxAsyncInputStream is null")
586578
.try_read(buffer, min_bytes)
587579
.await
588580
.map_err(cxx_to_io_error)
@@ -605,20 +597,17 @@ impl AsyncInputStream for CxxAsyncInputStream {
605597
}
606598

607599
/// Rust wrapper for the `CxxAsyncOutputStream` FFI type
608-
pub struct CxxAsyncOutputStream {
609-
inner: std::pin::Pin<Box<ffi::CxxAsyncOutputStream>>,
600+
pub struct CxxAsyncOutputStream<'a> {
601+
inner: Pin<&'a mut ffi::CxxAsyncOutputStream>,
610602
}
611603

612-
impl CxxAsyncOutputStream {
613-
/// Create a new `CxxAsyncOutputStream` from the FFI type
614-
pub fn new(ffi_stream: ffi::CxxAsyncOutputStream) -> Self {
615-
Self {
616-
inner: Box::pin(ffi_stream),
617-
}
604+
impl<'a> CxxAsyncOutputStream<'a> {
605+
pub fn new(inner: Pin<&'a mut ffi::CxxAsyncOutputStream>) -> Self {
606+
Self { inner }
618607
}
619608
}
620609

621-
impl AsyncOutputStream for CxxAsyncOutputStream {
610+
impl<'a> AsyncOutputStream for CxxAsyncOutputStream<'a> {
622611
async fn write(&mut self, buffer: &[u8]) -> Result<()> {
623612
self.inner
624613
.as_mut()
@@ -655,25 +644,18 @@ impl AsyncOutputStream for CxxAsyncOutputStream {
655644
}
656645

657646
/// Rust wrapper for the `CxxAsyncIoStream` FFI type
658-
pub struct CxxAsyncIoStream {
659-
inner: std::pin::Pin<Box<ffi::CxxAsyncIoStream>>,
647+
pub struct CxxAsyncIoStream<'a> {
648+
inner: Pin<&'a mut ffi::CxxAsyncIoStream>,
660649
}
661650

662-
impl CxxAsyncIoStream {
663-
/// Create a new `CxxAsyncIoStream` from the FFI type
664-
pub fn new(ffi_stream: ffi::CxxAsyncIoStream) -> Self {
665-
Self {
666-
inner: Box::pin(ffi_stream),
667-
}
651+
impl<'a> CxxAsyncIoStream<'a> {
652+
pub fn new(inner: Pin<&'a mut ffi::CxxAsyncIoStream>) -> Self {
653+
Self { inner }
668654
}
669655
}
670656

671-
impl AsyncInputStream for CxxAsyncIoStream {
672-
async fn try_read(
673-
&mut self,
674-
buffer: &mut [u8],
675-
min_bytes: usize,
676-
) -> Result<usize> {
657+
impl<'a> AsyncInputStream for CxxAsyncIoStream<'a> {
658+
async fn try_read(&mut self, buffer: &mut [u8], min_bytes: usize) -> Result<usize> {
677659
self.inner
678660
.as_mut()
679661
.try_read(buffer, min_bytes)
@@ -697,7 +679,7 @@ impl AsyncInputStream for CxxAsyncIoStream {
697679
}
698680
}
699681

700-
impl AsyncOutputStream for CxxAsyncIoStream {
682+
impl<'a> AsyncOutputStream for CxxAsyncIoStream<'a> {
701683
async fn write(&mut self, buffer: &[u8]) -> Result<()> {
702684
self.inner
703685
.as_mut()
@@ -733,7 +715,7 @@ impl AsyncOutputStream for CxxAsyncIoStream {
733715
}
734716
}
735717

736-
impl AsyncIoStream for CxxAsyncIoStream {
718+
impl<'a> AsyncIoStream for CxxAsyncIoStream<'a> {
737719
async fn shutdown_write(&mut self) -> Result<()> {
738720
self.inner.as_mut().shutdown_write();
739721
Ok(())

kj-rs/io/tests.c++

Lines changed: 94 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,40 @@ class ArrayInputStream: public kj::AsyncInputStream {
3737
kj::ArrayPtr<const kj::byte> data;
3838
};
3939

40+
// Simple output stream that writes to a kj::Vector
41+
class VectorOutputStream: public kj::AsyncOutputStream {
42+
public:
43+
VectorOutputStream() = default;
44+
virtual ~VectorOutputStream() = default;
45+
46+
kj::Promise<void> write(kj::ArrayPtr<const kj::byte> buffer) override {
47+
data.addAll(buffer);
48+
return kj::READY_NOW;
49+
}
50+
51+
kj::Promise<void> write(kj::ArrayPtr<const kj::ArrayPtr<const kj::byte>> pieces) override {
52+
for (auto piece : pieces) {
53+
data.addAll(piece);
54+
}
55+
return kj::READY_NOW;
56+
}
57+
58+
kj::Promise<void> whenWriteDisconnected() override {
59+
return kj::NEVER_DONE; // Never disconnected
60+
}
61+
62+
const kj::Vector<kj::byte>& getData() const {
63+
return data;
64+
}
65+
66+
void clear() {
67+
data.clear();
68+
}
69+
70+
private:
71+
kj::Vector<kj::byte> data;
72+
};
73+
4074
// C++ implementation of FNV-1a hash algorithm (matching Rust implementation)
4175
kj::Promise<uint64_t> computeStreamHash(kj::Own<kj::AsyncInputStream> stream) {
4276
static constexpr uint64_t FNV_OFFSET_BASIS = 14695981039346656037ULL;
@@ -79,15 +113,71 @@ KJ_TEST("Read C++ ArrayInputStream in Rust") {
79113

80114
auto testData = "Hello, World!"_kjb;
81115

82-
auto stream = std::make_unique<kj_rs_io::CxxAsyncInputStream>(kj::heap<ArrayInputStream>(testData));
83-
auto hash = compute_stream_hash_ffi(kj::mv(stream)).wait(waitScope);
116+
auto stream = kj_rs_io::CxxAsyncInputStream(kj::heap<ArrayInputStream>(testData));
117+
auto hash = compute_stream_hash_ffi(stream).wait(waitScope);
84118

85119
KJ_EXPECT(hash == 7993990320990026836);
86120

87-
auto stream2 = std::make_unique<kj_rs_io::CxxAsyncInputStream>(kj::heap<ArrayInputStream>(testData));
88-
auto hash2 = compute_stream_hash_ffi(kj::mv(stream2)).wait(waitScope);
121+
auto stream2 = kj_rs_io::CxxAsyncInputStream(kj::heap<ArrayInputStream>(testData));
122+
auto hash2 = compute_stream_hash_ffi(stream2).wait(waitScope);
89123

90124
KJ_EXPECT(hash == hash2);
91125
}
92126

127+
KJ_TEST("Write to C++ OutputStream from Rust") {
128+
kj::EventLoop loop;
129+
kj::WaitScope waitScope(loop);
130+
131+
auto vectorStream = kj::heap<VectorOutputStream>();
132+
auto* streamPtr = vectorStream.get();
133+
134+
auto stream = kj_rs_io::CxxAsyncOutputStream(kj::mv(vectorStream));
135+
136+
// Generate 100 bytes of pseudorandom data
137+
generate_prng_ffi(stream, 100).wait(waitScope);
138+
139+
// Check that data was written
140+
const auto& data = streamPtr->getData();
141+
KJ_EXPECT(data.size() == 100);
142+
143+
// Test that the data is deterministic by comparing with another stream
144+
auto vectorStream2 = kj::heap<VectorOutputStream>();
145+
auto* streamPtr2 = vectorStream2.get();
146+
auto stream2 = kj_rs_io::CxxAsyncOutputStream(kj::mv(vectorStream2));
147+
148+
generate_prng_ffi(stream2, 100).wait(waitScope);
149+
150+
const auto& data2 = streamPtr2->getData();
151+
KJ_EXPECT(data.size() == data2.size());
152+
153+
// Compare the data byte by byte
154+
for (size_t i = 0; i < data.size(); i++) {
155+
KJ_EXPECT(data[i] == data2[i], "Data should be deterministic", i, data[i], data2[i]);
156+
}
157+
}
158+
159+
KJ_TEST("Write large data to C++ OutputStream from Rust") {
160+
kj::EventLoop loop;
161+
kj::WaitScope waitScope(loop);
162+
163+
auto vectorStream = kj::heap<VectorOutputStream>();
164+
auto* streamPtr = vectorStream.get();
165+
166+
auto stream = kj_rs_io::CxxAsyncOutputStream(kj::mv(vectorStream));
167+
168+
// Generate 2048 * 2048 bytes (multiple chunks)
169+
generate_prng_ffi(stream, 2048 * 2048).wait(waitScope);
170+
171+
// Check that data was written
172+
const auto& data = streamPtr->getData();
173+
KJ_EXPECT(data.size() == 2048 * 2048);
174+
175+
// Basic check that the data varies (not all zeros)
176+
size_t zeroCount = 0;
177+
for (auto byte : data) {
178+
if (byte == 0) zeroCount++;
179+
}
180+
KJ_EXPECT(zeroCount < data.size() / 10, "Data should vary, less than 10% zeros");
181+
}
182+
93183
} // namespace kj_rs_io_test

0 commit comments

Comments
 (0)