Skip to content

Commit 4df2bf9

Browse files
shuklaayushbranch-rebase-bot[bot]
authored andcommitted
fix: use tiny-keccak fork in new keccak example
1 parent 9d36e3a commit 4df2bf9

File tree

3 files changed

+97
-111
lines changed

3 files changed

+97
-111
lines changed

examples/new-keccak/Cargo.toml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,18 @@ edition = "2021"
77
members = []
88

99
[dependencies]
10+
tiny-keccak = { git = "https://github.com/openvm-org/tiny-keccak", branch = "perf/custom-xorin-keccak", features = ["keccak"] }
11+
hex-literal = "1.1.0"
12+
13+
[target.'cfg(target_os = "zkvm")'.dependencies]
1014
openvm = { git = "https://github.com/openvm-org/openvm.git", features = [
1115
"std",
1216
] }
13-
openvm-new-keccak256-guest = { path = "../../extensions/new-keccak256/guest" }
14-
hex-literal = "1.1.0"
1517

1618
[features]
1719
default = []
1820

1921
# remove this if copying example outside of monorepo
2022
[patch."https://github.com/openvm-org/openvm.git"]
2123
openvm = { path = "../../crates/toolchain/openvm" }
22-
openvm-keccak256 = { path = "../../guest-libs/keccak256" }
23-
24+
openvm-new-keccak256-guest = { path = "../../extensions/new-keccak256/guest" }

examples/new-keccak/src/main.rs

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
// [!region imports]
2-
// [!endregion imports]
2+
use hex_literal::hex;
3+
use tiny_keccak::{Hasher, Keccak};
34

4-
// [!region main]
55
#[cfg(target_os = "zkvm")]
6-
use hex_literal::hex;
76
use openvm as _;
7+
// [!endregion imports]
8+
9+
// [!region main]
810

911
/// Vector of test cases for Keccak-256 hash function.
1012
/// Each test case consists of (input_bytes, expected_hash_result).
11-
#[cfg(target_os = "zkvm")]
1213
const KECCAK_TEST_CASES: &[(&[u8], [u8; 32])] = &[
1314
(
1415
b"",
@@ -147,25 +148,19 @@ const KECCAK_TEST_CASES: &[(&[u8], [u8; 32])] = &[
147148
hex!("f5392ee04880a0bd1336f30ee79b5c014a90728bf29f422dabb4ae6bc972f30b"),
148149
),
149150
];
150-
// todo: call the forked tiny keccak library once that is updated instead of directly calling the
151-
// keccak256_guest native functions
152151
pub fn main() {
153-
#[cfg(target_os = "zkvm")]
154-
{
155-
for &(input, expected) in KECCAK_TEST_CASES {
156-
let mut input = input.to_vec();
157-
let mut output = [0u8; 32];
158-
openvm_new_keccak256_guest::native_keccak256(
159-
input.as_ptr(),
160-
input.len(),
161-
output.as_mut_ptr(),
162-
);
163-
assert_eq!(output, expected);
164-
}
152+
// Run the keccak tests for all targets
153+
for &(input, expected) in KECCAK_TEST_CASES {
154+
let mut output = [0u8; 32];
165155

166-
// let mut expected_output = [0u8; 32];
167-
// openvm_keccak256::keccak256(&buffer);
168-
// assert_eq!(output, expected_output);
156+
// Using tiny-keccak API
157+
let mut hasher = Keccak::v256();
158+
hasher.update(input);
159+
hasher.finalize(&mut output);
160+
161+
assert_eq!(output, expected);
169162
}
163+
164+
println!("All {} keccak256 test cases passed!", KECCAK_TEST_CASES.len());
170165
}
171166
// [!endregion main]

extensions/new-keccak256/guest/src/lib.rs

Lines changed: 76 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -21,38 +21,64 @@ struct AlignedStackBuf<const N: usize> {
2121
data: [u8; N],
2222
}
2323

24+
/// SAFETY: Caller must ensure:
25+
/// - buffer and input are aligned to MIN_ALIGN
26+
/// - len is a multiple of MIN_ALIGN
27+
#[cfg(target_os = "zkvm")]
28+
#[inline(always)]
29+
unsafe fn native_xorin_unchecked(buffer: *mut u8, input: *const u8, len: usize) {
30+
__native_xorin(buffer, input, len);
31+
}
32+
2433
#[cfg(target_os = "zkvm")]
2534
#[no_mangle]
2635
pub extern "C" fn native_xorin(buffer: *mut u8, input: *const u8, len: usize) {
2736
if len == 0 {
2837
return;
2938
}
3039
unsafe {
31-
let aligned_buffer;
32-
let aligned_input;
40+
let buffer_aligned = buffer as usize % MIN_ALIGN == 0;
41+
let input_aligned = input as usize % MIN_ALIGN == 0;
42+
let len_aligned = len % MIN_ALIGN == 0;
43+
let all_aligned = buffer_aligned && input_aligned && len_aligned;
3344

34-
let actual_buffer = if buffer as usize % MIN_ALIGN == 0 {
35-
buffer
45+
if all_aligned {
46+
__native_xorin(buffer, input, len);
3647
} else {
37-
aligned_buffer = AlignedBuf::new(buffer, len, MIN_ALIGN);
38-
aligned_buffer.ptr
39-
};
48+
let adjusted_len = len.next_multiple_of(MIN_ALIGN);
49+
let aligned_buffer;
50+
let aligned_input;
4051

41-
let actual_input = if input as usize % MIN_ALIGN == 0 {
42-
input
43-
} else {
44-
aligned_input = AlignedBuf::new(input, len, MIN_ALIGN);
45-
aligned_input.ptr
46-
};
52+
let actual_buffer = if buffer_aligned && len_aligned {
53+
buffer
54+
} else {
55+
aligned_buffer = AlignedBuf::new(buffer, adjusted_len, MIN_ALIGN);
56+
aligned_buffer.ptr
57+
};
58+
59+
let actual_input = if input_aligned && len_aligned {
60+
input
61+
} else {
62+
aligned_input = AlignedBuf::new(input, adjusted_len, MIN_ALIGN);
63+
aligned_input.ptr
64+
};
4765

48-
__native_xorin(actual_buffer, actual_input, len);
66+
__native_xorin(actual_buffer, actual_input, adjusted_len);
4967

50-
if buffer as usize % MIN_ALIGN != 0 {
51-
core::ptr::copy_nonoverlapping(actual_buffer as *const u8, buffer, len);
68+
if !buffer_aligned || !len_aligned {
69+
core::ptr::copy_nonoverlapping(actual_buffer as *const u8, buffer, len);
70+
}
5271
}
5372
}
5473
}
5574

75+
/// SAFETY: Caller must ensure buffer is aligned to MIN_ALIGN
76+
#[cfg(target_os = "zkvm")]
77+
#[inline(always)]
78+
unsafe fn native_keccakf_unchecked(buffer: *mut u8) {
79+
__native_keccakf(buffer);
80+
}
81+
5682
#[cfg(target_os = "zkvm")]
5783
#[no_mangle]
5884
pub extern "C" fn native_keccakf(buffer: *mut u8) {
@@ -86,20 +112,21 @@ pub extern "C" fn native_keccakf(buffer: *mut u8) {
86112
#[cfg(target_os = "zkvm")]
87113
#[no_mangle]
88114
pub extern "C" fn native_keccak256(bytes: *const u8, len: usize, output: *mut u8) {
89-
// SAFETY: assuming safety assumptions of the inputs, we handle all cases where `bytes` or
90-
// `output` are not aligned to 4 bytes.
91115
unsafe {
116+
let bytes_aligned = bytes as usize % MIN_ALIGN == 0;
117+
let output_aligned = output as usize % MIN_ALIGN == 0;
118+
92119
let aligned_bytes;
93120
let aligned_output;
94121

95-
let actual_bytes = if len == 0 || bytes as usize % MIN_ALIGN == 0 {
122+
let actual_bytes = if len == 0 || bytes_aligned {
96123
bytes
97124
} else {
98125
aligned_bytes = AlignedBuf::new(bytes, len, MIN_ALIGN);
99126
aligned_bytes.ptr
100127
};
101128

102-
let actual_output = if output as usize % MIN_ALIGN == 0 {
129+
let actual_output = if output_aligned {
103130
output
104131
} else {
105132
aligned_output = AlignedBuf::uninit(KECCAK_OUTPUT_SIZE, MIN_ALIGN);
@@ -108,99 +135,62 @@ pub extern "C" fn native_keccak256(bytes: *const u8, len: usize, output: *mut u8
108135

109136
keccak256_impl(actual_bytes, len, actual_output);
110137

111-
if output as usize % MIN_ALIGN != 0 {
138+
if !output_aligned {
112139
core::ptr::copy_nonoverlapping(actual_output as *const u8, output, KECCAK_OUTPUT_SIZE);
113140
}
114141
}
115142
}
116143

144+
/// SAFETY: This function is only called from native_keccak256 which ensures:
145+
/// - input is aligned to MIN_ALIGN
146+
/// - output is aligned to MIN_ALIGN
147+
/// - All internal buffers are aligned by AlignedStackBuf
117148
#[cfg(target_os = "zkvm")]
118149
#[inline(always)]
119-
fn keccak_update(
120-
buffer: &mut AlignedStackBuf<KECCAK_WIDTH_BYTES>,
121-
input: *const u8,
122-
len: usize,
123-
) -> usize {
150+
unsafe fn keccak256_impl(input: *const u8, len: usize, output: *mut u8) {
151+
let mut buffer = AlignedStackBuf::<KECCAK_WIDTH_BYTES> {
152+
data: [0u8; KECCAK_WIDTH_BYTES],
153+
};
124154
let buffer_ptr = buffer.data.as_mut_ptr();
155+
125156
let mut offset = 0;
126157
let mut remaining = len;
127-
let input_aligned = input as usize % MIN_ALIGN == 0;
128158

129159
// Absorb full blocks
130160
while remaining >= KECCAK_RATE {
131-
if input_aligned {
132-
__native_xorin(buffer_ptr, unsafe { input.add(offset) }, KECCAK_RATE);
133-
} else {
134-
let mut block = AlignedStackBuf::<KECCAK_RATE> {
135-
data: [0u8; KECCAK_RATE],
136-
};
137-
unsafe {
138-
core::ptr::copy_nonoverlapping(
139-
input.add(offset),
140-
block.data.as_mut_ptr(),
141-
KECCAK_RATE,
142-
);
143-
__native_xorin(buffer_ptr, block.data.as_ptr(), KECCAK_RATE);
144-
}
145-
}
146-
unsafe {
147-
__native_keccakf(buffer_ptr);
148-
}
161+
native_xorin_unchecked(buffer_ptr, input.add(offset), KECCAK_RATE);
162+
native_keccakf_unchecked(buffer_ptr);
149163
offset += KECCAK_RATE;
150164
remaining -= KECCAK_RATE;
151165
}
152166

153167
// Handle remaining bytes
154168
if remaining > 0 {
155-
unsafe {
156-
if input_aligned && remaining % MIN_ALIGN == 0 {
157-
__native_xorin(buffer_ptr, input.add(offset), remaining);
158-
} else {
159-
let adjusted_len = remaining.next_multiple_of(MIN_ALIGN);
160-
let mut padded_input = AlignedStackBuf::<KECCAK_RATE> {
161-
data: [0u8; KECCAK_RATE],
162-
};
163-
core::ptr::copy_nonoverlapping(
164-
input.add(offset),
165-
padded_input.data.as_mut_ptr(),
166-
remaining,
167-
);
168-
__native_xorin(buffer_ptr, padded_input.data.as_ptr(), adjusted_len);
169-
}
169+
if remaining % MIN_ALIGN == 0 {
170+
native_xorin_unchecked(buffer_ptr, input.add(offset), remaining);
171+
} else {
172+
let adjusted_len = remaining.next_multiple_of(MIN_ALIGN);
173+
let mut padded_input = AlignedStackBuf::<KECCAK_RATE> {
174+
data: [0u8; KECCAK_RATE],
175+
};
176+
core::ptr::copy_nonoverlapping(
177+
input.add(offset),
178+
padded_input.data.as_mut_ptr(),
179+
remaining,
180+
);
181+
native_xorin_unchecked(buffer_ptr, padded_input.data.as_ptr(), adjusted_len);
170182
}
171183
}
172184

173-
remaining
174-
}
175-
176-
#[cfg(target_os = "zkvm")]
177-
#[inline(always)]
178-
fn keccak_finalize(
179-
buffer: &mut AlignedStackBuf<KECCAK_WIDTH_BYTES>,
180-
remaining_len: usize,
181-
output: *mut u8,
182-
) {
183185
// Apply Keccak padding (pad10*1)
184-
buffer.data[remaining_len] ^= 0x01;
186+
buffer.data[remaining] ^= 0x01;
185187
buffer.data[KECCAK_RATE - 1] ^= 0x80;
186188

187189
// Final permutation
188-
unsafe {
189-
__native_keccakf(buffer.data.as_mut_ptr());
190-
191-
// Extract output
192-
core::ptr::copy_nonoverlapping(buffer.data.as_ptr(), output, KECCAK_OUTPUT_SIZE);
193-
}
194-
}
190+
native_keccakf_unchecked(buffer_ptr);
195191

196-
#[cfg(target_os = "zkvm")]
197-
#[inline(always)]
198-
fn keccak256_impl(input: *const u8, len: usize, output: *mut u8) {
199-
let mut buffer = AlignedStackBuf::<KECCAK_WIDTH_BYTES> {
200-
data: [0u8; KECCAK_WIDTH_BYTES],
201-
};
202-
let remaining_len = keccak_update(&mut buffer, input, len);
203-
keccak_finalize(&mut buffer, remaining_len, output);
192+
// Extract output
193+
core::ptr::copy_nonoverlapping(buffer.data.as_ptr(), output, KECCAK_OUTPUT_SIZE);
204194
}
205195

206196
#[cfg(target_os = "zkvm")]

0 commit comments

Comments
 (0)