diff --git a/program-libs/hasher/src/keccak.rs b/program-libs/hasher/src/keccak.rs index 81d81d810c..ab1c666ee8 100644 --- a/program-libs/hasher/src/keccak.rs +++ b/program-libs/hasher/src/keccak.rs @@ -9,6 +9,8 @@ use crate::{ pub struct Keccak; impl Hasher for Keccak { + const ID: u8 = 2; + fn hash(val: &[u8]) -> Result { Self::hashv(&[val]) } diff --git a/program-libs/hasher/src/lib.rs b/program-libs/hasher/src/lib.rs index 9f4e4758c0..83a0875ae9 100644 --- a/program-libs/hasher/src/lib.rs +++ b/program-libs/hasher/src/lib.rs @@ -24,6 +24,7 @@ pub const HASH_BYTES: usize = 32; pub type Hash = [u8; HASH_BYTES]; pub trait Hasher { + const ID: u8; fn hash(val: &[u8]) -> Result; fn hashv(vals: &[&[u8]]) -> Result; fn zero_bytes() -> ZeroBytes; diff --git a/program-libs/hasher/src/poseidon.rs b/program-libs/hasher/src/poseidon.rs index 0cd6c670da..b13d4a6a83 100644 --- a/program-libs/hasher/src/poseidon.rs +++ b/program-libs/hasher/src/poseidon.rs @@ -78,6 +78,8 @@ impl From for u64 { pub struct Poseidon; impl Hasher for Poseidon { + const ID: u8 = 0; + fn hash(val: &[u8]) -> Result { Self::hashv(&[val]) } diff --git a/program-libs/hasher/src/sha256.rs b/program-libs/hasher/src/sha256.rs index 8a4b985a52..acf55cc21a 100644 --- a/program-libs/hasher/src/sha256.rs +++ b/program-libs/hasher/src/sha256.rs @@ -9,6 +9,7 @@ use crate::{ pub struct Sha256; impl Hasher for Sha256 { + const ID: u8 = 1; fn hash(val: &[u8]) -> Result { Self::hashv(&[val]) } diff --git a/sdk-libs/macros/src/discriminator.rs b/sdk-libs/macros/src/discriminator.rs index 1d289db888..0b1e3ea0ff 100644 --- a/sdk-libs/macros/src/discriminator.rs +++ b/sdk-libs/macros/src/discriminator.rs @@ -4,6 +4,14 @@ use quote::quote; use syn::{ItemStruct, Result}; pub(crate) fn discriminator(input: ItemStruct) -> Result { + discriminator_with_hasher(input, false) +} + +pub(crate) fn discriminator_sha(input: ItemStruct) -> Result { + discriminator_with_hasher(input, true) +} + +fn discriminator_with_hasher(input: ItemStruct, is_sha: bool) -> Result { let account_name = &input.ident; let (impl_gen, type_gen, where_clause) = input.generics.split_for_impl(); @@ -12,6 +20,10 @@ pub(crate) fn discriminator(input: ItemStruct) -> Result { discriminator.copy_from_slice(&Sha256::hash(account_name.to_string().as_bytes()).unwrap()[..8]); let discriminator: proc_macro2::TokenStream = format!("{discriminator:?}").parse().unwrap(); + // For SHA256 variant, we could add specific logic here if needed + // Currently both variants work the same way since discriminator is just based on struct name + let _variant_marker = if is_sha { "sha256" } else { "poseidon" }; + Ok(quote! { impl #impl_gen LightDiscriminator for #account_name #type_gen #where_clause { const LIGHT_DISCRIMINATOR: [u8; 8] = #discriminator; @@ -47,4 +59,46 @@ mod tests { assert!(output.contains("impl LightDiscriminator for MyAccount")); assert!(output.contains("[181 , 255 , 112 , 42 , 17 , 188 , 66 , 199]")); } + + #[test] + fn test_discriminator_sha() { + let input: ItemStruct = parse_quote! { + struct MyAccount { + a: u32, + b: i32, + c: u64, + d: i64, + } + }; + + let output = discriminator_sha(input).unwrap(); + let output = output.to_string(); + + assert!(output.contains("impl LightDiscriminator for MyAccount")); + assert!(output.contains("[181 , 255 , 112 , 42 , 17 , 188 , 66 , 199]")); + } + + #[test] + fn test_discriminator_sha_large_struct() { + // Test that SHA256 discriminator can handle large structs (that would fail with regular hasher) + let input: ItemStruct = parse_quote! { + struct LargeAccount { + pub field1: u64, pub field2: u64, pub field3: u64, pub field4: u64, + pub field5: u64, pub field6: u64, pub field7: u64, pub field8: u64, + pub field9: u64, pub field10: u64, pub field11: u64, pub field12: u64, + pub field13: u64, pub field14: u64, pub field15: u64, + pub owner: solana_program::pubkey::Pubkey, + pub authority: solana_program::pubkey::Pubkey, + } + }; + + let result = discriminator_sha(input); + assert!( + result.is_ok(), + "SHA256 discriminator should handle large structs" + ); + + let output = result.unwrap().to_string(); + assert!(output.contains("impl LightDiscriminator for LargeAccount")); + } } diff --git a/sdk-libs/macros/src/hasher/data_hasher.rs b/sdk-libs/macros/src/hasher/data_hasher.rs index 2486fdd4b7..7d27bdc619 100644 --- a/sdk-libs/macros/src/hasher/data_hasher.rs +++ b/sdk-libs/macros/src/hasher/data_hasher.rs @@ -37,7 +37,14 @@ pub(crate) fn generate_data_hasher_impl( slices[num_flattned_fields] = element.as_slice(); } - H::hashv(slices.as_slice()) + let mut result = H::hashv(slices.as_slice())?; + + // Apply field size truncation for non-Poseidon hashers + if H::ID != 0 { + result[0] = 0; + } + + Ok(result) } } } @@ -59,10 +66,50 @@ pub(crate) fn generate_data_hasher_impl( println!("DataHasher::hash inputs {:?}", debug_prints); } } - H::hashv(&[ + let mut result = H::hashv(&[ #(#data_hasher_assignments.as_slice(),)* - ]) + ])?; + + // Apply field size truncation for non-Poseidon hashers + if H::ID != 0 { + result[0] = 0; + } + + Ok(result) + } + } + } + }; + + Ok(hasher_impl) +} + +/// SHA256-specific DataHasher implementation that serializes the whole struct +pub(crate) fn generate_data_hasher_impl_sha( + struct_name: &syn::Ident, + generics: &syn::Generics, +) -> Result { + let (impl_gen, type_gen, where_clause) = generics.split_for_impl(); + + let hasher_impl = quote! { + impl #impl_gen ::light_hasher::DataHasher for #struct_name #type_gen #where_clause { + fn hash(&self) -> ::std::result::Result<[u8; 32], ::light_hasher::HasherError> + where + H: ::light_hasher::Hasher + { + use ::light_hasher::Hasher; + use borsh::BorshSerialize; + + // For SHA256, we serialize the whole struct and hash it in one go + let serialized = self.try_to_vec().map_err(|_| ::light_hasher::HasherError::BorshError)?; + let mut result = H::hash(&serialized)?; + + // Truncate field size for non-Poseidon hashers + if H::ID != 0 { + result[0] = 0; } + + Ok(result) } } }; diff --git a/sdk-libs/macros/src/hasher/input_validator.rs b/sdk-libs/macros/src/hasher/input_validator.rs index af57976b8d..0b2800e15a 100644 --- a/sdk-libs/macros/src/hasher/input_validator.rs +++ b/sdk-libs/macros/src/hasher/input_validator.rs @@ -60,6 +60,36 @@ pub(crate) fn validate_input(input: &ItemStruct) -> Result<()> { Ok(()) } +/// SHA256-specific validation - much more relaxed constraints +pub(crate) fn validate_input_sha(input: &ItemStruct) -> Result<()> { + // Check that we have a struct with named fields + match &input.fields { + Fields::Named(_) => (), + _ => { + return Err(Error::new_spanned( + input, + "Only structs with named fields are supported", + )) + } + }; + + // For SHA256, we don't limit field count or require specific attributes + // Just ensure flatten is not used (not implemented for SHA256 path) + let flatten_field_exists = input + .fields + .iter() + .any(|field| get_field_attribute(field) == FieldAttribute::Flatten); + + if flatten_field_exists { + return Err(Error::new_spanned( + input, + "Flatten attribute is not supported in SHA256 hasher.", + )); + } + + Ok(()) +} + /// Gets the primary attribute for a field (only one attribute can be active) pub(crate) fn get_field_attribute(field: &Field) -> FieldAttribute { if field.attrs.iter().any(|attr| attr.path().is_ident("hash")) { diff --git a/sdk-libs/macros/src/hasher/light_hasher.rs b/sdk-libs/macros/src/hasher/light_hasher.rs index 911cc35f73..fbb9da4271 100644 --- a/sdk-libs/macros/src/hasher/light_hasher.rs +++ b/sdk-libs/macros/src/hasher/light_hasher.rs @@ -3,10 +3,10 @@ use quote::quote; use syn::{Fields, ItemStruct, Result}; use crate::hasher::{ - data_hasher::generate_data_hasher_impl, + data_hasher::{generate_data_hasher_impl, generate_data_hasher_impl_sha}, field_processor::{process_field, FieldProcessingContext}, - input_validator::{get_field_attribute, validate_input, FieldAttribute}, - to_byte_array::generate_to_byte_array_impl, + input_validator::{get_field_attribute, validate_input, validate_input_sha, FieldAttribute}, + to_byte_array::{generate_to_byte_array_impl_sha, generate_to_byte_array_impl_with_hasher}, }; /// - ToByteArray: @@ -49,6 +49,33 @@ use crate::hasher::{ /// - Enums, References, SmartPointers: /// - Not supported pub(crate) fn derive_light_hasher(input: ItemStruct) -> Result { + derive_light_hasher_with_hasher(input, "e!(::light_hasher::Poseidon)) +} + +pub(crate) fn derive_light_hasher_sha(input: ItemStruct) -> Result { + // Use SHA256-specific validation (no field count limits) + validate_input_sha(&input)?; + + let generics = input.generics.clone(); + + let fields = match &input.fields { + Fields::Named(fields) => fields.clone(), + _ => unreachable!("Validation should have caught this"), + }; + + let field_count = fields.named.len(); + + let to_byte_array_impl = generate_to_byte_array_impl_sha(&input.ident, &generics, field_count)?; + let data_hasher_impl = generate_data_hasher_impl_sha(&input.ident, &generics)?; + + Ok(quote! { + #to_byte_array_impl + + #data_hasher_impl + }) +} + +fn derive_light_hasher_with_hasher(input: ItemStruct, hasher: &TokenStream) -> Result { // Validate the input structure validate_input(&input)?; @@ -74,8 +101,13 @@ pub(crate) fn derive_light_hasher(input: ItemStruct) -> Result { process_field(field, i, &mut context); }); - let to_byte_array_impl = - generate_to_byte_array_impl(&input.ident, &generics, field_count, &context)?; + let to_byte_array_impl = generate_to_byte_array_impl_with_hasher( + &input.ident, + &generics, + field_count, + &context, + hasher, + )?; let data_hasher_impl = generate_data_hasher_impl(&input.ident, &generics, &context)?; @@ -244,7 +276,7 @@ impl ::light_hasher::DataHasher for TruncateOptionStruct { #[cfg(debug_assertions)] { if std::env::var("RUST_BACKTRACE").is_ok() { - let debug_prints: Vec<[u8; 32]> = vec![ + let debug_prints: Vec<[u8;32]> = vec![ if let Some(a) = & self.a { let result = a.hash_to_field_size() ?; if result == [0u8; 32] { return Err(::light_hasher::errors::HasherError::OptionHashToFieldSizeZero); } @@ -405,4 +437,277 @@ impl ::light_hasher::DataHasher for OuterStruct { }; assert!(derive_light_hasher(input).is_ok()); } + + #[test] + fn test_sha256_large_struct_with_pubkeys() { + // Test that SHA256 can handle large structs with Pubkeys that would fail with Poseidon + // This struct has 15 fields including Pubkeys without #[hash] attribute + let input: ItemStruct = parse_quote! { + struct LargeAccountSha { + pub field1: u64, + pub field2: u64, + pub field3: u64, + pub field4: u64, + pub field5: u64, + pub field6: u64, + pub field7: u64, + pub field8: u64, + pub field9: u64, + pub field10: u64, + pub field11: u64, + pub field12: u64, + pub field13: u64, + // Pubkeys without #[hash] attribute - this would fail with Poseidon + pub owner: solana_program::pubkey::Pubkey, + pub authority: solana_program::pubkey::Pubkey, + } + }; + + // SHA256 should handle this fine + let sha_result = derive_light_hasher_sha(input.clone()); + assert!( + sha_result.is_ok(), + "SHA256 should handle large structs with Pubkeys" + ); + + // Regular Poseidon hasher should fail due to field count (>12) and Pubkey without #[hash] + let poseidon_result = derive_light_hasher(input); + assert!( + poseidon_result.is_err(), + "Poseidon should fail with >12 fields and unhashed Pubkeys" + ); + } + + #[test] + fn test_sha256_vs_poseidon_hashing_behavior() { + // Test a struct that both can handle to show the difference in hashing approach + let input: ItemStruct = parse_quote! { + struct TestAccount { + pub data: [u8; 31], + pub counter: u64, + } + }; + + // Both should succeed + let sha_result = derive_light_hasher_sha(input.clone()); + assert!(sha_result.is_ok()); + + let poseidon_result = derive_light_hasher(input); + assert!(poseidon_result.is_ok()); + + // Verify SHA256 implementation serializes whole struct + let sha_output = sha_result.unwrap(); + let sha_code = sha_output.to_string(); + + // SHA256 should use try_to_vec() for whole struct serialization (account for spaces) + assert!( + sha_code.contains("try_to_vec") && sha_code.contains("BorshSerialize"), + "SHA256 should serialize whole struct using try_to_vec. Actual code: {}", + sha_code + ); + assert!( + sha_code.contains("result [0] = 0") || sha_code.contains("result[0] = 0"), + "SHA256 should truncate first byte. Actual code: {}", + sha_code + ); + + // Poseidon should use field-by-field hashing + let poseidon_output = poseidon_result.unwrap(); + let poseidon_code = poseidon_output.to_string(); + + assert!( + poseidon_code.contains("to_byte_array") && poseidon_code.contains("as_slice"), + "Poseidon should use field-by-field hashing with to_byte_array. Actual code: {}", + poseidon_code + ); + } + + #[test] + fn test_sha256_no_field_limit() { + // Test that SHA256 doesn't enforce the 12-field limit + let input: ItemStruct = parse_quote! { + struct ManyFieldsStruct { + pub f1: u32, pub f2: u32, pub f3: u32, pub f4: u32, + pub f5: u32, pub f6: u32, pub f7: u32, pub f8: u32, + pub f9: u32, pub f10: u32, pub f11: u32, pub f12: u32, + pub f13: u32, pub f14: u32, pub f15: u32, pub f16: u32, + pub f17: u32, pub f18: u32, pub f19: u32, pub f20: u32, + } + }; + + // SHA256 should handle 20 fields without issue + let result = derive_light_hasher_sha(input); + assert!(result.is_ok(), "SHA256 should handle any number of fields"); + } + + #[test] + fn test_sha256_flatten_not_supported() { + // Test that SHA256 rejects flatten attribute (not implemented) + let input: ItemStruct = parse_quote! { + struct FlattenStruct { + #[flatten] + pub inner: InnerStruct, + pub data: u64, + } + }; + + let result = derive_light_hasher_sha(input); + assert!(result.is_err(), "SHA256 should reject flatten attribute"); + + let error_msg = result.unwrap_err().to_string(); + assert!( + error_msg.contains("not supported in SHA256"), + "Should mention SHA256 limitation" + ); + } + + #[test] + fn test_sha256_with_discriminator_integration() { + // Test that shows LightHasherSha works with LightDiscriminatorSha for large structs + // This would be impossible with regular Poseidon-based macros + let input: ItemStruct = parse_quote! { + struct LargeIntegratedAccount { + pub field1: u64, pub field2: u64, pub field3: u64, pub field4: u64, + pub field5: u64, pub field6: u64, pub field7: u64, pub field8: u64, + pub field9: u64, pub field10: u64, pub field11: u64, pub field12: u64, + pub field13: u64, pub field14: u64, pub field15: u64, pub field16: u64, + pub field17: u64, pub field18: u64, pub field19: u64, pub field20: u64, + // Pubkeys without #[hash] attribute + pub owner: solana_program::pubkey::Pubkey, + pub authority: solana_program::pubkey::Pubkey, + pub delegate: solana_program::pubkey::Pubkey, + } + }; + + // Both SHA256 hasher and discriminator should work + let sha_hasher_result = derive_light_hasher_sha(input.clone()); + assert!( + sha_hasher_result.is_ok(), + "SHA256 hasher should work with large structs" + ); + + let sha_discriminator_result = crate::discriminator::discriminator_sha(input.clone()); + assert!( + sha_discriminator_result.is_ok(), + "SHA256 discriminator should work with large structs" + ); + + // Regular Poseidon variants should fail + let poseidon_hasher_result = derive_light_hasher(input); + assert!( + poseidon_hasher_result.is_err(), + "Poseidon hasher should fail with large structs" + ); + + // Verify the generated code contains expected patterns + let sha_hasher_code = sha_hasher_result.unwrap().to_string(); + assert!( + sha_hasher_code.contains("try_to_vec"), + "Should use serialization approach" + ); + assert!( + sha_hasher_code.contains("BorshSerialize"), + "Should use Borsh serialization" + ); + + let sha_discriminator_code = sha_discriminator_result.unwrap().to_string(); + assert!( + sha_discriminator_code.contains("LightDiscriminator"), + "Should implement LightDiscriminator" + ); + assert!( + sha_discriminator_code.contains("LIGHT_DISCRIMINATOR"), + "Should provide discriminator constant" + ); + } + + #[test] + fn test_complete_sha256_ecosystem_practical_example() { + // Demonstrates a real-world scenario where SHA256 variants are essential + // This struct would be impossible with Poseidon due to: + // 1. >12 fields (23+ fields) + // 2. Multiple Pubkeys without #[hash] attribute + // 3. Large data structures + let input: ItemStruct = parse_quote! { + pub struct ComplexGameState { + // Game metadata (13 fields) + pub game_id: u64, + pub round: u32, + pub turn: u8, + pub phase: u8, + pub start_time: i64, + pub end_time: i64, + pub max_players: u8, + pub current_players: u8, + pub entry_fee: u64, + pub prize_pool: u64, + pub game_mode: u32, + pub difficulty: u8, + pub status: u8, + + // Player information (6 Pubkey fields - would require #[hash] with Poseidon) + pub creator: solana_program::pubkey::Pubkey, + pub winner: solana_program::pubkey::Pubkey, + pub current_player: solana_program::pubkey::Pubkey, + pub authority: solana_program::pubkey::Pubkey, + pub treasury: solana_program::pubkey::Pubkey, + pub program_id: solana_program::pubkey::Pubkey, + + // Game state data (4+ more fields) + pub board_state: [u8; 64], // Large array + pub player_scores: [u32; 8], // Array of scores + pub moves_history: [u16; 32], // Move history + pub special_flags: u32, + + // This gives us 23+ fields total - way beyond Poseidon's 12-field limit + } + }; + + // SHA256 variants should handle this complex struct effortlessly + let sha_hasher_result = derive_light_hasher_sha(input.clone()); + assert!( + sha_hasher_result.is_ok(), + "SHA256 hasher must handle complex real-world structs" + ); + + let sha_discriminator_result = crate::discriminator::discriminator_sha(input.clone()); + assert!( + sha_discriminator_result.is_ok(), + "SHA256 discriminator must handle complex real-world structs" + ); + + // Poseidon would fail with this struct + let poseidon_result = derive_light_hasher(input); + assert!( + poseidon_result.is_err(), + "Poseidon cannot handle structs with >12 fields and unhashed Pubkeys" + ); + + // Verify SHA256 generates efficient serialization-based code + let hasher_code = sha_hasher_result.unwrap().to_string(); + assert!( + hasher_code.contains("try_to_vec"), + "Should serialize entire struct efficiently" + ); + assert!( + hasher_code.contains("BorshSerialize"), + "Should use Borsh for serialization" + ); + assert!( + hasher_code.contains("result [0] = 0") || hasher_code.contains("result[0] = 0"), + "Should apply field size truncation. Actual code: {}", + hasher_code + ); + + // Verify discriminator works correctly + let discriminator_code = sha_discriminator_result.unwrap().to_string(); + assert!( + discriminator_code.contains("ComplexGameState"), + "Should target correct struct" + ); + assert!( + discriminator_code.contains("LIGHT_DISCRIMINATOR"), + "Should provide discriminator constant" + ); + } } diff --git a/sdk-libs/macros/src/hasher/mod.rs b/sdk-libs/macros/src/hasher/mod.rs index 5c81807edf..c2ebd8034e 100644 --- a/sdk-libs/macros/src/hasher/mod.rs +++ b/sdk-libs/macros/src/hasher/mod.rs @@ -4,4 +4,4 @@ mod input_validator; mod light_hasher; mod to_byte_array; -pub(crate) use light_hasher::derive_light_hasher; +pub(crate) use light_hasher::{derive_light_hasher, derive_light_hasher_sha}; diff --git a/sdk-libs/macros/src/hasher/to_byte_array.rs b/sdk-libs/macros/src/hasher/to_byte_array.rs index 27d49ae232..9cec46c117 100644 --- a/sdk-libs/macros/src/hasher/to_byte_array.rs +++ b/sdk-libs/macros/src/hasher/to_byte_array.rs @@ -4,11 +4,12 @@ use syn::Result; use crate::hasher::field_processor::FieldProcessingContext; -pub(crate) fn generate_to_byte_array_impl( +pub(crate) fn generate_to_byte_array_impl_with_hasher( struct_name: &syn::Ident, generics: &syn::Generics, field_count: usize, context: &FieldProcessingContext, + hasher: &TokenStream, ) -> Result { let (impl_gen, type_gen, where_clause) = generics.split_for_impl(); @@ -20,34 +21,70 @@ pub(crate) fn generate_to_byte_array_impl( Some(s) => s, None => &alt_res, }; - let field_assignment: TokenStream = syn::parse_str(str)?; - - // Create a token stream with the field_assignment and the import code - let mut hash_imports = proc_macro2::TokenStream::new(); - for code in &context.hash_to_field_size_code { - hash_imports.extend(code.clone()); - } + let content: TokenStream = str.parse().expect("Invalid generated code"); Ok(quote! { impl #impl_gen ::light_hasher::to_byte_array::ToByteArray for #struct_name #type_gen #where_clause { - const NUM_FIELDS: usize = #field_count; + const NUM_FIELDS: usize = 1; fn to_byte_array(&self) -> ::std::result::Result<[u8; 32], ::light_hasher::HasherError> { - #hash_imports - #field_assignment + use ::light_hasher::to_byte_array::ToByteArray; + use ::light_hasher::hash_to_field_size::HashToFieldSize; + #content } } }) } else { + let data_hasher_assignments = &context.data_hasher_assignments; Ok(quote! { impl #impl_gen ::light_hasher::to_byte_array::ToByteArray for #struct_name #type_gen #where_clause { const NUM_FIELDS: usize = #field_count; fn to_byte_array(&self) -> ::std::result::Result<[u8; 32], ::light_hasher::HasherError> { - ::light_hasher::DataHasher::hash::<::light_hasher::Poseidon>(self) - } + use ::light_hasher::to_byte_array::ToByteArray; + use ::light_hasher::hash_to_field_size::HashToFieldSize; + use ::light_hasher::Hasher; + let mut result = #hasher::hashv(&[ + #(#data_hasher_assignments.as_slice(),)* + ])?; + + // Truncate field size for non-Poseidon hashers + if #hasher::ID != 0 { + result[0] = 0; + } + Ok(result) + } } }) } } + +/// SHA256-specific ToByteArray implementation that serializes the whole struct +pub(crate) fn generate_to_byte_array_impl_sha( + struct_name: &syn::Ident, + generics: &syn::Generics, + field_count: usize, +) -> Result { + let (impl_gen, type_gen, where_clause) = generics.split_for_impl(); + + Ok(quote! { + impl #impl_gen ::light_hasher::to_byte_array::ToByteArray for #struct_name #type_gen #where_clause { + const NUM_FIELDS: usize = #field_count; + + fn to_byte_array(&self) -> ::std::result::Result<[u8; 32], ::light_hasher::HasherError> { + use borsh::BorshSerialize; + use ::light_hasher::Hasher; + + // For SHA256, we can serialize the whole struct and hash it in one go + let serialized = self.try_to_vec().map_err(|_| ::light_hasher::HasherError::BorshError)?; + let mut result = ::light_hasher::Sha256::hash(&serialized)?; + + // Truncate field size for non-Poseidon hashers + result[0] = 0; + + Ok(result) + } + } + }) +} diff --git a/sdk-libs/macros/src/lib.rs b/sdk-libs/macros/src/lib.rs index 324660c861..8cd83ecbcb 100644 --- a/sdk-libs/macros/src/lib.rs +++ b/sdk-libs/macros/src/lib.rs @@ -1,6 +1,7 @@ extern crate proc_macro; use accounts::{process_light_accounts, process_light_system_accounts}; -use hasher::derive_light_hasher; +use discriminator::{discriminator, discriminator_sha}; +use hasher::{derive_light_hasher, derive_light_hasher_sha}; use proc_macro::TokenStream; use syn::{parse_macro_input, DeriveInput, ItemMod, ItemStruct}; use traits::process_light_traits; @@ -135,7 +136,35 @@ pub fn light_traits_derive(input: TokenStream) -> TokenStream { #[proc_macro_derive(LightDiscriminator)] pub fn light_discriminator(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as ItemStruct); - discriminator::discriminator(input) + discriminator(input) + .unwrap_or_else(|err| err.to_compile_error()) + .into() +} + +/// SHA256 variant of the LightDiscriminator derive macro. +/// +/// This derive macro provides the same discriminator functionality as LightDiscriminator +/// but is designed to be used with SHA256-based hashing for consistency. +/// +/// ## Example +/// +/// ```ignore +/// use light_sdk::sha::{LightHasher, LightDiscriminator}; +/// +/// #[derive(LightHasher, LightDiscriminator)] +/// pub struct LargeGameState { +/// pub field1: u64, pub field2: u64, pub field3: u64, pub field4: u64, +/// pub field5: u64, pub field6: u64, pub field7: u64, pub field8: u64, +/// pub field9: u64, pub field10: u64, pub field11: u64, pub field12: u64, +/// pub field13: u64, pub field14: u64, pub field15: u64, +/// pub owner: Pubkey, +/// pub authority: Pubkey, +/// } +/// ``` +#[proc_macro_derive(LightDiscriminatorSha)] +pub fn light_discriminator_sha(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as ItemStruct); + discriminator_sha(input) .unwrap_or_else(|err| err.to_compile_error()) .into() } @@ -256,6 +285,32 @@ pub fn light_hasher(input: TokenStream) -> TokenStream { .into() } +/// SHA256 variant of the LightHasher derive macro. +/// +/// This derive macro automatically implements the `DataHasher` and `ToByteArray` traits +/// for structs, using SHA256 as the hashing algorithm instead of Poseidon. +/// +/// ## Example +/// +/// ```ignore +/// use light_sdk::sha::LightHasher; +/// +/// #[derive(LightHasher)] +/// pub struct GameState { +/// #[hash] +/// pub player: Pubkey, // Will be hashed to 31 bytes +/// pub level: u32, +/// } +/// ``` +#[proc_macro_derive(LightHasherSha, attributes(hash, skip))] +pub fn light_hasher_sha(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as ItemStruct); + + derive_light_hasher_sha(input) + .unwrap_or_else(|err| err.to_compile_error()) + .into() +} + /// Alias of `LightHasher`. #[proc_macro_derive(DataHasher, attributes(skip, hash))] pub fn data_hasher(input: TokenStream) -> TokenStream { diff --git a/sdk-libs/sdk/src/account.rs b/sdk-libs/sdk/src/account.rs index 8206696040..44d83c83f3 100644 --- a/sdk-libs/sdk/src/account.rs +++ b/sdk-libs/sdk/src/account.rs @@ -65,7 +65,10 @@ //! ``` // TODO: add example for manual hashing -use std::ops::{Deref, DerefMut}; +use std::{ + marker::PhantomData, + ops::{Deref, DerefMut}, +}; use light_compressed_account::{ compressed_account::PackedMerkleContext, @@ -76,22 +79,42 @@ use solana_pubkey::Pubkey; use crate::{ error::LightSdkError, - light_hasher::{DataHasher, Poseidon}, + light_hasher::{DataHasher, Hasher, Poseidon, Sha256}, AnchorDeserialize, AnchorSerialize, LightDiscriminator, }; +const DEFAULT_DATA_HASH: [u8; 32] = [0u8; 32]; + +pub trait Size { + fn size(&self) -> usize; +} + +pub type LightAccount<'a, A> = LightAccountInner<'a, Poseidon, A>; + +pub mod sha { + use super::*; + /// LightAccount variant that uses SHA256 hashing + pub type LightAccount<'a, A> = super::LightAccountInner<'a, Sha256, A>; +} + #[derive(Debug, PartialEq)] -pub struct LightAccount< +pub struct LightAccountInner< 'a, + H: Hasher, A: AnchorSerialize + AnchorDeserialize + LightDiscriminator + DataHasher + Default, > { owner: &'a Pubkey, pub account: A, account_info: CompressedAccountInfo, + should_remove_data: bool, + _hasher: PhantomData, } -impl<'a, A: AnchorSerialize + AnchorDeserialize + LightDiscriminator + DataHasher + Default> - LightAccount<'a, A> +impl< + 'a, + H: Hasher, + A: AnchorSerialize + AnchorDeserialize + LightDiscriminator + DataHasher + Default, + > LightAccountInner<'a, H, A> { pub fn new_init( owner: &'a Pubkey, @@ -111,6 +134,8 @@ impl<'a, A: AnchorSerialize + AnchorDeserialize + LightDiscriminator + DataHashe input: None, output: Some(output_account_info), }, + should_remove_data: false, + _hasher: PhantomData, } } @@ -120,7 +145,7 @@ impl<'a, A: AnchorSerialize + AnchorDeserialize + LightDiscriminator + DataHashe input_account: A, ) -> Result { let input_account_info = { - let input_data_hash = input_account.hash::()?; + let input_data_hash = input_account.hash::()?; let tree_info = input_account_meta.get_tree_info(); InAccountInfo { data_hash: input_data_hash, @@ -155,6 +180,8 @@ impl<'a, A: AnchorSerialize + AnchorDeserialize + LightDiscriminator + DataHashe input: Some(input_account_info), output: Some(output_account_info), }, + should_remove_data: false, + _hasher: PhantomData, }) } @@ -164,7 +191,7 @@ impl<'a, A: AnchorSerialize + AnchorDeserialize + LightDiscriminator + DataHashe input_account: A, ) -> Result { let input_account_info = { - let input_data_hash = input_account.hash::()?; + let input_data_hash = input_account.hash::()?; let tree_info = input_account_meta.get_tree_info(); InAccountInfo { data_hash: input_data_hash, @@ -179,6 +206,7 @@ impl<'a, A: AnchorSerialize + AnchorDeserialize + LightDiscriminator + DataHashe discriminator: A::LIGHT_DISCRIMINATOR, } }; + Ok(Self { owner, account: input_account, @@ -187,6 +215,8 @@ impl<'a, A: AnchorSerialize + AnchorDeserialize + LightDiscriminator + DataHashe input: Some(input_account_info), output: None, }, + should_remove_data: false, + _hasher: PhantomData, }) } @@ -237,18 +267,28 @@ impl<'a, A: AnchorSerialize + AnchorDeserialize + LightDiscriminator + DataHashe /// that should only be called once per instruction. pub fn to_account_info(mut self) -> Result { if let Some(output) = self.account_info.output.as_mut() { - output.data_hash = self.account.hash::()?; - output.data = self - .account - .try_to_vec() - .map_err(|_| LightSdkError::Borsh)?; + if self.should_remove_data { + // TODO: review security. + output.data_hash = DEFAULT_DATA_HASH; + } else { + output.data_hash = self.account.hash::()?; + if H::ID != 0 { + output.data_hash[0] = 0; + } + output.data = self + .account + .try_to_vec() + .map_err(|_| LightSdkError::Borsh)?; + } } Ok(self.account_info) } } -impl Deref - for LightAccount<'_, A> +impl< + H: Hasher, + A: AnchorSerialize + AnchorDeserialize + LightDiscriminator + DataHasher + Default, + > Deref for LightAccountInner<'_, H, A> { type Target = A; @@ -257,8 +297,10 @@ impl DerefMut - for LightAccount<'_, A> +impl< + H: Hasher, + A: AnchorSerialize + AnchorDeserialize + LightDiscriminator + DataHasher + Default, + > DerefMut for LightAccountInner<'_, H, A> { fn deref_mut(&mut self) -> &mut ::Target { &mut self.account diff --git a/sdk-libs/sdk/src/lib.rs b/sdk-libs/sdk/src/lib.rs index b8eef1be97..ad2f41c7da 100644 --- a/sdk-libs/sdk/src/lib.rs +++ b/sdk-libs/sdk/src/lib.rs @@ -103,6 +103,17 @@ /// Compressed account abstraction similar to anchor Account. pub mod account; +pub use account::LightAccount; + +/// SHA256-based variants +pub mod sha { + pub use light_sdk_macros::{ + LightDiscriminatorSha as LightDiscriminator, LightHasherSha as LightHasher, + }; + + pub use crate::account::sha::LightAccount; +} + /// Functions to derive compressed account addresses. pub mod address; /// Utilities to invoke the light-system-program via cpi. @@ -123,7 +134,8 @@ use borsh::{BorshDeserialize as AnchorDeserialize, BorshSerialize as AnchorSeria pub use light_account_checks::{self, discriminator::Discriminator as LightDiscriminator}; pub use light_hasher; pub use light_sdk_macros::{ - derive_light_cpi_signer, light_system_accounts, LightDiscriminator, LightHasher, LightTraits, + derive_light_cpi_signer, light_system_accounts, LightDiscriminator, LightDiscriminatorSha, + LightHasher, LightHasherSha, LightTraits, }; pub use light_sdk_types::constants; use solana_account_info::AccountInfo;