Skip to content

Commit 609313c

Browse files
authored
Merge pull request #18 from fslaborg/perf/optimize-outer-product-1760211579-b72b5d54-e75939c77af8ef4c
Perfect! The fix looks good. I will later add some additional manual test to be on the save side.
2 parents 1e1622b + d23d401 commit 609313c

File tree

3 files changed

+97
-12
lines changed

3 files changed

+97
-12
lines changed

src/FsMath/SpanMath.fs

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -318,8 +318,9 @@ type SpanMath =
318318

319319

320320
// outer product #######
321-
321+
322322
/// Computes the outer product of two spans.
323+
/// Result[i,j] = u[i] * v[j]
323324
static member inline outerProduct<'T
324325
when 'T :> Numerics.INumber<'T>
325326
and 'T : struct
@@ -335,19 +336,36 @@ type SpanMath =
335336
let cols = v.Length
336337
let data = Array.zeroCreate<'T> (rows * cols)
337338

338-
for i = 0 to rows - 1 do
339-
let ui = u[i]
340-
for j = 0 to cols - 1 do
341-
let vSpan = v
342-
let simdCols = Numerics.Vector<'T>.Count
343-
let simdCount = cols / simdCols
344-
let ceiling = simdCount * simdCols
339+
if Numerics.Vector.IsHardwareAccelerated && cols >= Numerics.Vector<'T>.Count then
340+
// SIMD-accelerated path
341+
let simdWidth = Numerics.Vector<'T>.Count
342+
let simdCount = cols / simdWidth
343+
let scalarStart = simdCount * simdWidth
344+
345+
// Cast v to SIMD vectors once
346+
let vVec = MemoryMarshal.Cast<'T, Numerics.Vector<'T>>(v)
345347

346-
let vVec = MemoryMarshal.Cast<'T, Numerics.Vector<'T>>(v)
348+
for i = 0 to rows - 1 do
349+
let rowOffset = i * cols
350+
let rowSpan = data.AsSpan(rowOffset, cols)
351+
let rowVec = MemoryMarshal.Cast<'T, Numerics.Vector<'T>>(rowSpan)
347352

353+
// Broadcast u[i] to a SIMD vector
354+
let uBroadcast = Numerics.Vector<'T>(u[i])
355+
356+
// Process SIMD chunks
348357
for k = 0 to simdCount - 1 do
349-
let vi = Numerics.Vector<'T>(ui)
350-
let res = vi * vVec[k]
351-
res.CopyTo(MemoryMarshal.CreateSpan(&data.[i * cols + k * simdCols], simdCols))
358+
rowVec[k] <- uBroadcast * vVec[k]
359+
360+
// Process scalar tail
361+
for j = scalarStart to cols - 1 do
362+
data[rowOffset + j] <- u[i] * v[j]
363+
else
364+
// Scalar fallback
365+
for i = 0 to rows - 1 do
366+
let ui = u[i]
367+
let rowOffset = i * cols
368+
for j = 0 to cols - 1 do
369+
data[rowOffset + j] <- ui * v[j]
352370

353371
(rows, cols, data)

tests/FsMath.Tests/FsMath.Tests.fsproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
<Compile Include="VectorOpsCoverageTests.fs" />
2626
<Compile Include="AlgTypesTopLevelOpsCoverageTests.fs" />
2727
<Compile Include="MatrixFloatTests.fs" />
28+
<Compile Include="MatrixOuterProductTests.fs" />
2829
<Compile Include="MatrixFormattingTests.fs" />
2930
<Compile Include="MatrixAdditionalTests.fs" />
3031
<Compile Include="MatrixEdgeCaseTests.fs" />
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
module MatrixOuterProductTests
2+
3+
open Xunit
4+
open FsMath
5+
6+
[<Fact>]
7+
let ``Outer product produces correct dimensions`` () =
8+
let u = [| 1.0; 2.0; 3.0 |]
9+
let v = [| 4.0; 5.0 |]
10+
let result = Matrix.outerProduct u v
11+
Assert.Equal(3, result.NumRows)
12+
Assert.Equal(2, result.NumCols)
13+
14+
[<Fact>]
15+
let ``Outer product computes correct values`` () =
16+
let u = [| 1.0; 2.0; 3.0 |]
17+
let v = [| 4.0; 5.0 |]
18+
let result = Matrix.outerProduct u v
19+
// Expected: [[1*4, 1*5], [2*4, 2*5], [3*4, 3*5]]
20+
// = [[4, 5], [8, 10], [12, 15]]
21+
Assert.Equal(4.0, result.[0, 0])
22+
Assert.Equal(5.0, result.[0, 1])
23+
Assert.Equal(8.0, result.[1, 0])
24+
Assert.Equal(10.0, result.[1, 1])
25+
Assert.Equal(12.0, result.[2, 0])
26+
Assert.Equal(15.0, result.[2, 1])
27+
28+
[<Fact>]
29+
let ``Outer product works with single element vectors`` () =
30+
let u = [| 3.0 |]
31+
let v = [| 7.0 |]
32+
let result = Matrix.outerProduct u v
33+
Assert.Equal(1, result.NumRows)
34+
Assert.Equal(1, result.NumCols)
35+
Assert.Equal(21.0, result.[0, 0])
36+
37+
[<Fact>]
38+
let ``Outer product works with larger vectors`` () =
39+
let u = [| 1.0; 2.0; 3.0; 4.0 |]
40+
let v = [| 10.0; 20.0; 30.0 |]
41+
let result = Matrix.outerProduct u v
42+
Assert.Equal(4, result.NumRows)
43+
Assert.Equal(3, result.NumCols)
44+
// Check a few values
45+
Assert.Equal(10.0, result.[0, 0]) // 1 * 10
46+
Assert.Equal(20.0, result.[0, 1]) // 1 * 20
47+
Assert.Equal(30.0, result.[0, 2]) // 1 * 30
48+
Assert.Equal(30.0, result.[2, 0]) // 3 * 10
49+
Assert.Equal(80.0, result.[3, 1]) // 4 * 20
50+
Assert.Equal(120.0, result.[3, 2]) // 4 * 30
51+
52+
[<Fact>]
53+
let ``Outer product with SIMD-friendly size`` () =
54+
// Size 16 ensures we use SIMD path on most systems (Vector<float>.Count is usually 4 or 8)
55+
let u = Array.init 10 (fun i -> float (i + 1))
56+
let v = Array.init 16 (fun i -> float (i + 1))
57+
let result = Matrix.outerProduct u v
58+
59+
Assert.Equal(10, result.NumRows)
60+
Assert.Equal(16, result.NumCols)
61+
62+
// Verify a few values
63+
Assert.Equal(1.0, result.[0, 0]) // 1 * 1
64+
Assert.Equal(16.0, result.[0, 15]) // 1 * 16
65+
Assert.Equal(50.0, result.[4, 9]) // 5 * 10
66+
Assert.Equal(160.0, result.[9, 15]) // 10 * 16

0 commit comments

Comments
 (0)