Skip to content

Commit e6b8ef0

Browse files
committed
[MINOR] Compressed tests
1 parent fd1ba7c commit e6b8ef0

File tree

6 files changed

+733
-146
lines changed

6 files changed

+733
-146
lines changed

src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java

Lines changed: 100 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,8 @@ else if(ov == OverLapping.SQUASH) {
307307
}
308308
else if(ov == OverLapping.PLUS_ROW_VECTOR) {
309309

310-
MatrixBlock v = TestUtils.generateTestMatrixBlock(1, cols, -1, 1, 1.0, 4);
310+
MatrixBlock v = TestUtils.generateTestMatrixBlock(1, cols, 0, 4, 1.0, 4);
311+
v = TestUtils.ceil(v);
311312
BinaryOperator bop = new BinaryOperator(Plus.getPlusFnObject(), _k);
312313
mb = mb.binaryOperations(bop, v, null);
313314
cmb = cmb.binaryOperations(bop, v, null);
@@ -504,13 +505,15 @@ public void testMatrixMultChain(ChainType ctype) {
504505

505506
@Test
506507
public void testVectorMatrixMult() {
507-
MatrixBlock vector = TestUtils.generateTestMatrixBlock(1, rows, 0.9, 1.5, 1.0, 3);
508+
MatrixBlock vector = TestUtils.generateTestMatrixBlock(1, rows, 0, 5, 1.0, 3);
509+
vector = TestUtils.ceil(vector);
508510
testLeftMatrixMatrix(vector);
509511
}
510512

511513
@Test
512514
public void testLeftMatrixMatrixMultSmall() {
513-
MatrixBlock matrix = TestUtils.generateTestMatrixBlock(3, rows, 0.9, 1.5, 1.0, 3);
515+
MatrixBlock matrix = TestUtils.generateTestMatrixBlock(3, rows, 0, 5, 1.0, 3);
516+
matrix = TestUtils.ceil(matrix);
514517
testLeftMatrixMatrix(matrix);
515518
}
516519

@@ -522,7 +525,8 @@ public void testLeftMatrixMatrixMultConst() {
522525

523526
@Test
524527
public void testLeftMatrixMatrixMultSparse() {
525-
MatrixBlock matrix = TestUtils.generateTestMatrixBlock(2, rows, 0.9, 1.5, .1, 3);
528+
MatrixBlock matrix = TestUtils.generateTestMatrixBlock(2, rows, 0, 5, .1, 3);
529+
matrix = TestUtils.ceil(matrix);
526530
testLeftMatrixMatrix(matrix);
527531
}
528532

@@ -1053,6 +1057,98 @@ public void testSlice(int rl, int ru, int cl, int cu) {
10531057
}
10541058
}
10551059

1060+
1061+
@Test
1062+
public void testReshape2() {
1063+
testReshape(2);
1064+
}
1065+
1066+
@Test
1067+
public void testReshape3() {
1068+
testReshape(3);
1069+
}
1070+
1071+
@Test
1072+
public void testReshape10() {
1073+
testReshape(10);
1074+
}
1075+
1076+
/**
1077+
* Test the reshape mechanic of the compressed block by reshaping the matrix by making it x times wider.
1078+
*
1079+
* @param multiplier the multiplier x.
1080+
*/
1081+
public void testReshape(int multiplier) {
1082+
try {
1083+
if((double) rows / multiplier != rows / multiplier)
1084+
return;
1085+
1086+
final MatrixBlock ret2 = cmb.reshape(rows / multiplier, cols * multiplier, true);
1087+
final MatrixBlock ret1 = mb.reshape(rows / multiplier, cols * multiplier, true);
1088+
compareResultMatrices(ret1, ret2, 1);
1089+
}
1090+
catch(Exception e) {
1091+
e.printStackTrace();
1092+
throw new DMLRuntimeException("Error in Reshape", e);
1093+
}
1094+
}
1095+
1096+
@Test
1097+
public void testReshape2_divider() {
1098+
testReshapeDivider(2);
1099+
}
1100+
1101+
@Test
1102+
public void testReshape3_divider() {
1103+
testReshapeDivider(3);
1104+
}
1105+
1106+
@Test
1107+
public void testReshape10_divider() {
1108+
testReshapeDivider(10);
1109+
}
1110+
1111+
/**
1112+
* Test the reshape mechanic of the compressed block by reshaping the matrix by making it x times taller.
1113+
*
1114+
* @param divider the divider x.
1115+
*/
1116+
public void testReshapeDivider(int divider) {
1117+
try {
1118+
if((double) cols /divider != cols / divider)
1119+
return;
1120+
1121+
final MatrixBlock ret2 = cmb.reshape(rows * divider, cols / divider, true);
1122+
final MatrixBlock ret1 = mb.reshape(rows * divider, cols / divider, true);
1123+
compareResultMatrices(ret1, ret2, 1);
1124+
}
1125+
catch(Exception e) {
1126+
e.printStackTrace();
1127+
throw new DMLRuntimeException("Error in Reshape", e);
1128+
}
1129+
}
1130+
1131+
1132+
@Test
1133+
public void testReshape_opposite() {
1134+
testReshape(cols, rows);
1135+
}
1136+
1137+
public void testReshape(int newRows, int newCols) {
1138+
try {
1139+
if((double) newRows * newCols != rows * cols)
1140+
return;
1141+
1142+
final MatrixBlock ret2 = cmb.reshape(newRows, newCols, true);
1143+
final MatrixBlock ret1 = mb.reshape(newRows, newCols, true);
1144+
compareResultMatrices(ret1, ret2, 1);
1145+
}
1146+
catch(Exception e) {
1147+
e.printStackTrace();
1148+
throw new DMLRuntimeException("Error in Reshape", e);
1149+
}
1150+
}
1151+
10561152
@Test
10571153
public void testCompressAgain() {
10581154
try {
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.test.component.compress.colgroup;
21+
22+
import static org.junit.Assert.fail;
23+
24+
import java.util.ArrayList;
25+
import java.util.Collection;
26+
import java.util.List;
27+
28+
import org.apache.commons.lang3.NotImplementedException;
29+
import org.apache.commons.logging.Log;
30+
import org.apache.commons.logging.LogFactory;
31+
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
32+
import org.apache.sysds.runtime.compress.CompressionSettings;
33+
import org.apache.sysds.runtime.compress.CompressionSettingsBuilder;
34+
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
35+
import org.apache.sysds.runtime.compress.colgroup.AColGroup.CompressionType;
36+
import org.apache.sysds.runtime.compress.colgroup.ColGroupFactory;
37+
import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory;
38+
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
39+
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfo;
40+
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup;
41+
import org.apache.sysds.runtime.compress.estim.EstimationFactors;
42+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
43+
import org.apache.sysds.test.TestUtils;
44+
import org.junit.Test;
45+
import org.junit.runner.RunWith;
46+
import org.junit.runners.Parameterized;
47+
import org.junit.runners.Parameterized.Parameters;
48+
49+
@RunWith(value = Parameterized.class)
50+
public class CombineColGroups {
51+
protected static final Log LOG = LogFactory.getLog(CombineColGroups.class.getName());
52+
53+
/** Uncompressed ground truth */
54+
final MatrixBlock mb;
55+
/** ColGroup 1 */
56+
final AColGroup a;
57+
/** ColGroup 2 */
58+
final AColGroup b;
59+
60+
@Parameters
61+
public static Collection<Object[]> data() {
62+
ArrayList<Object[]> tests = new ArrayList<>();
63+
64+
try {
65+
addTwoCols(tests, 100, 3);
66+
addTwoCols(tests, 1000, 3);
67+
// addSingleVSMultiCol(tests, 100, 3, 1, 3);
68+
// addSingleVSMultiCol(tests, 100, 3, 3, 4);
69+
addSingleVSMultiCol(tests, 1000, 3, 1, 3, 1.0);
70+
addSingleVSMultiCol(tests, 1000, 3, 3, 4, 1.0);
71+
addSingleVSMultiCol(tests, 1000, 3, 3, 1, 1.0);
72+
addSingleVSMultiCol(tests, 1000, 2, 1, 10, 0.05);
73+
addSingleVSMultiCol(tests, 1000, 2, 10, 10, 0.05);
74+
addSingleVSMultiCol(tests, 1000, 2, 10, 1, 0.05);
75+
}
76+
catch(Exception e) {
77+
e.printStackTrace();
78+
fail("failed constructing tests");
79+
}
80+
81+
return tests;
82+
}
83+
84+
public CombineColGroups(MatrixBlock mb, AColGroup a, AColGroup b) {
85+
this.mb = mb;
86+
this.a = a;
87+
this.b = b;
88+
89+
CompressedMatrixBlock.debug = true;
90+
}
91+
92+
@Test
93+
public void combine() {
94+
try {
95+
AColGroup c = a.combine(b, mb.getNumRows());
96+
MatrixBlock ref = new MatrixBlock(mb.getNumRows(), mb.getNumColumns(), false);
97+
ref.allocateDenseBlock();
98+
c.decompressToDenseBlock(ref.getDenseBlock(), 0, mb.getNumRows());
99+
ref.recomputeNonZeros();
100+
String errMessage = a.getClass().getSimpleName() + ": " + a.getColIndices() + " -- "
101+
+ b.getClass().getSimpleName() + ": " + b.getColIndices();
102+
103+
TestUtils.compareMatricesBitAvgDistance(mb, ref, 0, 0, errMessage);
104+
}
105+
catch(NotImplementedException e) {
106+
// allowed
107+
}
108+
catch(Exception e) {
109+
e.printStackTrace();
110+
fail(e.getMessage());
111+
}
112+
}
113+
114+
private static void addTwoCols(ArrayList<Object[]> tests, int nRow, int distinct) {
115+
MatrixBlock mb = TestUtils.ceil(//
116+
TestUtils.generateTestMatrixBlock(nRow, 2, 0, distinct, 1.0, 231));
117+
118+
List<AColGroup> c1s = getGroups(mb, ColIndexFactory.createI(0));
119+
List<AColGroup> c2s = getGroups(mb, ColIndexFactory.createI(1));
120+
121+
for(int i = 0; i < c1s.size(); i++) {
122+
for(int j = 0; j < c2s.size(); j++) {
123+
tests.add(new Object[] {mb, c1s.get(i), c2s.get(j)});
124+
}
125+
}
126+
}
127+
128+
private static void addSingleVSMultiCol(ArrayList<Object[]> tests, int nRow, int distinct, int nColL, int nColR,
129+
double sparsity) {
130+
MatrixBlock mb = TestUtils.ceil(//
131+
TestUtils.generateTestMatrixBlock(nRow, nColL + nColR, 0, distinct, sparsity, 231));
132+
133+
List<AColGroup> c1s = getGroups(mb, ColIndexFactory.create(nColL));
134+
List<AColGroup> c2s = getGroups(mb, ColIndexFactory.create(nColL, nColR + nColL));
135+
136+
for(int i = 0; i < c1s.size(); i++) {
137+
for(int j = 0; j < c2s.size(); j++) {
138+
tests.add(new Object[] {mb, c1s.get(0), c2s.get(0)});
139+
}
140+
}
141+
}
142+
143+
private static List<AColGroup> getGroups(MatrixBlock mb, IColIndex cols) {
144+
final CompressionSettings cs = new CompressionSettingsBuilder().create();
145+
146+
final int nRow = mb.getNumColumns();
147+
final List<CompressedSizeInfoColGroup> es = new ArrayList<>();
148+
final EstimationFactors f = new EstimationFactors(nRow, nRow, mb.getSparsity());
149+
es.add(new CompressedSizeInfoColGroup(cols, f, 312152, CompressionType.DDC));
150+
es.add(new CompressedSizeInfoColGroup(cols, f, 321521, CompressionType.RLE));
151+
es.add(new CompressedSizeInfoColGroup(cols, f, 321452, CompressionType.SDC));
152+
es.add(new CompressedSizeInfoColGroup(cols, f, 325151, CompressionType.UNCOMPRESSED));
153+
final CompressedSizeInfo csi = new CompressedSizeInfo(es);
154+
return ColGroupFactory.compressColGroups(mb, csi, cs);
155+
}
156+
}

0 commit comments

Comments
 (0)