Skip to content

Commit 67ef366

Browse files
committed
feat!: Introduce helper struct StaticAllocation
Previously, `Library::kmalloc()` returned only one address: the address to which values could be written. The address from which reading could start needed to be computed manually by every entity wanting to read from a static allocation, leading to - a lot of duplicated code, - leaked implementation details about the static allocator, and - developer confusion regarding the address returned by the allocator. Now, the helper struct `StaticAllocation` resolves all the above issues. fix #119 BREAKING CHANGE: `Library::kmalloc()` used to return only the write address. Now, it returns a helper struct from which both write and read address can be queried.
1 parent d66afd0 commit 67ef366

File tree

12 files changed

+142
-127
lines changed

12 files changed

+142
-127
lines changed

tasm-lib/src/arithmetic/u64/div_mod_u64.rs

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,7 @@ impl DeprecatedSnippet for DivModU64 {
7171
let sub_u32 = library.import(Box::new(Safesub));
7272
let leading_zeros_u64 = library.import(Box::new(LeadingZerosU64));
7373
let add_u32 = library.import(Box::new(Safeadd));
74-
let mem_address_for_spilled_divisor = library.kmalloc(2);
75-
let last_mem_address_for_spilled_divisor =
76-
mem_address_for_spilled_divisor + BFieldElement::one();
74+
let spilled_divisor_alloc = library.kmalloc(2);
7775

7876
// The below code has been compiled from a Rust implementation of an LLVM function
7977
// called `divmoddi4` that can do u64 divmod with only access to u32 bit divmod and
@@ -86,7 +84,7 @@ impl DeprecatedSnippet for DivModU64 {
8684
{entrypoint}:
8785
dup 1
8886
dup 1
89-
push {mem_address_for_spilled_divisor}
87+
push {spilled_divisor_alloc.write_address()}
9088
write_mem 2
9189
pop 1
9290
dup 3
@@ -103,15 +101,15 @@ impl DeprecatedSnippet for DivModU64 {
103101
call {and_u64}
104102
swap 1
105103
pop 1
106-
push {last_mem_address_for_spilled_divisor}
107-
read_mem 2
104+
push {spilled_divisor_alloc.read_address()}
105+
read_mem {spilled_divisor_alloc.num_words()}
108106
pop 1
109107
push 32
110108
call {shift_right_u64}
111109
swap 1
112110
pop 1
113-
push {last_mem_address_for_spilled_divisor}
114-
read_mem 2
111+
push {spilled_divisor_alloc.read_address()}
112+
read_mem {spilled_divisor_alloc.num_words()}
115113
pop 1
116114
push 00000000004294967295
117115
push 0
@@ -125,8 +123,8 @@ impl DeprecatedSnippet for DivModU64 {
125123
push 0
126124
dup 11
127125
dup 11
128-
push {last_mem_address_for_spilled_divisor}
129-
read_mem 2
126+
push {spilled_divisor_alloc.read_address()}
127+
read_mem {spilled_divisor_alloc.num_words()}
130128
pop 1
131129
dup 3
132130
dup 3
@@ -248,8 +246,8 @@ impl DeprecatedSnippet for DivModU64 {
248246
pop 1
249247
swap 7
250248
pop 1
251-
push {last_mem_address_for_spilled_divisor}
252-
read_mem 2
249+
push {spilled_divisor_alloc.read_address()}
250+
read_mem {spilled_divisor_alloc.num_words()}
253251
pop 1
254252
dup 5
255253
dup 5
@@ -266,8 +264,8 @@ impl DeprecatedSnippet for DivModU64 {
266264
pop 1
267265
dup 3
268266
dup 3
269-
push {last_mem_address_for_spilled_divisor}
270-
read_mem 2
267+
push {spilled_divisor_alloc.read_address()}
268+
read_mem {spilled_divisor_alloc.num_words()}
271269
pop 1
272270
dup 5
273271
dup 5
@@ -290,8 +288,8 @@ impl DeprecatedSnippet for DivModU64 {
290288
recurse
291289
_binop_Or_bool_bool_44_then:
292290
pop 1
293-
push {last_mem_address_for_spilled_divisor}
294-
read_mem 2
291+
push {spilled_divisor_alloc.read_address()}
292+
read_mem {spilled_divisor_alloc.num_words()}
295293
pop 1
296294
push 0
297295
push 1
@@ -311,8 +309,8 @@ impl DeprecatedSnippet for DivModU64 {
311309
_binop_Or_bool_bool_44_else:
312310
push 0
313311
push 0
314-
push {last_mem_address_for_spilled_divisor}
315-
read_mem 2
312+
push {spilled_divisor_alloc.read_address()}
313+
read_mem {spilled_divisor_alloc.num_words()}
316314
pop 1
317315
swap 3
318316
eq
@@ -322,8 +320,8 @@ impl DeprecatedSnippet for DivModU64 {
322320
push 0
323321
eq
324322
assert
325-
push {last_mem_address_for_spilled_divisor}
326-
read_mem 2
323+
push {spilled_divisor_alloc.read_address()}
324+
read_mem {spilled_divisor_alloc.num_words()}
327325
pop 1
328326
call {leading_zeros_u64}
329327
dup 2
@@ -393,8 +391,8 @@ impl DeprecatedSnippet for DivModU64 {
393391
dup 7
394392
push 0
395393
eq
396-
push {last_mem_address_for_spilled_divisor}
397-
read_mem 2
394+
push {spilled_divisor_alloc.read_address()}
395+
read_mem {spilled_divisor_alloc.num_words()}
398396
pop 1
399397
push 0
400398
push 1

tasm-lib/src/hashing/algebraic_hasher/sample_scalars_static_length_kmalloc.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,11 @@ impl BasicSnippet for SampleScalarsStaticLengthKMalloc {
7171
let entrypoint = self.entrypoint();
7272
let squeeze_repeatedly_static_number =
7373
library.import(Box::new(SqueezeRepeatedlyStaticNumber { num_squeezes }));
74-
let scalars_pointer = library.kmalloc(self.num_words_to_allocate());
74+
let scalars_pointer_alloc = library.kmalloc(self.num_words_to_allocate());
7575

7676
triton_asm!(
7777
{entrypoint}:
78-
push {scalars_pointer}
78+
push {scalars_pointer_alloc.write_address()}
7979
call {squeeze_repeatedly_static_number}
8080
return
8181
)

tasm-lib/src/hashing/merkle_root_from_xfes_wrapper.rs

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,23 +26,23 @@ impl BasicSnippet for MerkleRootFromXfesWrapper {
2626

2727
fn code(&self, library: &mut crate::library::Library) -> Vec<LabelledInstruction> {
2828
let entrypoint = self.entrypoint();
29-
let length_pointer = library.kmalloc(1);
29+
let list_length_alloc = library.kmalloc(1);
3030

31-
let pointer_for_node_memory = library.kmalloc(MAX_LENGTH_SUPPORTED * (Digest::LEN as u32));
31+
let node_memory_alloc = library.kmalloc(MAX_LENGTH_SUPPORTED * (Digest::LEN as u32));
3232

3333
let snippet_for_length_256 = MerkleRootFromXfesStaticSize {
3434
log2_length: 8,
35-
static_memory_pointer: pointer_for_node_memory,
35+
static_memory_pointer: node_memory_alloc.write_address(),
3636
};
3737
let snippet_for_length_256 = library.import(Box::new(snippet_for_length_256));
3838
let snippet_for_length_512 = MerkleRootFromXfesStaticSize {
3939
log2_length: 9,
40-
static_memory_pointer: pointer_for_node_memory,
40+
static_memory_pointer: node_memory_alloc.write_address(),
4141
};
4242
let snippet_for_length_512 = library.import(Box::new(snippet_for_length_512));
4343
let snippet_for_length_1024 = MerkleRootFromXfesStaticSize {
4444
log2_length: 10,
45-
static_memory_pointer: pointer_for_node_memory,
45+
static_memory_pointer: node_memory_alloc.write_address(),
4646
};
4747
let snippet_for_length_1024 = library.import(Box::new(snippet_for_length_1024));
4848

@@ -75,8 +75,8 @@ impl BasicSnippet for MerkleRootFromXfesWrapper {
7575
// _ (*xfes + 1) len
7676

7777
dup 0
78-
push {length_pointer}
79-
write_mem 1
78+
push {list_length_alloc.write_address()}
79+
write_mem {list_length_alloc.num_words()}
8080
pop 1
8181
// _ (*xfes + 1) len
8282

@@ -86,17 +86,17 @@ impl BasicSnippet for MerkleRootFromXfesWrapper {
8686
call {snippet_for_length_256}
8787
// _ ((*xfes + 1)|[root])
8888

89-
push {length_pointer}
90-
read_mem 1
89+
push {list_length_alloc.read_address()}
90+
read_mem {list_length_alloc.num_words()}
9191
pop 1
9292
push 512
9393
eq
9494
skiz
9595
call {snippet_for_length_512}
9696
// _ ((*xfes + 1)|[root])
9797

98-
push {length_pointer}
99-
read_mem 1
98+
push {list_length_alloc.read_address()}
99+
read_mem {list_length_alloc.num_words()}
100100
pop 1
101101
push 1024
102102
eq

tasm-lib/src/library.rs

Lines changed: 50 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
use std::collections::HashMap;
22

3+
use arbitrary::Arbitrary;
34
use itertools::Itertools;
5+
use num_traits::ConstOne;
46
use triton_vm::memory_layout::MemoryRegion;
57
use triton_vm::prelude::*;
68

@@ -30,6 +32,34 @@ pub struct Library {
3032
num_allocated_words: u32,
3133
}
3234

35+
/// Represents a [static memory allocation][kmalloc] within Triton VM.
36+
/// Both its location within Triton VM's memory and its size and are fix.
37+
///
38+
/// [kmalloc]: Library::kmalloc
39+
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Arbitrary)]
40+
pub struct StaticAllocation {
41+
write_address: BFieldElement,
42+
num_words: u32,
43+
}
44+
45+
impl StaticAllocation {
46+
/// The address from which the allocated memory can be read.
47+
pub fn read_address(&self) -> BFieldElement {
48+
let offset = bfe!(self.num_words) - BFieldElement::ONE;
49+
self.write_address() + offset
50+
}
51+
52+
/// The address to which the allocated memory can be written.
53+
pub fn write_address(&self) -> BFieldElement {
54+
self.write_address
55+
}
56+
57+
/// The number of words allocated in this memory block.
58+
pub fn num_words(&self) -> u32 {
59+
self.num_words
60+
}
61+
}
62+
3363
impl Default for Library {
3464
fn default() -> Self {
3565
Self::new()
@@ -119,19 +149,26 @@ impl Library {
119149
self.all_external_dependencies().concat()
120150
}
121151

122-
/// Statically allocate `num_words` words of memory. Panics if more static
123-
/// memory is required than what the capacity allows for.
124-
pub fn kmalloc(&mut self, num_words: u32) -> BFieldElement {
152+
/// Statically allocate `num_words` words of memory.
153+
///
154+
/// # Panics
155+
///
156+
/// Panics if
157+
/// - `num_words` is zero,
158+
/// - the total number of statically allocated words exceeds `u32::MAX`.
159+
pub fn kmalloc(&mut self, num_words: u32) -> StaticAllocation {
125160
assert!(num_words > 0, "must allocate a positive number of words");
126-
let address = STATIC_MEMORY_FIRST_ADDRESS
127-
- bfe!(self.num_allocated_words)
128-
- BFieldElement::new(num_words as u64 - 1);
161+
let write_address =
162+
STATIC_MEMORY_FIRST_ADDRESS - bfe!(self.num_allocated_words) - bfe!(num_words - 1);
129163
self.num_allocated_words = self
130164
.num_allocated_words
131165
.checked_add(num_words)
132166
.expect("Cannot allocate more that u32::MAX words through `kmalloc`.");
133167

134-
address
168+
StaticAllocation {
169+
write_address,
170+
num_words,
171+
}
135172
}
136173
}
137174

@@ -466,13 +503,13 @@ mod tests {
466503
const MINUS_TWO: BFieldElement = BFieldElement::new(BFieldElement::MAX - 1);
467504
let mut lib = Library::new();
468505

469-
let first_free_address = lib.kmalloc(1);
470-
assert_eq!(MINUS_TWO, first_free_address);
506+
let first_chunk = lib.kmalloc(1);
507+
assert_eq!(MINUS_TWO, first_chunk.write_address());
471508

472-
let second_free_address = lib.kmalloc(7);
473-
assert_eq!(-BFieldElement::new(9), second_free_address,);
509+
let second_chunk = lib.kmalloc(7);
510+
assert_eq!(-bfe!(9), second_chunk.write_address());
474511

475-
let third_free_address = lib.kmalloc(1000);
476-
assert_eq!(-BFieldElement::new(1009), third_free_address);
512+
let third_chunk = lib.kmalloc(1000);
513+
assert_eq!(-bfe!(1009), third_chunk.write_address());
477514
}
478515
}

tasm-lib/src/list/contains.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,7 @@ impl BasicSnippet for Contains {
4141
let elem_size: u32 = self.element_type.stack_size().try_into().unwrap();
4242
let elem_size_plus_one = elem_size + 1;
4343

44-
let needle_pointer_write = library.kmalloc(self.element_type.stack_size() as u32);
45-
let needle_pointer_read = needle_pointer_write + bfe!(elem_size - 1);
44+
let needle_alloc = library.kmalloc(elem_size);
4645

4746
let loop_label = format!("{entrypoint}_loop");
4847

@@ -76,7 +75,7 @@ impl BasicSnippet for Contains {
7675
pop 1
7776
// _ 0 *list *list[i-1] [haystack]
7877

79-
push {needle_pointer_read}
78+
push {needle_alloc.read_address()}
8079
{&read_element}
8180
pop 1
8281
// _ 0 *list *list[i-1] [haystack] [needle]
@@ -102,7 +101,7 @@ impl BasicSnippet for Contains {
102101
// BEFORE: _ *list [value]
103102
// AFTER: match_found
104103
{entrypoint}:
105-
push {needle_pointer_write}
104+
push {needle_alloc.write_address()}
106105
{&write_element}
107106
pop 1
108107
// _ *list

tasm-lib/src/list/multiset_equality_u64s.rs

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,7 @@ impl BasicSnippet for MultisetEqualityU64s {
3434
let hash_varlen = library.import(Box::new(HashVarlen));
3535
let compare_xfes = DataType::Xfe.compare();
3636

37-
let running_product_result_write_pointer =
38-
library.kmalloc(EXTENSION_DEGREE.try_into().unwrap());
39-
let running_product_result_read_pointer =
40-
running_product_result_write_pointer + bfe!(EXTENSION_DEGREE as u64 - 1);
37+
let running_product_result_alloc = library.kmalloc(EXTENSION_DEGREE.try_into().unwrap());
4138

4239
let compare_lengths = triton_asm!(
4340
// _ *a *b
@@ -172,8 +169,8 @@ impl BasicSnippet for MultisetEqualityU64s {
172169
// _ *a *b size [-indeterminate] *a *a [garbage; 2] [a_rp]
173170

174171
/* store result in static memory and cleanup stack */
175-
push {running_product_result_write_pointer}
176-
write_mem {EXTENSION_DEGREE}
172+
push {running_product_result_alloc.write_address()}
173+
write_mem {running_product_result_alloc.num_words()}
177174
pop 5
178175
// _ *a *b size [-indeterminate]
179176

@@ -214,8 +211,8 @@ impl BasicSnippet for MultisetEqualityU64s {
214211
pop 3
215212
// _ [b_rp]
216213

217-
push {running_product_result_read_pointer}
218-
read_mem {EXTENSION_DEGREE}
214+
push {running_product_result_alloc.read_address()}
215+
read_mem {running_product_result_alloc.num_words()}
219216
pop 1
220217
// _ [b_rp] [a_rp]
221218

tasm-lib/src/verifier/claim/new_recursive.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use num::Zero;
1+
use num_traits::ConstZero;
22
use triton_vm::isa::op_stack::NUM_OP_STACK_REGISTERS;
33
use triton_vm::prelude::*;
44

@@ -36,14 +36,14 @@ impl BasicSnippet for NewRecursive {
3636
let entrypoint = self.entrypoint();
3737
let claim_size = Claim {
3838
program_digest: Digest::default(),
39-
input: vec![BFieldElement::zero(); self.input_size],
40-
output: vec![BFieldElement::zero(); self.output_size],
39+
input: vec![BFieldElement::ZERO; self.input_size],
40+
output: vec![BFieldElement::ZERO; self.output_size],
4141
}
4242
.encode()
4343
.len();
44-
let claim_pointer = library.kmalloc(claim_size.try_into().unwrap());
44+
let claim_alloc = library.kmalloc(claim_size.try_into().unwrap());
4545
const METADATA_SIZE_FOR_FIELD_WITH_VEC_VALUE: usize = 2;
46-
let output_field_pointer = claim_pointer;
46+
let output_field_pointer = claim_alloc.write_address();
4747
let output_field_size: u32 = (1 + self.output_size).try_into().unwrap();
4848
let input_field_pointer = output_field_pointer + bfe!(output_field_size + 1);
4949
let input_field_size: u32 = (1 + self.input_size).try_into().unwrap();
@@ -99,7 +99,7 @@ impl BasicSnippet for NewRecursive {
9999
{&write_digest_to_memory}
100100
// _
101101

102-
push {claim_pointer}
102+
push {claim_alloc.write_address()}
103103
// _ *claim
104104

105105
return

0 commit comments

Comments
 (0)