Skip to content

Commit 74b9d86

Browse files
anuunchin authored and Baunsgaard committed
[SYSTEMDS-3780] Compression-fused Quantization
This commit adds a new fused operator that both quantizes an input and compresses the result. The operator does not allocate the intermediate quantized matrix. Closes #2226
1 parent 744d98a commit 74b9d86

32 files changed

+2183
-114
lines changed

src/main/java/org/apache/sysds/common/Builtins.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ public enum Builtins {
8989
COLVAR("colVars", false),
9090
COMPONENTS("components", true),
9191
COMPRESS("compress", false, ReturnType.MULTI_RETURN),
92+
QUANTIZE_COMPRESS("quantize_compress", false, ReturnType.MULTI_RETURN),
9293
CONFUSIONMATRIX("confusionMatrix", true),
9394
CONV2D("conv2d", false),
9495
CONV2D_BACKWARD_FILTER("conv2d_backward_filter", false),

src/main/java/org/apache/sysds/common/InstructionType.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ public enum InstructionType {
5050
Partition,
5151
Compression,
5252
DeCompression,
53+
QuantizeCompression,
5354
SpoofFused,
5455
Prefetch,
5556
EvictLineageCache,

src/main/java/org/apache/sysds/common/Opcodes.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,7 @@ public enum Opcodes {
296296
PARTITION("partition", InstructionType.Partition),
297297
COMPRESS(Compression.OPCODE, InstructionType.Compression, InstructionType.Compression),
298298
DECOMPRESS(DeCompression.OPCODE, InstructionType.DeCompression, InstructionType.DeCompression),
299+
QUANTIZE_COMPRESS("quantize_compress", InstructionType.QuantizeCompression),
299300
SPOOF("spoof", InstructionType.SpoofFused),
300301
PREFETCH("prefetch", InstructionType.Prefetch),
301302
EVICT("_evict", InstructionType.EvictLineageCache),

src/main/java/org/apache/sysds/common/Types.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -634,8 +634,9 @@ public enum OpOp2 {
634634
//fused ML-specific operators for performance
635635
MINUS_NZ(false), //sparse-safe minus: X-(mean*ppred(X,0,!=))
636636
LOG_NZ(false), //sparse-safe log; ppred(X,0,"!=")*log(X,0.5)
637-
MINUS1_MULT(false); //1-X*Y
638-
637+
MINUS1_MULT(false), //1-X*Y
638+
QUANTIZE_COMPRESS(false); //quantization-fused compression
639+
639640
private final boolean _validOuter;
640641

641642
private OpOp2(boolean outer) {

src/main/java/org/apache/sysds/hops/OptimizerUtils.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,15 @@ public enum MemoryManager {
280280
*/
281281
public static boolean ALLOW_SCRIPT_LEVEL_COMPRESS_COMMAND = true;
282282

283+
/**
284+
* This variable allows the user to insert the quantize_compress command in a DML script.
285+
*/
286+
public static boolean ALLOW_SCRIPT_LEVEL_QUANTIZE_COMPRESS_COMMAND = true;
287+
288+
/**
289+
* Boolean specifying if quantization-fused compression rewrite is allowed.
290+
*/
291+
public static boolean ALLOW_QUANTIZE_COMPRESS_REWRITE = true;
283292

284293
/**
285294
* Boolean specifying if compression rewrites is allowed. This is disabled at run time if the IPA for Workload aware compression

src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,9 @@ public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites)
9090
if( OptimizerUtils.ALLOW_AUTO_VECTORIZATION )
9191
_dagRuleSet.add( new RewriteIndexingVectorization() ); //dependency: cse, simplifications
9292
_dagRuleSet.add( new RewriteInjectSparkPReadCheckpointing() ); //dependency: reblock
93-
93+
if( OptimizerUtils.ALLOW_QUANTIZE_COMPRESS_REWRITE )
94+
_dagRuleSet.add( new RewriteQuantizationFusedCompression() );
95+
9496
//add statement block rewrite rules
9597
if( OptimizerUtils.ALLOW_BRANCH_REMOVAL )
9698
_sbRuleSet.add( new RewriteRemoveUnnecessaryBranches() ); //dependency: constant folding
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.hops.rewrite;
21+
22+
import java.util.ArrayList;
23+
import java.util.HashMap;
24+
import java.util.List;
25+
import java.util.Map.Entry;
26+
27+
import org.apache.sysds.common.Types.OpOp1;
28+
import org.apache.sysds.common.Types.OpOp2;
29+
import org.apache.sysds.hops.UnaryOp;
30+
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
31+
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
32+
import org.apache.sysds.hops.BinaryOp;
33+
34+
import org.apache.sysds.common.Types.DataType;
35+
import org.apache.sysds.common.Types.ValueType;
36+
37+
import org.apache.sysds.hops.Hop;
38+
39+
/**
40+
* Rule: RewriteQuantizationFusedCompression. Detects the sequence `M2 = floor(M * S)` followed by `C = compress(M2)` and
41+
* fuses it into a single QUANTIZE_COMPRESS operation. This rewrite improves performance by avoiding intermediate results:
42+
* the consumers of the compress hop are rewired to the fused hop.
43+
*/
44+
public class RewriteQuantizationFusedCompression extends HopRewriteRule {
45+
@Override
46+
public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) {
47+
if(roots == null)
48+
return null;
49+
50+
// traverse the HOP DAG
51+
HashMap<String, Hop> floors = new HashMap<>();
52+
HashMap<String, Hop> compresses = new HashMap<>();
53+
for(Hop h : roots)
54+
collectFloorCompressSequences(h, floors, compresses);
55+
56+
Hop.resetVisitStatus(roots);
57+
58+
// check compresses for compress-after-floor pattern
59+
for(Entry<String, Hop> e : compresses.entrySet()) {
60+
String inputname = e.getKey();
61+
Hop compresshop = e.getValue();
62+
63+
if(floors.containsKey(inputname) // floors same name
64+
&& ((floors.get(inputname).getBeginLine() < compresshop.getBeginLine()) ||
65+
(floors.get(inputname).getEndLine() < compresshop.getEndLine()) ||
66+
(floors.get(inputname).getBeginLine() == compresshop.getBeginLine() &&
67+
floors.get(inputname).getEndLine() == compresshop.getBeginLine() &&
68+
floors.get(inputname).getBeginColumn() < compresshop.getBeginColumn()))) {
69+
70+
// retrieve the floor hop and inputs
71+
Hop floorhop = floors.get(inputname);
72+
Hop floorInput = floorhop.getInput().get(0);
73+
74+
// check if the input of the floor operation is a matrix
75+
if(floorInput.getDataType() == DataType.MATRIX) {
76+
77+
// Check if the input of the floor operation involves a multiplication operation
78+
if(floorInput instanceof BinaryOp && ((BinaryOp) floorInput).getOp() == OpOp2.MULT) {
79+
Hop initialMatrix = floorInput.getInput().get(0);
80+
Hop sf = floorInput.getInput().get(1);
81+
82+
// create fused hop
83+
BinaryOp fusedhop = new BinaryOp("test", DataType.MATRIX, ValueType.FP64,
84+
OpOp2.QUANTIZE_COMPRESS, initialMatrix, sf);
85+
86+
// rewire compress consumers to fusedHop
87+
List<Hop> parents = new ArrayList<>(compresshop.getParent());
88+
for(Hop p : parents) {
89+
HopRewriteUtils.replaceChildReference(p, compresshop, fusedhop);
90+
}
91+
}
92+
}
93+
}
94+
}
95+
return roots;
96+
}
97+
98+
@Override
99+
public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
100+
// do nothing, floor/compress do not occur in predicates
101+
return root;
102+
}
103+
104+
private void collectFloorCompressSequences(Hop hop, HashMap<String, Hop> floors, HashMap<String, Hop> compresses) {
105+
if(hop.isVisited())
106+
return;
107+
108+
// process children
109+
if(!hop.getInput().isEmpty())
110+
for(Hop c : hop.getInput())
111+
collectFloorCompressSequences(c, floors, compresses);
112+
113+
// process current hop
114+
if(hop instanceof UnaryOp) {
115+
UnaryOp uop = (UnaryOp) hop;
116+
if(uop.getOp() == OpOp1.FLOOR) {
117+
floors.put(uop.getName(), uop);
118+
}
119+
else if(uop.getOp() == OpOp1.COMPRESS) {
120+
compresses.put(uop.getInput(0).getName(), uop);
121+
}
122+
}
123+
hop.setVisited();
124+
}
125+
}

src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -751,7 +751,7 @@ else if(((ConstIdentifier) getThirdExpr().getOutput())
751751
else
752752
raiseValidateError("Compress/DeCompress instruction not allowed in dml script");
753753
break;
754-
754+
755755
default: //always unconditional
756756
raiseValidateError("Unknown Builtin Function opcode: " + _opcode, false);
757757
}
@@ -2013,6 +2013,34 @@ else if(this.getOpCode() == Builtins.MAX_POOL || this.getOpCode() == Builtins.AV
20132013
else
20142014
raiseValidateError("Compress/DeCompress instruction not allowed in dml script");
20152015
break;
2016+
case QUANTIZE_COMPRESS:
2017+
if(OptimizerUtils.ALLOW_SCRIPT_LEVEL_QUANTIZE_COMPRESS_COMMAND) {
2018+
checkNumParameters(2);
2019+
Expression firstExpr = getFirstExpr();
2020+
Expression secondExpr = getSecondExpr();
2021+
2022+
checkMatrixParam(getFirstExpr());
2023+
2024+
if(secondExpr != null) {
2025+
// check if scale factor is a scalar, vector or matrix
2026+
checkMatrixScalarParam(secondExpr);
2027+
// if scale factor is a vector or matrix, make sure it has an appropriate shape
2028+
if(secondExpr.getOutput().getDataType() != DataType.SCALAR) {
2029+
if(is1DMatrix(secondExpr)) {
2030+
long vectorLength = secondExpr.getOutput().getDim1();
2031+
if(vectorLength != firstExpr.getOutput().getDim1()) {
2032+
raiseValidateError(
2033+
"The length of the row-wise scale factor vector must match the number of rows in the matrix.");
2034+
}
2035+
}
2036+
else {
2037+
checkMatchingDimensions(firstExpr, secondExpr);
2038+
}
2039+
}
2040+
}
2041+
}
2042+
break;
2043+
20162044
case ROW_COUNT_DISTINCT:
20172045
checkNumParameters(1);
20182046
checkMatrixParam(getFirstExpr());

src/main/java/org/apache/sysds/parser/DMLTranslator.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2585,6 +2585,9 @@ else if ( sop.equalsIgnoreCase(Opcodes.NOTEQUAL.toString()) )
25852585
case DECOMPRESS:
25862586
currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), ValueType.FP64, OpOp1.DECOMPRESS, expr);
25872587
break;
2588+
case QUANTIZE_COMPRESS:
2589+
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp2.valueOf(source.getOpCode().name()), expr, expr2);
2590+
break;
25882591

25892592
// Boolean binary
25902593
case XOR:

src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
5151
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
5252
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
53+
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
5354
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
5455
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
5556
import org.apache.sysds.runtime.util.CommonThreadPool;
@@ -137,6 +138,21 @@ public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock mb,
137138
return compress(mb, k, new CompressionSettingsBuilder(), root);
138139
}
139140

141+
public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock mb, MatrixBlock sf, int k, WTreeRoot root) {
142+
// Handle only row vectors, as column-wise quantization is not allowed.
143+
// The restriction is handled upstream
144+
double[] scaleFactors = sf.getDenseBlockValues();
145+
CompressionSettingsBuilder builder = new CompressionSettingsBuilder().setScaleFactor(scaleFactors);
146+
return compress(mb, k, builder, root);
147+
}
148+
149+
public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock mb, ScalarObject sf, int k, WTreeRoot root) {
150+
double[] scaleFactors = new double[1];
151+
scaleFactors[0] = sf.getDoubleValue();
152+
CompressionSettingsBuilder builder = new CompressionSettingsBuilder().setScaleFactor(scaleFactors);
153+
return compress(mb, k, builder, root);
154+
}
155+
140156
public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock mb, int k, CostEstimatorBuilder csb) {
141157
return compress(mb, k, new CompressionSettingsBuilder(), csb);
142158
}
@@ -285,7 +301,7 @@ else if(mb instanceof CompressedMatrixBlock && ((CompressedMatrixBlock) mb).isOv
285301
return new ImmutablePair<>(mb, null);
286302
}
287303

288-
_stats.denseSize = MatrixBlock.estimateSizeInMemory(mb.getNumRows(), mb.getNumColumns(), 1.0);
304+
_stats.denseSize = MatrixBlock.estimateSizeInMemory(mb.getNumRows(), mb.getNumColumns(), 1.0);
289305
_stats.sparseSize = MatrixBlock.estimateSizeSparseInMemory(mb.getNumRows(), mb.getNumColumns(), mb.getSparsity());
290306
_stats.originalSize = mb.getInMemorySize();
291307
_stats.originalCost = costEstimator.getCost(mb);
@@ -300,8 +316,10 @@ else if(mb instanceof CompressedMatrixBlock && ((CompressedMatrixBlock) mb).isOv
300316

301317
res = new CompressedMatrixBlock(mb); // copy metadata and allocate soft reference
302318
logInit();
319+
303320
classifyPhase();
304-
if(compressionGroups == null)
321+
322+
if(compressionGroups == null)
305323
return abortCompression();
306324

307325
// clear extra data from analysis
@@ -491,7 +509,26 @@ private Pair<MatrixBlock, CompressionStatistics> abortCompression() {
491509
MatrixBlock ucmb = ((CompressedMatrixBlock) mb).getUncompressed("Decompressing for abort: ", k);
492510
return new ImmutablePair<>(ucmb, _stats);
493511
}
494-
return new ImmutablePair<>(mb, _stats);
512+
if(compSettings.scaleFactors == null) {
513+
LOG.warn("Scale factors are null - returning original matrix.");
514+
return new ImmutablePair<>(mb, _stats);
515+
} else {
516+
LOG.warn("Scale factors are present - returning scaled matrix.");
517+
MatrixBlock scaledMb = new MatrixBlock(mb.getNumRows(), mb.getNumColumns(), mb.isInSparseFormat());
518+
scaledMb.copy(mb);
519+
520+
// Apply scaling and flooring
521+
// TODO: Use internal matrix prod
522+
for(int r = 0; r < mb.getNumRows(); r++) {
523+
double scaleFactor = compSettings.scaleFactors.length == 1 ? compSettings.scaleFactors[0] : compSettings.scaleFactors[r];
524+
for(int c = 0; c < mb.getNumColumns(); c++) {
525+
double newValue = Math.floor(mb.get(r, c) * scaleFactor);
526+
scaledMb.set(r, c, newValue);
527+
}
528+
}
529+
scaledMb.recomputeNonZeros();
530+
return new ImmutablePair<>(scaledMb, _stats);
531+
}
495532
}
496533

497534
private void logInit() {

0 commit comments

Comments
 (0)