@@ -328,10 +328,13 @@ mod tests {
328328 impl Sysvar for TestSysvar { }
329329 impl SysvarSerialize for TestSysvar { }
330330
331- // NOTE tests that use this mock MUST carry the #[serial] attribute
331+ // NOTE: Tests using these mocks MUST carry the #[serial] attribute
332+ // because they modify global SYSCALL_STUBS state.
333+
332334 struct MockGetSysvarSyscall {
333335 data : Vec < u8 > ,
334336 }
337+
335338 impl SyscallStubs for MockGetSysvarSyscall {
336339 #[ allow( clippy:: arithmetic_side_effects) ]
337340 fn sol_get_sysvar (
@@ -346,12 +349,49 @@ mod tests {
346349 SUCCESS
347350 }
348351 }
352+
353+ /// Mock syscall stub for tests. Requires `#[serial]` attribute.
349354 pub fn mock_get_sysvar_syscall ( data : & [ u8 ] ) {
350355 set_syscall_stubs ( Box :: new ( MockGetSysvarSyscall {
351356 data : data. to_vec ( ) ,
352357 } ) ) ;
353358 }
354359
360+ struct ValidateIdSyscall {
361+ data : Vec < u8 > ,
362+ expected_id : Pubkey ,
363+ }
364+
365+ impl SyscallStubs for ValidateIdSyscall {
366+ #[ allow( clippy:: arithmetic_side_effects) ]
367+ fn sol_get_sysvar (
368+ & self ,
369+ sysvar_id_addr : * const u8 ,
370+ var_addr : * mut u8 ,
371+ offset : u64 ,
372+ length : u64 ,
373+ ) -> u64 {
374+ // Validate that the correct sysvar id pointer was passed
375+ let passed_id = unsafe { * ( sysvar_id_addr as * const Pubkey ) } ;
376+ assert_eq ! ( passed_id, self . expected_id) ;
377+
378+ let slice = unsafe { std:: slice:: from_raw_parts_mut ( var_addr, length as usize ) } ;
379+ slice. copy_from_slice ( & self . data [ offset as usize ..( offset + length) as usize ] ) ;
380+ SUCCESS
381+ }
382+ }
383+
384+ /// Mock syscall stub that validates sysvar ID. Requires `#[serial]` attribute.
385+ pub fn mock_get_sysvar_syscall_with_id (
386+ data : & [ u8 ] ,
387+ expected_id : & Pubkey ,
388+ ) -> Box < dyn SyscallStubs > {
389+ set_syscall_stubs ( Box :: new ( ValidateIdSyscall {
390+ data : data. to_vec ( ) ,
391+ expected_id : * expected_id,
392+ } ) )
393+ }
394+
355395 /// Convert a value to its in-memory byte representation.
356396 ///
357397 /// # Safety
0 commit comments