WIP - Speed improvements to resize convolution (no vpermps w/ FMA) #2793

Draft · wants to merge 13 commits into main
75 changes: 75 additions & 0 deletions src/ImageSharp/Common/Helpers/Numerics.cs
@@ -1080,4 +1080,79 @@ public static nuint Vector512Count<TVector>(this Span<float> span)
public static nuint Vector512Count<TVector>(int length)
where TVector : struct
=> (uint)length / (uint)Vector512<TVector>.Count;

/// <summary>
/// Normalizes the values in a given <see cref="Span{T}"/>.
/// </summary>
/// <param name="span">The sequence of <see cref="float"/> values to normalize.</param>
/// <param name="sum">The sum of the values in <paramref name="span"/>.</param>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static void Normalize(Span<float> span, float sum)
{
if (Vector512.IsHardwareAccelerated)
{
ref float startRef = ref MemoryMarshal.GetReference(span);
ref float endRef = ref Unsafe.Add(ref startRef, span.Length & ~15);
Vector512<float> sum512 = Vector512.Create(sum);

while (Unsafe.IsAddressLessThan(ref startRef, ref endRef))
{
Unsafe.As<float, Vector512<float>>(ref startRef) /= sum512;
startRef = ref Unsafe.Add(ref startRef, (nuint)16);
}

if ((span.Length & 15) >= 8)
{
Unsafe.As<float, Vector256<float>>(ref startRef) /= sum512.GetLower();
startRef = ref Unsafe.Add(ref startRef, (nuint)8);
}

if ((span.Length & 7) >= 4)
{
Unsafe.As<float, Vector128<float>>(ref startRef) /= sum512.GetLower().GetLower();
startRef = ref Unsafe.Add(ref startRef, (nuint)4);
}

endRef = ref Unsafe.Add(ref startRef, span.Length & 3);

while (Unsafe.IsAddressLessThan(ref startRef, ref endRef))
{
startRef /= sum;
startRef = ref Unsafe.Add(ref startRef, (nuint)1);
}
}
else if (Vector256.IsHardwareAccelerated)
{
ref float startRef = ref MemoryMarshal.GetReference(span);
ref float endRef = ref Unsafe.Add(ref startRef, span.Length & ~7);
Vector256<float> sum256 = Vector256.Create(sum);

while (Unsafe.IsAddressLessThan(ref startRef, ref endRef))
{
Unsafe.As<float, Vector256<float>>(ref startRef) /= sum256;
startRef = ref Unsafe.Add(ref startRef, (nuint)8);
}

if ((span.Length & 7) >= 4)
{
Unsafe.As<float, Vector128<float>>(ref startRef) /= sum256.GetLower();
startRef = ref Unsafe.Add(ref startRef, (nuint)4);
}

endRef = ref Unsafe.Add(ref startRef, span.Length & 3);

while (Unsafe.IsAddressLessThan(ref startRef, ref endRef))
{
startRef /= sum;
startRef = ref Unsafe.Add(ref startRef, (nuint)1);
}
}
else
{
for (int i = 0; i < span.Length; i++)
{
span[i] /= sum;
}
}
}
}
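
For reference, the new Numerics.Normalize helper divides every element of the span by the supplied sum, taking the widest available vector width (512-bit, then 256-bit) and finishing the tail with a scalar loop. A minimal usage sketch, assuming caller code with access to the internal helper and hypothetical weight values:

// Sketch only: a hypothetical 5-tap kernel; after the call the weights sum to ~1.
float[] weights = { 1f, 4f, 6f, 4f, 1f };
float sum = 0f;
foreach (float w in weights)
{
    sum += w;
}

// Every element is divided by 16 in place.
Numerics.Normalize(weights, sum);
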
1 change: 1 addition & 0 deletions src/ImageSharp/Common/Helpers/Vector512Utilities.cs
@@ -2,6 +2,7 @@
// Licensed under the Six Labors Split License.

using System.Diagnostics.CodeAnalysis;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;
191 changes: 128 additions & 63 deletions src/ImageSharp/Processing/Processors/Transforms/Resize/ResizeKernel.cs
@@ -5,7 +5,7 @@
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;
using SixLabors.ImageSharp.Common.Helpers;

namespace SixLabors.ImageSharp.Processing.Processors.Transforms;

@@ -14,11 +14,18 @@ namespace SixLabors.ImageSharp.Processing.Processors.Transforms;
/// </summary>
internal readonly unsafe struct ResizeKernel
{
/// <summary>
/// The buffer with the convolution factors.
/// Note that when vectorization is supported (see <see cref="IsHardwareAccelerated"/>), this is of size 4x that reported in <see cref="Length"/>.
/// </summary>
private readonly float* bufferPtr;

/// <summary>
/// Initializes a new instance of the <see cref="ResizeKernel"/> struct.
/// </summary>
/// <param name="startIndex">The starting index for the destination row.</param>
/// <param name="bufferPtr">The pointer to the buffer with the convolution factors.</param>
/// <param name="length">The length of the kernel.</param>
[MethodImpl(InliningOptions.ShortMethod)]
internal ResizeKernel(int startIndex, float* bufferPtr, int length)
{
@@ -27,6 +34,15 @@ internal ResizeKernel(int startIndex, float* bufferPtr, int length)
this.Length = length;
}

/// <summary>
/// Gets a value indicating whether vectorization is supported.
/// </summary>
public static bool IsHardwareAccelerated
{
[MethodImpl(MethodImplOptions.AggressiveInlining)]
get => Vector256.IsHardwareAccelerated;
}

/// <summary>
/// Gets the start index for the destination row.
/// </summary>
@@ -53,7 +69,15 @@ public int Length
public Span<float> Values
{
[MethodImpl(InliningOptions.ShortMethod)]
get => new(this.bufferPtr, this.Length);
get
{
if (Vector256.IsHardwareAccelerated)
{
return new(this.bufferPtr, this.Length * 4);
}

return new(this.bufferPtr, this.Length);
}
}

/// <summary>
@@ -68,73 +92,99 @@ public Vector4 Convolve(Span<Vector4> rowSpan)
[MethodImpl(InliningOptions.ShortMethod)]
public Vector4 ConvolveCore(ref Vector4 rowStartRef)
{
if (Avx2.IsSupported && Fma.IsSupported)
if (IsHardwareAccelerated)
{
float* bufferStart = this.bufferPtr;
float* bufferEnd = bufferStart + (this.Length & ~3);
Vector256<float> result256_0 = Vector256<float>.Zero;
Vector256<float> result256_1 = Vector256<float>.Zero;
ReadOnlySpan<byte> maskBytes = new byte[]
if (Vector512.IsHardwareAccelerated)
{
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
1, 0, 0, 0, 1, 0, 0, 0,
1, 0, 0, 0, 1, 0, 0, 0,
};
Vector256<int> mask = Unsafe.ReadUnaligned<Vector256<int>>(ref MemoryMarshal.GetReference(maskBytes));
float* bufferStart = this.bufferPtr;
ref Vector4 rowEndRef = ref Unsafe.Add(ref rowStartRef, this.Length & ~7);
Vector512<float> result512_0 = Vector512<float>.Zero;
Vector512<float> result512_1 = Vector512<float>.Zero;

while (bufferStart < bufferEnd)
{
// It is important to use a single expression here so that the JIT will correctly use vfmadd231ps
// for the FMA operation, and execute it directly on the target register and reading directly from
// memory for the first parameter. This skips initializing a SIMD register, and an extra copy.
// The code below should compile in the following assembly on .NET 5 x64:
//
// vmovsd xmm2, [rax] ; load *(double*)bufferStart into xmm2 as [ab, _]
// vpermps ymm2, ymm1, ymm2 ; permute as a float YMM register to [a, a, a, a, b, b, b, b]
// vfmadd231ps ymm0, ymm2, [r8] ; result256_0 = FMA(pixels, factors) + result256_0
//
// For tracking the codegen issue with FMA, see: https://github.com/dotnet/runtime/issues/12212.
// Additionally, we're also unrolling two computations per each loop iterations to leverage the
// fact that most CPUs have two ports to schedule multiply operations for FMA instructions.
result256_0 = Fma.MultiplyAdd(
Unsafe.As<Vector4, Vector256<float>>(ref rowStartRef),
Avx2.PermuteVar8x32(Vector256.CreateScalarUnsafe(*(double*)bufferStart).AsSingle(), mask),
result256_0);

result256_1 = Fma.MultiplyAdd(
Unsafe.As<Vector4, Vector256<float>>(ref Unsafe.Add(ref rowStartRef, 2)),
Avx2.PermuteVar8x32(Vector256.CreateScalarUnsafe(*(double*)(bufferStart + 2)).AsSingle(), mask),
result256_1);

bufferStart += 4;
rowStartRef = ref Unsafe.Add(ref rowStartRef, 4);
}
while (Unsafe.IsAddressLessThan(ref rowStartRef, ref rowEndRef))
{
Vector512<float> pixels512_0 = Unsafe.As<Vector4, Vector512<float>>(ref rowStartRef);
Vector512<float> pixels512_1 = Unsafe.As<Vector4, Vector512<float>>(ref Unsafe.Add(ref rowStartRef, (nuint)4));

result256_0 = Avx.Add(result256_0, result256_1);
result512_0 = Vector512_.MultiplyAdd(result512_0, Vector512.Load(bufferStart), pixels512_0);
result512_1 = Vector512_.MultiplyAdd(result512_1, Vector512.Load(bufferStart + 16), pixels512_1);

if ((this.Length & 3) >= 2)
{
result256_0 = Fma.MultiplyAdd(
Unsafe.As<Vector4, Vector256<float>>(ref rowStartRef),
Avx2.PermuteVar8x32(Vector256.CreateScalarUnsafe(*(double*)bufferStart).AsSingle(), mask),
result256_0);
bufferStart += 32;
rowStartRef = ref Unsafe.Add(ref rowStartRef, (nuint)8);
}

bufferStart += 2;
rowStartRef = ref Unsafe.Add(ref rowStartRef, 2);
}
result512_0 += result512_1;

Vector128<float> result128 = Sse.Add(result256_0.GetLower(), result256_0.GetUpper());
if ((this.Length & 7) >= 4)
{
Vector512<float> pixels512_0 = Unsafe.As<Vector4, Vector512<float>>(ref rowStartRef);
result512_0 = Vector512_.MultiplyAdd(result512_0, Vector512.Load(bufferStart), pixels512_0);

if ((this.Length & 1) != 0)
{
result128 = Fma.MultiplyAdd(
Unsafe.As<Vector4, Vector128<float>>(ref rowStartRef),
Vector128.Create(*bufferStart),
result128);
bufferStart += 16;
rowStartRef = ref Unsafe.Add(ref rowStartRef, (nuint)4);
}

Vector256<float> result256 = result512_0.GetLower() + result512_0.GetUpper();

if ((this.Length & 3) >= 2)
{
Vector256<float> pixels256_0 = Unsafe.As<Vector4, Vector256<float>>(ref rowStartRef);
result256 = Vector256_.MultiplyAdd(result256, Vector256.Load(bufferStart), pixels256_0);

bufferStart += 8;
rowStartRef = ref Unsafe.Add(ref rowStartRef, (nuint)2);
}

Vector128<float> result128 = result256.GetLower() + result256.GetUpper();

if ((this.Length & 1) != 0)
{
Vector128<float> pixels128 = Unsafe.As<Vector4, Vector128<float>>(ref rowStartRef);
result128 = Vector128_.MultiplyAdd(result128, Vector128.Load(bufferStart), pixels128);
}

return result128.AsVector4();
}
else
{
float* bufferStart = this.bufferPtr;
ref Vector4 rowEndRef = ref Unsafe.Add(ref rowStartRef, this.Length & ~3);
Vector256<float> result256_0 = Vector256<float>.Zero;
Vector256<float> result256_1 = Vector256<float>.Zero;

while (Unsafe.IsAddressLessThan(ref rowStartRef, ref rowEndRef))
{
Vector256<float> pixels256_0 = Unsafe.As<Vector4, Vector256<float>>(ref rowStartRef);
Vector256<float> pixels256_1 = Unsafe.As<Vector4, Vector256<float>>(ref Unsafe.Add(ref rowStartRef, (nuint)2));

result256_0 = Vector256_.MultiplyAdd(result256_0, Vector256.Load(bufferStart), pixels256_0);
result256_1 = Vector256_.MultiplyAdd(result256_1, Vector256.Load(bufferStart + 8), pixels256_1);

bufferStart += 16;
rowStartRef = ref Unsafe.Add(ref rowStartRef, (nuint)4);
}

result256_0 += result256_1;

if ((this.Length & 3) >= 2)
{
Vector256<float> pixels256_0 = Unsafe.As<Vector4, Vector256<float>>(ref rowStartRef);
result256_0 = Vector256_.MultiplyAdd(result256_0, Vector256.Load(bufferStart), pixels256_0);

bufferStart += 8;
rowStartRef = ref Unsafe.Add(ref rowStartRef, (nuint)2);
}

Vector128<float> result128 = result256_0.GetLower() + result256_0.GetUpper();

return *(Vector4*)&result128;
if ((this.Length & 1) != 0)
{
Vector128<float> pixels128 = Unsafe.As<Vector4, Vector128<float>>(ref rowStartRef);
result128 = Vector128_.MultiplyAdd(result128, Vector128.Load(bufferStart), pixels128);
}

return result128.AsVector4();
}
}
else
{
@@ -149,7 +199,7 @@ public Vector4 ConvolveCore(ref Vector4 rowStartRef)
result += rowStartRef * *bufferStart;

bufferStart++;
rowStartRef = ref Unsafe.Add(ref rowStartRef, 1);
rowStartRef = ref Unsafe.Add(ref rowStartRef, (nuint)1);
}

return result;
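
All three branches above compute the same dot product between the kernel weights and a window of Vector4 pixels; the vector paths simply consume weights that have been pre-duplicated four times each (see FillOrCopyAndExpand below), so the factors can be loaded contiguously and multiplied lane-for-lane against the pixels, with no vpermps/broadcast inside the loop as in the removed FMA path. A scalar model of the computation, as a sketch with a hypothetical helper name and assuming System.Numerics.Vector4:

// Sketch only: what ConvolveCore computes, expressed without vectorization.
// 'row' starts at the pixel addressed by rowStartRef and 'kernel' holds one weight per tap.
static Vector4 ConvolveScalar(ReadOnlySpan<float> kernel, ReadOnlySpan<Vector4> row)
{
    Vector4 result = Vector4.Zero;
    for (int i = 0; i < kernel.Length; i++)
    {
        result += row[i] * kernel[i];
    }

    return result;
}
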
@@ -160,17 +210,32 @@ public Vector4 ConvolveCore(ref Vector4 rowStartRef)
/// Copy the contents of <see cref="ResizeKernel"/> altering <see cref="StartIndex"/>
/// to the value <paramref name="left"/>.
/// </summary>
/// <param name="left">The new value for <see cref="StartIndex"/>.</param>
[MethodImpl(InliningOptions.ShortMethod)]
internal ResizeKernel AlterLeftValue(int left)
=> new(left, this.bufferPtr, this.Length);

internal void Fill(Span<double> values)
internal void FillOrCopyAndExpand(Span<float> values)
{
DebugGuard.IsTrue(values.Length == this.Length, nameof(values), "ResizeKernel.FillOrCopyAndExpand: values.Length != this.Length!");

for (int i = 0; i < this.Length; i++)
if (Vector256.IsHardwareAccelerated)
{
Vector4* bufferStart = (Vector4*)this.bufferPtr;
ref float valuesStart = ref MemoryMarshal.GetReference(values);
ref float valuesEnd = ref Unsafe.Add(ref valuesStart, values.Length);

while (Unsafe.IsAddressLessThan(ref valuesStart, ref valuesEnd))
{
*bufferStart = new Vector4(valuesStart);

bufferStart++;
valuesStart = ref Unsafe.Add(ref valuesStart, (nuint)1);
}
}
else
{
this.Values[i] = (float)values[i];
values.CopyTo(this.Values);
}
}
}
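
The renamed FillOrCopyAndExpand is what the permute-free convolution above relies on: when Vector256 is accelerated, each scalar weight is written into all four lanes of a Vector4, so the factor buffer holds Length * 4 floats and a straight Vector512/Vector256 load lines one weight up against each channel of the corresponding pixel. A standalone sketch of the resulting layout, with hypothetical weight values:

// Sketch only: a 3-tap kernel { 0.25, 0.5, 0.25 } is stored as 12 floats,
//   0.25 0.25 0.25 0.25  0.5 0.5 0.5 0.5  0.25 0.25 0.25 0.25
// so weight i occupies the same lanes as the X/Y/Z/W channels of pixel i.
float[] kernel = { 0.25f, 0.5f, 0.25f };
float[] expanded = new float[kernel.Length * 4];
for (int i = 0; i < kernel.Length; i++)
{
    for (int j = 0; j < 4; j++)
    {
        expanded[(i * 4) + j] = kernel[i];
    }
}
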
@@ -54,7 +54,7 @@ protected internal override void Initialize<TResampler>(in TResampler sampler)
int bottomStartDest = this.DestinationLength - this.cornerInterval;
for (int i = startOfFirstRepeatedMosaic; i < bottomStartDest; i++)
{
double center = ((i + .5) * this.ratio) - .5;
float center = (float)(((i + .5) * this.ratio) - .5);
int left = (int)TolerantMath.Ceiling(center - this.radius);
ResizeKernel kernel = this.kernels[i - this.period];
this.kernels[i] = kernel.AlterLeftValue(left);