Skip to content

Commit 421e09b

Browse files
committed
more mapping tests
1 parent cc5c61e commit 421e09b

File tree

4 files changed

+157
-0
lines changed

4 files changed

+157
-0
lines changed

src/test/java/org/apache/sysds/test/component/compress/mapping/CustomMappingTest.java

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,24 @@
1919

2020
package org.apache.sysds.test.component.compress.mapping;
2121

22+
import static org.junit.Assert.assertEquals;
23+
import static org.junit.Assert.assertFalse;
24+
import static org.junit.Assert.assertThrows;
25+
import static org.junit.Assert.assertTrue;
2226
import static org.junit.Assert.fail;
27+
import static org.mockito.Mockito.mock;
28+
import static org.mockito.Mockito.spy;
29+
import static org.mockito.Mockito.when;
2330

31+
import org.apache.commons.lang3.NotImplementedException;
2432
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
33+
import org.apache.sysds.runtime.compress.DMLCompressionException;
34+
import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData;
2535
import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory;
36+
import org.apache.sysds.runtime.compress.colgroup.offset.AOffset;
37+
import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory;
38+
import org.apache.sysds.runtime.data.DenseBlock;
39+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
2640
import org.junit.Test;
2741

2842
public class CustomMappingTest {
@@ -49,4 +63,87 @@ public void createBinary() {
4963
fail(e.getMessage());
5064
}
5165
}
66+
67+
@Test
68+
public void verifySpy() {
69+
CompressedMatrixBlock.debug = true;
70+
AMapToData d = MapToFactory.create(data, 2);
71+
AMapToData spy = spy(d);
72+
when(spy.getIndex(2)).thenReturn(32);
73+
assertThrows(DMLCompressionException.class, () -> spy.verify());
74+
}
75+
76+
@Test
77+
public void equals() {
78+
CompressedMatrixBlock.debug = true;
79+
AMapToData d = MapToFactory.create(data, 2);
80+
AMapToData d2 = MapToFactory.create(data, 2);
81+
assertTrue(d.equals(d));
82+
assertTrue(d.equals(d2));
83+
assertFalse(d.equals(MapToFactory.create(new int[]{1,2,3}, 4)));
84+
assertFalse(d.equals(Integer.valueOf(23)));
85+
}
86+
87+
@Test
88+
public void countRuns() {
89+
CompressedMatrixBlock.debug = true;
90+
AMapToData d = MapToFactory.create(new int[] {1, 1, 1, 1, 1, 2, 2, 2, 2, 2}, 3);
91+
AOffset o = OffsetFactory.createOffset(new int[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
92+
assertEquals(d.countRuns(o), 2);
93+
}
94+
95+
@Test
96+
public void countRuns2() {
97+
CompressedMatrixBlock.debug = true;
98+
AMapToData d = MapToFactory.create(new int[] {1, 1, 1, 1, 1, 2, 2, 2, 2, 2}, 3);
99+
AOffset o = OffsetFactory.createOffset(new int[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 11});
100+
assertEquals(d.countRuns(o), 3);
101+
}
102+
103+
@Test
104+
public void getMax() {
105+
CompressedMatrixBlock.debug = true;
106+
AMapToData d = MapToFactory.create(new int[] {1, 1, 1, 1, 1, 2, 2, 2, 2, 2}, 3);
107+
assertEquals(d.getMax(), 2);
108+
d = MapToFactory.create(new int[] {1, 1, 1, 1, 1, 2, 2, 2, 5, 2}, 10);
109+
assertEquals(d.getMax(), 5);
110+
d = MapToFactory.create(new int[] {1, 1, 1, 9, 1, 2, 2, 2, 2, 2}, 10);
111+
assertEquals(d.getMax(), 9);
112+
}
113+
114+
@Test
115+
public void copyInt(){
116+
CompressedMatrixBlock.debug = true;
117+
AMapToData d = MapToFactory.create(new int[] {10,9,8,7,6,5,4,3,2,1}, 11);
118+
AMapToData d2 = MapToFactory.create(new int[] {1,2,3,4,5,6,7,8,9,10}, Integer.MAX_VALUE -2);
119+
d.copy(d2);
120+
for(int i = 0; i < 10; i ++){
121+
assertEquals(d.getIndex(i), d2.getIndex(i));
122+
}
123+
}
124+
125+
@Test
126+
public void setInteger(){
127+
CompressedMatrixBlock.debug = true;
128+
AMapToData d = MapToFactory.create(new int[] {10,9,8,7,6,5,4,3,2,1}, 11);
129+
130+
for(int i = 0; i < 10; i ++){
131+
assertEquals(d.getIndex(i), 10- i);
132+
}
133+
d.set(4, Integer.valueOf(13));
134+
assertEquals(d.getIndex(4), 13);
135+
}
136+
137+
@Test(expected = NotImplementedException.class)
138+
public void preAggDenseNonContiguous(){
139+
AMapToData d = MapToFactory.create(new int[] {10,9,8,7,6,5,4,3,2,1}, 11);
140+
MatrixBlock mb = new MatrixBlock();
141+
MatrixBlock spy = spy(mb);
142+
DenseBlock db = mock(DenseBlock.class);
143+
when(db.isContiguous()).thenReturn(false);
144+
when(spy.getDenseBlock()).thenReturn(db);
145+
146+
d.preAggregateDense(spy, null, 10, 13,0, 10);
147+
}
148+
52149
}

src/test/java/org/apache/sysds/test/component/compress/mapping/MappingTests.java

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import java.util.Arrays;
3232
import java.util.Collection;
3333
import java.util.Random;
34+
import java.util.concurrent.ExecutorService;
3435

3536
import org.apache.commons.lang3.NotImplementedException;
3637
import org.apache.commons.logging.Log;
@@ -41,6 +42,7 @@
4142
import org.apache.sysds.runtime.compress.colgroup.mapping.MapToCharPByte;
4243
import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory;
4344
import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory.MAP_TYPE;
45+
import org.apache.sysds.runtime.util.CommonThreadPool;
4446
import org.junit.Test;
4547
import org.junit.runner.RunWith;
4648
import org.junit.runners.Parameterized;
@@ -350,6 +352,58 @@ public void testAppendNotSame() {
350352
LOG.error("Did not throw exception with: " + m);
351353
}
352354

355+
@Test
356+
public void splitReshapeParallel() throws Exception {
357+
if(m.size() % 2 == 0){
358+
359+
ExecutorService pool = CommonThreadPool.get();
360+
AMapToData[] ret = m.splitReshapeDDCPushDown(2, pool);
361+
362+
for(int i = 0; i < m.size(); i++){
363+
assertEquals(m.getIndex(i), ret[i % 2].getIndex(i/2));
364+
}
365+
}
366+
}
367+
368+
369+
@Test
370+
public void splitReshape2() throws Exception {
371+
if(m.size() % 2 == 0){
372+
373+
AMapToData[] ret = m.splitReshapeDDC(2);
374+
375+
for(int i = 0; i < m.size(); i++){
376+
assertEquals(m.getIndex(i), ret[i % 2].getIndex(i/2));
377+
}
378+
}
379+
}
380+
381+
@Test
382+
public void splitReshape4() throws Exception {
383+
if(m.size() % 4 == 0){
384+
385+
AMapToData[] ret = m.splitReshapeDDC(4);
386+
387+
for(int i = 0; i < m.size(); i++){
388+
assertEquals(m.getIndex(i), ret[i % 4].getIndex(i/4));
389+
}
390+
}
391+
}
392+
393+
@Test
394+
395+
public void getCounts(){
396+
int[] counts = m.getCounts();
397+
int countZeros = 0;
398+
for(int i= 0; i < m.size(); i++){
399+
if(m.getIndex(i) == 0)
400+
countZeros++;
401+
}
402+
assertEquals(counts[0], countZeros);
403+
}
404+
405+
406+
353407
private static class Holder implements IMapToDataGroup {
354408

355409
AMapToData d;

src/test/java/org/apache/sysds/test/component/compress/mapping/PreAggregateDDC_DDCTest.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ public static Collection<Object[]> data() {
7676
create(tests, 10000, 32, 2, 1, r.nextInt(sm));
7777
create(tests, 10000, 2, 2, 1, r.nextInt(sm));
7878
create(tests, 10000, 2, 2, 10, r.nextInt(sm));
79+
create(tests, 10005, 2, 2, 1, r.nextInt(sm));
80+
create(tests, 10005, 2, 2, 10, r.nextInt(sm));
7981

8082
createSkewed(tests, 10000, 2, 2, 10, r.nextInt(sm), 0.1);
8183
createSkewed(tests, 10000, 2, 2, 10, r.nextInt(sm), 0.01);

src/test/java/org/apache/sysds/test/component/compress/mapping/PreAggregateSDCZ_SDCZTest.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,10 @@ public static Collection<Object[]> data() {
7575
create(tests, 10000, 150, 13, 1, 1000, 100, r.nextInt(sm));
7676
create(tests, 10000, 150, 149, 1, 1000, 100, r.nextInt(sm));
7777

78+
create(tests, 10000, 32, 200, 1, 100, 1000, r.nextInt(sm));
79+
create(tests, 10000, 150, 13, 1, 100, 1000, r.nextInt(sm));
80+
create(tests, 10000, 150, 149, 1, 100, 1000, r.nextInt(sm));
81+
7882
return tests;
7983
}
8084

0 commit comments

Comments
 (0)