Skip to content

Commit 5e7d46c

Browse files
authored
[Rust] use cleanup list as necessary for stream/future writes (#1240)
Previously, I had failed to consider (or test) the case where lowering values for stream/future writes would require intermediate allocations that need to be cleaned up. That manifested as compiler errors due to generated code referencing a non-existent `cleanup_list` variable. Fixes #1153 Signed-off-by: Joel Dice <[email protected]>
1 parent da2aed4 commit 5e7d46c

File tree

4 files changed

+69
-25
lines changed

4 files changed

+69
-25
lines changed

crates/rust/src/bindgen.rs

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,22 @@ impl<'a, 'b> FunctionBindgen<'a, 'b> {
5454
}
5555
}
5656

57-
fn emit_cleanup(&mut self) {
57+
pub(crate) fn flush_cleanup(&mut self) {
58+
if !self.cleanup.is_empty() {
59+
self.needs_cleanup_list = true;
60+
self.push_str("cleanup_list.extend_from_slice(&[");
61+
for (ptr, layout) in mem::take(&mut self.cleanup) {
62+
self.push_str("(");
63+
self.push_str(&ptr);
64+
self.push_str(", ");
65+
self.push_str(&layout);
66+
self.push_str("),");
67+
}
68+
self.push_str("]);\n");
69+
}
70+
}
71+
72+
pub(crate) fn emit_cleanup(&mut self) {
5873
if self.emitted_cleanup {
5974
return;
6075
}
@@ -244,18 +259,7 @@ impl Bindgen for FunctionBindgen<'_, '_> {
244259
}
245260

246261
fn finish_block(&mut self, operands: &mut Vec<String>) {
247-
if !self.cleanup.is_empty() {
248-
self.needs_cleanup_list = true;
249-
self.push_str("cleanup_list.extend_from_slice(&[");
250-
for (ptr, layout) in mem::take(&mut self.cleanup) {
251-
self.push_str("(");
252-
self.push_str(&ptr);
253-
self.push_str(", ");
254-
self.push_str(&layout);
255-
self.push_str("),");
256-
}
257-
self.push_str("]);\n");
258-
}
262+
self.flush_cleanup();
259263
let (prev_src, prev_cleanup) = self.block_storage.pop().unwrap();
260264
let src = mem::replace(&mut self.src, prev_src);
261265
self.cleanup = prev_cleanup;

crates/rust/src/interface.rs

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -535,34 +535,40 @@ macro_rules! {macro_name} {{
535535
};
536536
let size = size.size_wasm32();
537537
let align = align.align_wasm32();
538-
let (lower, lift) = if let Some(payload_type) = payload_type {
539-
let lower =
538+
let (lower, cleanup, lift) = if let Some(payload_type) = payload_type {
539+
let (lower, cleanup) =
540540
self.lower_to_memory("address", "&value", &payload_type, &module);
541541
let lift =
542542
self.lift_from_memory("address", "value", &payload_type, &module);
543-
(lower, lift)
543+
(lower, cleanup, lift)
544544
} else {
545-
(String::new(), "let value = ();\n".into())
545+
(String::new(), None, "let value = ();\n".into())
546546
};
547547

548+
let (cleanup_start, cleanup_end) =
549+
cleanup.unwrap_or_else(|| (String::new(), String::new()));
550+
548551
let box_ = self.path_to_box();
549552
let code = format!(
550553
r#"
551554
#[doc(hidden)]
552555
pub mod vtable{ordinal} {{
553556
fn write(future: u32, value: {name}) -> ::core::pin::Pin<{box_}<dyn ::core::future::Future<Output = bool>>> {{
554557
{box_}::pin(async move {{
558+
{cleanup_start}
555559
#[repr(align({align}))]
556560
struct Buffer([::core::mem::MaybeUninit::<u8>; {size}]);
557561
let mut buffer = Buffer([::core::mem::MaybeUninit::uninit(); {size}]);
558562
let address = buffer.0.as_mut_ptr() as *mut u8;
559563
{lower}
560564
561-
match unsafe {{ {async_support}::await_future_result(start_write, future, address).await }} {{
565+
let result = match unsafe {{ {async_support}::await_future_result(start_write, future, address).await }} {{
562566
{async_support}::AsyncWaitResult::Values(_) => true,
563567
{async_support}::AsyncWaitResult::End => false,
564568
{async_support}::AsyncWaitResult::Error(_) => unreachable!("received error while performing write"),
565-
}}
569+
}};
570+
{cleanup_end}
571+
result
566572
}})
567573
}}
568574
@@ -660,14 +666,15 @@ pub mod vtable{ordinal} {{
660666
let size = size.size_wasm32();
661667
let align = align.align_wasm32();
662668
let alloc = self.path_to_std_alloc_module();
663-
let (lower_address, lower, lift_address, lift) = match payload_type {
669+
let (lower_address, lower, cleanup, lift_address, lift) = match payload_type
670+
{
664671
Some(payload_type) if !stream_direct(payload_type) => {
665672
let address = format!(
666673
"let address = unsafe {{ {alloc}::alloc\
667674
({alloc}::Layout::from_size_align_unchecked\
668675
({size} * values.len(), {align})) }};"
669676
);
670-
let lower = self.lower_to_memory(
677+
let (lower, cleanup) = self.lower_to_memory(
671678
"address",
672679
"value",
673680
&payload_type,
@@ -696,7 +703,7 @@ for (index, dst) in values.iter_mut().take(count).enumerate() {{
696703
}}
697704
"#
698705
);
699-
(address.clone(), lower, address, lift)
706+
(address.clone(), lower, cleanup, address, lift)
700707
}
701708
_ => {
702709
let lower_address =
@@ -706,25 +713,29 @@ for (index, dst) in values.iter_mut().take(count).enumerate() {{
706713
(
707714
lower_address,
708715
String::new(),
716+
None,
709717
lift_address,
710718
"let value = ();\n".into(),
711719
)
712720
}
713721
};
714722

723+
let (cleanup_start, cleanup_end) =
724+
cleanup.unwrap_or_else(|| (String::new(), String::new()));
725+
715726
let box_ = self.path_to_box();
716727
let code = format!(
717728
r#"
718729
#[doc(hidden)]
719730
pub mod vtable{ordinal} {{
720731
fn write(stream: u32, values: &[{name}]) -> ::core::pin::Pin<{box_}<dyn ::core::future::Future<Output = usize> + '_>> {{
721732
{box_}::pin(async move {{
733+
{cleanup_start}
722734
{lower_address}
723735
{lower}
724736
725737
let mut total = 0;
726738
while total < values.len() {{
727-
728739
match unsafe {{
729740
{async_support}::await_stream_result(
730741
start_write,
@@ -738,6 +749,7 @@ pub mod vtable{ordinal} {{
738749
{async_support}::AsyncWaitResult::End => break,
739750
}}
740751
}}
752+
{cleanup_end}
741753
total
742754
}})
743755
}}
@@ -867,10 +879,30 @@ pub mod vtable{ordinal} {{
867879
}
868880
}
869881

870-
fn lower_to_memory(&mut self, address: &str, value: &str, ty: &Type, module: &str) -> String {
882+
fn lower_to_memory(
883+
&mut self,
884+
address: &str,
885+
value: &str,
886+
ty: &Type,
887+
module: &str,
888+
) -> (String, Option<(String, String)>) {
871889
let mut f = FunctionBindgen::new(self, Vec::new(), true, module, true);
872890
abi::lower_to_memory(f.r#gen.resolve, &mut f, address.into(), value.into(), ty);
873-
format!("unsafe {{ {} }}", String::from(f.src))
891+
f.flush_cleanup();
892+
let lower = format!("unsafe {{ {} }}", String::from(f.src));
893+
let cleanup = if f.needs_cleanup_list {
894+
f.src = Default::default();
895+
f.emit_cleanup();
896+
let body = String::from(f.src);
897+
let vec = self.path_to_vec();
898+
Some((
899+
format!("let mut cleanup_list = {vec}::new();\n"),
900+
format!("unsafe {{ {body} }}"),
901+
))
902+
} else {
903+
None
904+
};
905+
(lower, cleanup)
874906
}
875907

876908
fn lift_from_memory(&mut self, address: &str, value: &str, ty: &Type, module: &str) -> String {

tests/codegen/futures.wit

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@ interface futures {
2727
future-f32-ret: func() -> future<f32>;
2828
future-f64-ret: func() -> future<f64>;
2929

30+
future-result-list-string-ret: func() -> future<result<list<string>>>;
31+
future-result-list-list-u8-ret: func() -> future<result<list<list<u8>>>>;
32+
future-list-list-list-u8-ret: func() -> future<list<list<list<u8>>>>;
33+
3034
tuple-future: func(x: future<tuple<u8, s8>>) -> future<tuple<s64, u32>>;
3135
string-future-arg: func(a: future<string>);
3236
string-future-ret: func() -> future<string>;

tests/codegen/streams.wit

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ interface streams {
4141
stream-f32-ret: func() -> stream<f32>;
4242
stream-f64-ret: func() -> stream<f64>;
4343

44+
stream-result-list-string-ret: func() -> stream<result<list<string>>>;
45+
stream-result-list-list-u8-ret: func() -> stream<result<list<list<u8>>>>;
46+
stream-list-list-list-u8-ret: func() -> stream<list<list<list<u8>>>>;
47+
4448
tuple-stream: func(x: stream<tuple<u8, s8>>) -> stream<tuple<s64, u32>>;
4549
string-stream-arg: func(a: stream<string>);
4650
string-stream-ret: func() -> stream<string>;

0 commit comments

Comments
 (0)