Skip to content

Commit db8820d

Browse files
committed
[SYSTEMDS-3541] Exploratory workload-aware compression on intermediates
Added a config option for aggressive compression and extended the compression workload analyzer to detect aggregation operations and binary matrix-vector operations when inputs are compressed as a single column group. Updated cost estimation for compression on already compressed inputs and removed scalars from compressible intermediate candidates. Added support for double compressed binary matrix-matrix operations and implemented both single-threaded and multithreaded compressed binary matrix-vector operations with single column group encoding. Removed the relaxed compression threshold and added a logging statement for potential improvements in compressed binary matrix-vector operations. Enabled always sampling for binary matrix-vector operations in CLALibBinaryCellOp, expanded test coverage, and introduced a new compression algorithm test case for k-means with intermediate compression enabled. I also extended the CLALibBinaryCellOp binary matrix-vector (sparse & dense) op task to support left and right operations.
1 parent 78b23cf commit db8820d

File tree

9 files changed

+605
-127
lines changed

9 files changed

+605
-127
lines changed

src/main/java/org/apache/sysds/conf/DMLConfig.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ public class DMLConfig
7979
public static final String PARALLEL_TOKENIZE = "sysds.parallel.tokenize";
8080
public static final String PARALLEL_TOKENIZE_NUM_BLOCKS = "sysds.parallel.tokenize.numBlocks";
8181
public static final String COMPRESSED_LINALG = "sysds.compressed.linalg";
82+
public static final String COMPRESSED_LINALG_INTERMEDIATE = "sysds.compressed.linalg.intermediate";
8283
public static final String COMPRESSED_LOSSY = "sysds.compressed.lossy";
8384
public static final String COMPRESSED_VALID_COMPRESSIONS = "sysds.compressed.valid.compressions";
8485
public static final String COMPRESSED_OVERLAPPING = "sysds.compressed.overlapping";

src/main/java/org/apache/sysds/hops/rewrite/RewriteCompressedReblock.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,11 +171,12 @@ public static boolean satisfiesAggressiveCompressionCondition(Hop hop) {
171171
satisfies |= HopRewriteUtils.isTernary(hop, OpOp3.CTABLE)
172172
&& hop.getInput(0).getDataType().isMatrix()
173173
&& hop.getInput(1).getDataType().isMatrix();
174-
satisfies |= HopRewriteUtils.isData(hop, OpOpData.PERSISTENTREAD) && !hop.isScalar();
174+
satisfies |= HopRewriteUtils.isData(hop, OpOpData.PERSISTENTREAD);
175175
satisfies |= HopRewriteUtils.isUnary(hop, OpOp1.ROUND, OpOp1.FLOOR, OpOp1.NOT, OpOp1.CEIL);
176176
satisfies |= HopRewriteUtils.isBinary(hop, OpOp2.EQUAL, OpOp2.NOTEQUAL, OpOp2.LESS,
177177
OpOp2.LESSEQUAL, OpOp2.GREATER, OpOp2.GREATEREQUAL, OpOp2.AND, OpOp2.OR, OpOp2.MODULUS);
178178
satisfies |= HopRewriteUtils.isTernary(hop, OpOp3.CTABLE);
179+
satisfies &= !hop.isScalar();
179180
}
180181
if(LOG.isDebugEnabled() && satisfies)
181182
LOG.debug("Operation Satisfies: " + hop);

src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,8 @@ private void classifyPhase() {
344344
// final int nRows = mb.getNumRows();
345345
final int nCols = mb.getNumColumns();
346346
// Assume the scaling of cocoding is at maximum square root good relative to number of columns.
347-
final double scale = Math.sqrt(nCols);
347+
final double scale = mb instanceof CompressedMatrixBlock &&
348+
((CompressedMatrixBlock) mb).getColGroups().size() == 1 ? 1 : Math.sqrt(nCols);
348349
final double threshold = _stats.estimatedCostCols / scale;
349350

350351
if(threshold < _stats.originalCost *

src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java

Lines changed: 374 additions & 86 deletions
Large diffs are not rendered by default.

src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java

Lines changed: 73 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import java.util.LinkedList;
2727
import java.util.List;
2828
import java.util.Map;
29+
import java.util.Objects;
2930
import java.util.Set;
3031
import java.util.Stack;
3132

@@ -38,6 +39,8 @@
3839
import org.apache.sysds.common.Types.OpOpData;
3940
import org.apache.sysds.common.Types.ParamBuiltinOp;
4041
import org.apache.sysds.common.Types.ReOrgOp;
42+
import org.apache.sysds.conf.ConfigurationManager;
43+
import org.apache.sysds.conf.DMLConfig;
4144
import org.apache.sysds.hops.AggBinaryOp;
4245
import org.apache.sysds.hops.AggUnaryOp;
4346
import org.apache.sysds.hops.BinaryOp;
@@ -81,9 +84,15 @@ public class WorkloadAnalyzer {
8184
private final DMLProgram prog;
8285
private final Map<Long, Op> treeLookup;
8386
private final Stack<Hop> stack;
87+
private Stack<StatementBlock> lineage = new Stack<>();
8488

8589
public static Map<Long, WTreeRoot> getAllCandidateWorkloads(DMLProgram prog) {
8690
// extract all compression candidates from program (in program order)
91+
String configValue = ConfigurationManager.getDMLConfig()
92+
.getTextValue(DMLConfig.COMPRESSED_LINALG_INTERMEDIATE);
93+
// if set update it, otherwise keep it as set before
94+
ALLOW_INTERMEDIATE_CANDIDATES = configValue != null && Objects.equals(configValue.toUpperCase(), "TRUE") ||
95+
configValue == null && ALLOW_INTERMEDIATE_CANDIDATES;
8796
List<Hop> candidates = getCandidates(prog);
8897

8998
// for each candidate, create pruned workload tree
@@ -115,6 +124,7 @@ private WorkloadAnalyzer(DMLProgram prog) {
115124
this.overlapping = new HashSet<>();
116125
this.treeLookup = new HashMap<>();
117126
this.stack = new Stack<>();
127+
this.lineage = new Stack<>();
118128
}
119129

120130
private WorkloadAnalyzer(DMLProgram prog, Set<Long> compressed, HashMap<String, Long> transientCompressed,
@@ -235,6 +245,7 @@ private static void getCandidates(Hop hop, DMLProgram prog, List<Hop> cands, Set
235245

236246
private void createWorkloadTreeNodes(AWTreeNode n, StatementBlock sb, DMLProgram prog, Set<String> fStack) {
237247
WTreeNode node;
248+
lineage.add(sb);
238249
if(sb instanceof FunctionStatementBlock) {
239250
FunctionStatementBlock fsb = (FunctionStatementBlock) sb;
240251
FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);
@@ -291,7 +302,7 @@ else if(sb instanceof ForStatementBlock) { // incl parfor
291302
if(hop instanceof FunctionOp) {
292303
FunctionOp fop = (FunctionOp) hop;
293304
if(HopRewriteUtils.isTransformEncode(fop))
294-
return;
305+
break;
295306
else if(!fStack.contains(fop.getFunctionKey())) {
296307
fStack.add(fop.getFunctionKey());
297308
FunctionStatementBlock fsb = prog.getFunctionStatementBlock(fop.getFunctionKey());
@@ -323,9 +334,11 @@ else if(!fStack.contains(fop.getFunctionKey())) {
323334
}
324335
}
325336
}
337+
lineage.pop();
326338
return;
327339
}
328340
n.addChild(node);
341+
lineage.pop();
329342
}
330343

331344
private void createStack(Hop hop) {
@@ -396,7 +409,22 @@ else if(hop instanceof AggUnaryOp) {
396409
return;
397410
}
398411
else {
399-
o = new OpNormal(hop, false);
412+
boolean compressedOut = false;
413+
Hop parentHop = hop.getInput(0);
414+
if(HopRewriteUtils.isBinary(parentHop, OpOp2.EQUAL, OpOp2.NOTEQUAL, OpOp2.LESS,
415+
OpOp2.LESSEQUAL, OpOp2.GREATER, OpOp2.GREATEREQUAL)){
416+
Hop leftIn = parentHop.getInput(0);
417+
Hop rightIn = parentHop.getInput(1);
418+
// input ops might be not in the current statement block -> check for transient reads
419+
if(HopRewriteUtils.isAggUnaryOp(leftIn, AggOp.MIN, AggOp.MAX) ||
420+
HopRewriteUtils.isAggUnaryOp(rightIn, AggOp.MIN, AggOp.MAX) ||
421+
checkTransientRead(hop, leftIn) ||
422+
checkTransientRead(hop, rightIn)
423+
)
424+
compressedOut = true;
425+
426+
}
427+
o = new OpNormal(hop, compressedOut);
400428
}
401429
}
402430
else if(hop instanceof UnaryOp) {
@@ -477,9 +505,17 @@ else if(ol) {
477505
if(!HopRewriteUtils.isBinarySparseSafe(hop))
478506
o.setDensifying();
479507

508+
} else if(HopRewriteUtils.isBinaryMatrixColVectorOperation(hop) ) {
509+
Hop leftIn = hop.getInput(0);
510+
Hop rightIn = hop.getInput(1);
511+
if(HopRewriteUtils.isBinary(hop, OpOp2.DIV) && rightIn instanceof AggUnaryOp && leftIn == rightIn.getInput(0)){
512+
o = new OpNormal(hop, true);
513+
} else {
514+
setDecompressionOnAllInputs(hop, parent);
515+
return;
516+
}
480517
}
481518
else if(HopRewriteUtils.isBinaryMatrixMatrixOperation(hop) ||
482-
HopRewriteUtils.isBinaryMatrixColVectorOperation(hop) ||
483519
HopRewriteUtils.isBinaryMatrixMatrixOperationWithSharedInput(hop)) {
484520
setDecompressionOnAllInputs(hop, parent);
485521
return;
@@ -623,6 +659,40 @@ else if(hop instanceof AggUnaryOp) {
623659
}
624660
}
625661

662+
private boolean checkTransientRead(Hop hop, Hop input) {
663+
// op is not in current statement block
664+
if(HopRewriteUtils.isData(input, OpOpData.TRANSIENTREAD)){
665+
String varName = input.getName();
666+
StatementBlock csb = lineage.peek();
667+
StatementBlock parentStatement = lineage.get(lineage.size() -2);
668+
669+
if(parentStatement instanceof WhileStatementBlock) {
670+
WhileStatementBlock wsb = (WhileStatementBlock) parentStatement;
671+
WhileStatement wstmt = (WhileStatement) wsb.getStatement(0);
672+
ArrayList<StatementBlock> stmts = wstmt.getBody();
673+
boolean foundCurrent = false;
674+
StatementBlock sb;
675+
676+
// traverse statement blocks in reverse to find the statement block, which came before the current
677+
// if we iterate in default order, we might find an earlier updated version of the current variable
678+
for (int i = stmts.size()-1; i >= 0; i--) {
679+
sb = stmts.get(i);
680+
if(foundCurrent && sb.variablesUpdated().containsVariable(varName)) {
681+
for(Hop cand : sb.getHops()){
682+
if(HopRewriteUtils.isData(cand, OpOpData.TRANSIENTWRITE) && cand.getName().equals(varName)
683+
&& HopRewriteUtils.isAggUnaryOp( cand.getInput(0), AggOp.MIN, AggOp.MAX)){
684+
return true;
685+
}
686+
}
687+
} else if(sb == csb){
688+
foundCurrent = true;
689+
}
690+
}
691+
}
692+
}
693+
return false;
694+
}
695+
626696
private boolean isCompressed(Hop hop) {
627697
return compressed.contains(hop.getHopID());
628698
}

src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,28 @@ public int putIfAbsentI(K key, int value) {
152152

153153
}
154154

155+
public int putIfAbsentReturnVal(K key, int value) {
156+
157+
if(key == null) {
158+
if(nullV == -1) {
159+
size++;
160+
nullV = value;
161+
return -1;
162+
}
163+
else
164+
return nullV;
165+
}
166+
else {
167+
final int ix = hash(key);
168+
Node<K> b = buckets[ix];
169+
if(b == null)
170+
return createBucketReturnVal(ix, key, value);
171+
else
172+
return putIfAbsentBucketReturnval(ix, key, value);
173+
}
174+
175+
}
176+
155177
private int putIfAbsentBucket(int ix, K key, int value) {
156178
Node<K> b = buckets[ix];
157179
while(true) {
@@ -167,6 +189,21 @@ private int putIfAbsentBucket(int ix, K key, int value) {
167189
}
168190
}
169191

192+
private int putIfAbsentBucketReturnval(int ix, K key, int value) {
193+
Node<K> b = buckets[ix];
194+
while(true) {
195+
if(b.key.equals(key))
196+
return b.value;
197+
if(b.next == null) {
198+
b.setNext(new Node<>(key, value, null));
199+
size++;
200+
resize();
201+
return value;
202+
}
203+
b = b.next;
204+
}
205+
}
206+
170207
public int putI(K key, int value) {
171208
if(key == null) {
172209
int tmp = nullV;
@@ -191,6 +228,12 @@ private int createBucket(int ix, K key, int value) {
191228
return -1;
192229
}
193230

231+
private int createBucketReturnVal(int ix, K key, int value) {
232+
buckets[ix] = new Node<K>(key, value, null);
233+
size++;
234+
return value;
235+
}
236+
194237
private int addToBucket(int ix, K key, int value) {
195238
Node<K> b = buckets[ix];
196239
while(true) {

src/test/java/org/apache/sysds/test/component/compress/lib/CLALibBinaryCellOpCustomTest.java

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,13 @@
2222
import static org.mockito.Mockito.spy;
2323
import static org.mockito.Mockito.when;
2424

25+
import org.apache.commons.lang3.tuple.Pair;
2526
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
2627
import org.apache.sysds.runtime.compress.CompressedMatrixBlockFactory;
28+
import org.apache.sysds.runtime.compress.CompressionStatistics;
2729
import org.apache.sysds.runtime.compress.lib.CLALibBinaryCellOp;
30+
import org.apache.sysds.runtime.functionobjects.GreaterThanEquals;
31+
import org.apache.sysds.runtime.functionobjects.LessThanEquals;
2832
import org.apache.sysds.runtime.functionobjects.Minus;
2933
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
3034
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
@@ -47,6 +51,35 @@ public void notColVector() {
4751
TestUtils.compareMatricesBitAvgDistance(new MatrixBlock(10, 10, 1.3), cRet2, 0, 0, op.toString());
4852
}
4953

54+
@Test
55+
public void twoHotEncodedOutput() {
56+
BinaryOperator op = new BinaryOperator(LessThanEquals.getLessThanEqualsFnObject(), 2);
57+
BinaryOperator op2 = new BinaryOperator(LessThanEquals.getLessThanEqualsFnObject());
58+
BinaryOperator opLeft = new BinaryOperator(GreaterThanEquals.getGreaterThanEqualsFnObject(), 2);
59+
BinaryOperator opLeft2 = new BinaryOperator(GreaterThanEquals.getGreaterThanEqualsFnObject());
60+
61+
MatrixBlock cDense = new MatrixBlock(30, 30, 2.0);
62+
for (int i = 0; i < 30; i++) {
63+
cDense.set(i,0, 1);
64+
}
65+
cDense.set(0,1, 1);
66+
Pair<MatrixBlock, CompressionStatistics> pair = CompressedMatrixBlockFactory.compress(cDense, 1);
67+
CompressedMatrixBlock c = (CompressedMatrixBlock) pair.getKey();
68+
MatrixBlock c2 = new MatrixBlock(30, 1, 1.0);
69+
CompressedMatrixBlock spy = spy(c);
70+
when(spy.getCachedDecompressed()).thenReturn(null);
71+
72+
MatrixBlock cRet = CLALibBinaryCellOp.binaryOperationsRight(op, spy, c2);
73+
MatrixBlock cRet2 = CLALibBinaryCellOp.binaryOperationsRight(op2, spy, c2);
74+
TestUtils.compareMatricesBitAvgDistance(cRet, cRet2, 0, 0, op.toString());
75+
76+
MatrixBlock cRetleft = CLALibBinaryCellOp.binaryOperationsLeft(opLeft, spy, c2);
77+
MatrixBlock cRetleft2 = CLALibBinaryCellOp.binaryOperationsLeft(opLeft2, spy, c2);
78+
TestUtils.compareMatricesBitAvgDistance(cRetleft, cRetleft2, 0, 0, op.toString());
79+
80+
TestUtils.compareMatricesBitAvgDistance(cRet, cRetleft, 0, 0, op.toString());
81+
}
82+
5083
@Test
5184
public void notColVectorEmptyReturn() {
5285
BinaryOperator op = new BinaryOperator(Minus.getMinusFnObject(), 2);

0 commit comments

Comments
 (0)