Skip to content

Commit 92d4478

Browse files
lispcyi-sun
authored andcommitted
feat: make keccak/sha256 memory access aligned to 4-byte (#1859)
picked from 1.4 branch. i think at least this PR will not make anything worse?
1 parent ca36de3 commit 92d4478

File tree

4 files changed

+157
-6
lines changed

4 files changed

+157
-6
lines changed
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
extern crate alloc;
2+
3+
use alloc::alloc::{alloc, dealloc, handle_alloc_error, Layout};
4+
use core::ptr::NonNull;
5+
6+
/// Bytes allocated according to the given Layout
7+
pub struct AlignedBuf {
8+
pub ptr: *mut u8,
9+
pub layout: Layout,
10+
}
11+
12+
impl AlignedBuf {
13+
/// Allocate a new buffer whose start address is aligned to `align` bytes.
14+
/// *NOTE* if `len` is zero then a creates new `NonNull` that is dangling and 16-byte aligned.
15+
pub fn uninit(len: usize, align: usize) -> Self {
16+
let layout = Layout::from_size_align(len, align).unwrap();
17+
if layout.size() == 0 {
18+
return Self {
19+
ptr: NonNull::<u128>::dangling().as_ptr() as *mut u8,
20+
layout,
21+
};
22+
}
23+
// SAFETY: `len` is nonzero
24+
let ptr = unsafe { alloc(layout) };
25+
if ptr.is_null() {
26+
handle_alloc_error(layout);
27+
}
28+
AlignedBuf { ptr, layout }
29+
}
30+
31+
/// Allocate a new buffer whose start address is aligned to `align` bytes
32+
/// and copy the given data into it.
33+
///
34+
/// # Safety
35+
/// - `bytes` must not be null
36+
/// - `len` should not be zero
37+
///
38+
/// See [alloc]. In particular `data` should not be empty.
39+
pub unsafe fn new(bytes: *const u8, len: usize, align: usize) -> Self {
40+
let buf = Self::uninit(len, align);
41+
// SAFETY:
42+
// - src and dst are not null
43+
// - src and dst are allocated for size
44+
// - no alignment requirements on u8
45+
// - non-overlapping since ptr is newly allocated
46+
unsafe {
47+
core::ptr::copy_nonoverlapping(bytes, buf.ptr, len);
48+
}
49+
50+
buf
51+
}
52+
}
53+
54+
impl Drop for AlignedBuf {
55+
fn drop(&mut self) {
56+
if self.layout.size() != 0 {
57+
unsafe {
58+
dealloc(self.ptr, self.layout);
59+
}
60+
}
61+
}
62+
}

crates/toolchain/platform/src/lib.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,15 @@
44
#![deny(rustdoc::broken_intra_doc_links)]
55
#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
66

7-
#[cfg(all(feature = "rust-runtime", target_os = "zkvm"))]
7+
#[cfg(target_os = "zkvm")]
88
pub use openvm_custom_insn::{custom_insn_i, custom_insn_r};
9+
#[cfg(target_os = "zkvm")]
10+
pub mod alloc;
911
#[cfg(all(feature = "rust-runtime", target_os = "zkvm"))]
1012
pub mod heap;
1113
#[cfg(all(feature = "export-libm", target_os = "zkvm"))]
1214
mod libm_extern;
15+
1316
pub mod memory;
1417
pub mod print;
1518
#[cfg(feature = "rust-runtime")]
@@ -19,9 +22,6 @@ pub mod rust_rt;
1922
/// 4 bytes (i.e. 32 bits) as the zkVM is an implementation of the rv32im ISA.
2023
pub const WORD_SIZE: usize = core::mem::size_of::<u32>();
2124

22-
/// Size of a zkVM memory page.
23-
pub const PAGE_SIZE: usize = 1024;
24-
2525
/// Standard IO file descriptors for use with sys_read and sys_write.
2626
pub mod fileno {
2727
pub const STDIN: u32 = 0;

extensions/keccak256/guest/src/lib.rs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
#![no_std]
22

3+
#[cfg(target_os = "zkvm")]
4+
extern crate alloc;
5+
#[cfg(target_os = "zkvm")]
6+
use openvm_platform::alloc::AlignedBuf;
7+
38
/// This is custom-0 defined in RISC-V spec document
49
pub const OPCODE: u8 = 0x0b;
510
pub const KECCAK256_FUNCT3: u8 = 0b100;
@@ -21,6 +26,43 @@ pub const KECCAK256_FUNCT7: u8 = 0;
2126
#[inline(always)]
2227
#[no_mangle]
2328
pub extern "C" fn native_keccak256(bytes: *const u8, len: usize, output: *mut u8) {
29+
// SAFETY: assuming safety assumptions of the inputs, we handle all cases where `bytes` or
30+
// `output` are not aligned to 4 bytes.
31+
const MIN_ALIGN: usize = 4;
32+
unsafe {
33+
if bytes as usize % MIN_ALIGN != 0 {
34+
let aligned_buff = AlignedBuf::new(bytes, len, MIN_ALIGN);
35+
if output as usize % MIN_ALIGN != 0 {
36+
let aligned_out = AlignedBuf::uninit(32, MIN_ALIGN);
37+
__native_keccak256(aligned_buff.ptr, len, aligned_out.ptr);
38+
core::ptr::copy_nonoverlapping(aligned_out.ptr as *const u8, output, 32);
39+
} else {
40+
__native_keccak256(aligned_buff.ptr, len, output);
41+
}
42+
} else {
43+
if output as usize % MIN_ALIGN != 0 {
44+
let aligned_out = AlignedBuf::uninit(32, MIN_ALIGN);
45+
__native_keccak256(bytes, len, aligned_out.ptr);
46+
core::ptr::copy_nonoverlapping(aligned_out.ptr as *const u8, output, 32);
47+
} else {
48+
__native_keccak256(bytes, len, output);
49+
}
50+
};
51+
}
52+
}
53+
54+
/// keccak256 intrinsic binding
55+
///
56+
/// # Safety
57+
///
58+
/// The VM accepts the preimage by pointer and length, and writes the
59+
/// 32-byte hash.
60+
/// - `bytes` must point to an input buffer at least `len` long.
61+
/// - `output` must point to a buffer that is at least 32-bytes long.
62+
/// - `bytes` and `output` must be 4-byte aligned.
63+
#[cfg(target_os = "zkvm")]
64+
#[inline(always)]
65+
fn __native_keccak256(bytes: *const u8, len: usize, output: *mut u8) {
2466
openvm_platform::custom_insn_r!(
2567
opcode = OPCODE,
2668
funct3 = KECCAK256_FUNCT3,

extensions/sha256/guest/src/lib.rs

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,69 @@
11
#![no_std]
22

3+
#[cfg(target_os = "zkvm")]
4+
use openvm_platform::alloc::AlignedBuf;
5+
36
/// This is custom-0 defined in RISC-V spec document
47
pub const OPCODE: u8 = 0x0b;
58
pub const SHA256_FUNCT3: u8 = 0b100;
69
pub const SHA256_FUNCT7: u8 = 0x1;
710

8-
/// zkvm native implementation of sha256
11+
/// Native hook for sha256
12+
///
913
/// # Safety
1014
///
1115
/// The VM accepts the preimage by pointer and length, and writes the
1216
/// 32-byte hash.
1317
/// - `bytes` must point to an input buffer at least `len` long.
1418
/// - `output` must point to a buffer that is at least 32-bytes long.
1519
///
16-
/// [`sha2-256`]: https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf
20+
/// [`sha2`]: https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf
1721
#[cfg(target_os = "zkvm")]
1822
#[inline(always)]
1923
#[no_mangle]
2024
pub extern "C" fn zkvm_sha256_impl(bytes: *const u8, len: usize, output: *mut u8) {
25+
// SAFETY: assuming safety assumptions of the inputs, we handle all cases where `bytes` or
26+
// `output` are not aligned to 4 bytes.
27+
// The minimum alignment required for the input and output buffers
28+
const MIN_ALIGN: usize = 4;
29+
// The preferred alignment for the input buffer, since the input is read in chunks of 16 bytes
30+
const INPUT_ALIGN: usize = 16;
31+
// The preferred alignment for the output buffer, since the output is written in chunks of 32
32+
// bytes
33+
const OUTPUT_ALIGN: usize = 32;
34+
unsafe {
35+
if bytes as usize % MIN_ALIGN != 0 {
36+
let aligned_buff = AlignedBuf::new(bytes, len, INPUT_ALIGN);
37+
if output as usize % MIN_ALIGN != 0 {
38+
let aligned_out = AlignedBuf::uninit(32, OUTPUT_ALIGN);
39+
__native_sha256(aligned_buff.ptr, len, aligned_out.ptr);
40+
core::ptr::copy_nonoverlapping(aligned_out.ptr as *const u8, output, 32);
41+
} else {
42+
__native_sha256(aligned_buff.ptr, len, output);
43+
}
44+
} else {
45+
if output as usize % MIN_ALIGN != 0 {
46+
let aligned_out = AlignedBuf::uninit(32, OUTPUT_ALIGN);
47+
__native_sha256(bytes, len, aligned_out.ptr);
48+
core::ptr::copy_nonoverlapping(aligned_out.ptr as *const u8, output, 32);
49+
} else {
50+
__native_sha256(bytes, len, output);
51+
}
52+
};
53+
}
54+
}
55+
56+
/// sha256 intrinsic binding
57+
///
58+
/// # Safety
59+
///
60+
/// The VM accepts the preimage by pointer and length, and writes the
61+
/// 32-byte hash.
62+
/// - `bytes` must point to an input buffer at least `len` long.
63+
/// - `output` must point to a buffer that is at least 32-bytes long.
64+
/// - `bytes` and `output` must be 4-byte aligned.
65+
#[cfg(target_os = "zkvm")]
66+
#[inline(always)]
67+
fn __native_sha256(bytes: *const u8, len: usize, output: *mut u8) {
2168
openvm_platform::custom_insn_r!(opcode = OPCODE, funct3 = SHA256_FUNCT3, funct7 = SHA256_FUNCT7, rd = In output, rs1 = In bytes, rs2 = In len);
2269
}

0 commit comments

Comments
 (0)