1
1
use std:: collections:: HashMap ;
2
2
3
3
use itertools:: Itertools ;
4
- use num_traits:: One ;
5
4
use rand:: prelude:: * ;
6
5
use triton_vm:: isa:: parser:: tokenize;
7
6
use triton_vm:: prelude:: * ;
@@ -24,8 +23,10 @@ use crate::InitVmState;
24
23
25
24
use super :: inner_function:: InnerFunction ;
26
25
27
- const MORE_THAN_ONE_INPUT_OR_OUTPUT_TYPE_IN_INNER_FUNCTION : & str = "inner function in `map` \
28
- currently only works with *one* input element. Use a tuple data type to circumvent this.";
26
+ const INNER_FN_INCORRECT_NUM_INPUTS : & str = "Inner function in `map` only works with *one* \
27
+ input. Use a tuple as a workaround.";
28
+ const INNER_FN_INCORRECT_NUM_OUTPUTS : & str = "Inner function in `map` only works with *one* \
29
+ output. Use a tuple as a workaround.";
29
30
30
31
/// Applies a given function to every element of a list, and collects the new elements
31
32
/// into a new list.
@@ -41,43 +42,31 @@ impl Map {
41
42
42
43
impl BasicSnippet for Map {
43
44
fn inputs ( & self ) -> Vec < ( DataType , String ) > {
44
- match & self . f {
45
- InnerFunction :: BasicSnippet ( bs) => {
46
- assert ! (
47
- bs. inputs( ) . len( ) . is_one( ) ,
48
- "{MORE_THAN_ONE_INPUT_OR_OUTPUT_TYPE_IN_INNER_FUNCTION}"
49
- ) ;
50
- let element_type = & bs. inputs ( ) [ 0 ] . 0 ;
51
- vec ! [ (
52
- DataType :: List ( Box :: new( element_type. clone( ) ) ) ,
53
- "*input_list" . to_string( ) ,
54
- ) ]
55
- }
56
- _ => vec ! [ (
57
- DataType :: List ( Box :: new( DataType :: VoidPointer ) ) ,
58
- "*input_list" . to_string( ) ,
59
- ) ] ,
60
- }
45
+ let element_type = if let InnerFunction :: BasicSnippet ( snippet) = & self . f {
46
+ let [ ( ref element_type, _) ] = snippet. inputs ( ) [ ..] else {
47
+ panic ! ( "{INNER_FN_INCORRECT_NUM_INPUTS}" ) ;
48
+ } ;
49
+ element_type. to_owned ( )
50
+ } else {
51
+ DataType :: VoidPointer
52
+ } ;
53
+
54
+ let list_type = DataType :: List ( Box :: new ( element_type) ) ;
55
+ vec ! [ ( list_type, "*input_list" . to_string( ) ) ]
61
56
}
62
57
63
58
fn outputs ( & self ) -> Vec < ( DataType , String ) > {
64
- match & self . f {
65
- InnerFunction :: BasicSnippet ( bs) => {
66
- assert ! (
67
- bs. inputs( ) . len( ) . is_one( ) ,
68
- "{MORE_THAN_ONE_INPUT_OR_OUTPUT_TYPE_IN_INNER_FUNCTION}"
69
- ) ;
70
- let element_type = & bs. outputs ( ) [ 0 ] . 0 ;
71
- vec ! [ (
72
- DataType :: List ( Box :: new( element_type. clone( ) ) ) ,
73
- "*output_list" . to_string( ) ,
74
- ) ]
75
- }
76
- _ => vec ! [ (
77
- DataType :: List ( Box :: new( DataType :: VoidPointer ) ) ,
78
- "*output_list" . to_string( ) ,
79
- ) ] ,
80
- }
59
+ let element_type = if let InnerFunction :: BasicSnippet ( snippet) = & self . f {
60
+ let [ ( ref element_type, _) ] = snippet. outputs ( ) [ ..] else {
61
+ panic ! ( "{INNER_FN_INCORRECT_NUM_OUTPUTS}" ) ;
62
+ } ;
63
+ element_type. to_owned ( )
64
+ } else {
65
+ DataType :: VoidPointer
66
+ } ;
67
+
68
+ let list_type = DataType :: List ( Box :: new ( element_type) ) ;
69
+ vec ! [ ( list_type, "*output_list" . to_string( ) ) ]
81
70
}
82
71
83
72
fn entrypoint ( & self ) -> String {
@@ -109,10 +98,7 @@ impl BasicSnippet for Map {
109
98
}
110
99
}
111
100
InnerFunction :: DeprecatedSnippet ( sn) => {
112
- assert ! (
113
- sn. input_types( ) . len( ) . is_one( ) ,
114
- "{MORE_THAN_ONE_INPUT_OR_OUTPUT_TYPE_IN_INNER_FUNCTION}"
115
- ) ;
101
+ assert_eq ! ( 1 , sn. input_types( ) . len( ) , "{INNER_FN_INCORRECT_NUM_INPUTS}" ) ;
116
102
let fn_body = sn. function_code ( library) ;
117
103
let ( _, instructions) = tokenize ( & fn_body) . unwrap ( ) ;
118
104
let labelled_instructions = isa:: parser:: to_labelled_instructions ( & instructions) ;
@@ -125,10 +111,7 @@ impl BasicSnippet for Map {
125
111
( triton_asm ! ( call { snippet_name } ) , String :: default ( ) )
126
112
}
127
113
InnerFunction :: BasicSnippet ( bs) => {
128
- assert ! (
129
- bs. inputs( ) . len( ) . is_one( ) ,
130
- "{MORE_THAN_ONE_INPUT_OR_OUTPUT_TYPE_IN_INNER_FUNCTION}"
131
- ) ;
114
+ assert_eq ! ( 1 , bs. inputs( ) . len( ) , "{INNER_FN_INCORRECT_NUM_INPUTS}" ) ;
132
115
let labelled_instructions = bs. annotated_code ( library) ;
133
116
let snippet_name =
134
117
library. explicit_import ( & bs. entrypoint ( ) , & labelled_instructions) ;
@@ -142,40 +125,26 @@ impl BasicSnippet for Map {
142
125
let write_to_output_list = output_type. write_value_to_memory_leave_pointer ( ) ;
143
126
let input_elem_size = input_type. stack_size ( ) ;
144
127
let output_elem_size = output_type. stack_size ( ) ;
145
- let input_elem_size_plus_one = input_elem_size + 1 ;
146
- let output_elem_size_plus_one = output_type. stack_size ( ) + 1 ;
147
- let minus_two_times_output_size = -( output_type. stack_size ( ) as i32 * 2 ) ;
148
128
149
129
let mul_elem_size = |n| match n {
150
130
0 => triton_asm ! ( pop 1 push 0 ) ,
151
131
1 => triton_asm ! ( ) ,
152
- n => triton_asm ! (
153
- push { n}
154
- mul
155
- ) ,
132
+ n => triton_asm ! ( push { n} mul) ,
156
133
} ;
157
134
158
135
let adjust_output_list_pointer = match output_elem_size {
159
- 0 => triton_asm ! ( ) ,
160
- 1 => triton_asm ! ( ) ,
161
- n => triton_asm ! (
162
- push { -( n as i32 - 1 ) }
163
- add
164
- ) ,
136
+ 0 | 1 => triton_asm ! ( ) ,
137
+ n => triton_asm ! ( addi { -( n as i32 - 1 ) } ) ,
165
138
} ;
166
139
167
140
let final_output_list_pointer_adjust = match output_elem_size {
168
- 0 => triton_asm ! ( ) ,
169
- 1 => triton_asm ! ( ) ,
170
- n => triton_asm ! (
171
- push { n - 1 }
172
- add
173
- ) ,
141
+ 0 | 1 => triton_asm ! ( ) ,
142
+ n => triton_asm ! ( addi { n - 1 } ) ,
174
143
} ;
175
144
176
145
triton_asm ! (
177
- // BEFORE: _ <[additional_input_args]> *input_list
178
- // AFTER: _ <[additional_input_args]> *output_list
146
+ // BEFORE: _ <[additional_input_args]> *input_list
147
+ // AFTER: _ <[additional_input_args]> *output_list
179
148
{ entrypoint} :
180
149
dup 0
181
150
read_mem 1
@@ -219,10 +188,9 @@ impl BasicSnippet for Map {
219
188
swap 1
220
189
pop 1
221
190
222
-
223
191
return
224
192
225
- // INVARIANT: _ <aia> *end_condition_input_list *output_elem *input_elem
193
+ // INVARIANT: _ <aia> *end_condition_input_list *output_elem *input_elem
226
194
{ main_loop} :
227
195
// test return condition
228
196
dup 2
@@ -233,26 +201,25 @@ impl BasicSnippet for Map {
233
201
234
202
dup 0
235
203
{ & read_from_input_list}
236
- // _ <aia> *end_condition_input_list *output_elem *input_elem [input_elem] *prev_input_elem
204
+ // _ <aia> *end_condition_input_list *output_elem *input_elem [input_elem] *prev_input_elem
237
205
238
- swap { input_elem_size_plus_one }
206
+ swap { input_elem_size + 1 }
239
207
pop 1
240
- // _ <aia> *end_condition_input_list *output_elem *prev_input_elem [input_elem]
208
+ // _ <aia> *end_condition_input_list *output_elem *prev_input_elem [input_elem]
241
209
242
210
// map
243
211
{ & call_inner_function}
244
- // _ <aia> *end_condition_input_list *output_elem *prev_input_elem [output_elem]
212
+ // _ <aia> *end_condition_input_list *output_elem *prev_input_elem [output_elem]
245
213
246
214
// write
247
- dup { output_elem_size_plus_one }
248
- // _ <aia> *end_condition_input_list *output_elem *prev_input_elem [output_elem] *output_elem
215
+ dup { output_type . stack_size ( ) + 1 }
216
+ // _ <aia> *end_condition_input_list *output_elem *prev_input_elem [output_elem] *output_elem
249
217
250
218
{ & write_to_output_list}
251
- // _ <aia> *end_condition_input_list *output_elem *prev_input_elem *next_output_elem
219
+ // _ <aia> *end_condition_input_list *output_elem *prev_input_elem *next_output_elem
252
220
253
- push { minus_two_times_output_size}
254
- add
255
- // _ <aia> *end_condition_input_list *output_elem *prev_input_elem *prev_output_elem
221
+ addi { -( output_type. stack_size( ) as i32 * 2 ) }
222
+ // _ <aia> *end_condition_input_list *output_elem *prev_input_elem *prev_output_elem
256
223
257
224
swap 2
258
225
pop 1
0 commit comments