Skip to content

Commit 1c556be

Browse files
committed
fix: Use correct types in “higher order” snippets
1 parent 311bab3 commit 1c556be

File tree

4 files changed

+17
-49
lines changed

4 files changed

+17
-49
lines changed

tasm-lib/src/list/higher_order/all.rs

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -75,13 +75,9 @@ impl All {
7575

7676
impl BasicSnippet for All {
7777
fn inputs(&self) -> Vec<(DataType, String)> {
78-
let input_type = match &self.f {
79-
InnerFunction::BasicSnippet(basic_snippet) => {
80-
DataType::List(Box::new(basic_snippet.inputs()[0].0.clone()))
81-
}
82-
_ => DataType::VoidPointer,
83-
};
84-
vec![(input_type, "*input_list".to_string())]
78+
let element_type = self.f.domain();
79+
let list_type = DataType::List(Box::new(element_type));
80+
vec![(list_type, "*input_list".to_string())]
8581
}
8682

8783
fn outputs(&self) -> Vec<(DataType, String)> {

tasm-lib/src/list/higher_order/filter.rs

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -25,29 +25,21 @@ use super::inner_function::InnerFunction;
2525

2626
/// Filters a given list for elements that satisfy a predicate. A new
2727
/// list is created, containing only those elements that satisfy the
28-
/// predicate. The predicate must be given as an InnerFunction.
28+
/// predicate. The predicate must be given as an [`InnerFunction`].
2929
pub struct Filter {
3030
pub f: InnerFunction,
3131
}
3232

3333
impl BasicSnippet for Filter {
3434
fn inputs(&self) -> Vec<(DataType, String)> {
35-
let list_type = match &self.f {
36-
InnerFunction::BasicSnippet(basic_snippet) => {
37-
DataType::List(Box::new(basic_snippet.inputs()[0].0.clone()))
38-
}
39-
_ => DataType::VoidPointer,
40-
};
35+
let element_type = self.f.domain();
36+
let list_type = DataType::List(Box::new(element_type));
4137
vec![(list_type, "*input_list".to_string())]
4238
}
4339

4440
fn outputs(&self) -> Vec<(DataType, String)> {
45-
let list_type = match &self.f {
46-
InnerFunction::BasicSnippet(basic_snippet) => {
47-
DataType::List(Box::new(basic_snippet.inputs()[0].0.clone()))
48-
}
49-
_ => DataType::VoidPointer,
50-
};
41+
let element_type = self.f.range();
42+
let list_type = DataType::List(Box::new(element_type));
5143
vec![(list_type, "*output_list".to_string())]
5244
}
5345

@@ -90,7 +82,7 @@ impl BasicSnippet for Filter {
9082
// If function was supplied as raw instructions, we need to append the inner function to the function
9183
// body. Otherwise, `library` handles the imports.
9284
let maybe_inner_function_body_raw = match &self.f {
93-
InnerFunction::RawCode(rc) => rc.function.iter().map(|x| x.to_string()).join("\n"),
85+
InnerFunction::RawCode(rc) => rc.function.iter().join("\n"),
9486
InnerFunction::DeprecatedSnippet(_) => String::default(),
9587
InnerFunction::NoFunctionBody(_) => todo!(),
9688
InnerFunction::BasicSnippet(_) => String::default(),
@@ -114,7 +106,7 @@ impl BasicSnippet for Filter {
114106
call {main_loop} // _ *input_list *output_list input_len input_len output_len
115107

116108
swap 2 pop 2 // _ *input_list *output_list output_len
117-
call {set_length} // _input_list *output_list
109+
call {set_length} // _ *input_list *output_list
118110

119111
swap 1 // _ *output_list *input_list
120112
pop 1 // _ *output_list
@@ -123,7 +115,7 @@ impl BasicSnippet for Filter {
123115
// INVARIANT: _ *input_list *output_list input_len input_index output_index
124116
{main_loop}:
125117
// test return condition
126-
dup 1 // _ *input_list *output_list input_len input_index output_index input_index
118+
dup 1 // _ *input_list *output_list input_len input_index output_index input_index
127119
dup 3 eq // _ *input_list *output_list input_len input_index output_index input_index==input_len
128120

129121
skiz return

tasm-lib/src/list/higher_order/inner_function.rs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -123,10 +123,9 @@ impl InnerFunction {
123123
}
124124
InnerFunction::NoFunctionBody(f) => f.input_type.clone(),
125125
InnerFunction::BasicSnippet(bs) => {
126-
let [ref input] = bs.inputs()[..] else {
126+
let [(ref input, _)] = bs.inputs()[..] else {
127127
panic!("{MORE_THAN_ONE_INPUT_OR_OUTPUT_TYPE_IN_INNER_FUNCTION}");
128128
};
129-
let (input, _) = input;
130129
input.clone()
131130
}
132131
}
@@ -143,10 +142,9 @@ impl InnerFunction {
143142
}
144143
InnerFunction::NoFunctionBody(lnat) => lnat.output_type.clone(),
145144
InnerFunction::BasicSnippet(bs) => {
146-
let [ref output] = bs.outputs()[..] else {
145+
let [(ref output, _)] = bs.outputs()[..] else {
147146
panic!("{MORE_THAN_ONE_INPUT_OR_OUTPUT_TYPE_IN_INNER_FUNCTION}");
148147
};
149-
let (output, _) = output;
150148
output.clone()
151149
}
152150
}
@@ -163,8 +161,8 @@ impl InnerFunction {
163161
}
164162

165163
/// Run the VM for on a given stack and memory to observe how it manipulates the
166-
/// stack. This is a helper function for [`apply`](apply), which in some cases just
167-
/// grabs the inner function's code and then needs a VM to apply it.
164+
/// stack. This is a helper function for [`apply`](Self::apply), which in some cases
165+
/// just grabs the inner function's code and then needs a VM to apply it.
168166
fn run_vm(
169167
instructions: &[LabelledInstruction],
170168
stack: &mut Vec<BFieldElement>,

tasm-lib/src/list/higher_order/map.rs

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,6 @@ use super::inner_function::InnerFunction;
2525

2626
const INNER_FN_INCORRECT_NUM_INPUTS: &str = "Inner function in `map` only works with *one* \
2727
input. Use a tuple as a workaround.";
28-
const INNER_FN_INCORRECT_NUM_OUTPUTS: &str = "Inner function in `map` only works with *one* \
29-
output. Use a tuple as a workaround.";
3028

3129
/// Applies a given function to every element of a list, and collects the new elements
3230
/// into a new list.
@@ -42,29 +40,13 @@ impl Map {
4240

4341
impl BasicSnippet for Map {
4442
fn inputs(&self) -> Vec<(DataType, String)> {
45-
let element_type = if let InnerFunction::BasicSnippet(snippet) = &self.f {
46-
let [(ref element_type, _)] = snippet.inputs()[..] else {
47-
panic!("{INNER_FN_INCORRECT_NUM_INPUTS}");
48-
};
49-
element_type.to_owned()
50-
} else {
51-
DataType::VoidPointer
52-
};
53-
43+
let element_type = self.f.domain();
5444
let list_type = DataType::List(Box::new(element_type));
5545
vec![(list_type, "*input_list".to_string())]
5646
}
5747

5848
fn outputs(&self) -> Vec<(DataType, String)> {
59-
let element_type = if let InnerFunction::BasicSnippet(snippet) = &self.f {
60-
let [(ref element_type, _)] = snippet.outputs()[..] else {
61-
panic!("{INNER_FN_INCORRECT_NUM_OUTPUTS}");
62-
};
63-
element_type.to_owned()
64-
} else {
65-
DataType::VoidPointer
66-
};
67-
49+
let element_type = self.f.range();
6850
let list_type = DataType::List(Box::new(element_type));
6951
vec![(list_type, "*output_list".to_string())]
7052
}

0 commit comments

Comments
 (0)