Skip to content

Commit a6d8bc0

Browse files
committed
[MINOR] minor cleanups and optimizations to CLA MM primitives
This commit includes a specialized decompressing MM for DDC column groups with identity matrix dictionaries. Closes #2210
1 parent ece172a commit a6d8bc0

File tree

4 files changed

+230
-114
lines changed

4 files changed

+230
-114
lines changed

src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -601,6 +601,30 @@ public void leftMMIdentityPreAggregateDense(MatrixBlock that, MatrixBlock ret, i
601601

602602
@Override
603603
public void rightDecompressingMult(MatrixBlock right, MatrixBlock ret, int rl, int ru, int nRows, int crl, int cru) {
604+
if(_dict instanceof IdentityDictionary)
605+
identityRightDecompressingMult(right, ret, rl, ru, crl, cru);
606+
else
607+
defaultRightDecompressingMult(right, ret, rl, ru, crl, cru);
608+
}
609+
610+
private void identityRightDecompressingMult(MatrixBlock right, MatrixBlock ret, int rl, int ru, int crl, int cru) {
611+
final double[] b = right.getDenseBlockValues();
612+
final double[] c = ret.getDenseBlockValues();
613+
final int jd = right.getNumColumns();
614+
final int vLen = 8;
615+
final int lenJ = cru - crl;
616+
final int end = cru - (lenJ % vLen);
617+
for(int i = rl; i < ru; i++) {
618+
int k = _data.getIndex(i);
619+
final int offOut = i * jd + crl;
620+
final double aa = 1;
621+
final int k_right = _colIndexes.get(k);
622+
vectMM(aa, b, c, end, jd, crl, cru, offOut, k_right, vLen);
623+
624+
}
625+
}
626+
627+
private void defaultRightDecompressingMult(MatrixBlock right, MatrixBlock ret, int rl, int ru, int crl, int cru) {
604628
final double[] a = _dict.getValues();
605629
final double[] b = right.getDenseBlockValues();
606630
final double[] c = ret.getDenseBlockValues();
@@ -930,8 +954,6 @@ protected void denseSelection(MatrixBlock selection, P[] points, MatrixBlock ret
930954
}
931955
}
932956

933-
934-
935957
private void leftMMIdentityPreAggregateDenseSingleRow(double[] values, int pos, double[] values2, int pos2, int cl,
936958
int cu) {
937959
IdentityDictionary a = (IdentityDictionary) _dict;

src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
3535
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
3636
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
37+
import org.apache.sysds.utils.stats.Timing;
3738

3839
/**
3940
* Support compressed MM chain operation to fuse the following cases :
@@ -53,6 +54,9 @@
5354
public final class CLALibMMChain {
5455
static final Log LOG = LogFactory.getLog(CLALibMMChain.class.getName());
5556

57+
/** Reusable cache intermediate double array for temporary decompression */
58+
private static ThreadLocal<double[]> cacheIntermediate = null;
59+
5660
private CLALibMMChain() {
5761
// private constructor
5862
}
@@ -87,20 +91,31 @@ private CLALibMMChain() {
8791
public static MatrixBlock mmChain(CompressedMatrixBlock x, MatrixBlock v, MatrixBlock w, MatrixBlock out,
8892
ChainType ctype, int k) {
8993

94+
Timing t = new Timing();
9095
if(x.isEmpty())
9196
return returnEmpty(x, out);
9297

9398
// Morph the columns to efficient types for the operation.
9499
x = filterColGroups(x);
100+
double preFilterTime = t.stop();
95101

96102
// Allow overlapping intermediate if the intermediate is guaranteed not to be overlapping.
97103
final boolean allowOverlap = x.getColGroups().size() == 1 && isOverlappingAllowed();
98104

99105
// Right hand side multiplication
100-
MatrixBlock tmp = CLALibRightMultBy.rightMultByMatrix(x, v, null, k, allowOverlap);
106+
MatrixBlock tmp = CLALibRightMultBy.rightMultByMatrix(x, v, null, k, true);
107+
108+
double rmmTime = t.stop();
101109

102-
if(ctype == ChainType.XtwXv) // Multiply intermediate with vector if needed
110+
if(ctype == ChainType.XtwXv) { // Multiply intermediate with vector if needed
103111
tmp = binaryMultW(tmp, w, k);
112+
}
113+
114+
if(!allowOverlap && tmp instanceof CompressedMatrixBlock) {
115+
tmp = decompressIntermediate((CompressedMatrixBlock) tmp, k);
116+
}
117+
118+
double decompressTime = t.stop();
104119

105120
if(tmp instanceof CompressedMatrixBlock)
106121
// Compressed Compressed Matrix Multiplication
@@ -109,12 +124,50 @@ public static MatrixBlock mmChain(CompressedMatrixBlock x, MatrixBlock v, Matrix
109124
// LMM with Compressed - uncompressed multiplication.
110125
CLALibLeftMultBy.leftMultByMatrixTransposed(x, tmp, out, k);
111126

127+
double lmmTime = t.stop();
112128
if(out.getNumColumns() != 1) // transpose the output to make it a row output if needed
113129
out = LibMatrixReorg.transposeInPlace(out, k);
114130

131+
if(LOG.isDebugEnabled()) {
132+
StringBuilder sb = new StringBuilder("\n");
133+
sb.append("\nPreFilter Time : " + preFilterTime);
134+
sb.append("\nChain RMM : " + rmmTime);
135+
sb.append("\nChain RMM Decompress: " + decompressTime);
136+
sb.append("\nChain LMM : " + lmmTime);
137+
sb.append("\nChain Transpose : " + t.stop());
138+
LOG.debug(sb.toString());
139+
}
140+
115141
return out;
116142
}
117143

144+
private static MatrixBlock decompressIntermediate(CompressedMatrixBlock tmp, int k) {
145+
// cacheIntermediate
146+
final int rows = tmp.getNumRows();
147+
final int cols = tmp.getNumColumns();
148+
final int nCells = rows * cols;
149+
final double[] tmpArr;
150+
if(cacheIntermediate == null) {
151+
tmpArr = new double[nCells];
152+
cacheIntermediate = new ThreadLocal<>();
153+
cacheIntermediate.set(tmpArr);
154+
}
155+
else {
156+
double[] cachedArr = cacheIntermediate.get();
157+
if(cachedArr == null || cachedArr.length < nCells) {
158+
tmpArr = new double[nCells];
159+
cacheIntermediate.set(tmpArr);
160+
}
161+
else {
162+
tmpArr = cachedArr;
163+
}
164+
}
165+
166+
final MatrixBlock tmpV = new MatrixBlock(tmp.getNumRows(), tmp.getNumColumns(), tmpArr);
167+
CLALibDecompress.decompressTo((CompressedMatrixBlock) tmp, tmpV, 0, 0, k, false, true);
168+
return tmpV;
169+
}
170+
118171
private static boolean isOverlappingAllowed() {
119172
return ConfigurationManager.getDMLConfig().getBooleanValue(DMLConfig.COMPRESSED_OVERLAPPING);
120173
}
@@ -146,6 +199,8 @@ private static CompressedMatrixBlock filterColGroups(CompressedMatrixBlock x) {
146199
final List<AColGroup> groups = x.getColGroups();
147200
final boolean shouldFilter = CLALibUtils.shouldPreFilter(groups);
148201
if(shouldFilter) {
202+
if(CLALibUtils.alreadyPreFiltered(groups, x.getNumColumns()))
203+
return x;
149204
final int nCol = x.getNumColumns();
150205
final double[] constV = new double[nCol];
151206
final List<AColGroup> filteredGroups = CLALibUtils.filterGroups(groups, constV);

0 commit comments

Comments
 (0)