Skip to content

Commit 311bab3

Browse files
committed
fix: Check length of output types in outputs()
Previously, method `Map::outputs()` checked the length of `snippet.inputs()`. Now, it correctly checks the length of `snippet.outputs()`. Also, de-duplicate code and improve style. changelog: ignore
1 parent ac71ccb commit 311bab3

File tree

1 file changed

+45
-78
lines changed
  • tasm-lib/src/list/higher_order

1 file changed

+45
-78
lines changed

tasm-lib/src/list/higher_order/map.rs

Lines changed: 45 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
use std::collections::HashMap;
22

33
use itertools::Itertools;
4-
use num_traits::One;
54
use rand::prelude::*;
65
use triton_vm::isa::parser::tokenize;
76
use triton_vm::prelude::*;
@@ -24,8 +23,10 @@ use crate::InitVmState;
2423

2524
use super::inner_function::InnerFunction;
2625

27-
const MORE_THAN_ONE_INPUT_OR_OUTPUT_TYPE_IN_INNER_FUNCTION: &str = "inner function in `map` \
28-
currently only works with *one* input element. Use a tuple data type to circumvent this.";
26+
const INNER_FN_INCORRECT_NUM_INPUTS: &str = "Inner function in `map` only works with *one* \
27+
input. Use a tuple as a workaround.";
28+
const INNER_FN_INCORRECT_NUM_OUTPUTS: &str = "Inner function in `map` only works with *one* \
29+
output. Use a tuple as a workaround.";
2930

3031
/// Applies a given function to every element of a list, and collects the new elements
3132
/// into a new list.
@@ -41,43 +42,31 @@ impl Map {
4142

4243
impl BasicSnippet for Map {
4344
fn inputs(&self) -> Vec<(DataType, String)> {
44-
match &self.f {
45-
InnerFunction::BasicSnippet(bs) => {
46-
assert!(
47-
bs.inputs().len().is_one(),
48-
"{MORE_THAN_ONE_INPUT_OR_OUTPUT_TYPE_IN_INNER_FUNCTION}"
49-
);
50-
let element_type = &bs.inputs()[0].0;
51-
vec![(
52-
DataType::List(Box::new(element_type.clone())),
53-
"*input_list".to_string(),
54-
)]
55-
}
56-
_ => vec![(
57-
DataType::List(Box::new(DataType::VoidPointer)),
58-
"*input_list".to_string(),
59-
)],
60-
}
45+
let element_type = if let InnerFunction::BasicSnippet(snippet) = &self.f {
46+
let [(ref element_type, _)] = snippet.inputs()[..] else {
47+
panic!("{INNER_FN_INCORRECT_NUM_INPUTS}");
48+
};
49+
element_type.to_owned()
50+
} else {
51+
DataType::VoidPointer
52+
};
53+
54+
let list_type = DataType::List(Box::new(element_type));
55+
vec![(list_type, "*input_list".to_string())]
6156
}
6257

6358
fn outputs(&self) -> Vec<(DataType, String)> {
64-
match &self.f {
65-
InnerFunction::BasicSnippet(bs) => {
66-
assert!(
67-
bs.inputs().len().is_one(),
68-
"{MORE_THAN_ONE_INPUT_OR_OUTPUT_TYPE_IN_INNER_FUNCTION}"
69-
);
70-
let element_type = &bs.outputs()[0].0;
71-
vec![(
72-
DataType::List(Box::new(element_type.clone())),
73-
"*output_list".to_string(),
74-
)]
75-
}
76-
_ => vec![(
77-
DataType::List(Box::new(DataType::VoidPointer)),
78-
"*output_list".to_string(),
79-
)],
80-
}
59+
let element_type = if let InnerFunction::BasicSnippet(snippet) = &self.f {
60+
let [(ref element_type, _)] = snippet.outputs()[..] else {
61+
panic!("{INNER_FN_INCORRECT_NUM_OUTPUTS}");
62+
};
63+
element_type.to_owned()
64+
} else {
65+
DataType::VoidPointer
66+
};
67+
68+
let list_type = DataType::List(Box::new(element_type));
69+
vec![(list_type, "*output_list".to_string())]
8170
}
8271

8372
fn entrypoint(&self) -> String {
@@ -109,10 +98,7 @@ impl BasicSnippet for Map {
10998
}
11099
}
111100
InnerFunction::DeprecatedSnippet(sn) => {
112-
assert!(
113-
sn.input_types().len().is_one(),
114-
"{MORE_THAN_ONE_INPUT_OR_OUTPUT_TYPE_IN_INNER_FUNCTION}"
115-
);
101+
assert_eq!(1, sn.input_types().len(), "{INNER_FN_INCORRECT_NUM_INPUTS}");
116102
let fn_body = sn.function_code(library);
117103
let (_, instructions) = tokenize(&fn_body).unwrap();
118104
let labelled_instructions = isa::parser::to_labelled_instructions(&instructions);
@@ -125,10 +111,7 @@ impl BasicSnippet for Map {
125111
(triton_asm!(call { snippet_name }), String::default())
126112
}
127113
InnerFunction::BasicSnippet(bs) => {
128-
assert!(
129-
bs.inputs().len().is_one(),
130-
"{MORE_THAN_ONE_INPUT_OR_OUTPUT_TYPE_IN_INNER_FUNCTION}"
131-
);
114+
assert_eq!(1, bs.inputs().len(), "{INNER_FN_INCORRECT_NUM_INPUTS}");
132115
let labelled_instructions = bs.annotated_code(library);
133116
let snippet_name =
134117
library.explicit_import(&bs.entrypoint(), &labelled_instructions);
@@ -142,40 +125,26 @@ impl BasicSnippet for Map {
142125
let write_to_output_list = output_type.write_value_to_memory_leave_pointer();
143126
let input_elem_size = input_type.stack_size();
144127
let output_elem_size = output_type.stack_size();
145-
let input_elem_size_plus_one = input_elem_size + 1;
146-
let output_elem_size_plus_one = output_type.stack_size() + 1;
147-
let minus_two_times_output_size = -(output_type.stack_size() as i32 * 2);
148128

149129
let mul_elem_size = |n| match n {
150130
0 => triton_asm!(pop 1 push 0),
151131
1 => triton_asm!(),
152-
n => triton_asm!(
153-
push {n}
154-
mul
155-
),
132+
n => triton_asm!(push {n} mul),
156133
};
157134

158135
let adjust_output_list_pointer = match output_elem_size {
159-
0 => triton_asm!(),
160-
1 => triton_asm!(),
161-
n => triton_asm!(
162-
push {-(n as i32 - 1)}
163-
add
164-
),
136+
0 | 1 => triton_asm!(),
137+
n => triton_asm!(addi {-(n as i32 - 1)}),
165138
};
166139

167140
let final_output_list_pointer_adjust = match output_elem_size {
168-
0 => triton_asm!(),
169-
1 => triton_asm!(),
170-
n => triton_asm!(
171-
push {n - 1}
172-
add
173-
),
141+
0 | 1 => triton_asm!(),
142+
n => triton_asm!(addi {n - 1}),
174143
};
175144

176145
triton_asm!(
177-
// BEFORE: _ <[additional_input_args]> *input_list
178-
// AFTER: _ <[additional_input_args]> *output_list
146+
// BEFORE: _ <[additional_input_args]> *input_list
147+
// AFTER: _ <[additional_input_args]> *output_list
179148
{entrypoint}:
180149
dup 0
181150
read_mem 1
@@ -219,10 +188,9 @@ impl BasicSnippet for Map {
219188
swap 1
220189
pop 1
221190

222-
223191
return
224192

225-
// INVARIANT: _ <aia> *end_condition_input_list *output_elem *input_elem
193+
// INVARIANT: _ <aia> *end_condition_input_list *output_elem *input_elem
226194
{main_loop}:
227195
// test return condition
228196
dup 2
@@ -233,26 +201,25 @@ impl BasicSnippet for Map {
233201

234202
dup 0
235203
{&read_from_input_list}
236-
// _ <aia> *end_condition_input_list *output_elem *input_elem [input_elem] *prev_input_elem
204+
// _ <aia> *end_condition_input_list *output_elem *input_elem [input_elem] *prev_input_elem
237205

238-
swap {input_elem_size_plus_one}
206+
swap {input_elem_size + 1}
239207
pop 1
240-
// _ <aia> *end_condition_input_list *output_elem *prev_input_elem [input_elem]
208+
// _ <aia> *end_condition_input_list *output_elem *prev_input_elem [input_elem]
241209

242210
// map
243211
{&call_inner_function}
244-
// _ <aia> *end_condition_input_list *output_elem *prev_input_elem [output_elem]
212+
// _ <aia> *end_condition_input_list *output_elem *prev_input_elem [output_elem]
245213

246214
// write
247-
dup {output_elem_size_plus_one}
248-
// _ <aia> *end_condition_input_list *output_elem *prev_input_elem [output_elem] *output_elem
215+
dup {output_type.stack_size() + 1}
216+
// _ <aia> *end_condition_input_list *output_elem *prev_input_elem [output_elem] *output_elem
249217

250218
{&write_to_output_list}
251-
// _ <aia> *end_condition_input_list *output_elem *prev_input_elem *next_output_elem
219+
// _ <aia> *end_condition_input_list *output_elem *prev_input_elem *next_output_elem
252220

253-
push {minus_two_times_output_size}
254-
add
255-
// _ <aia> *end_condition_input_list *output_elem *prev_input_elem *prev_output_elem
221+
addi {-(output_type.stack_size() as i32 * 2)}
222+
// _ <aia> *end_condition_input_list *output_elem *prev_input_elem *prev_output_elem
256223

257224
swap 2
258225
pop 1

0 commit comments

Comments
 (0)