Skip to content

Commit f6dfda7

Browse files
committed
maybe faster
1 parent ecdccb8 commit f6dfda7

File tree

3 files changed

+27
-17
lines changed

3 files changed

+27
-17
lines changed

src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -611,14 +611,6 @@ public MatrixBlock aggregateUnaryOperations(AggregateUnaryOperator op, MatrixVal
611611
public MatrixBlock transposeSelfMatrixMultOperations(MatrixBlock out, MMTSJType tstype, int k) {
612612
// check for transpose type
613613
if(tstype == MMTSJType.LEFT) {
614-
if(isEmpty())
615-
return new MatrixBlock(clen, clen, true);
616-
// create output matrix block
617-
if(out == null)
618-
out = new MatrixBlock(clen, clen, false);
619-
else
620-
out.reset(clen, clen, false);
621-
out.allocateDenseBlock();
622614
CLALibTSMM.leftMultByTransposeSelf(this, out, k);
623615
return out;
624616
}

src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,11 @@
3030
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
3131
import org.apache.sysds.runtime.compress.colgroup.ColGroupConst;
3232
import org.apache.sysds.runtime.functionobjects.Multiply;
33+
import org.apache.sysds.runtime.instructions.InstructionUtils;
3334
import org.apache.sysds.runtime.matrix.data.LibMatrixBincell;
3435
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
3536
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
37+
import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
3638
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
3739
import org.apache.sysds.utils.stats.Timing;
3840

@@ -95,6 +97,11 @@ public static MatrixBlock mmChain(CompressedMatrixBlock x, MatrixBlock v, Matrix
9597
if(x.isEmpty())
9698
return returnEmpty(x, out);
9799

100+
if(ctype == ChainType.XtXv && x.getColGroups().size() < 5 && x.getNumColumns()> 30){
101+
MatrixBlock tmp = CLALibTSMM.leftMultByTransposeSelf(x, k);
102+
return tmp.aggregateBinaryOperations(tmp, v, out, InstructionUtils.getMatMultOperator(k));
103+
}
104+
98105
// Morph the columns to efficient types for the operation.
99106
x = filterColGroups(x);
100107
double preFilterTime = t.stop();

src/main/java/org/apache/sysds/runtime/compress/lib/CLALibTSMM.java

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ private CLALibTSMM() {
4242
// private constructor
4343
}
4444

45+
public static MatrixBlock leftMultByTransposeSelf(CompressedMatrixBlock cmb, int k) {
46+
return leftMultByTransposeSelf(cmb, new MatrixBlock(), k);
47+
}
48+
4549
/**
4650
* Self left Matrix multiplication (tsmm)
4751
*
@@ -51,17 +55,25 @@ private CLALibTSMM() {
5155
* @param ret The output matrix to put the result into
5256
* @param k The parallelization degree allowed
5357
*/
54-
public static void leftMultByTransposeSelf(CompressedMatrixBlock cmb, MatrixBlock ret, int k) {
58+
public static MatrixBlock leftMultByTransposeSelf(CompressedMatrixBlock cmb, MatrixBlock ret, int k) {
5559

60+
final int numColumns = cmb.getNumColumns();
61+
final int numRows = cmb.getNumRows();
62+
if(cmb.isEmpty())
63+
return new MatrixBlock(numColumns, numColumns, true);
64+
// create output matrix block
65+
if(ret == null)
66+
ret = new MatrixBlock(numColumns, numColumns, false);
67+
else
68+
ret.reset(numColumns, numColumns, false);
69+
ret.allocateDenseBlock();
5670
final List<AColGroup> groups = cmb.getColGroups();
5771

58-
final int numColumns = cmb.getNumColumns();
5972
if(groups.size() >= numColumns) {
6073
MatrixBlock m = cmb.getUncompressed("TSMM to many columngroups", k);
6174
LibMatrixMult.matrixMultTransposeSelf(m, ret, true, k);
62-
return;
75+
return ret;
6376
}
64-
final int numRows = cmb.getNumRows();
6577
final boolean shouldFilter = CLALibUtils.shouldPreFilter(groups);
6678
final boolean overlapping = cmb.isOverlapping();
6779
if(shouldFilter) {
@@ -77,6 +89,7 @@ public static void leftMultByTransposeSelf(CompressedMatrixBlock cmb, MatrixBloc
7789

7890
ret.setNonZeros(LibMatrixMult.copyUpperToLowerTriangle(ret));
7991
ret.examSparsity();
92+
return ret;
8093
}
8194

8295
private static void addCorrectionLayer(List<AColGroup> filteredGroups, MatrixBlock result, int nRows, int nCols,
@@ -86,8 +99,6 @@ private static void addCorrectionLayer(List<AColGroup> filteredGroups, MatrixBlo
8699
addCorrectionLayer(constV, filteredColSum, nRows, retV);
87100
}
88101

89-
90-
91102
private static void tsmmColGroups(List<AColGroup> groups, MatrixBlock ret, int nRows, boolean overlapping, int k) {
92103
if(k <= 1)
93104
tsmmColGroupsSingleThread(groups, ret, nRows);
@@ -136,12 +147,12 @@ private static void tsmmColGroupsMultiThread(List<AColGroup> groups, MatrixBlock
136147

137148
public static void addCorrectionLayer(double[] constV, double[] filteredColSum, int nRow, double[] ret) {
138149
final int nColRow = constV.length;
139-
for(int row = 0; row < nColRow; row++){
150+
for(int row = 0; row < nColRow; row++) {
140151
int offOut = nColRow * row;
141152
final double v1l = constV[row];
142153
final double v2l = filteredColSum[row] + constV[row] * nRow;
143-
for(int col = row; col < nColRow; col++){
144-
ret[offOut + col] += v1l * filteredColSum[col] + v2l * constV[col];
154+
for(int col = row; col < nColRow; col++) {
155+
ret[offOut + col] += v1l * filteredColSum[col] + v2l * constV[col];
145156
}
146157
}
147158
}

0 commit comments

Comments
 (0)