@@ -71,7 +71,7 @@ public Vector4 Convolve(Span<Vector4> rowSpan)
7171 public Vector4 ConvolveCore ( ref Vector4 rowStartRef )
7272 {
7373#if SUPPORTS_RUNTIME_INTRINSICS
74- if ( Avx2 . IsSupported )
74+ if ( Fma . IsSupported )
7575 {
7676 float * bufferStart = this . bufferPtr ;
7777 float * bufferEnd = bufferStart + ( this . Length & ~ 1 ) ;
@@ -80,11 +80,20 @@ public Vector4 ConvolveCore(ref Vector4 rowStartRef)
8080
8181 while ( bufferStart < bufferEnd )
8282 {
83- Vector256 < float > rowItem256 = Unsafe . As < Vector4 , Vector256 < float > > ( ref rowStartRef ) ;
84- Vector256 < float > bufferItem256 = Avx2 . PermuteVar8x32 ( Vector256 . Create ( * ( double * ) bufferStart ) . AsSingle ( ) , mask ) ;
85- Vector256 < float > multiply256 = Avx . Multiply ( rowItem256 , bufferItem256 ) ;
86-
87- result256 = Avx . Add ( multiply256 , result256 ) ;
83+ // It is important to use a single expression here so that the JIT will correctly use vfmadd231ps
84+ // for the FMA operation, executing it directly on the target register and reading the first
85+ // parameter directly from memory. This skips initializing a SIMD register, and an extra copy.
86+ // The code below should compile to the following assembly on .NET 5 x64:
87+ //
88+ // vmovsd xmm2, [rax] ; load *(double*)bufferStart into xmm2 as [ab, _]
89+ // vpermps ymm2, ymm1, ymm2 ; permute as a float YMM register to [a, a, a, a, b, b, b, b]
90+ // vfmadd231ps ymm0, ymm2, [r8] ; result256 = (pixels * factors) + result256
91+ //
92+ // For tracking the codegen issue with FMA, see: https://github.com/dotnet/runtime/issues/12212.
93+ result256 = Fma . MultiplyAdd (
94+ Unsafe . As < Vector4 , Vector256 < float > > ( ref rowStartRef ) ,
95+ Avx2 . PermuteVar8x32 ( Vector256 . CreateScalarUnsafe ( * ( double * ) bufferStart ) . AsSingle ( ) , mask ) ,
96+ result256 ) ;
8897
8998 bufferStart += 2 ;
9099 rowStartRef = ref Unsafe . Add ( ref rowStartRef , 2 ) ;
@@ -94,11 +103,10 @@ public Vector4 ConvolveCore(ref Vector4 rowStartRef)
94103
95104 if ( ( this . Length & 1 ) != 0 )
96105 {
97- Vector128 < float > rowItem128 = Unsafe . As < Vector4 , Vector128 < float > > ( ref rowStartRef ) ;
98- var bufferItem128 = Vector128 . Create ( * bufferStart ) ;
99- Vector128 < float > multiply128 = Sse . Multiply ( rowItem128 , bufferItem128 ) ;
100-
101- result128 = Sse . Add ( multiply128 , result128 ) ;
106+ result128 = Fma . MultiplyAdd (
107+ Unsafe . As < Vector4 , Vector128 < float > > ( ref rowStartRef ) ,
108+ Vector128 . Create ( * bufferStart ) ,
109+ result128 ) ;
102110 }
103111
104112 return * ( Vector4 * ) & result128 ;
0 commit comments