Skip to content

Commit e12d9d2

Browse files
committed
[MINOR] Update cost estimation in CLA
Closes #2168
1 parent 28c811a commit e12d9d2

File tree

9 files changed

+210
-40
lines changed

9 files changed

+210
-40
lines changed

src/main/java/org/apache/sysds/runtime/compress/cost/ComputationCostEstimator.java

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,9 @@ protected double getCostSafe(CompressedSizeInfoColGroup g) {
6464
else if(g.isEmpty() || g.isConst())
6565
// const or densifying
6666
return getCost(nRows, 1, nCols, 1, 1);
67-
if(commonFraction > cvThreshold)
67+
else if(g.isIncompressable())
68+
return getCost(nRows* 3, nRows, nCols, nRows* 3, sparsity); // make incompressable very expensive.
69+
else if(commonFraction > cvThreshold)
6870
return getCost(nRows, nRows - g.getLargestOffInstances(), nCols, nVals, sparsity);
6971
else
7072
return getCost(nRows, nRows, nCols, nVals, sparsity);
@@ -142,8 +144,44 @@ private double dictionaryOpsCost(double nVals, double nCols, double sparsity) {
142144
}
143145

144146
private double leftMultCost(double nRowsScanned, double nRows, double nCols, double nVals, double sparsity) {
145-
// Plus nVals * 2 because of allocation of nVals array and scan of that
146-
final double preScalingCost = Math.max(nRowsScanned, nRows / 10) + nVals * 2;
147+
// left multiplication wants more co-coding.
148+
// therefore, increase the cost if we have few columns
149+
double preScalingCost = Math.max(nRowsScanned, nRows) * 2;
150+
if ((nCols == nVals || nCols == nVals +1) && nVals > 1000){
151+
preScalingCost = 0;
152+
}
153+
// if(nCols == 1) {
154+
// nCols *= 4;
155+
// preScalingCost *= 5.0;
156+
// }
157+
// else if(nCols == 2) {
158+
// nCols *= 3;
159+
// preScalingCost *= 3.3;
160+
// }
161+
// else if(nCols == 3) {
162+
// nCols *= 2;
163+
// preScalingCost *= 1.6;
164+
// }
165+
// else if(nCols == 4) {
166+
// nCols *= 1.5;
167+
// preScalingCost *= 1.4;
168+
// }
169+
// else if(nCols > 1000)
170+
// nCols *= 1.1; // more cost if lots and lots of columns
171+
// else if(nCols > 5)
172+
// nCols *= 0.7; // scale down cost of columns.
173+
174+
// // if the number of unique values is low increase the cost.
175+
// if(nVals < 10)
176+
// nVals *= 10;
177+
// else if(nVals < 256)
178+
// nVals *= 5;
179+
// else if(nVals < 1024)
180+
// nVals *= 2;
181+
// else if(nVals > 100000)// increase the cost if the number of distinct values is high.
182+
// nVals *= 4;
183+
// else if(nVals > 60000)// increase the cost if the number of distinct values is high.
184+
// nVals *= 2;
147185
final double postScalingCost = sparsity * nVals * nCols;
148186
return leftMultCost(preScalingCost, postScalingCost);
149187
}

src/main/java/org/apache/sysds/runtime/compress/estim/ComEstCompressed.java

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,12 @@ protected List<CompressedSizeInfoColGroup> CompressedSizeInfoColGroup(int clen,
5050

5151
@Override
5252
public CompressedSizeInfoColGroup getColGroupInfo(IColIndex colIndexes, int estimate, int nrUniqueUpperBound) {
53-
54-
// final IEncode map =
55-
throw new UnsupportedOperationException("Unimplemented method 'getColGroupInfo'");
53+
return null;
5654
}
5755

5856
@Override
5957
public CompressedSizeInfoColGroup getDeltaColGroupInfo(IColIndex colIndexes, int estimate, int nrUniqueUpperBound) {
60-
throw new UnsupportedOperationException("Unimplemented method 'getDeltaColGroupInfo'");
58+
return null;
6159
}
6260

6361
@Override
@@ -69,11 +67,11 @@ protected int worstCaseUpperBound(IColIndex columns) {
6967
}
7068
else {
7169
List<AColGroup> groups = CLALibCombineGroups.findGroupsInIndex(columns, cData.getColGroups());
72-
int nVals = 1;
70+
long nVals = 1;
7371
for(AColGroup g : groups)
7472
nVals *= g.getNumValues();
7573

76-
return Math.min(_data.getNumRows(), nVals);
74+
return Math.min(_data.getNumRows(), (int) Math.min(nVals, Integer.MAX_VALUE));
7775
}
7876
}
7977

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.runtime.compress.estim;
21+
22+
import java.util.ArrayList;
23+
import java.util.List;
24+
25+
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
26+
import org.apache.sysds.runtime.compress.CompressionSettings;
27+
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
28+
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
29+
import org.apache.sysds.runtime.compress.lib.CLALibCombineGroups;
30+
31+
public class ComEstCompressedSample extends ComEstSample {
32+
33+
private static boolean loggedWarning = false;
34+
35+
36+
public ComEstCompressedSample(CompressedMatrixBlock sample, CompressionSettings cs, CompressedMatrixBlock full,
37+
int k) {
38+
super(sample, cs, full, k);
39+
// cData = sample;
40+
}
41+
42+
@Override
43+
protected List<CompressedSizeInfoColGroup> CompressedSizeInfoColGroup(int clen, int k) {
44+
List<CompressedSizeInfoColGroup> ret = new ArrayList<>();
45+
final int nRow = _data.getNumRows();
46+
final List<AColGroup> fg = ((CompressedMatrixBlock) _data).getColGroups();
47+
final List<AColGroup> sg = ((CompressedMatrixBlock) _sample).getColGroups();
48+
49+
for(int i = 0; i < fg.size(); i++) {
50+
CompressedSizeInfoColGroup r = fg.get(i).getCompressionInfo(nRow);
51+
r.setMap(sg.get(i).getCompressionInfo(_sampleSize).getMap());
52+
ret.add(r);
53+
}
54+
55+
return ret;
56+
}
57+
58+
@Override
59+
public CompressedSizeInfoColGroup getColGroupInfo(IColIndex colIndexes, int estimate, int nrUniqueUpperBound) {
60+
if(!loggedWarning)
61+
LOG.warn("Compressed input cannot fallback to resampling " + colIndexes);
62+
loggedWarning = true;
63+
return null;
64+
}
65+
66+
@Override
67+
public CompressedSizeInfoColGroup getDeltaColGroupInfo(IColIndex colIndexes, int estimate, int nrUniqueUpperBound) {
68+
if(!loggedWarning)
69+
LOG.warn("Compressed input cannot fallback to resampling " + colIndexes);
70+
return null;
71+
}
72+
73+
@Override
74+
protected int worstCaseUpperBound(IColIndex columns) {
75+
CompressedMatrixBlock cData = ((CompressedMatrixBlock) _data);
76+
if(columns.size() == 1) {
77+
int id = columns.get(0);
78+
AColGroup g = cData.getColGroupForColumn(id);
79+
return g.getNumValues();
80+
}
81+
else {
82+
List<AColGroup> groups = CLALibCombineGroups.findGroupsInIndex(columns, cData.getColGroups());
83+
long nVals = 1;
84+
for(AColGroup g : groups)
85+
nVals *= g.getNumValues();
86+
87+
return Math.min(_data.getNumRows(), (int) Math.min(nVals, Integer.MAX_VALUE));
88+
}
89+
}
90+
91+
}

src/main/java/org/apache/sysds/runtime/compress/estim/ComEstFactory.java

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.apache.commons.logging.LogFactory;
2424
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
2525
import org.apache.sysds.runtime.compress.CompressionSettings;
26+
import org.apache.sysds.runtime.compress.lib.CLALibSlice;
2627
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
2728

2829
public interface ComEstFactory {
@@ -37,13 +38,13 @@ public interface ComEstFactory {
3738
* @return A new CompressionSizeEstimator used to extract information of column groups
3839
*/
3940
public static AComEst createEstimator(MatrixBlock data, CompressionSettings cs, int k) {
40-
if(data instanceof CompressedMatrixBlock)
41-
return createCompressedEstimator((CompressedMatrixBlock) data, cs);
42-
4341
final int nRows = cs.transposed ? data.getNumColumns() : data.getNumRows();
4442
final int nCols = cs.transposed ? data.getNumRows() : data.getNumColumns();
4543
final double sparsity = data.getSparsity();
4644
final int sampleSize = getSampleSize(cs, nRows, nCols, sparsity);
45+
46+
if(data instanceof CompressedMatrixBlock)
47+
return createCompressedEstimator((CompressedMatrixBlock) data, cs, sampleSize, k);
4748

4849
if(data.isEmpty())
4950
return createExactEstimator(data, cs);
@@ -76,8 +77,17 @@ private static ComEstExact createExactEstimator(MatrixBlock data, CompressionSet
7677
return new ComEstExact(data, cs);
7778
}
7879

79-
private static ComEstCompressed createCompressedEstimator(CompressedMatrixBlock data, CompressionSettings cs) {
80-
LOG.debug("Using Compressed Estimator");
80+
private static AComEst createCompressedEstimator(CompressedMatrixBlock data, CompressionSettings cs, int sampleSize,
81+
int k) {
82+
if(sampleSize < data.getNumRows()) {
83+
LOG.debug("Trying to sample");
84+
final MatrixBlock slice = CLALibSlice.sliceRowsCompressed(data, 0, sampleSize);
85+
if(slice instanceof CompressedMatrixBlock) {
86+
LOG.debug("Using Sampled Compressed Estimator " + sampleSize);
87+
return new ComEstCompressedSample((CompressedMatrixBlock) slice, cs, data, k);
88+
}
89+
}
90+
LOG.debug("Using Full Compressed Estimator");
8191
return new ComEstCompressed(data, cs);
8292
}
8393

src/main/java/org/apache/sysds/runtime/compress/estim/ComEstSample.java

Lines changed: 50 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import java.util.Random;
2424

2525
import org.apache.sysds.runtime.compress.CompressionSettings;
26+
import org.apache.sysds.runtime.compress.DMLCompressionException;
2627
import org.apache.sysds.runtime.compress.colgroup.AColGroup.CompressionType;
2728
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
2829
import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory;
@@ -42,13 +43,22 @@
4243
public class ComEstSample extends AComEst {
4344

4445
/** Sample extracted from the input data */
45-
private final MatrixBlock _sample;
46+
protected final MatrixBlock _sample;
4647
/** Parallelization degree */
47-
private final int _k;
48+
protected final int _k;
4849
/** Sample size */
49-
private final int _sampleSize;
50+
protected final int _sampleSize;
5051
/** Boolean specifying if the sample is in transposed format. */
51-
private boolean _transposed;
52+
protected boolean _transposed;
53+
54+
/**
 * Construct the estimator from an already extracted sample.
 *
 * @param sample The sample extracted from the input data, its number of rows is used as the sample size
 * @param cs     The compression settings, also specifying if the sample is in transposed format
 * @param full   The full input data
 * @param k      The parallelization degree
 */
public ComEstSample(MatrixBlock sample, CompressionSettings cs, MatrixBlock full, int k) {
	super(full, cs);
	_sample = sample;
	_sampleSize = sample.getNumRows();
	_transposed = cs.transposed;
	_k = k;
}
5262

5363
/**
5464
* CompressedSizeEstimatorSample, samples from the input data and estimates the size of the compressed matrix.
@@ -95,22 +105,44 @@ public CompressedSizeInfoColGroup getDeltaColGroupInfo(IColIndex colIndexes, int
95105
@Override
96106
protected int worstCaseUpperBound(IColIndex columns) {
97107
if(getNumColumns() == columns.size())
98-
return Math.min(getNumRows(), (int) _data.getNonZeros());
108+
return Math.min(getNumRows(), (int) Math.min(_data.getNonZeros(), Integer.MAX_VALUE));
99109
return getNumRows();
100110
}
101111

102112
@Override
103113
protected CompressedSizeInfoColGroup combine(IColIndex combinedColumns, CompressedSizeInfoColGroup g1,
104114
CompressedSizeInfoColGroup g2, int maxDistinct) {
105-
final IEncode map = g1.getMap().combine(g2.getMap());
106-
return extractInfo(map, combinedColumns, maxDistinct);
115+
try {
116+
final IEncode map = g1.getMap().combine(g2.getMap());
117+
return extractInfo(map, combinedColumns, maxDistinct);
118+
}
119+
catch(Exception e) {
120+
121+
String s1 = g1.toString();
122+
if(s1.length() > 1000)
123+
s1 = s1.substring(0, 1000);
124+
125+
String s2 = g2.toString();
126+
if(s2.length() > 1000)
127+
s2 = s2.substring(0, 1000);
128+
129+
throw new DMLCompressionException("Failed to combine :\n" + s1 + "\n\n" + s2, e);
130+
}
107131
}
108132

109133
private CompressedSizeInfoColGroup extractInfo(IEncode map, IColIndex colIndexes, int maxDistinct) {
110-
final double spar = _data.getSparsity();
111-
final EstimationFactors sampleFacts = map.extractFacts(_sampleSize, spar, spar, _cs);
112-
final EstimationFactors em = scaleFactors(sampleFacts, colIndexes, maxDistinct, map.isDense());
113-
return new CompressedSizeInfoColGroup(colIndexes, em, _cs.validCompressions, map);
134+
try {
135+
final double spar = _data.getSparsity();
136+
final EstimationFactors sampleFacts = map.extractFacts(_sampleSize, spar, spar, _cs);
137+
final EstimationFactors em = scaleFactors(sampleFacts, colIndexes, maxDistinct, map.isDense());
138+
return new CompressedSizeInfoColGroup(colIndexes, em, _cs.validCompressions, map);
139+
}
140+
catch(Exception e) {
141+
String ms = map.toString();
142+
if(ms.length() > 1000)
143+
ms = ms.substring(0, 1000);
144+
throw new DMLCompressionException("Failed to extract info: \n" + ms, e);
145+
}
114146
}
115147

116148
private EstimationFactors scaleFactors(EstimationFactors sampleFacts, IColIndex colIndexes, int maxDistinct,
@@ -125,6 +157,9 @@ private EstimationFactors scaleFactors(EstimationFactors sampleFacts, IColIndex
125157
final long nnz = calculateNNZ(colIndexes, scalingFactor);
126158
final int numOffs = calculateOffs(sampleFacts, numRows, scalingFactor, colIndexes, (int) nnz);
127159
final int estDistinct = distinctCountScale(sampleFacts, numOffs, numRows, maxDistinct, dense, nCol);
160+
// if(estDistinct < sampleFacts.numVals)
161+
// throw new DMLCompressionException("Failed estimating distinct: " + estDistinct + " should have been above "
162+
// + sampleFacts.numVals + "\n" + Arrays.toString(sampleFacts.frequencies));
128163

129164
// calculate the largest instance count.
130165
final int maxLargestInstanceCount = numRows - estDistinct + 1;
@@ -133,11 +168,9 @@ private EstimationFactors scaleFactors(EstimationFactors sampleFacts, IColIndex
133168
final int mostFrequentOffsetCount = Math.max(Math.min(maxLargestInstanceCount, scaledLargestInstanceCount),
134169
numRows - numOffs);
135170

136-
final double overallSparsity = calculateSparsity(colIndexes, nnz, scalingFactor,
137-
sampleFacts.overAllSparsity);
171+
final double overallSparsity = calculateSparsity(colIndexes, nnz, scalingFactor, sampleFacts.overAllSparsity);
138172
// For robustness, add 30 percent more tuple sparsity
139173
final double tupleSparsity = Math.min(overallSparsity * 1.3, 1.0); // increase sparsity by 30%.
140-
141174
if(_cs.isRLEAllowed()) {
142175
final int scaledRuns = Math.max(estDistinct,
143176
calculateRuns(sampleFacts, scalingFactor, numOffs, estDistinct));
@@ -161,14 +194,14 @@ private int distinctCountScale(EstimationFactors sampleFacts, int numOffs, int n
161194
final int[] freq = sampleFacts.frequencies;
162195
if(freq == null || freq.length == 0)
163196
return numOffs; // very aggressive number of distinct
197+
maxDistinct = Math.max(maxDistinct, sampleFacts.numVals);
164198
// sampled size is smaller than actual if there were empty rows.
165199
// and the more we can reduce this value the more accurate the estimation will become.
166200
final int sampledSize = sampleFacts.numOffs;
167-
int est = SampleEstimatorFactory.distinctCount(freq, dense ? numRows : numOffs, sampledSize,
168-
_cs.estimationType);
201+
int est = SampleEstimatorFactory.distinctCount(freq, dense ? numRows : numOffs, sampledSize, _cs.estimationType);
169202
if(est > 10000)
170203
est += est * 0.5;
171-
if(nCol > 4) // Increase estimate if we get into many columns cocoding to be safe
204+
if(nCol > 4 && est > 100) // Increase estimate if we get into many columns cocoding to be safe
172205
est += ((double) est) * ((double) nCol) / 10;
173206
// Bound the estimate with the maxDistinct.
174207
return Math.max(Math.min(est, Math.min(maxDistinct, numOffs)), 1);

src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfoColGroup.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,8 @@ public boolean containsZeros() {
220220

221221
private static EnumMap<CompressionType, Double> calculateCompressionSizes(IColIndex cols, EstimationFactors fact,
222222
Set<CompressionType> validCompressionTypes) {
223+
if(validCompressionTypes.size() > 10 )
224+
throw new DMLCompressionException("Invalid big number of compression types");
223225
EnumMap<CompressionType, Double> res = new EnumMap<>(CompressionType.class);
224226
for(CompressionType ct : validCompressionTypes) {
225227
double compSize = getCompressionSize(cols, ct, fact);

src/main/java/org/apache/sysds/runtime/compress/estim/EstimationFactors.java

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -87,16 +87,14 @@ public EstimationFactors(int numVals, int numOffs, int largestOff, int[] frequen
8787
this.tupleSparsity = tupleSparsity;
8888

8989
if(overAllSparsity > 1 || overAllSparsity < 0)
90-
throw new DMLCompressionException("Invalid OverAllSparsity of: " + overAllSparsity);
90+
overAllSparsity = Math.max(0, Math.min(1, overAllSparsity));
9191
else if(tupleSparsity > 1 || tupleSparsity < 0)
92-
throw new DMLCompressionException("Invalid TupleSparsity of:" + tupleSparsity);
92+
tupleSparsity = Math.max(0, Math.min(1, tupleSparsity));
9393
else if(largestOff > numRows)
94-
throw new DMLCompressionException(
95-
"Invalid number of instance of most common element should be lower than number of rows. " + largestOff
96-
+ " > numRows: " + numRows);
94+
largestOff = numRows;
9795
else if(numVals > numOffs)
98-
throw new DMLCompressionException(
99-
"Num vals cannot be greater than num offs: vals: " + numVals + " offs: " + numOffs);
96+
numVals = numOffs;
97+
10098

10199
if(CompressedMatrixBlock.debug && frequencies != null) {
102100
for(int i = 0; i < frequencies.length; i++) {

src/main/java/org/apache/sysds/runtime/compress/utils/ACountHashMap.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ public int size() {
5454
}
5555

5656
/**
57-
* Increment and return the id of the incremeted index.
57+
* Increment and return the id of the incremented index.
5858
*
5959
* @param key The key to increment
6060
* @return The id of the incremented entry.

0 commit comments

Comments
 (0)