Skip to content

Commit 3af418a

Browse files
committed
feat(TasmStruct): Add destructuring capabilities
When writing Triton assembly involving `TasmStruct`s, it is common to access many, if not all, fields of the struct. Previously, the only sensible way to do so was duplicating the struct pointer, then executing the field-getter. Now, executing the code emitted by the new `destructure` function replaces a struct pointer with pointers to all its fields.
1 parent 3d3daa7 commit 3af418a

File tree

3 files changed

+427
-0
lines changed

3 files changed

+427
-0
lines changed

tasm-lib/src/structure/auto_generated_tasm_object_implementations.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ macro_rules! derive_tasm_object_for {
4747
// Pure delegation: forwards `field_name` unchanged to the function of the same
// name on the type bound to the `$fake` macro metavariable and returns its
// instructions as-is.
fn get_field_start_with_jump_distance(field_name: &str) -> Vec<LabelledInstruction> {
4848
$fake::get_field_start_with_jump_distance(field_name)
4949
}
50+
51+
// Pure delegation: returns whatever `destructure` of the type bound to the
// `$fake` macro metavariable returns.
fn destructure() -> Vec<LabelledInstruction> {
52+
$fake::destructure()
53+
}
5054
}
5155
};
5256
}
@@ -119,4 +123,8 @@ impl TasmStruct for MmrAccumulator {
119123
// Pure delegation: forwards `field_name` unchanged to
// `FakeMmrAccumulator::get_field_start_with_jump_distance` and returns its
// instructions as-is.
fn get_field_start_with_jump_distance(field_name: &str) -> Vec<LabelledInstruction> {
120124
FakeMmrAccumulator::get_field_start_with_jump_distance(field_name)
121125
}
126+
127+
// Pure delegation: returns whatever `FakeMmrAccumulator::destructure` returns.
fn destructure() -> Vec<LabelledInstruction> {
128+
FakeMmrAccumulator::destructure()
129+
}
122130
}

tasm-lib/src/structure/tasm_object.rs

Lines changed: 276 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,51 @@ pub trait TasmStruct: TasmObject {
108108
/// [`get_field_with_size`](TasmObject::get_field_with_size) instead.
109109
#[doc(hidden)]
110110
fn get_field_start_with_jump_distance(field_name: &str) -> Vec<LabelledInstruction>;
111+
112+
/// Destructure a struct into the pointers to its fields.
113+
///
/// Executing the returned instructions replaces the pointer to the struct,
/// which is expected on top of the stack, with one pointer per field. As the
/// example below shows, the pointer to the first declared field ends up on top
/// of the stack, the pointer to the last declared field lowest.
///
114+
/// ```text
115+
/// BEFORE: _ *struct
116+
/// AFTER:  _ [pointers to all fields; first field's pointer on top]
117+
/// ```
118+
///
119+
/// # Example
120+
///
121+
/// The example below defines a struct `Foo` and encodes an instance of it into
122+
/// memory. It then creates a Triton VM program to read and destructure the
123+
/// `Foo` instance, extracting and outputting the `bar` field. Finally, it runs
124+
/// the program and asserts that the extracted value matches the original `bar`
125+
/// value.
126+
///
127+
/// ```ignore // derive macro `BFieldCodec` does not behave nicely; todo
128+
/// # use tasm_lib::prelude::*;
129+
/// # use tasm_lib::triton_vm::prelude::*;
130+
/// # use tasm_lib::memory::encode_to_memory;
131+
/// #[derive(BFieldCodec, TasmObject)]
132+
/// struct Foo {
133+
///     bar: u32,
134+
///     baz: XFieldElement,
135+
/// }
136+
///
137+
/// let foo = Foo { bar: 13, baz: xfe!(0) };
138+
/// let foo_ptr = bfe!(42);
139+
/// let mut non_determinism = NonDeterminism::default();
140+
/// encode_to_memory(&mut non_determinism.ram, foo_ptr, &foo);
141+
///
142+
/// let program = triton_program! {
143+
///     read_io 1              // _ *foo
144+
///     {&Foo::destructure()}  // _ *baz *bar
145+
///     read_mem 1             // _ *baz bar (*bar - 1)
146+
///     pop 1                  // _ *baz bar
147+
///     write_io 1             // _ *baz
148+
///     halt
149+
/// };
150+
///
151+
/// let output = VM::run(program, PublicInput::new(vec![foo_ptr]), non_determinism).unwrap();
152+
/// let [bar] = output[..] else { panic!() };
153+
/// assert_eq!(bfe!(foo.bar), bar);
154+
/// ```
155+
fn destructure() -> Vec<LabelledInstruction>;
111156
}
112157

113158
pub fn decode_from_memory_with_size<T: BFieldCodec>(
@@ -891,6 +936,237 @@ mod tests {
891936
}
892937
}
893938

939+
#[cfg(test)]
940+
mod destructure {
941+
use super::*;
942+
943+
/// Destructuring a struct without fields must consume the struct pointer and
/// leave everything below it on the stack untouched. A canary value pushed
/// below the pointer is compared against itself afterwards; if the stack
/// height were off, the in-VM `assert` would fail and `unwrap` would panic.
#[test]
fn unit_struct() {
    #[derive(BFieldCodec, TasmObject)]
    struct Empty {}

    let canary = bfe!(0xdead_face_u64);
    let destructure_empty = Empty::destructure();

    let program = triton_program! {
        push {canary}           // _ c
        push 0                  // _ c *empty
        {&destructure_empty}    // _ c
        push {canary}           // _ c c
        eq                      // _ (c == c)
        assert                  // _
        halt
    };

    let run_result = VM::run(program, PublicInput::default(), NonDeterminism::default());
    run_result.unwrap();
}
960+
961+
mod one_field {
962+
use super::*;
963+
964+
#[derive(Debug, Copy, Clone, BFieldCodec, TasmObject, Arbitrary)]
965+
struct TupleStatic(u32);
966+
967+
#[derive(Debug, Clone, BFieldCodec, TasmObject, Arbitrary)]
968+
struct TupleDynamic(Vec<u32>);
969+
970+
#[derive(Debug, Copy, Clone, BFieldCodec, TasmObject, Arbitrary)]
971+
struct NamedStatic {
972+
field: u32,
973+
}
974+
975+
#[derive(Debug, Clone, BFieldCodec, TasmObject, Arbitrary)]
976+
struct NamedDynamic {
977+
field: Vec<u32>,
978+
}
979+
980+
// This macro is a little bit cursed due to the `$post_process`. Since it is
981+
// very limited in scope, I say it's better than duplicating essentially the
982+
// same code four times. If you want to extend the scope of this macro, please
983+
// re-design it.
984+
macro_rules! one_field_test_case {
985+
(fn $test_name:ident for $ty:ident: $f_name:tt $($post_process:tt)*) => {
986+
#[proptest]
987+
fn $test_name(
988+
#[strategy(arb())] foo: $ty,
989+
#[strategy(arb())] ptr: BFieldElement,
990+
) {
991+
let program = triton_program! {
992+
push {ptr}
993+
{&$ty::destructure()}
994+
read_mem 1 pop 1 write_io 1
995+
halt
996+
};
997+
998+
let mut non_determinism = NonDeterminism::default();
999+
encode_to_memory(&mut non_determinism.ram, ptr, &foo);
1000+
1001+
let output = VM::run(program, PublicInput::default(), non_determinism);
1002+
let [output] = output.unwrap()[..] else {
1003+
return Err(TestCaseError::Fail("unexpected output".into()));
1004+
};
1005+
1006+
let $ty { $f_name: the_field } = foo;
1007+
let expected = the_field$($post_process)*;
1008+
prop_assert_eq!(bfe!(expected), output);
1009+
}
1010+
};
1011+
}
1012+
1013+
one_field_test_case!( fn tuple_static for TupleStatic: 0 );
1014+
one_field_test_case!( fn tuple_dynamic for TupleDynamic: 0.len() );
1015+
one_field_test_case!( fn named_static for NamedStatic: field );
1016+
one_field_test_case!( fn named_dynamic for NamedDynamic: field.len() );
1017+
}
1018+
1019+
mod two_fields {
1020+
use super::*;
1021+
1022+
#[derive(Debug, Copy, Clone, BFieldCodec, TasmObject, Arbitrary)]
1023+
struct TupleStatStat(u32, u32);
1024+
1025+
#[derive(Debug, Clone, BFieldCodec, TasmObject, Arbitrary)]
1026+
struct TupleStatDyn(u32, Vec<u32>);
1027+
1028+
#[derive(Debug, Clone, BFieldCodec, TasmObject, Arbitrary)]
1029+
struct TupleDynStat(Vec<u32>, u32);
1030+
1031+
#[derive(Debug, Clone, BFieldCodec, TasmObject, Arbitrary)]
1032+
struct TupleDynDyn(Vec<u32>, Vec<u32>);
1033+
1034+
#[derive(Debug, Copy, Clone, BFieldCodec, TasmObject, Arbitrary)]
1035+
struct NamedStatStat {
1036+
a: u32,
1037+
b: u32,
1038+
}
1039+
1040+
#[derive(Debug, Clone, BFieldCodec, TasmObject, Arbitrary)]
1041+
struct NamedStatDyn {
1042+
a: u32,
1043+
b: Vec<u32>,
1044+
}
1045+
1046+
#[derive(Debug, Clone, BFieldCodec, TasmObject, Arbitrary)]
1047+
struct NamedDynStat {
1048+
a: Vec<u32>,
1049+
b: u32,
1050+
}
1051+
1052+
#[derive(Debug, Clone, BFieldCodec, TasmObject, Arbitrary)]
1053+
struct NamedDynDyn {
1054+
a: Vec<u32>,
1055+
b: Vec<u32>,
1056+
}
1057+
1058+
// This macro is a little bit cursed due to the `$post_process`es. Since it is
1059+
// very limited in scope, I say it's better than duplicating essentially the
1060+
// same code eight times. If you want to extend the scope of this macro, please
1061+
// re-design it.
1062+
macro_rules! two_fields_test_case {
1063+
(fn $test_name:ident for $ty:ident:
1064+
($f_name_0:tt $($post_process_0:tt)*)
1065+
($f_name_1:tt $($post_process_1:tt)*)
1066+
) => {
1067+
#[proptest]
1068+
fn $test_name(
1069+
#[strategy(arb())] foo: $ty,
1070+
#[strategy(arb())] ptr: BFieldElement,
1071+
) {
1072+
let program = triton_program! {
1073+
push {ptr}
1074+
{&$ty::destructure()}
1075+
read_mem 1 pop 1 write_io 1
1076+
read_mem 1 pop 1 write_io 1
1077+
halt
1078+
};
1079+
1080+
let mut non_determinism = NonDeterminism::default();
1081+
encode_to_memory(&mut non_determinism.ram, ptr, &foo);
1082+
1083+
let output = VM::run(program, PublicInput::default(), non_determinism);
1084+
let [output_0, output_1] = output.unwrap()[..] else {
1085+
return Err(TestCaseError::Fail("unexpected output".into()));
1086+
};
1087+
1088+
let $ty { $f_name_0: field_0, $f_name_1: field_1 } = foo;
1089+
let expected_0 = field_0$($post_process_0)*;
1090+
let expected_1 = field_1$($post_process_1)*;
1091+
prop_assert_eq!(bfe!(expected_0), output_0);
1092+
prop_assert_eq!(bfe!(expected_1), output_1);
1093+
}
1094+
};
1095+
}
1096+
1097+
two_fields_test_case!( fn tuple_stat_stat for TupleStatStat: (0) (1) );
1098+
two_fields_test_case!( fn tuple_stat_dyn for TupleStatDyn: (0) (1.len()) );
1099+
two_fields_test_case!( fn tuple_dyn_stat for TupleDynStat: (0.len()) (1) );
1100+
two_fields_test_case!( fn tuple_dyn_dyn for TupleDynDyn: (0.len()) (1.len()) );
1101+
two_fields_test_case!( fn named_stat_stat for NamedStatStat: (a) (b) );
1102+
two_fields_test_case!( fn named_stat_dyn for NamedStatDyn: (a) (b.len()) );
1103+
two_fields_test_case!( fn named_dyn_stat for NamedDynStat: (a.len()) (b) );
1104+
two_fields_test_case!( fn named_dyn_dyn for NamedDynDyn: (a.len()) (b.len()) );
1105+
}
1106+
1107+
/// Destructures a struct with a known, hand-written memory layout and checks
/// both the exact field pointers and that decoding from those pointers
/// reconstructs the original struct.
#[test]
fn all_static_dynamic_neighbor_combinations() {
    /// A struct where all neighbor combinations of fields with
    /// {static, dynamic}×{static, dynamic} sizes occur.
    #[derive(Debug, BFieldCodec, TasmObject, Eq, PartialEq)]
    struct Foo {
        a: XFieldElement,
        b: Vec<Digest>,
        c: Vec<Vec<XFieldElement>>,
        d: u128,
        e: u64,
    }

    let foo = Foo {
        a: xfe!([42, 43, 44]),
        b: vec![Digest::new(bfe_array![45, 46, 47, 48, 49])],
        c: vec![vec![], xfe_vec![[50, 51, 52]]],
        d: 53 + (54 << 32) + (55 << 64) + (56 << 96),
        e: 57 + (58 << 32),
    };

    // The expected encoding, annotated with the offsets each field occupies.
    // The pointer offsets asserted further down are derived from this table.
    let foo_encoding = bfe_vec![
        /* e: 00..=01 */ 57, 58, //
        /* d: 02..=05 */ 53, 54, 55, 56, //
        /* c: 06..=14 */ 8, 2, 1, 0, 4, 1, 50, 51, 52, //
        /* b: 15..=21 */ 6, 1, 45, 46, 47, 48, 49, //
        /* a: 22..=24 */ 42, 43, 44 //
    ];

    // Deliberately a hard `assert_eq!`: a `debug_assert_eq!` is compiled out when
    // the tests are built without debug assertions (e.g. `cargo test --release`),
    // which would leave the layout table above unchecked.
    assert_eq!(foo_encoding, foo.encode());

    let foo_ptr = bfe!(100);
    let mut non_determinism = NonDeterminism::default();
    encode_to_memory(&mut non_determinism.ram, foo_ptr, &foo);

    let program = triton_program! {
        read_io 1               // _ *foo
        {&Foo::destructure()}   // _ *e *d *c *b *a
        write_io 5              // _
        halt
    };

    // The RAM is still needed for decoding below, hence the clone.
    let input = PublicInput::new(vec![foo_ptr]);
    let output = VM::run(program, input, non_determinism.clone()).unwrap();

    // The pointers are output starting from the top of the stack, i.e., `*a` first.
    let [a_ptr, b_ptr, c_ptr, d_ptr, e_ptr] = output[..] else {
        panic!("expected 5 pointers");
    };

    // For the dynamically sized fields `b` and `c`, the expected pointer is one
    // past the start of the field's region in the table above, i.e., it skips
    // the leading size indicator.
    assert_eq!(foo_ptr + bfe!(22), a_ptr);
    assert_eq!(foo_ptr + bfe!(16), b_ptr);
    assert_eq!(foo_ptr + bfe!(7), c_ptr);
    assert_eq!(foo_ptr + bfe!(2), d_ptr);
    assert_eq!(foo_ptr + bfe!(0), e_ptr);

    // Round trip: decoding every field from its pointer must yield `foo` again.
    let a = *XFieldElement::decode_from_memory(&non_determinism.ram, a_ptr).unwrap();
    let b = *Vec::decode_from_memory(&non_determinism.ram, b_ptr).unwrap();
    let c = *Vec::decode_from_memory(&non_determinism.ram, c_ptr).unwrap();
    let d = *u128::decode_from_memory(&non_determinism.ram, d_ptr).unwrap();
    let e = *u64::decode_from_memory(&non_determinism.ram, e_ptr).unwrap();
    let foo_again = Foo { a, b, c, d, e };
    assert_eq!(foo, foo_again);
}
1168+
}
1169+
8941170
#[test]
8951171
fn test_option() {
8961172
let mut rng = thread_rng();

0 commit comments

Comments
 (0)