@@ -235,78 +235,84 @@ impl CubicSplineKernelNeonF32 {
     }
 }
 
-#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
 #[test]
+#[cfg_attr(
+    not(all(target_arch = "aarch64", target_feature = "neon")),
+    ignore = "Skipped on non-aarch64 targets"
+)]
 fn test_cubic_spline_kernel_neon() {
-    use core::arch::aarch64::*;
+    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
+    {
+        use core::arch::aarch64::*;
 
-    // Test a few representative compact support radii
-    let hs: [f32; 3] = [0.025, 0.1, 2.0];
-    for &h in hs.iter() {
-        let scalar = CubicSplineKernel::new(h);
-        let neon = CubicSplineKernelNeonF32::new(h);
-
-        // Sample radii from 0 to 2h (beyond support should be 0)
-        let n: usize = 1024;
-        let mut r0: f32 = 0.0;
-        let dr: f32 = (2.0 * h) / (n as f32);
-
-        for _chunk in 0..(n / 4) {
-            // Prepare 4 lanes of radii
-            let rs = [r0, r0 + dr, r0 + 2.0 * dr, r0 + 3.0 * dr];
-            let r_vec = unsafe { vld1q_f32(rs.as_ptr()) };
-
-            // Evaluate NEON and store back to array
-            let w_vec = unsafe { neon.evaluate(r_vec) };
-            let mut w_neon = [0.0f32; 4];
-            unsafe { vst1q_f32(w_neon.as_mut_ptr(), w_vec) };
-
-            // Compare against scalar lane-wise
-            for lane in 0..4 {
-                let r_lane = rs[lane];
-                let w_scalar = scalar.evaluate(r_lane);
-                let diff = (w_neon[lane] - w_scalar).abs();
-
-                // Absolute tolerance with mild relative component to be robust across scales
-                let tol = 5e-6_f32.max(2e-5_f32 * w_scalar.abs());
+        // Test a few representative compact support radii
+        let hs: [f32; 3] = [0.025, 0.1, 2.0];
+        for &h in hs.iter() {
+            let scalar = CubicSplineKernel::new(h);
+            let neon = CubicSplineKernelNeonF32::new(h);
+
+            // Sample radii from 0 to 2h (beyond support should be 0)
+            let n: usize = 1024;
+            let mut r0: f32 = 0.0;
+            let dr: f32 = (2.0 * h) / (n as f32);
+
+            for _chunk in 0..(n / 4) {
+                // Prepare 4 lanes of radii
+                let rs = [r0, r0 + dr, r0 + 2.0 * dr, r0 + 3.0 * dr];
+                let r_vec = unsafe { vld1q_f32(rs.as_ptr()) };
+
+                // Evaluate NEON and store back to array
+                let w_vec = unsafe { neon.evaluate(r_vec) };
+                let mut w_neon = [0.0f32; 4];
+                unsafe { vst1q_f32(w_neon.as_mut_ptr(), w_vec) };
+
+                // Compare against scalar lane-wise
+                for lane in 0..4 {
+                    let r_lane = rs[lane];
+                    let w_scalar = scalar.evaluate(r_lane);
+                    let diff = (w_neon[lane] - w_scalar).abs();
+
+                    // Absolute tolerance with mild relative component to be robust across scales
+                    let tol = 5e-6_f32.max(2e-5_f32 * w_scalar.abs());
+                    assert!(
+                        diff <= tol,
+                        "NEON kernel mismatch (h={}, r={}, lane={}): neon={}, scalar={}, diff={}, tol={}",
+                        h,
+                        r_lane,
+                        lane,
+                        w_neon[lane],
+                        w_scalar,
+                        diff,
+                        tol
+                    );
+                }
+
+                r0 += 4.0 * dr;
+            }
+
+            // Also check a couple of out-of-support points explicitly
+            for &r in &[h * 1.01, h * 1.5, h * 2.0, h * 2.5] {
+                let w_scalar = scalar.evaluate(r);
+                let w_neon = {
+                    let v = unsafe { vld1q_f32([r, r, r, r].as_ptr()) };
+                    let w = unsafe { neon.evaluate(v) };
+                    let mut tmp = [0.0f32; 4];
+                    unsafe { vst1q_f32(tmp.as_mut_ptr(), w) };
+                    tmp[0]
+                };
+                let diff = (w_neon - w_scalar).abs();
+                let tol = 5e-6_f32.max(1e-5_f32 * w_scalar.abs());
                 assert!(
                     diff <= tol,
-                    "NEON kernel mismatch (h={}, r={}, lane={}): neon={}, scalar={}, diff={}, tol={}",
+                    "NEON kernel mismatch outside support (h={}, r={}): neon={}, scalar={}, diff={}, tol={}",
                     h,
-                    r_lane,
-                    lane,
-                    w_neon[lane],
+                    r,
+                    w_neon,
                     w_scalar,
                     diff,
                     tol
                 );
             }
-
-            r0 += 4.0 * dr;
-        }
-
-        // Also check a couple of out-of-support points explicitly
-        for &r in &[h * 1.01, h * 1.5, h * 2.0, h * 2.5] {
-            let w_scalar = scalar.evaluate(r);
-            let w_neon = {
-                let v = unsafe { vld1q_f32([r, r, r, r].as_ptr()) };
-                let w = unsafe { neon.evaluate(v) };
-                let mut tmp = [0.0f32; 4];
-                unsafe { vst1q_f32(tmp.as_mut_ptr(), w) };
-                tmp[0]
-            };
-            let diff = (w_neon - w_scalar).abs();
-            let tol = 5e-6_f32.max(1e-5_f32 * w_scalar.abs());
-            assert!(
-                diff <= tol,
-                "NEON kernel mismatch outside support (h={}, r={}): neon={}, scalar={}, diff={}, tol={}",
-                h,
-                r,
-                w_neon,
-                w_scalar,
-                diff,
-                tol
-            );
-        }
     }
 }
@@ -372,94 +378,104 @@ impl CubicSplineKernelAvxF32 {
     }
 }
 
-#[cfg(all(
-    any(target_arch = "x86_64", target_arch = "x86"),
-    target_feature = "avx2",
-    target_feature = "fma"
-))]
 #[test]
+#[cfg_attr(
+    not(all(
+        any(target_arch = "x86_64", target_arch = "x86"),
+        target_feature = "avx2",
+        target_feature = "fma"
+    )),
+    ignore = "Skipped on non-x86 targets"
+)]
 fn test_cubic_spline_kernel_avx() {
-    #[cfg(target_arch = "x86")]
-    use core::arch::x86::*;
-    #[cfg(target_arch = "x86_64")]
-    use core::arch::x86_64::*;
+    #[cfg(all(
+        any(target_arch = "x86_64", target_arch = "x86"),
+        target_feature = "avx2",
+        target_feature = "fma"
+    ))]
+    {
+        #[cfg(target_arch = "x86")]
+        use core::arch::x86::*;
+        #[cfg(target_arch = "x86_64")]
+        use core::arch::x86_64::*;
 
-    // Test a few representative compact support radii
-    let hs: [f32; 3] = [0.025, 0.1, 2.0];
-    for &h in hs.iter() {
-        let scalar = CubicSplineKernel::new(h);
-        let avx = CubicSplineKernelAvxF32::new(h);
-
-        // Sample radii from 0 to 2h (beyond support should be 0)
-        let n: usize = 1024;
-        let mut r0: f32 = 0.0;
-        let dr: f32 = (2.0 * h) / (n as f32);
-
-        for _chunk in 0..(n / 8) {
-            // Prepare 8 lanes of radii
-            let rs = [
-                r0,
-                r0 + dr,
-                r0 + 2.0 * dr,
-                r0 + 3.0 * dr,
-                r0 + 4.0 * dr,
-                r0 + 5.0 * dr,
-                r0 + 6.0 * dr,
-                r0 + 7.0 * dr,
-            ];
-
-            // Evaluate AVX and store back to array
-            let r_vec = unsafe { _mm256_loadu_ps(rs.as_ptr()) };
-            let w_vec = unsafe { avx.evaluate(r_vec) };
-            let mut w_avx = [0.0f32; 8];
-            unsafe { _mm256_storeu_ps(w_avx.as_mut_ptr(), w_vec) };
-
-            // Compare against scalar lane-wise
-            for lane in 0..8 {
-                let r_lane = rs[lane];
-                let w_scalar = scalar.evaluate(r_lane);
-                let diff = (w_avx[lane] - w_scalar).abs();
-
-                // Absolute tolerance with mild relative component to be robust across scales
+        // Test a few representative compact support radii
+        let hs: [f32; 3] = [0.025, 0.1, 2.0];
+        for &h in hs.iter() {
+            let scalar = CubicSplineKernel::new(h);
+            let avx = CubicSplineKernelAvxF32::new(h);
+
+            // Sample radii from 0 to 2h (beyond support should be 0)
+            let n: usize = 1024;
+            let mut r0: f32 = 0.0;
+            let dr: f32 = (2.0 * h) / (n as f32);
+
+            for _chunk in 0..(n / 8) {
+                // Prepare 8 lanes of radii
+                let rs = [
+                    r0,
+                    r0 + dr,
+                    r0 + 2.0 * dr,
+                    r0 + 3.0 * dr,
+                    r0 + 4.0 * dr,
+                    r0 + 5.0 * dr,
+                    r0 + 6.0 * dr,
+                    r0 + 7.0 * dr,
+                ];
+
+                // Evaluate AVX and store back to array
+                let r_vec = unsafe { _mm256_loadu_ps(rs.as_ptr()) };
+                let w_vec = unsafe { avx.evaluate(r_vec) };
+                let mut w_avx = [0.0f32; 8];
+                unsafe { _mm256_storeu_ps(w_avx.as_mut_ptr(), w_vec) };
+
+                // Compare against scalar lane-wise
+                for lane in 0..8 {
+                    let r_lane = rs[lane];
+                    let w_scalar = scalar.evaluate(r_lane);
+                    let diff = (w_avx[lane] - w_scalar).abs();
+
+                    // Absolute tolerance with mild relative component to be robust across scales
+                    let tol = 1e-6_f32.max(1e-5_f32 * w_scalar.abs());
+                    assert!(
+                        diff <= tol,
+                        "AVX kernel mismatch (h={}, r={}, lane={}): avx={}, scalar={}, diff={}, tol={}",
+                        h,
+                        r_lane,
+                        lane,
+                        w_avx[lane],
+                        w_scalar,
+                        diff,
+                        tol
+                    );
+                }
+
+                r0 += 8.0 * dr;
+            }
+
+            // Also check a couple of out-of-support points explicitly
+            for &r in &[h * 1.01, h * 1.5, h * 2.0, h * 2.5] {
+                let w_scalar = scalar.evaluate(r);
+                let w_avx = {
+                    let v = unsafe { _mm256_set1_ps(r) };
+                    let w = unsafe { avx.evaluate(v) };
+                    let mut tmp = [0.0f32; 8];
+                    unsafe { _mm256_storeu_ps(tmp.as_mut_ptr(), w) };
+                    tmp[0]
+                };
+                let diff = (w_avx - w_scalar).abs();
                 let tol = 1e-6_f32.max(1e-5_f32 * w_scalar.abs());
                 assert!(
                     diff <= tol,
-                    "AVX kernel mismatch (h={}, r={}, lane={}): avx={}, scalar={}, diff={}, tol={}",
+                    "AVX kernel mismatch outside support (h={}, r={}): avx={}, scalar={}, diff={}, tol={}",
                     h,
-                    r_lane,
-                    lane,
-                    w_avx[lane],
+                    r,
+                    w_avx,
                     w_scalar,
                     diff,
                     tol
                 );
             }
-
-            r0 += 8.0 * dr;
-        }
-
-        // Also check a couple of out-of-support points explicitly
-        for &r in &[h * 1.01, h * 1.5, h * 2.0, h * 2.5] {
-            let w_scalar = scalar.evaluate(r);
-            let w_avx = {
-                let v = unsafe { _mm256_set1_ps(r) };
-                let w = unsafe { avx.evaluate(v) };
-                let mut tmp = [0.0f32; 8];
-                unsafe { _mm256_storeu_ps(tmp.as_mut_ptr(), w) };
-                tmp[0]
-            };
-            let diff = (w_avx - w_scalar).abs();
-            let tol = 1e-6_f32.max(1e-5_f32 * w_scalar.abs());
-            assert!(
-                diff <= tol,
-                "AVX kernel mismatch outside support (h={}, r={}): avx={}, scalar={}, diff={}, tol={}",
-                h,
-                r,
-                w_avx,
-                w_scalar,
-                diff,
-                tol
-            );
-        }
     }
 }