Skip to content

Commit d726705

Browse files
committed
[MINOR] Improved code coverage and fixes sparsity estimators
1 parent 01586a4 commit d726705

File tree

4 files changed

+91
-49
lines changed

4 files changed

+91
-49
lines changed

src/main/java/org/apache/sysds/hops/estim/EstimationUtils.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,8 @@ public static long getSelfProductOutputNnz(MatrixBlock m1) {
9494
double[] avals = a.values(i);
9595
int aix = a.pos(i);
9696
Arrays.fill(tmp, 0); //reset
97-
for( int k=aix; k<aix+n; k++ ) {
98-
double aval = avals[k];
97+
for( int k=0; k<n; k++ ) {
98+
double aval = avals[aix+k];
9999
if( aval == 0 ) continue;
100100
double[] bvals = a.values(k);
101101
int bix = a.pos(k);

src/main/java/org/apache/sysds/hops/estim/EstimatorBitsetMM.java

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,16 @@
4545
*/
4646
public class EstimatorBitsetMM extends SparsityEstimator
4747
{
48+
private final int _type;
49+
50+
public EstimatorBitsetMM() {
51+
this(-1);
52+
}
53+
54+
public EstimatorBitsetMM(int type) {
55+
_type = type;
56+
}
57+
4858
@Override
4959
public DataCharacteristics estim(MMNode root) {
5060
BitsetMatrix m1Map = getCachedSynopsis(root.getLeft());
@@ -205,14 +215,14 @@ public BitsetMatrix transpose() {
205215
//protected abstract BitsetMatrix reshape(int rows, int cols, boolean byrow);
206216
}
207217

208-
public static BitsetMatrix createBitset(int m, int n) {
209-
return (long)m*n < Integer.MAX_VALUE ?
218+
public BitsetMatrix createBitset(int m, int n) {
219+
return ((long)m*n < Integer.MAX_VALUE && _type != 2) ?
210220
new BitsetMatrix1(m, n) : //linearized long array
211221
new BitsetMatrix2(m, n); //bitset per row
212222
}
213223

214-
public static BitsetMatrix createBitset(MatrixBlock in) {
215-
return in.getLength() < Integer.MAX_VALUE ?
224+
public BitsetMatrix createBitset(MatrixBlock in) {
225+
return (in.getLength() < Integer.MAX_VALUE && _type != 2) ?
216226
new BitsetMatrix1(in) : //linearized long array
217227
new BitsetMatrix2(in); //bitset per row
218228
}

src/test/java/org/apache/sysds/test/component/estim/SelfProductTest.java

Lines changed: 64 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,10 @@
2626
import org.apache.sysds.hops.estim.EstimatorBasicWorst;
2727
import org.apache.sysds.hops.estim.EstimatorBitsetMM;
2828
import org.apache.sysds.hops.estim.EstimatorDensityMap;
29+
import org.apache.sysds.hops.estim.EstimatorLayeredGraph;
2930
import org.apache.sysds.hops.estim.EstimatorMatrixHistogram;
3031
import org.apache.sysds.hops.estim.EstimatorSample;
32+
import org.apache.sysds.hops.estim.EstimatorSampleRa;
3133
import org.apache.sysds.hops.estim.SparsityEstimator;
3234
import org.apache.sysds.runtime.instructions.InstructionUtils;
3335
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
@@ -37,8 +39,10 @@
3739
public class SelfProductTest extends AutomatedTestBase
3840
{
3941
private final static int m = 2500;
40-
private final static double sparsity1 = 0.0001;
41-
private final static double sparsity2 = 0.000001;
42+
private final static double sparsity0 = 0.5;
43+
private final static double sparsity1 = 0.1;
44+
private final static double sparsity2 = 0.0001;
45+
private final static double sparsity3 = 0.000001;
4246
private final static double eps1 = 0.05;
4347
private final static double eps2 = 1e-4;
4448
private final static double eps3 = 0;
@@ -50,97 +54,114 @@ public void setUp() {
5054
}
5155

5256
@Test
53-
public void testBasicAvgCase1() {
54-
runSparsityEstimateTest(new EstimatorBasicAvg(), m, sparsity1);
55-
}
56-
57-
@Test
58-
public void testBasicAvgCase2() {
57+
public void testBasicAvgCase() {
58+
runSparsityEstimateTest(new EstimatorBasicAvg(), m/4, sparsity0);
59+
runSparsityEstimateTest(new EstimatorBasicAvg(), m/2, sparsity1);
5960
runSparsityEstimateTest(new EstimatorBasicAvg(), m, sparsity2);
61+
runSparsityEstimateTest(new EstimatorBasicAvg(), m, sparsity3);
6062
}
6163

6264
@Test
63-
public void testDensityMapCase1() {
64-
runSparsityEstimateTest(new EstimatorDensityMap(), m, sparsity1);
65-
}
66-
67-
@Test
68-
public void testDensityMapCase2() {
65+
public void testDensityMapCase() {
66+
runSparsityEstimateTest(new EstimatorDensityMap(), m/4, sparsity0);
67+
runSparsityEstimateTest(new EstimatorDensityMap(), m/2, sparsity1);
6968
runSparsityEstimateTest(new EstimatorDensityMap(), m, sparsity2);
69+
runSparsityEstimateTest(new EstimatorDensityMap(), m, sparsity3);
7070
}
7171

7272
@Test
73-
public void testDensityMap7Case1() {
74-
runSparsityEstimateTest(new EstimatorDensityMap(7), m, sparsity1);
75-
}
76-
77-
@Test
78-
public void testDensityMap7Case2() {
73+
public void testDensityMap7Case() {
74+
runSparsityEstimateTest(new EstimatorDensityMap(7), m/4, sparsity0);
75+
runSparsityEstimateTest(new EstimatorDensityMap(7), m/2, sparsity1);
7976
runSparsityEstimateTest(new EstimatorDensityMap(7), m, sparsity2);
77+
runSparsityEstimateTest(new EstimatorDensityMap(7), m, sparsity3);
8078
}
8179

8280
@Test
83-
public void testBitsetMatrixCase1() {
84-
runSparsityEstimateTest(new EstimatorBitsetMM(), m, sparsity1);
81+
public void testBitsetMatrixCase() {
82+
runSparsityEstimateTest(new EstimatorBitsetMM(), m/4, sparsity0);
83+
runSparsityEstimateTest(new EstimatorBitsetMM(), m/2, sparsity1);
84+
runSparsityEstimateTest(new EstimatorBitsetMM(), m, sparsity2);
85+
runSparsityEstimateTest(new EstimatorBitsetMM(), m, sparsity3);
8586
}
8687

8788
@Test
88-
public void testBitsetMatrixCase2() {
89-
runSparsityEstimateTest(new EstimatorBitsetMM(), m, sparsity2);
89+
public void testBitset2MatrixCase() {
90+
runSparsityEstimateTest(new EstimatorBitsetMM(2), m/4, sparsity0);
91+
runSparsityEstimateTest(new EstimatorBitsetMM(2), m/2, sparsity1);
92+
runSparsityEstimateTest(new EstimatorBitsetMM(2), m, sparsity2);
93+
runSparsityEstimateTest(new EstimatorBitsetMM(2), m, sparsity3);
9094
}
9195

9296
@Test
93-
public void testMatrixHistogramCase1() {
94-
runSparsityEstimateTest(new EstimatorMatrixHistogram(false), m, sparsity1);
97+
public void testMatrixHistogramCase() {
98+
runSparsityEstimateTest(new EstimatorMatrixHistogram(false), m/4, sparsity0);
99+
runSparsityEstimateTest(new EstimatorMatrixHistogram(false), m/2, sparsity1);
100+
runSparsityEstimateTest(new EstimatorMatrixHistogram(false), m, sparsity2);
101+
runSparsityEstimateTest(new EstimatorMatrixHistogram(false), m, sparsity3);
95102
}
96103

97104
@Test
98-
public void testMatrixHistogramCase2() {
99-
runSparsityEstimateTest(new EstimatorMatrixHistogram(false), m, sparsity2);
105+
public void testMatrixHistogramExceptCase() {
106+
runSparsityEstimateTest(new EstimatorMatrixHistogram(true), m/4, sparsity0);
107+
runSparsityEstimateTest(new EstimatorMatrixHistogram(true), m/2, sparsity1);
108+
runSparsityEstimateTest(new EstimatorMatrixHistogram(true), m, sparsity2);
109+
runSparsityEstimateTest(new EstimatorMatrixHistogram(true), m, sparsity3);
100110
}
101111

102112
@Test
103-
public void testMatrixHistogramExceptCase1() {
104-
runSparsityEstimateTest(new EstimatorMatrixHistogram(true), m, sparsity1);
113+
public void testSamplingDefCase() {
114+
runSparsityEstimateTest(new EstimatorSample(), m, sparsity2);
115+
runSparsityEstimateTest(new EstimatorSample(), m, sparsity3);
105116
}
106117

107118
@Test
108-
public void testMatrixHistogramExceptCase2() {
109-
runSparsityEstimateTest(new EstimatorMatrixHistogram(true), m, sparsity2);
119+
public void testSampling20Case() {
120+
runSparsityEstimateTest(new EstimatorSample(0.2), m, sparsity2);
121+
runSparsityEstimateTest(new EstimatorSample(0.2), m, sparsity3);
110122
}
111123

112124
@Test
113-
public void testSamplingDefCase1() {
114-
runSparsityEstimateTest(new EstimatorSample(), m, sparsity1);
125+
public void testSamplingRaDefCase() {
126+
runSparsityEstimateTest(new EstimatorSampleRa(), m/4, sparsity0);
127+
runSparsityEstimateTest(new EstimatorSampleRa(), m, sparsity2);
128+
runSparsityEstimateTest(new EstimatorSampleRa(), m, sparsity3);
115129
}
116130

117131
@Test
118-
public void testSamplingDefCase2() {
119-
runSparsityEstimateTest(new EstimatorSample(), m, sparsity2);
132+
public void testSamplingRa20Case() {
133+
runSparsityEstimateTest(new EstimatorSampleRa(0.2), m/4, sparsity0);
134+
runSparsityEstimateTest(new EstimatorSampleRa(0.2), m, sparsity2);
135+
runSparsityEstimateTest(new EstimatorSampleRa(0.2), m, sparsity3);
120136
}
121137

122138
@Test
123-
public void testSampling20Case1() {
124-
runSparsityEstimateTest(new EstimatorSample(0.2), m, sparsity1);
139+
public void testLayeredGraphDefCase() {
140+
runSparsityEstimateTest(new EstimatorLayeredGraph(), m, sparsity2);
141+
runSparsityEstimateTest(new EstimatorLayeredGraph(), m, sparsity3);
125142
}
126143

127144
@Test
128-
public void testSampling20Case2() {
129-
runSparsityEstimateTest(new EstimatorSample(0.2), m, sparsity2);
145+
public void testLayeredGraph64Case() {
146+
runSparsityEstimateTest(new EstimatorLayeredGraph(64), m, sparsity2);
147+
runSparsityEstimateTest(new EstimatorLayeredGraph(64), m, sparsity3);
130148
}
131149

132150
private static void runSparsityEstimateTest(SparsityEstimator estim, int n, double sp) {
133-
MatrixBlock m1 = MatrixBlock.randOperations(m, n, sp, 1, 1, "uniform", 3);
151+
MatrixBlock m1 = MatrixBlock.randOperations(n, n, sp, 1, 1, "uniform", 3);
134152
MatrixBlock m3 = m1.aggregateBinaryOperations(m1, m1,
135153
new MatrixBlock(), InstructionUtils.getMatMultOperator(1));
136-
double spExact = OptimizerUtils.getSparsity(m, m,
154+
double spExact1 = OptimizerUtils.getSparsity(n, n,
137155
EstimationUtils.getSelfProductOutputNnz(m1));
156+
double spExact2 = sp<0.4 ? OptimizerUtils.getSparsity(n, n,
157+
EstimationUtils.getSparseProductOutputNnz(m1, m1)) : spExact1;
138158

139159
//compare estimated and real sparsity
140160
double est = estim.estim(m1, m1);
141161
TestUtils.compareScalars(est, m3.getSparsity(),
142162
(estim instanceof EstimatorBitsetMM) ? eps3 : //exact
143163
(estim instanceof EstimatorBasicWorst) ? eps1 : eps2);
144-
TestUtils.compareScalars(m3.getSparsity(), spExact, eps3);
164+
TestUtils.compareScalars(m3.getSparsity(), spExact1, eps3);
165+
TestUtils.compareScalars(m3.getSparsity(), spExact2, eps3);
145166
}
146167
}

src/test/java/org/apache/sysds/test/component/estim/SquaredProductChainTest.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.apache.sysds.hops.estim.EstimatorBasicWorst;
2525
import org.apache.sysds.hops.estim.EstimatorBitsetMM;
2626
import org.apache.sysds.hops.estim.EstimatorDensityMap;
27+
import org.apache.sysds.hops.estim.EstimatorLayeredGraph;
2728
import org.apache.sysds.hops.estim.EstimatorMatrixHistogram;
2829
import org.apache.sysds.hops.estim.MMNode;
2930
import org.apache.sysds.hops.estim.SparsityEstimator;
@@ -126,6 +127,16 @@ public void testMatrixHistogramExceptCase2() {
126127
runSparsityEstimateTest(new EstimatorMatrixHistogram(true), m, k, n, n2, case2);
127128
}
128129

130+
@Test
131+
public void testLayeredGraphCase1() {
132+
runSparsityEstimateTest(new EstimatorLayeredGraph(32), m, k, n, n2, case1);
133+
}
134+
135+
@Test
136+
public void testLayeredGraphCase2() {
137+
runSparsityEstimateTest(new EstimatorLayeredGraph(32), m, k, n, n2, case2);
138+
}
139+
129140
private static void runSparsityEstimateTest(SparsityEstimator estim, int m, int k, int n, int n2, double[] sp) {
130141
MatrixBlock m1 = MatrixBlock.randOperations(m, k, sp[0], 1, 1, "uniform", 1);
131142
MatrixBlock m2 = MatrixBlock.randOperations(k, n, sp[1], 1, 1, "uniform", 2);

0 commit comments

Comments
 (0)