2626import org .apache .sysds .hops .estim .EstimatorBasicWorst ;
2727import org .apache .sysds .hops .estim .EstimatorBitsetMM ;
2828import org .apache .sysds .hops .estim .EstimatorDensityMap ;
29+ import org .apache .sysds .hops .estim .EstimatorLayeredGraph ;
2930import org .apache .sysds .hops .estim .EstimatorMatrixHistogram ;
3031import org .apache .sysds .hops .estim .EstimatorSample ;
32+ import org .apache .sysds .hops .estim .EstimatorSampleRa ;
3133import org .apache .sysds .hops .estim .SparsityEstimator ;
3234import org .apache .sysds .runtime .instructions .InstructionUtils ;
3335import org .apache .sysds .runtime .matrix .data .MatrixBlock ;
3739public class SelfProductTest extends AutomatedTestBase
3840{
3941 private final static int m = 2500 ;
40- private final static double sparsity1 = 0.0001 ;
41- private final static double sparsity2 = 0.000001 ;
42+ private final static double sparsity0 = 0.5 ;
43+ private final static double sparsity1 = 0.1 ;
44+ private final static double sparsity2 = 0.0001 ;
45+ private final static double sparsity3 = 0.000001 ;
4246 private final static double eps1 = 0.05 ;
4347 private final static double eps2 = 1e-4 ;
4448 private final static double eps3 = 0 ;
@@ -50,97 +54,114 @@ public void setUp() {
5054 }
5155
5256 @ Test
53- public void testBasicAvgCase1 () {
54- runSparsityEstimateTest (new EstimatorBasicAvg (), m , sparsity1 );
55- }
56-
57- @ Test
58- public void testBasicAvgCase2 () {
57+ public void testBasicAvgCase () {
58+ runSparsityEstimateTest (new EstimatorBasicAvg (), m /4 , sparsity0 );
59+ runSparsityEstimateTest (new EstimatorBasicAvg (), m /2 , sparsity1 );
5960 runSparsityEstimateTest (new EstimatorBasicAvg (), m , sparsity2 );
61+ runSparsityEstimateTest (new EstimatorBasicAvg (), m , sparsity3 );
6062 }
6163
6264 @ Test
63- public void testDensityMapCase1 () {
64- runSparsityEstimateTest (new EstimatorDensityMap (), m , sparsity1 );
65- }
66-
67- @ Test
68- public void testDensityMapCase2 () {
65+ public void testDensityMapCase () {
66+ runSparsityEstimateTest (new EstimatorDensityMap (), m /4 , sparsity0 );
67+ runSparsityEstimateTest (new EstimatorDensityMap (), m /2 , sparsity1 );
6968 runSparsityEstimateTest (new EstimatorDensityMap (), m , sparsity2 );
69+ runSparsityEstimateTest (new EstimatorDensityMap (), m , sparsity3 );
7070 }
7171
7272 @ Test
73- public void testDensityMap7Case1 () {
74- runSparsityEstimateTest (new EstimatorDensityMap (7 ), m , sparsity1 );
75- }
76-
77- @ Test
78- public void testDensityMap7Case2 () {
73+ public void testDensityMap7Case () {
74+ runSparsityEstimateTest (new EstimatorDensityMap (7 ), m /4 , sparsity0 );
75+ runSparsityEstimateTest (new EstimatorDensityMap (7 ), m /2 , sparsity1 );
7976 runSparsityEstimateTest (new EstimatorDensityMap (7 ), m , sparsity2 );
77+ runSparsityEstimateTest (new EstimatorDensityMap (7 ), m , sparsity3 );
8078 }
8179
8280 @ Test
83- public void testBitsetMatrixCase1 () {
84- runSparsityEstimateTest (new EstimatorBitsetMM (), m , sparsity1 );
81+ public void testBitsetMatrixCase () {
82+ runSparsityEstimateTest (new EstimatorBitsetMM (), m /4 , sparsity0 );
83+ runSparsityEstimateTest (new EstimatorBitsetMM (), m /2 , sparsity1 );
84+ runSparsityEstimateTest (new EstimatorBitsetMM (), m , sparsity2 );
85+ runSparsityEstimateTest (new EstimatorBitsetMM (), m , sparsity3 );
8586 }
8687
8788 @ Test
88- public void testBitsetMatrixCase2 () {
89- runSparsityEstimateTest (new EstimatorBitsetMM (), m , sparsity2 );
89+ public void testBitset2MatrixCase () {
90+ runSparsityEstimateTest (new EstimatorBitsetMM (2 ), m /4 , sparsity0 );
91+ runSparsityEstimateTest (new EstimatorBitsetMM (2 ), m /2 , sparsity1 );
92+ runSparsityEstimateTest (new EstimatorBitsetMM (2 ), m , sparsity2 );
93+ runSparsityEstimateTest (new EstimatorBitsetMM (2 ), m , sparsity3 );
9094 }
9195
9296 @ Test
93- public void testMatrixHistogramCase1 () {
94- runSparsityEstimateTest (new EstimatorMatrixHistogram (false ), m , sparsity1 );
97+ public void testMatrixHistogramCase () {
98+ runSparsityEstimateTest (new EstimatorMatrixHistogram (false ), m /4 , sparsity0 );
99+ runSparsityEstimateTest (new EstimatorMatrixHistogram (false ), m /2 , sparsity1 );
100+ runSparsityEstimateTest (new EstimatorMatrixHistogram (false ), m , sparsity2 );
101+ runSparsityEstimateTest (new EstimatorMatrixHistogram (false ), m , sparsity3 );
95102 }
96103
97104 @ Test
98- public void testMatrixHistogramCase2 () {
99- runSparsityEstimateTest (new EstimatorMatrixHistogram (false ), m , sparsity2 );
105+ public void testMatrixHistogramExceptCase () {
106+ runSparsityEstimateTest (new EstimatorMatrixHistogram (true ), m /4 , sparsity0 );
107+ runSparsityEstimateTest (new EstimatorMatrixHistogram (true ), m /2 , sparsity1 );
108+ runSparsityEstimateTest (new EstimatorMatrixHistogram (true ), m , sparsity2 );
109+ runSparsityEstimateTest (new EstimatorMatrixHistogram (true ), m , sparsity3 );
100110 }
101111
102112 @ Test
103- public void testMatrixHistogramExceptCase1 () {
104- runSparsityEstimateTest (new EstimatorMatrixHistogram (true ), m , sparsity1 );
113+ public void testSamplingDefCase () {
114+ runSparsityEstimateTest (new EstimatorSample (), m , sparsity2 );
115+ runSparsityEstimateTest (new EstimatorSample (), m , sparsity3 );
105116 }
106117
107118 @ Test
108- public void testMatrixHistogramExceptCase2 () {
109- runSparsityEstimateTest (new EstimatorMatrixHistogram (true ), m , sparsity2 );
119+ public void testSampling20Case () {
120+ runSparsityEstimateTest (new EstimatorSample (0.2 ), m , sparsity2 );
121+ runSparsityEstimateTest (new EstimatorSample (0.2 ), m , sparsity3 );
110122 }
111123
112124 @ Test
113- public void testSamplingDefCase1 () {
114- runSparsityEstimateTest (new EstimatorSample (), m , sparsity1 );
125+ public void testSamplingRaDefCase () {
126+ runSparsityEstimateTest (new EstimatorSampleRa (), m /4 , sparsity0 );
127+ runSparsityEstimateTest (new EstimatorSampleRa (), m , sparsity2 );
128+ runSparsityEstimateTest (new EstimatorSampleRa (), m , sparsity3 );
115129 }
116130
117131 @ Test
118- public void testSamplingDefCase2 () {
119- runSparsityEstimateTest (new EstimatorSample (), m , sparsity2 );
132+ public void testSamplingRa20Case () {
133+ runSparsityEstimateTest (new EstimatorSampleRa (0.2 ), m /4 , sparsity0 );
134+ runSparsityEstimateTest (new EstimatorSampleRa (0.2 ), m , sparsity2 );
135+ runSparsityEstimateTest (new EstimatorSampleRa (0.2 ), m , sparsity3 );
120136 }
121137
122138 @ Test
123- public void testSampling20Case1 () {
124- runSparsityEstimateTest (new EstimatorSample (0.2 ), m , sparsity1 );
139+ public void testLayeredGraphDefCase () {
140+ runSparsityEstimateTest (new EstimatorLayeredGraph (), m , sparsity2 );
141+ runSparsityEstimateTest (new EstimatorLayeredGraph (), m , sparsity3 );
125142 }
126143
127144 @ Test
128- public void testSampling20Case2 () {
129- runSparsityEstimateTest (new EstimatorSample (0.2 ), m , sparsity2 );
145+ public void testLayeredGraph64Case () {
146+ runSparsityEstimateTest (new EstimatorLayeredGraph (64 ), m , sparsity2 );
147+ runSparsityEstimateTest (new EstimatorLayeredGraph (64 ), m , sparsity3 );
130148 }
131149
132150 private static void runSparsityEstimateTest (SparsityEstimator estim , int n , double sp ) {
133- MatrixBlock m1 = MatrixBlock .randOperations (m , n , sp , 1 , 1 , "uniform" , 3 );
151+ MatrixBlock m1 = MatrixBlock .randOperations (n , n , sp , 1 , 1 , "uniform" , 3 );
134152 MatrixBlock m3 = m1 .aggregateBinaryOperations (m1 , m1 ,
135153 new MatrixBlock (), InstructionUtils .getMatMultOperator (1 ));
136- double spExact = OptimizerUtils .getSparsity (m , m ,
154+ double spExact1 = OptimizerUtils .getSparsity (n , n ,
137155 EstimationUtils .getSelfProductOutputNnz (m1 ));
156+ double spExact2 = sp <0.4 ? OptimizerUtils .getSparsity (n , n ,
157+ EstimationUtils .getSparseProductOutputNnz (m1 , m1 )) : spExact1 ;
138158
139159 //compare estimated and real sparsity
140160 double est = estim .estim (m1 , m1 );
141161 TestUtils .compareScalars (est , m3 .getSparsity (),
142162 (estim instanceof EstimatorBitsetMM ) ? eps3 : //exact
143163 (estim instanceof EstimatorBasicWorst ) ? eps1 : eps2 );
144- TestUtils .compareScalars (m3 .getSparsity (), spExact , eps3 );
164+ TestUtils .compareScalars (m3 .getSparsity (), spExact1 , eps3 );
165+ TestUtils .compareScalars (m3 .getSparsity (), spExact2 , eps3 );
145166 }
146167}
0 commit comments