Skip to content

Commit b838693

Browse files
CommonOps_DSCC
- Speed up more mult variants
1 parent 09db214 commit b838693

File tree

14 files changed

+343
-165
lines changed

14 files changed

+343
-165
lines changed

change.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,9 @@ Date format: year/month/day
2929
* Removed the functions below since they had a runtime complexity of O(N^2) relative to matrix size instead of O(N)
3030
- multTransA(S,S,S), multTransB(S,S,S), innerProductLower(S,S,S)
3131
- Thanks Florentin Dörre for first noticing the performance issue
32+
* Speed up multTransAB(S,D,D), multTransA(S,D,D), multTransB(S,D,D) by a large margin
3233
- DMatrixSparseCSC
33-
* If sorted a binary search is used to lookup rows. Thanks Florentin Dörre.
34+
* If sorted, a binary search is used to lookup rows. Thanks Florentin Dörre.
3435
- ReadMatrixCsv
3536
* Thanks DEDZTBH for fixing an indexing error when reading complex data types
3637
- Added Concurrent Algorithms

main/ejml-dsparse/benchmarks/src/org/ejml/sparse/csc/BenchmarkMatrixMultDense_DSCC.java

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
package org.ejml.sparse.csc;
2020

21+
import org.ejml.data.DGrowArray;
2122
import org.ejml.data.DMatrixRMaj;
2223
import org.ejml.data.DMatrixSparseCSC;
2324
import org.ejml.dense.row.RandomMatrices_DDRM;
@@ -47,27 +48,30 @@ public class BenchmarkMatrixMultDense_DSCC {
4748
@Param({"100000"})
4849
private int elementCount;
4950

50-
DMatrixSparseCSC A;
51+
DMatrixSparseCSC A,A_small;
5152
DMatrixRMaj B = new DMatrixRMaj(1, 1);
5253
DMatrixRMaj C = new DMatrixRMaj(1, 1);
5354

55+
DGrowArray work = new DGrowArray();
56+
5457
@Setup
5558
public void setup() {
5659
Random rand = new Random(2345);
5760
A = RandomMatrices_DSCC.rectangle(dimension, dimension, elementCount, rand);
61+
A_small = RandomMatrices_DSCC.rectangle(dimension/4, dimension/4, elementCount/4, rand);
5862
B = RandomMatrices_DDRM.rectangle(dimension, dimension, -1, 1, rand);
5963
C = B.create(dimension, dimension);
6064
}
6165

6266
@Benchmark public void mult() { CommonOps_DSCC.mult(A, B, C); }
6367
@Benchmark public void multAdd() { CommonOps_DSCC.multAdd(A, B, C); }
64-
@Benchmark public void multTransA() { CommonOps_DSCC.multTransA(A, B, C); }
65-
@Benchmark public void multAddTransA() { CommonOps_DSCC.multAddTransA(A, B, C); }
66-
@Benchmark public void multTransB() { CommonOps_DSCC.multTransB(A, B, C); }
67-
@Benchmark public void multAddTransB() { CommonOps_DSCC.multAddTransB(A, B, C); }
68+
@Benchmark public void multTransA() { CommonOps_DSCC.multTransA(A, B, C, work); }
69+
@Benchmark public void multAddTransA() { CommonOps_DSCC.multAddTransA(A, B, C, work); }
70+
@Benchmark public void multTransB() { CommonOps_DSCC.multTransB(A, B, C, work); }
71+
@Benchmark public void multAddTransB() { CommonOps_DSCC.multAddTransB(A, B, C, work); }
6872
@Benchmark public void multTransAB() { CommonOps_DSCC.multTransAB(A, B, C); }
6973
@Benchmark public void multAddTransAB() { CommonOps_DSCC.multAddTransAB(A, B, C); }
70-
@Benchmark public void invert() { CommonOps_DSCC.invert(A, C); }
74+
@Benchmark public void invert() { CommonOps_DSCC.invert(A_small, C); }
7175

7276
public static void main( String[] args ) throws RunnerException {
7377
Options opt = new OptionsBuilder()

main/ejml-dsparse/benchmarks/src/org/ejml/sparse/csc/BenchmarkMatrixMultDense_MT_DSCC.java

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ public class BenchmarkMatrixMultDense_MT_DSCC {
4949
@Param({"100000"})
5050
private int elementCount;
5151

52-
DMatrixSparseCSC A;
52+
DMatrixSparseCSC A,A_small;
5353
DMatrixRMaj B = new DMatrixRMaj(1, 1);
5454
DMatrixRMaj C = new DMatrixRMaj(1, 1);
5555

@@ -59,19 +59,20 @@ public class BenchmarkMatrixMultDense_MT_DSCC {
5959
public void setup() {
6060
Random rand = new Random(2345);
6161
A = RandomMatrices_DSCC.rectangle(dimension, dimension, elementCount, rand);
62+
A_small = RandomMatrices_DSCC.rectangle(dimension/4, dimension/4, elementCount/4, rand);
6263
B = RandomMatrices_DDRM.rectangle(dimension, dimension, -1, 1, rand);
6364
C = B.create(dimension, dimension);
6465
}
6566

6667
@Benchmark public void mult() { CommonOps_MT_DSCC.mult(A, B, C, work); }
6768
@Benchmark public void multAdd() { CommonOps_MT_DSCC.multAdd(A, B, C, work); }
68-
@Benchmark public void multTransA() { CommonOps_MT_DSCC.multTransA(A, B, C); }
69-
@Benchmark public void multAddTransA() { CommonOps_MT_DSCC.multAddTransA(A, B, C); }
69+
@Benchmark public void multTransA() { CommonOps_MT_DSCC.multTransA(A, B, C, work); }
70+
@Benchmark public void multAddTransA() { CommonOps_MT_DSCC.multAddTransA(A, B, C, work); }
7071
@Benchmark public void multTransB() { CommonOps_MT_DSCC.multTransB(A, B, C, work); }
7172
@Benchmark public void multAddTransB() { CommonOps_MT_DSCC.multAddTransB(A, B, C, work); }
72-
// @Benchmark public void multTransAB() { CommonOps_MT_DSCC.multTransAB(A, B, C); }
73-
// @Benchmark public void multAddTransAB() { CommonOps_MT_DSCC.multAddTransAB(A, B, C); }
74-
// @Benchmark public void invert() { CommonOps_MT_DSCC.invert(A, C); }
73+
@Benchmark public void multTransAB() { CommonOps_MT_DSCC.multTransAB(A, B, C); }
74+
@Benchmark public void multAddTransAB() { CommonOps_MT_DSCC.multAddTransAB(A, B, C); }
75+
// @Benchmark public void invert() { CommonOps_MT_DSCC.invert(A_small, C); }
7576

7677
public static void main( String[] args ) throws RunnerException {
7778
Options opt = new OptionsBuilder()

main/ejml-dsparse/src/org/ejml/sparse/csc/CommonOps_DSCC.java

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -188,27 +188,35 @@ public static void multAdd( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj outpu
188188
* @param B Dense Matrix
189189
* @param outputC Dense Matrix
190190
*/
191-
public static DMatrixRMaj multTransA( DMatrixSparseCSC A, DMatrixRMaj B, @Nullable DMatrixRMaj outputC ) {
191+
public static DMatrixRMaj multTransA( DMatrixSparseCSC A, DMatrixRMaj B, @Nullable DMatrixRMaj outputC,
192+
@Nullable DGrowArray work ) {
192193
if (A.numRows != B.numRows)
193194
throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B));
194195

196+
if (work == null)
197+
work = new DGrowArray();
198+
195199
outputC = reshapeOrDeclare(outputC, A.numCols, B.numCols);
196200

197-
ImplMultiplication_DSCC.multTransA(A, B, outputC);
201+
ImplMultiplication_DSCC.multTransA(A, B, outputC, work);
198202

199203
return outputC;
200204
}
201205

202206
/**
203207
* <p>C = C + A<sup>T</sup>*B</p>
204208
*/
205-
public static void multAddTransA( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj outputC ) {
209+
public static void multAddTransA( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj outputC,
210+
@Nullable DGrowArray work ) {
206211
if (A.numRows != B.numRows)
207212
throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B));
208213
if (A.numCols != outputC.numRows || B.numCols != outputC.numCols)
209214
throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B, outputC));
210215

211-
ImplMultiplication_DSCC.multAddTransA(A, B, outputC);
216+
if (work == null)
217+
work = new DGrowArray();
218+
219+
ImplMultiplication_DSCC.multAddTransA(A, B, outputC, work);
212220
}
213221

214222
/**
@@ -218,26 +226,34 @@ public static void multAddTransA( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj
218226
* @param B Dense Matrix
219227
* @param outputC Dense Matrix
220228
*/
221-
public static DMatrixRMaj multTransB( DMatrixSparseCSC A, DMatrixRMaj B, @Nullable DMatrixRMaj outputC ) {
229+
public static DMatrixRMaj multTransB( DMatrixSparseCSC A, DMatrixRMaj B, @Nullable DMatrixRMaj outputC,
230+
@Nullable DGrowArray work ) {
222231
if (A.numCols != B.numCols)
223232
throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B));
224233
outputC = reshapeOrDeclare(outputC, A.numRows, B.numRows);
225234

226-
ImplMultiplication_DSCC.multTransB(A, B, outputC);
235+
if (work == null)
236+
work = new DGrowArray();
237+
238+
ImplMultiplication_DSCC.multTransB(A, B, outputC, work);
227239

228240
return outputC;
229241
}
230242

231243
/**
232244
* <p>C = C + A*B<sup>T</sup></p>
233245
*/
234-
public static void multAddTransB( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj outputC ) {
246+
public static void multAddTransB( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj outputC,
247+
@Nullable DGrowArray work ) {
235248
if (A.numCols != B.numCols)
236249
throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B));
237250
if (A.numRows != outputC.numRows || B.numRows != outputC.numCols)
238251
throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B, outputC));
239252

240-
ImplMultiplication_DSCC.multAddTransB(A, B, outputC);
253+
if (work == null)
254+
work = new DGrowArray();
255+
256+
ImplMultiplication_DSCC.multAddTransB(A, B, outputC, work);
241257
}
242258

243259
/**

main/ejml-dsparse/src/org/ejml/sparse/csc/CommonOps_MT_DSCC.java

Lines changed: 66 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ public static DMatrixSparseCSC mult( DMatrixSparseCSC A, DMatrixSparseCSC B, @Nu
5656
throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B));
5757
outputC = reshapeOrDeclare(outputC, A, A.numRows, B.numCols);
5858

59+
if (listWork == null)
60+
listWork = new GrowArray<>(Workspace_MT_DSCC::new);
61+
5962
ImplMultiplication_MT_DSCC.mult(A, B, outputC, listWork);
6063

6164
return outputC;
@@ -80,6 +83,9 @@ public static DMatrixSparseCSC add( double alpha, DMatrixSparseCSC A, double bet
8083
throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B));
8184
outputC = reshapeOrDeclare(outputC, A, A.numRows, A.numCols);
8285

86+
if (listWork == null)
87+
listWork = new GrowArray<>(Workspace_MT_DSCC::new);
88+
8389
ImplCommonOps_MT_DSCC.add(alpha, A, beta, B, outputC, listWork);
8490

8591
return outputC;
@@ -93,12 +99,14 @@ public static DMatrixSparseCSC add( double alpha, DMatrixSparseCSC A, double bet
9399
* @param outputC Dense Matrix
94100
*/
95101
public static DMatrixRMaj mult( DMatrixSparseCSC A, DMatrixRMaj B, @Nullable DMatrixRMaj outputC,
96-
@Nullable GrowArray<DGrowArray> listWork ) {
102+
@Nullable GrowArray<DGrowArray> workArrays ) {
97103
if (A.numCols != B.numRows)
98104
throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B));
99105
outputC = reshapeOrDeclare(outputC, A.numRows, B.numCols);
106+
if (workArrays == null)
107+
workArrays = new GrowArray<>(DGrowArray::new);
100108

101-
ImplMultiplication_MT_DSCC.mult(A, B, outputC, listWork);
109+
ImplMultiplication_MT_DSCC.mult(A, B, outputC, workArrays);
102110

103111
return outputC;
104112
}
@@ -107,13 +115,16 @@ public static DMatrixRMaj mult( DMatrixSparseCSC A, DMatrixRMaj B, @Nullable DMa
107115
* <p>C = C + A<sup>T</sup>*B</p>
108116
*/
109117
public static void multAdd( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj outputC,
110-
@Nullable GrowArray<DGrowArray> listWork ) {
118+
@Nullable GrowArray<DGrowArray> workArrays ) {
111119
if (A.numCols != B.numRows)
112120
throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B));
113121
if (A.numRows != outputC.numRows || B.numCols != outputC.numCols)
114122
throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B, outputC));
115123

116-
ImplMultiplication_MT_DSCC.multAdd(A, B, outputC, listWork);
124+
if (workArrays == null)
125+
workArrays = new GrowArray<>(DGrowArray::new);
126+
127+
ImplMultiplication_MT_DSCC.multAdd(A, B, outputC, workArrays);
117128
}
118129

119130
/**
@@ -123,27 +134,35 @@ public static void multAdd( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj outpu
123134
* @param B Dense Matrix
124135
* @param outputC Dense Matrix
125136
*/
126-
public static DMatrixRMaj multTransA( DMatrixSparseCSC A, DMatrixRMaj B, @Nullable DMatrixRMaj outputC ) {
137+
public static DMatrixRMaj multTransA( DMatrixSparseCSC A, DMatrixRMaj B, @Nullable DMatrixRMaj outputC,
138+
@Nullable GrowArray<DGrowArray> workArray ) {
127139
if (A.numRows != B.numRows)
128140
throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B));
129141

130142
outputC = reshapeOrDeclare(outputC, A.numCols, B.numCols);
131143

132-
ImplMultiplication_MT_DSCC.multTransA(A, B, outputC);
144+
if (workArray == null)
145+
workArray = new GrowArray<>(DGrowArray::new);
146+
147+
ImplMultiplication_MT_DSCC.multTransA(A, B, outputC, workArray);
133148

134149
return outputC;
135150
}
136151

137152
/**
138153
* <p>C = C + A<sup>T</sup>*B</p>
139154
*/
140-
public static void multAddTransA( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj outputC ) {
155+
public static void multAddTransA( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj outputC,
156+
@Nullable GrowArray<DGrowArray> workArray ) {
141157
if (A.numRows != B.numRows)
142158
throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B));
143159
if (A.numCols != outputC.numRows || B.numCols != outputC.numCols)
144160
throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B, outputC));
145161

146-
ImplMultiplication_MT_DSCC.multAddTransA(A, B, outputC);
162+
if (workArray == null)
163+
workArray = new GrowArray<>(DGrowArray::new);
164+
165+
ImplMultiplication_MT_DSCC.multAddTransA(A, B, outputC, workArray);
147166
}
148167

149168
/**
@@ -154,12 +173,15 @@ public static void multAddTransA( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj
154173
* @param outputC Dense Matrix
155174
*/
156175
public static DMatrixRMaj multTransB( DMatrixSparseCSC A, DMatrixRMaj B, @Nullable DMatrixRMaj outputC,
157-
@Nullable GrowArray<DGrowArray> listWork ) {
176+
@Nullable GrowArray<DGrowArray> workArrays ) {
158177
if (A.numCols != B.numCols)
159178
throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B));
160179
outputC = reshapeOrDeclare(outputC, A.numRows, B.numRows);
161180

162-
ImplMultiplication_MT_DSCC.multTransB(A, B, outputC, listWork);
181+
if (workArrays == null)
182+
workArrays = new GrowArray<>(DGrowArray::new);
183+
184+
ImplMultiplication_MT_DSCC.multTransB(A, B, outputC, workArrays);
163185

164186
return outputC;
165187
}
@@ -168,12 +190,44 @@ public static DMatrixRMaj multTransB( DMatrixSparseCSC A, DMatrixRMaj B, @Nullab
168190
* <p>C = C + A*B<sup>T</sup></p>
169191
*/
170192
public static void multAddTransB( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj outputC,
171-
@Nullable GrowArray<DGrowArray> listWork ) {
193+
@Nullable GrowArray<DGrowArray> workArrays ) {
172194
if (A.numCols != B.numCols)
173195
throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B));
174196
if (A.numRows != outputC.numRows || B.numRows != outputC.numCols)
175197
throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B, outputC));
176198

177-
ImplMultiplication_MT_DSCC.multAddTransB(A, B, outputC, listWork);
199+
if (workArrays == null)
200+
workArrays = new GrowArray<>(DGrowArray::new);
201+
202+
ImplMultiplication_MT_DSCC.multAddTransB(A, B, outputC, workArrays);
203+
}
204+
205+
/**
206+
* Performs matrix multiplication. C = A<sup>T</sup>*B<sup>T</sup>
207+
*
208+
* @param A Matrix
209+
* @param B Dense Matrix
210+
* @param outputC Dense Matrix
211+
*/
212+
public static DMatrixRMaj multTransAB( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj outputC ) {
213+
if (A.numRows != B.numCols)
214+
throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B));
215+
outputC = reshapeOrDeclare(outputC, A.numCols, B.numRows);
216+
217+
ImplMultiplication_MT_DSCC.multTransAB(A, B, outputC);
218+
219+
return outputC;
220+
}
221+
222+
/**
223+
* <p>C = C + A<sup>T</sup>*B<sup>T</sup></p>
224+
*/
225+
public static void multAddTransAB( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj outputC ) {
226+
if (A.numRows != B.numCols)
227+
throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B));
228+
if (A.numCols != outputC.numRows || B.numRows != outputC.numCols)
229+
throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B, outputC));
230+
231+
ImplMultiplication_MT_DSCC.multAddTransAB(A, B, outputC);
178232
}
179233
}

main/ejml-dsparse/src/org/ejml/sparse/csc/misc/ImplCommonOps_MT_DSCC.java

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import org.ejml.concurrency.EjmlConcurrency;
2222
import org.ejml.data.DMatrixSparseCSC;
2323
import org.ejml.sparse.csc.mult.Workspace_MT_DSCC;
24-
import org.jetbrains.annotations.Nullable;
2524
import pabeles.concurrency.GrowArray;
2625

2726
import static org.ejml.UtilEjml.adjust;
@@ -46,10 +45,7 @@ public class ImplCommonOps_MT_DSCC {
4645
* @param listWork (Optional) Storage for internal workspace. Can be null.
4746
*/
4847
public static void add( double alpha, DMatrixSparseCSC A, double beta, DMatrixSparseCSC B, DMatrixSparseCSC C,
49-
@Nullable GrowArray<Workspace_MT_DSCC> listWork ) {
50-
if (listWork == null)
51-
listWork = new GrowArray<>(Workspace_MT_DSCC::new);
52-
48+
GrowArray<Workspace_MT_DSCC> listWork ) {
5349
// Break the problem up into blocks of columns and process them independently
5450
EjmlConcurrency.loopBlocks(0, A.numCols, listWork, ( workspace, col0, col1 ) -> {
5551
DMatrixSparseCSC workC = workspace.mat;

0 commit comments

Comments
 (0)