@@ -21,38 +21,64 @@ struct AlignedStackBuf<const N: usize> {
2121 data : [ u8 ; N ] ,
2222}
2323
24+ /// SAFETY: Caller must ensure:
25+ /// - buffer and input are aligned to MIN_ALIGN
26+ /// - len is a multiple of MIN_ALIGN
27+ #[ cfg( target_os = "zkvm" ) ]
28+ #[ inline( always) ]
29+ unsafe fn native_xorin_unchecked ( buffer : * mut u8 , input : * const u8 , len : usize ) {
30+ __native_xorin ( buffer, input, len) ;
31+ }
32+
2433#[ cfg( target_os = "zkvm" ) ]
2534#[ no_mangle]
2635pub extern "C" fn native_xorin ( buffer : * mut u8 , input : * const u8 , len : usize ) {
2736 if len == 0 {
2837 return ;
2938 }
3039 unsafe {
31- let aligned_buffer;
32- let aligned_input;
40+ let buffer_aligned = buffer as usize % MIN_ALIGN == 0 ;
41+ let input_aligned = input as usize % MIN_ALIGN == 0 ;
42+ let len_aligned = len % MIN_ALIGN == 0 ;
43+ let all_aligned = buffer_aligned && input_aligned && len_aligned;
3344
34- let actual_buffer = if buffer as usize % MIN_ALIGN == 0 {
35- buffer
45+ if all_aligned {
46+ __native_xorin ( buffer, input , len ) ;
3647 } else {
37- aligned_buffer = AlignedBuf :: new ( buffer , len, MIN_ALIGN ) ;
38- aligned_buffer. ptr
39- } ;
48+ let adjusted_len = len. next_multiple_of ( MIN_ALIGN ) ;
49+ let aligned_buffer;
50+ let aligned_input ;
4051
41- let actual_input = if input as usize % MIN_ALIGN == 0 {
42- input
43- } else {
44- aligned_input = AlignedBuf :: new ( input, len, MIN_ALIGN ) ;
45- aligned_input. ptr
46- } ;
52+ let actual_buffer = if buffer_aligned && len_aligned {
53+ buffer
54+ } else {
55+ aligned_buffer = AlignedBuf :: new ( buffer, adjusted_len, MIN_ALIGN ) ;
56+ aligned_buffer. ptr
57+ } ;
58+
59+ let actual_input = if input_aligned && len_aligned {
60+ input
61+ } else {
62+ aligned_input = AlignedBuf :: new ( input, adjusted_len, MIN_ALIGN ) ;
63+ aligned_input. ptr
64+ } ;
4765
48- __native_xorin ( actual_buffer, actual_input, len ) ;
66+ __native_xorin ( actual_buffer, actual_input, adjusted_len ) ;
4967
50- if buffer as usize % MIN_ALIGN != 0 {
51- core:: ptr:: copy_nonoverlapping ( actual_buffer as * const u8 , buffer, len) ;
68+ if !buffer_aligned || !len_aligned {
69+ core:: ptr:: copy_nonoverlapping ( actual_buffer as * const u8 , buffer, len) ;
70+ }
5271 }
5372 }
5473}
5574
75+ /// SAFETY: Caller must ensure buffer is aligned to MIN_ALIGN
76+ #[ cfg( target_os = "zkvm" ) ]
77+ #[ inline( always) ]
78+ unsafe fn native_keccakf_unchecked ( buffer : * mut u8 ) {
79+ __native_keccakf ( buffer) ;
80+ }
81+
5682#[ cfg( target_os = "zkvm" ) ]
5783#[ no_mangle]
5884pub extern "C" fn native_keccakf ( buffer : * mut u8 ) {
@@ -86,20 +112,21 @@ pub extern "C" fn native_keccakf(buffer: *mut u8) {
86112#[ cfg( target_os = "zkvm" ) ]
87113#[ no_mangle]
88114pub extern "C" fn native_keccak256 ( bytes : * const u8 , len : usize , output : * mut u8 ) {
89- // SAFETY: assuming safety assumptions of the inputs, we handle all cases where `bytes` or
90- // `output` are not aligned to 4 bytes.
91115 unsafe {
116+ let bytes_aligned = bytes as usize % MIN_ALIGN == 0 ;
117+ let output_aligned = output as usize % MIN_ALIGN == 0 ;
118+
92119 let aligned_bytes;
93120 let aligned_output;
94121
95- let actual_bytes = if len == 0 || bytes as usize % MIN_ALIGN == 0 {
122+ let actual_bytes = if len == 0 || bytes_aligned {
96123 bytes
97124 } else {
98125 aligned_bytes = AlignedBuf :: new ( bytes, len, MIN_ALIGN ) ;
99126 aligned_bytes. ptr
100127 } ;
101128
102- let actual_output = if output as usize % MIN_ALIGN == 0 {
129+ let actual_output = if output_aligned {
103130 output
104131 } else {
105132 aligned_output = AlignedBuf :: uninit ( KECCAK_OUTPUT_SIZE , MIN_ALIGN ) ;
@@ -108,99 +135,62 @@ pub extern "C" fn native_keccak256(bytes: *const u8, len: usize, output: *mut u8
108135
109136 keccak256_impl ( actual_bytes, len, actual_output) ;
110137
111- if output as usize % MIN_ALIGN != 0 {
138+ if !output_aligned {
112139 core:: ptr:: copy_nonoverlapping ( actual_output as * const u8 , output, KECCAK_OUTPUT_SIZE ) ;
113140 }
114141 }
115142}
116143
144+ /// SAFETY: This function is only called from native_keccak256 which ensures:
145+ /// - input is aligned to MIN_ALIGN
146+ /// - output is aligned to MIN_ALIGN
147+ /// - All internal buffers are aligned by AlignedStackBuf
117148#[ cfg( target_os = "zkvm" ) ]
118149#[ inline( always) ]
119- fn keccak_update (
120- buffer : & mut AlignedStackBuf < KECCAK_WIDTH_BYTES > ,
121- input : * const u8 ,
122- len : usize ,
123- ) -> usize {
150+ unsafe fn keccak256_impl ( input : * const u8 , len : usize , output : * mut u8 ) {
151+ let mut buffer = AlignedStackBuf :: < KECCAK_WIDTH_BYTES > {
152+ data : [ 0u8 ; KECCAK_WIDTH_BYTES ] ,
153+ } ;
124154 let buffer_ptr = buffer. data . as_mut_ptr ( ) ;
155+
125156 let mut offset = 0 ;
126157 let mut remaining = len;
127- let input_aligned = input as usize % MIN_ALIGN == 0 ;
128158
129159 // Absorb full blocks
130160 while remaining >= KECCAK_RATE {
131- if input_aligned {
132- __native_xorin ( buffer_ptr, unsafe { input. add ( offset) } , KECCAK_RATE ) ;
133- } else {
134- let mut block = AlignedStackBuf :: < KECCAK_RATE > {
135- data : [ 0u8 ; KECCAK_RATE ] ,
136- } ;
137- unsafe {
138- core:: ptr:: copy_nonoverlapping (
139- input. add ( offset) ,
140- block. data . as_mut_ptr ( ) ,
141- KECCAK_RATE ,
142- ) ;
143- __native_xorin ( buffer_ptr, block. data . as_ptr ( ) , KECCAK_RATE ) ;
144- }
145- }
146- unsafe {
147- __native_keccakf ( buffer_ptr) ;
148- }
161+ native_xorin_unchecked ( buffer_ptr, input. add ( offset) , KECCAK_RATE ) ;
162+ native_keccakf_unchecked ( buffer_ptr) ;
149163 offset += KECCAK_RATE ;
150164 remaining -= KECCAK_RATE ;
151165 }
152166
153167 // Handle remaining bytes
154168 if remaining > 0 {
155- unsafe {
156- if input_aligned && remaining % MIN_ALIGN == 0 {
157- __native_xorin ( buffer_ptr, input. add ( offset) , remaining) ;
158- } else {
159- let adjusted_len = remaining. next_multiple_of ( MIN_ALIGN ) ;
160- let mut padded_input = AlignedStackBuf :: < KECCAK_RATE > {
161- data : [ 0u8 ; KECCAK_RATE ] ,
162- } ;
163- core:: ptr:: copy_nonoverlapping (
164- input. add ( offset) ,
165- padded_input. data . as_mut_ptr ( ) ,
166- remaining,
167- ) ;
168- __native_xorin ( buffer_ptr, padded_input. data . as_ptr ( ) , adjusted_len) ;
169- }
169+ if remaining % MIN_ALIGN == 0 {
170+ native_xorin_unchecked ( buffer_ptr, input. add ( offset) , remaining) ;
171+ } else {
172+ let adjusted_len = remaining. next_multiple_of ( MIN_ALIGN ) ;
173+ let mut padded_input = AlignedStackBuf :: < KECCAK_RATE > {
174+ data : [ 0u8 ; KECCAK_RATE ] ,
175+ } ;
176+ core:: ptr:: copy_nonoverlapping (
177+ input. add ( offset) ,
178+ padded_input. data . as_mut_ptr ( ) ,
179+ remaining,
180+ ) ;
181+ native_xorin_unchecked ( buffer_ptr, padded_input. data . as_ptr ( ) , adjusted_len) ;
170182 }
171183 }
172184
173- remaining
174- }
175-
176- #[ cfg( target_os = "zkvm" ) ]
177- #[ inline( always) ]
178- fn keccak_finalize (
179- buffer : & mut AlignedStackBuf < KECCAK_WIDTH_BYTES > ,
180- remaining_len : usize ,
181- output : * mut u8 ,
182- ) {
183185 // Apply Keccak padding (pad10*1)
184- buffer. data [ remaining_len ] ^= 0x01 ;
186+ buffer. data [ remaining ] ^= 0x01 ;
185187 buffer. data [ KECCAK_RATE - 1 ] ^= 0x80 ;
186188
187189 // Final permutation
188- unsafe {
189- __native_keccakf ( buffer. data . as_mut_ptr ( ) ) ;
190-
191- // Extract output
192- core:: ptr:: copy_nonoverlapping ( buffer. data . as_ptr ( ) , output, KECCAK_OUTPUT_SIZE ) ;
193- }
194- }
190+ native_keccakf_unchecked ( buffer_ptr) ;
195191
196- #[ cfg( target_os = "zkvm" ) ]
197- #[ inline( always) ]
198- fn keccak256_impl ( input : * const u8 , len : usize , output : * mut u8 ) {
199- let mut buffer = AlignedStackBuf :: < KECCAK_WIDTH_BYTES > {
200- data : [ 0u8 ; KECCAK_WIDTH_BYTES ] ,
201- } ;
202- let remaining_len = keccak_update ( & mut buffer, input, len) ;
203- keccak_finalize ( & mut buffer, remaining_len, output) ;
192+ // Extract output
193+ core:: ptr:: copy_nonoverlapping ( buffer. data . as_ptr ( ) , output, KECCAK_OUTPUT_SIZE ) ;
204194}
205195
206196#[ cfg( target_os = "zkvm" ) ]
0 commit comments