Skip to content

Commit 92aa65a

Browse files
e-straussBaunsgaard
authored andcommitted
[SYSTEMDS-3541] Exploratory workload-aware compression on intermediates
This commit explores aggressive compression of intermediates, specifically for the kmeans builtin algorithm. It adds new compressed operations that avoid decompression and minimize the compression time of intermediates. The runtime of the kmeans algorithm on the census dataset was reduced from initially 50s with intermediate compression down to 17.5s with all the optimizations, which is an overall improvement of 33% in comparison to the baseline runtime of 27s for workload-aware, non-aggressive compression.

A summary of the changes:
- added a config option for aggressive compression
- removed scalars from the possible compressible intermediate candidates
- extended the compression workload analyzer to pick up aggregation operations if the input is compressed as a single column group
- extended the compression workload analyzer to pick up binary matrix-vector operations if both inputs have the same column indices
- updated the cost estimation for compression of already compressed inputs with a single column group
- added support for double-compressed binary matrix-matrix operations
- single-threaded compressed binary matrix-vector operation with single column group encoding
- multi-threaded compressed binary matrix-vector operation with single column group encoding
- added support for left binary matrix-vector operations for sparse outputs

Closes #2230
1 parent eb7c65c commit 92aa65a

File tree

9 files changed

+605
-127
lines changed

9 files changed

+605
-127
lines changed

src/main/java/org/apache/sysds/conf/DMLConfig.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ public class DMLConfig
7979
public static final String PARALLEL_TOKENIZE = "sysds.parallel.tokenize";
8080
public static final String PARALLEL_TOKENIZE_NUM_BLOCKS = "sysds.parallel.tokenize.numBlocks";
8181
public static final String COMPRESSED_LINALG = "sysds.compressed.linalg";
82+
public static final String COMPRESSED_LINALG_INTERMEDIATE = "sysds.compressed.linalg.intermediate";
8283
public static final String COMPRESSED_LOSSY = "sysds.compressed.lossy";
8384
public static final String COMPRESSED_VALID_COMPRESSIONS = "sysds.compressed.valid.compressions";
8485
public static final String COMPRESSED_OVERLAPPING = "sysds.compressed.overlapping";

src/main/java/org/apache/sysds/hops/rewrite/RewriteCompressedReblock.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,11 +171,12 @@ public static boolean satisfiesAggressiveCompressionCondition(Hop hop) {
171171
satisfies |= HopRewriteUtils.isTernary(hop, OpOp3.CTABLE)
172172
&& hop.getInput(0).getDataType().isMatrix()
173173
&& hop.getInput(1).getDataType().isMatrix();
174-
satisfies |= HopRewriteUtils.isData(hop, OpOpData.PERSISTENTREAD) && !hop.isScalar();
174+
satisfies |= HopRewriteUtils.isData(hop, OpOpData.PERSISTENTREAD);
175175
satisfies |= HopRewriteUtils.isUnary(hop, OpOp1.ROUND, OpOp1.FLOOR, OpOp1.NOT, OpOp1.CEIL);
176176
satisfies |= HopRewriteUtils.isBinary(hop, OpOp2.EQUAL, OpOp2.NOTEQUAL, OpOp2.LESS,
177177
OpOp2.LESSEQUAL, OpOp2.GREATER, OpOp2.GREATEREQUAL, OpOp2.AND, OpOp2.OR, OpOp2.MODULUS);
178178
satisfies |= HopRewriteUtils.isTernary(hop, OpOp3.CTABLE);
179+
satisfies &= !hop.isScalar();
179180
}
180181
if(LOG.isDebugEnabled() && satisfies)
181182
LOG.debug("Operation Satisfies: " + hop);

src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,8 @@ private void classifyPhase() {
344344
// final int nRows = mb.getNumRows();
345345
final int nCols = mb.getNumColumns();
346346
// Assume the scaling of cocoding is at maximum square root good relative to number of columns.
347-
final double scale = Math.sqrt(nCols);
347+
final double scale = mb instanceof CompressedMatrixBlock &&
348+
((CompressedMatrixBlock) mb).getColGroups().size() == 1 ? 1 : Math.sqrt(nCols);
348349
final double threshold = _stats.estimatedCostCols / scale;
349350

350351
if(threshold < _stats.originalCost *

src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java

Lines changed: 374 additions & 86 deletions
Large diffs are not rendered by default.

src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java

Lines changed: 73 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import java.util.LinkedList;
2727
import java.util.List;
2828
import java.util.Map;
29+
import java.util.Objects;
2930
import java.util.Set;
3031
import java.util.Stack;
3132

@@ -38,6 +39,8 @@
3839
import org.apache.sysds.common.Types.OpOpData;
3940
import org.apache.sysds.common.Types.ParamBuiltinOp;
4041
import org.apache.sysds.common.Types.ReOrgOp;
42+
import org.apache.sysds.conf.ConfigurationManager;
43+
import org.apache.sysds.conf.DMLConfig;
4144
import org.apache.sysds.hops.AggBinaryOp;
4245
import org.apache.sysds.hops.AggUnaryOp;
4346
import org.apache.sysds.hops.BinaryOp;
@@ -81,9 +84,15 @@ public class WorkloadAnalyzer {
8184
private final DMLProgram prog;
8285
private final Map<Long, Op> treeLookup;
8386
private final Stack<Hop> stack;
87+
private Stack<StatementBlock> lineage = new Stack<>();
8488

8589
public static Map<Long, WTreeRoot> getAllCandidateWorkloads(DMLProgram prog) {
8690
// extract all compression candidates from program (in program order)
91+
String configValue = ConfigurationManager.getDMLConfig()
92+
.getTextValue(DMLConfig.COMPRESSED_LINALG_INTERMEDIATE);
93+
// if set update it, otherwise keep it as set before
94+
ALLOW_INTERMEDIATE_CANDIDATES = configValue != null && Objects.equals(configValue.toUpperCase(), "TRUE") ||
95+
configValue == null && ALLOW_INTERMEDIATE_CANDIDATES;
8796
List<Hop> candidates = getCandidates(prog);
8897

8998
// for each candidate, create pruned workload tree
@@ -115,6 +124,7 @@ private WorkloadAnalyzer(DMLProgram prog) {
115124
this.overlapping = new HashSet<>();
116125
this.treeLookup = new HashMap<>();
117126
this.stack = new Stack<>();
127+
this.lineage = new Stack<>();
118128
}
119129

120130
private WorkloadAnalyzer(DMLProgram prog, Set<Long> compressed, HashMap<String, Long> transientCompressed,
@@ -235,6 +245,7 @@ private static void getCandidates(Hop hop, DMLProgram prog, List<Hop> cands, Set
235245

236246
private void createWorkloadTreeNodes(AWTreeNode n, StatementBlock sb, DMLProgram prog, Set<String> fStack) {
237247
WTreeNode node;
248+
lineage.add(sb);
238249
if(sb instanceof FunctionStatementBlock) {
239250
FunctionStatementBlock fsb = (FunctionStatementBlock) sb;
240251
FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);
@@ -291,7 +302,7 @@ else if(sb instanceof ForStatementBlock) { // incl parfor
291302
if(hop instanceof FunctionOp) {
292303
FunctionOp fop = (FunctionOp) hop;
293304
if(HopRewriteUtils.isTransformEncode(fop))
294-
return;
305+
break;
295306
else if(!fStack.contains(fop.getFunctionKey())) {
296307
fStack.add(fop.getFunctionKey());
297308
FunctionStatementBlock fsb = prog.getFunctionStatementBlock(fop.getFunctionKey());
@@ -323,9 +334,11 @@ else if(!fStack.contains(fop.getFunctionKey())) {
323334
}
324335
}
325336
}
337+
lineage.pop();
326338
return;
327339
}
328340
n.addChild(node);
341+
lineage.pop();
329342
}
330343

331344
private void createStack(Hop hop) {
@@ -396,7 +409,22 @@ else if(hop instanceof AggUnaryOp) {
396409
return;
397410
}
398411
else {
399-
o = new OpNormal(hop, false);
412+
boolean compressedOut = false;
413+
Hop parentHop = hop.getInput(0);
414+
if(HopRewriteUtils.isBinary(parentHop, OpOp2.EQUAL, OpOp2.NOTEQUAL, OpOp2.LESS,
415+
OpOp2.LESSEQUAL, OpOp2.GREATER, OpOp2.GREATEREQUAL)){
416+
Hop leftIn = parentHop.getInput(0);
417+
Hop rightIn = parentHop.getInput(1);
418+
// input ops might be not in the current statement block -> check for transient reads
419+
if(HopRewriteUtils.isAggUnaryOp(leftIn, AggOp.MIN, AggOp.MAX) ||
420+
HopRewriteUtils.isAggUnaryOp(rightIn, AggOp.MIN, AggOp.MAX) ||
421+
checkTransientRead(hop, leftIn) ||
422+
checkTransientRead(hop, rightIn)
423+
)
424+
compressedOut = true;
425+
426+
}
427+
o = new OpNormal(hop, compressedOut);
400428
}
401429
}
402430
else if(hop instanceof UnaryOp) {
@@ -477,9 +505,17 @@ else if(ol) {
477505
if(!HopRewriteUtils.isBinarySparseSafe(hop))
478506
o.setDensifying();
479507

508+
} else if(HopRewriteUtils.isBinaryMatrixColVectorOperation(hop) ) {
509+
Hop leftIn = hop.getInput(0);
510+
Hop rightIn = hop.getInput(1);
511+
if(HopRewriteUtils.isBinary(hop, OpOp2.DIV) && rightIn instanceof AggUnaryOp && leftIn == rightIn.getInput(0)){
512+
o = new OpNormal(hop, true);
513+
} else {
514+
setDecompressionOnAllInputs(hop, parent);
515+
return;
516+
}
480517
}
481518
else if(HopRewriteUtils.isBinaryMatrixMatrixOperation(hop) ||
482-
HopRewriteUtils.isBinaryMatrixColVectorOperation(hop) ||
483519
HopRewriteUtils.isBinaryMatrixMatrixOperationWithSharedInput(hop)) {
484520
setDecompressionOnAllInputs(hop, parent);
485521
return;
@@ -623,6 +659,40 @@ else if(hop instanceof AggUnaryOp) {
623659
}
624660
}
625661

662+
private boolean checkTransientRead(Hop hop, Hop input) {
663+
// op is not in current statement block
664+
if(HopRewriteUtils.isData(input, OpOpData.TRANSIENTREAD)){
665+
String varName = input.getName();
666+
StatementBlock csb = lineage.peek();
667+
StatementBlock parentStatement = lineage.get(lineage.size() -2);
668+
669+
if(parentStatement instanceof WhileStatementBlock) {
670+
WhileStatementBlock wsb = (WhileStatementBlock) parentStatement;
671+
WhileStatement wstmt = (WhileStatement) wsb.getStatement(0);
672+
ArrayList<StatementBlock> stmts = wstmt.getBody();
673+
boolean foundCurrent = false;
674+
StatementBlock sb;
675+
676+
// traverse statement blocks in reverse to find the statement block, which came before the current
677+
// if we iterate in default order, we might find an earlier updated version of the current variable
678+
for (int i = stmts.size()-1; i >= 0; i--) {
679+
sb = stmts.get(i);
680+
if(foundCurrent && sb.variablesUpdated().containsVariable(varName)) {
681+
for(Hop cand : sb.getHops()){
682+
if(HopRewriteUtils.isData(cand, OpOpData.TRANSIENTWRITE) && cand.getName().equals(varName)
683+
&& HopRewriteUtils.isAggUnaryOp( cand.getInput(0), AggOp.MIN, AggOp.MAX)){
684+
return true;
685+
}
686+
}
687+
} else if(sb == csb){
688+
foundCurrent = true;
689+
}
690+
}
691+
}
692+
}
693+
return false;
694+
}
695+
626696
private boolean isCompressed(Hop hop) {
627697
return compressed.contains(hop.getHopID());
628698
}

src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,28 @@ public int putIfAbsentI(K key, int value) {
152152

153153
}
154154

155+
public int putIfAbsentReturnVal(K key, int value) {
156+
157+
if(key == null) {
158+
if(nullV == -1) {
159+
size++;
160+
nullV = value;
161+
return -1;
162+
}
163+
else
164+
return nullV;
165+
}
166+
else {
167+
final int ix = hash(key);
168+
Node<K> b = buckets[ix];
169+
if(b == null)
170+
return createBucketReturnVal(ix, key, value);
171+
else
172+
return putIfAbsentBucketReturnval(ix, key, value);
173+
}
174+
175+
}
176+
155177
private int putIfAbsentBucket(int ix, K key, int value) {
156178
Node<K> b = buckets[ix];
157179
while(true) {
@@ -167,6 +189,21 @@ private int putIfAbsentBucket(int ix, K key, int value) {
167189
}
168190
}
169191

192+
private int putIfAbsentBucketReturnval(int ix, K key, int value) {
193+
Node<K> b = buckets[ix];
194+
while(true) {
195+
if(b.key.equals(key))
196+
return b.value;
197+
if(b.next == null) {
198+
b.setNext(new Node<>(key, value, null));
199+
size++;
200+
resize();
201+
return value;
202+
}
203+
b = b.next;
204+
}
205+
}
206+
170207
public int putI(K key, int value) {
171208
if(key == null) {
172209
int tmp = nullV;
@@ -191,6 +228,12 @@ private int createBucket(int ix, K key, int value) {
191228
return -1;
192229
}
193230

231+
private int createBucketReturnVal(int ix, K key, int value) {
232+
buckets[ix] = new Node<K>(key, value, null);
233+
size++;
234+
return value;
235+
}
236+
194237
private int addToBucket(int ix, K key, int value) {
195238
Node<K> b = buckets[ix];
196239
while(true) {

src/test/java/org/apache/sysds/test/component/compress/lib/CLALibBinaryCellOpCustomTest.java

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,13 @@
2222
import static org.mockito.Mockito.spy;
2323
import static org.mockito.Mockito.when;
2424

25+
import org.apache.commons.lang3.tuple.Pair;
2526
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
2627
import org.apache.sysds.runtime.compress.CompressedMatrixBlockFactory;
28+
import org.apache.sysds.runtime.compress.CompressionStatistics;
2729
import org.apache.sysds.runtime.compress.lib.CLALibBinaryCellOp;
30+
import org.apache.sysds.runtime.functionobjects.GreaterThanEquals;
31+
import org.apache.sysds.runtime.functionobjects.LessThanEquals;
2832
import org.apache.sysds.runtime.functionobjects.Minus;
2933
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
3034
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
@@ -47,6 +51,35 @@ public void notColVector() {
4751
TestUtils.compareMatricesBitAvgDistance(new MatrixBlock(10, 10, 1.3), cRet2, 0, 0, op.toString());
4852
}
4953

54+
@Test
55+
public void twoHotEncodedOutput() {
56+
BinaryOperator op = new BinaryOperator(LessThanEquals.getLessThanEqualsFnObject(), 2);
57+
BinaryOperator op2 = new BinaryOperator(LessThanEquals.getLessThanEqualsFnObject());
58+
BinaryOperator opLeft = new BinaryOperator(GreaterThanEquals.getGreaterThanEqualsFnObject(), 2);
59+
BinaryOperator opLeft2 = new BinaryOperator(GreaterThanEquals.getGreaterThanEqualsFnObject());
60+
61+
MatrixBlock cDense = new MatrixBlock(30, 30, 2.0);
62+
for (int i = 0; i < 30; i++) {
63+
cDense.set(i,0, 1);
64+
}
65+
cDense.set(0,1, 1);
66+
Pair<MatrixBlock, CompressionStatistics> pair = CompressedMatrixBlockFactory.compress(cDense, 1);
67+
CompressedMatrixBlock c = (CompressedMatrixBlock) pair.getKey();
68+
MatrixBlock c2 = new MatrixBlock(30, 1, 1.0);
69+
CompressedMatrixBlock spy = spy(c);
70+
when(spy.getCachedDecompressed()).thenReturn(null);
71+
72+
MatrixBlock cRet = CLALibBinaryCellOp.binaryOperationsRight(op, spy, c2);
73+
MatrixBlock cRet2 = CLALibBinaryCellOp.binaryOperationsRight(op2, spy, c2);
74+
TestUtils.compareMatricesBitAvgDistance(cRet, cRet2, 0, 0, op.toString());
75+
76+
MatrixBlock cRetleft = CLALibBinaryCellOp.binaryOperationsLeft(opLeft, spy, c2);
77+
MatrixBlock cRetleft2 = CLALibBinaryCellOp.binaryOperationsLeft(opLeft2, spy, c2);
78+
TestUtils.compareMatricesBitAvgDistance(cRetleft, cRetleft2, 0, 0, op.toString());
79+
80+
TestUtils.compareMatricesBitAvgDistance(cRet, cRetleft, 0, 0, op.toString());
81+
}
82+
5083
@Test
5184
public void notColVectorEmptyReturn() {
5285
BinaryOperator op = new BinaryOperator(Minus.getMinusFnObject(), 2);

0 commit comments

Comments
 (0)