Skip to content

Commit 6a4b3d8

Browse files
committed
[MINOR] minor cleanups and optimizations to CLA MM primitives
1 parent b751389 commit 6a4b3d8

File tree

3 files changed

+192
-61
lines changed

3 files changed

+192
-61
lines changed

src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
3535
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
3636
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
37+
import org.apache.sysds.utils.stats.Timing;
3738

3839
/**
3940
* Support compressed MM chain operation to fuse the following cases :
@@ -53,6 +54,9 @@
5354
public final class CLALibMMChain {
5455
static final Log LOG = LogFactory.getLog(CLALibMMChain.class.getName());
5556

57+
/** Reusable cache intermediate double array for temporary decompression */
58+
private static ThreadLocal<double[]> cacheIntermediate = null;
59+
5660
private CLALibMMChain() {
5761
// private constructor
5862
}
@@ -87,20 +91,31 @@ private CLALibMMChain() {
8791
public static MatrixBlock mmChain(CompressedMatrixBlock x, MatrixBlock v, MatrixBlock w, MatrixBlock out,
8892
ChainType ctype, int k) {
8993

94+
Timing t = new Timing();
9095
if(x.isEmpty())
9196
return returnEmpty(x, out);
9297

9398
// Morph the columns to efficient types for the operation.
9499
x = filterColGroups(x);
100+
double preFilterTime = t.stop();
95101

96102
// Allow overlapping intermediate if the intermediate is guaranteed not to be overlapping.
97103
final boolean allowOverlap = x.getColGroups().size() == 1 && isOverlappingAllowed();
98104

99105
// Right hand side multiplication
100-
MatrixBlock tmp = CLALibRightMultBy.rightMultByMatrix(x, v, null, k, allowOverlap);
106+
MatrixBlock tmp = CLALibRightMultBy.rightMultByMatrix(x, v, null, k, true);
107+
108+
double rmmTime = t.stop();
101109

102-
if(ctype == ChainType.XtwXv) // Multiply intermediate with vector if needed
110+
if(ctype == ChainType.XtwXv) { // Multiply intermediate with vector if needed
103111
tmp = binaryMultW(tmp, w, k);
112+
}
113+
114+
if(!allowOverlap && tmp instanceof CompressedMatrixBlock) {
115+
tmp = decompressIntermediate((CompressedMatrixBlock) tmp, k);
116+
}
117+
118+
double decompressTime = t.stop();
104119

105120
if(tmp instanceof CompressedMatrixBlock)
106121
// Compressed Compressed Matrix Multiplication
@@ -109,12 +124,50 @@ public static MatrixBlock mmChain(CompressedMatrixBlock x, MatrixBlock v, Matrix
109124
// LMM with Compressed - uncompressed multiplication.
110125
CLALibLeftMultBy.leftMultByMatrixTransposed(x, tmp, out, k);
111126

127+
double lmmTime = t.stop();
112128
if(out.getNumColumns() != 1) // transpose the output to make it a row output if needed
113129
out = LibMatrixReorg.transposeInPlace(out, k);
114130

131+
if(LOG.isDebugEnabled()) {
132+
StringBuilder sb = new StringBuilder("\n");
133+
sb.append("\nPreFilter Time : " + preFilterTime);
134+
sb.append("\nChain RMM : " + rmmTime);
135+
sb.append("\nChain RMM Decompress: " + decompressTime);
136+
sb.append("\nChain LMM : " + lmmTime);
137+
sb.append("\nChain Transpose : " + t.stop());
138+
LOG.debug(sb.toString());
139+
}
140+
115141
return out;
116142
}
117143

144+
private static MatrixBlock decompressIntermediate(CompressedMatrixBlock tmp, int k) {
145+
// cacheIntermediate
146+
final int rows = tmp.getNumRows();
147+
final int cols = tmp.getNumColumns();
148+
final int nCells = rows * cols;
149+
final double[] tmpArr;
150+
if(cacheIntermediate == null) {
151+
tmpArr = new double[nCells];
152+
cacheIntermediate = new ThreadLocal<>();
153+
cacheIntermediate.set(tmpArr);
154+
}
155+
else {
156+
double[] cachedArr = cacheIntermediate.get();
157+
if(cachedArr == null || cachedArr.length < nCells) {
158+
tmpArr = new double[nCells];
159+
cacheIntermediate.set(tmpArr);
160+
}
161+
else {
162+
tmpArr = cachedArr;
163+
}
164+
}
165+
166+
final MatrixBlock tmpV = new MatrixBlock(tmp.getNumRows(), tmp.getNumColumns(), tmpArr);
167+
CLALibDecompress.decompressTo((CompressedMatrixBlock) tmp, tmpV, 0, 0, k, false, true);
168+
return tmpV;
169+
}
170+
118171
private static boolean isOverlappingAllowed() {
119172
return ConfigurationManager.getDMLConfig().getBooleanValue(DMLConfig.COMPRESSED_OVERLAPPING);
120173
}
@@ -146,6 +199,8 @@ private static CompressedMatrixBlock filterColGroups(CompressedMatrixBlock x) {
146199
final List<AColGroup> groups = x.getColGroups();
147200
final boolean shouldFilter = CLALibUtils.shouldPreFilter(groups);
148201
if(shouldFilter) {
202+
if(CLALibUtils.alreadyPreFiltered(groups, x.getNumColumns()))
203+
return x;
149204
final int nCol = x.getNumColumns();
150205
final double[] constV = new double[nCol];
151206
final List<AColGroup> filteredGroups = CLALibUtils.filterGroups(groups, constV);

src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java

Lines changed: 114 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -34,19 +34,17 @@
3434
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
3535
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
3636
import org.apache.sysds.runtime.compress.colgroup.ColGroupConst;
37-
import org.apache.sysds.runtime.compress.colgroup.ColGroupDDC;
37+
import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed;
3838
import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory;
3939
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
40-
import org.apache.sysds.runtime.functionobjects.Plus;
4140
import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
4241
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
43-
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
4442
import org.apache.sysds.runtime.util.CommonThreadPool;
4543

4644
public final class CLALibRightMultBy {
4745
private static final Log LOG = LogFactory.getLog(CLALibRightMultBy.class.getName());
4846

49-
private CLALibRightMultBy(){
47+
private CLALibRightMultBy() {
5048
// private constructor
5149
}
5250

@@ -74,6 +72,11 @@ public static MatrixBlock rightMultByMatrix(CompressedMatrixBlock m1, MatrixBloc
7472
if(m2 instanceof CompressedMatrixBlock)
7573
m2 = ((CompressedMatrixBlock) m2).getUncompressed("Uncompressed right side of right MM", k);
7674

75+
if(betterIfDecompressed(m1)) {
76+
// perform uncompressed multiplication.
77+
return decompressingMatrixMult(m1, m2, k);
78+
}
79+
7780
if(!allowOverlap) {
7881
LOG.trace("Overlapping output not allowed in call to Right MM");
7982
return RMM(m1, m2, k);
@@ -87,14 +90,67 @@ public static MatrixBlock rightMultByMatrix(CompressedMatrixBlock m1, MatrixBloc
8790
if(retC.isOverlapping())
8891
retC.setNonZeros((long) rr * rc); // set non zeros to fully dense in case of overlapping.
8992
else
90-
retC.recomputeNonZeros(); // recompute if non overlapping compressed out.
93+
retC.recomputeNonZeros(k); // recompute if non overlapping compressed out.
9194
return retC;
9295
}
9396
}
97+
}
98+
99+
private static MatrixBlock decompressingMatrixMult(CompressedMatrixBlock m1, MatrixBlock m2, int k) {
100+
ExecutorService pool = CommonThreadPool.get(k);
101+
try {
102+
final int rl = m1.getNumRows();
103+
final int cr = m2.getNumColumns();
104+
// final int rr = m2.getNumRows(); // shared dim
105+
final MatrixBlock ret = new MatrixBlock(rl, cr, false);
106+
ret.allocateBlock();
107+
108+
// MatrixBlock m1uc = m1.decompress(k);
109+
final List<Future<Long>> tasks = new ArrayList<>();
110+
final List<AColGroup> groups = m1.getColGroups();
111+
final int blkI = Math.max((int) Math.ceil((double) rl / k), 16);
112+
final int blkJ = blkI > 16 ? cr : Math.max((cr / k), 512); // make it a multiplicative of 8.
113+
for(int i = 0; i < rl; i += blkI) {
114+
final int startI = i;
115+
final int endI = Math.min(i + blkI, rl);
116+
for(int j = 0; j < cr; j += blkJ){
117+
final int startJ = j;
118+
final int endJ = Math.min(j + blkJ, cr);
119+
tasks.add(pool.submit(() -> {
120+
for(AColGroup g : groups)
121+
g.rightDecompressingMult(m2, ret, startI, endI, rl, startJ, endJ);
122+
return ret.recomputeNonZeros(startI, endI - 1, startJ, endJ-1);
123+
}));
124+
}
125+
}
126+
long nnz = 0;
127+
for(Future<Long> t : tasks)
128+
nnz += t.get();
129+
130+
ret.setNonZeros(nnz);
131+
ret.examSparsity();
132+
return ret;
133+
}
134+
catch(InterruptedException | ExecutionException e) {
135+
throw new DMLRuntimeException(e);
136+
}
137+
finally {
138+
pool.shutdown();
139+
}
94140

95141
}
96142

143+
private static boolean betterIfDecompressed(CompressedMatrixBlock m) {
144+
for(AColGroup g : m.getColGroups()) {
145+
if(!(g instanceof ColGroupUncompressed) && g.getNumValues() * 2 >= m.getNumRows()) {
146+
return true;
147+
}
148+
}
149+
return false;
150+
}
151+
97152
private static CompressedMatrixBlock RMMOverlapping(CompressedMatrixBlock m1, MatrixBlock that, int k) {
153+
98154
final int rl = m1.getNumRows();
99155
final int cr = that.getNumColumns();
100156
final int rr = that.getNumRows(); // shared dim
@@ -103,21 +159,27 @@ private static CompressedMatrixBlock RMMOverlapping(CompressedMatrixBlock m1, Ma
103159
final CompressedMatrixBlock ret = new CompressedMatrixBlock(rl, cr);
104160

105161
final boolean shouldFilter = CLALibUtils.shouldPreFilter(colGroups);
162+
final double[] constV;
163+
final List<AColGroup> filteredGroups;
106164

107-
double[] constV = shouldFilter ? new double[rr] : null;
108-
final List<AColGroup> filteredGroups = CLALibUtils.filterGroups(colGroups, constV);
109-
if(colGroups == filteredGroups)
165+
if(shouldFilter) {
166+
constV = new double[rr];
167+
filteredGroups = CLALibUtils.filterGroups(colGroups, constV);
168+
}
169+
else {
170+
filteredGroups = colGroups;
110171
constV = null;
172+
}
111173

112-
if(k == 1)
174+
if(k == 1 || filteredGroups.size() == 1)
113175
RMMSingle(filteredGroups, that, retCg);
114176
else
115177
RMMParallel(filteredGroups, that, retCg, k);
116178

117179
if(constV != null) {
118180
final MatrixBlock cb = new MatrixBlock(1, constV.length, constV);
119181
final MatrixBlock cbRet = new MatrixBlock(1, that.getNumColumns(), false);
120-
LibMatrixMult.matrixMult(cb, that, cbRet);
182+
LibMatrixMult.matrixMult(cb, that, cbRet); // mm on row vector left.
121183
if(!cbRet.isEmpty())
122184
addConstant(cbRet, retCg);
123185
}
@@ -133,52 +195,72 @@ private static CompressedMatrixBlock RMMOverlapping(CompressedMatrixBlock m1, Ma
133195
}
134196

135197
private static void addConstant(MatrixBlock constantRow, List<AColGroup> out) {
136-
final int nCol = constantRow.getNumColumns();
137-
int bestCandidate = -1;
138-
int bestCandidateValuesSize = Integer.MAX_VALUE;
139-
for(int i = 0; i < out.size(); i++) {
140-
AColGroup g = out.get(i);
141-
if(g instanceof ColGroupDDC && g.getNumCols() == nCol && g.getNumValues() < bestCandidateValuesSize)
142-
bestCandidate = i;
143-
}
198+
// it is fairly safe to add the constant row to a column group.
199+
// but it is not necessary the fastest.
200+
201+
// final int nCol = constantRow.getNumColumns();
202+
// int bestCandidate = -1;
203+
// int bestCandidateValuesSize = Integer.MAX_VALUE;
204+
// for(int i = 0; i < out.size(); i++) {
205+
// AColGroup g = out.get(i);
206+
// if(g instanceof ColGroupDDC && g.getNumCols() == nCol && g.getNumValues() < bestCandidateValuesSize)
207+
// bestCandidate = i;
208+
// }
144209

145210
constantRow.sparseToDense();
146211

147-
if(bestCandidate != -1) {
148-
AColGroup bc = out.get(bestCandidate);
149-
out.remove(bestCandidate);
150-
AColGroup ng = bc.binaryRowOpRight(new BinaryOperator(Plus.getPlusFnObject(), 1),
151-
constantRow.getDenseBlockValues(), true);
152-
out.add(ng);
153-
}
154-
else
155-
out.add(ColGroupConst.create(constantRow.getDenseBlockValues()));
212+
// if(bestCandidate != -1) {
213+
// AColGroup bc = out.get(bestCandidate);
214+
// out.remove(bestCandidate);
215+
// AColGroup ng = bc.binaryRowOpRight(new BinaryOperator(Plus.getPlusFnObject(), 1),
216+
// constantRow.getDenseBlockValues(), true);
217+
// out.add(ng);
218+
// }
219+
// else
220+
out.add(ColGroupConst.create(constantRow.getDenseBlockValues()));
156221
}
157222

158223
private static MatrixBlock RMM(CompressedMatrixBlock m1, MatrixBlock that, int k) {
224+
225+
// Timing t = new Timing();
159226
// this version returns a decompressed result.
160227
final int rl = m1.getNumRows();
161228
final int cr = that.getNumColumns();
162229
final int rr = that.getNumRows(); // shared dim
163230
final List<AColGroup> colGroups = m1.getColGroups();
164-
final List<AColGroup> retCg = new ArrayList<>();
165231

166232
final boolean shouldFilter = CLALibUtils.shouldPreFilter(colGroups);
167233

168234
// start allocation of output.
169235
MatrixBlock ret = new MatrixBlock(rl, cr, false);
170236
final Future<MatrixBlock> f = ret.allocateBlockAsync();
171237

172-
double[] constV = shouldFilter ? new double[rr] : null;
173-
final List<AColGroup> filteredGroups = CLALibUtils.filterGroups(colGroups, constV);
174-
if(colGroups == filteredGroups)
238+
double[] constV;
239+
final List<AColGroup> filteredGroups;
240+
241+
if(shouldFilter) {
242+
if(CLALibUtils.alreadyPreFiltered(colGroups, cr)) {
243+
filteredGroups = new ArrayList<>(colGroups.size() - 1);
244+
constV = CLALibUtils.filterGroupsAndSplitPreAggOneConst(colGroups, filteredGroups);
245+
}
246+
else {
247+
constV = new double[rr];
248+
filteredGroups = CLALibUtils.filterGroups(colGroups, constV);
249+
}
250+
}
251+
else {
252+
filteredGroups = colGroups;
175253
constV = null;
254+
}
176255

256+
257+
final List<AColGroup> retCg = new ArrayList<>(filteredGroups.size());
177258
if(k == 1)
178259
RMMSingle(filteredGroups, that, retCg);
179260
else
180261
RMMParallel(filteredGroups, that, retCg, k);
181262

263+
182264
if(constV != null) {
183265
MatrixBlock constVMB = new MatrixBlock(1, constV.length, constV);
184266
MatrixBlock mmTemp = new MatrixBlock(1, cr, false);
@@ -233,7 +315,7 @@ private static boolean RMMParallel(List<AColGroup> filteredGroups, MatrixBlock t
233315
catch(InterruptedException | ExecutionException e) {
234316
throw new DMLRuntimeException(e);
235317
}
236-
finally{
318+
finally {
237319
pool.shutdown();
238320
}
239321
return containsNull;

0 commit comments

Comments
 (0)