Skip to content

Commit 3b4f6cd

Browse files
committed
[MINOR] Frame tests improvement 2
Add tests to reach 100% test coverage for Frame/data/lib. Closes #2120
1 parent d80e3a6 commit 3b4f6cd

File tree

9 files changed

+474
-100
lines changed

9 files changed

+474
-100
lines changed

src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibAppend.java

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,12 @@
3333
import org.apache.sysds.runtime.frame.data.columns.ColumnMetadata;
3434

3535
public class FrameLibAppend {
36-
3736
protected static final Log LOG = LogFactory.getLog(FrameLibAppend.class.getName());
37+
38+
private FrameLibAppend(){
39+
// private constructor.
40+
}
41+
3842
/**
3943
* Appends the given argument FrameBlock 'that' to this FrameBlock by creating a deep copy to prevent side effects.
4044
* For cbind, the frames are appended column-wise (same number of rows), while for rbind the frames are appended
@@ -50,7 +54,7 @@ public static FrameBlock append(FrameBlock a, FrameBlock b, boolean cbind) {
5054
return ret;
5155
}
5256

53-
public static FrameBlock appendCbind(FrameBlock a, FrameBlock b) {
57+
private static FrameBlock appendCbind(FrameBlock a, FrameBlock b) {
5458
final int nRow = a.getNumRows();
5559
final int nRowB = b.getNumRows();
5660

@@ -73,7 +77,7 @@ else if(b.getNumColumns() == 0)
7377
return new FrameBlock(_schema, _colnames, _colmeta, _coldata);
7478
}
7579

76-
public static FrameBlock appendRbind(FrameBlock a, FrameBlock b) {
80+
private static FrameBlock appendRbind(FrameBlock a, FrameBlock b) {
7781
final int nCol = a.getNumColumns();
7882
final int nColB = b.getNumColumns();
7983

src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibDetectSchema.java

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import java.util.ArrayList;
2323
import java.util.List;
2424
import java.util.concurrent.Callable;
25-
import java.util.concurrent.ExecutionException;
2625
import java.util.concurrent.ExecutorService;
2726
import java.util.concurrent.Future;
2827

@@ -67,11 +66,16 @@ public static FrameBlock detectSchema(FrameBlock in, double sampleFraction, int
6766
}
6867

6968
private FrameBlock apply() {
70-
final int cols = in.getNumColumns();
71-
final FrameBlock fb = new FrameBlock(UtilFunctions.nCopies(cols, ValueType.STRING));
72-
String[] schemaInfo = (k == 1) ? singleThreadApply() : parallelApply();
73-
fb.appendRow(schemaInfo);
74-
return fb;
69+
try{
70+
final int cols = in.getNumColumns();
71+
final FrameBlock fb = new FrameBlock(UtilFunctions.nCopies(cols, ValueType.STRING));
72+
String[] schemaInfo = (k == 1) ? singleThreadApply() : parallelApply();
73+
fb.appendRow(schemaInfo);
74+
return fb;
75+
}
76+
catch(Exception e){
77+
throw new DMLRuntimeException("Failed to detect schema", e);
78+
}
7579
}
7680

7781
private String[] singleThreadApply() {
@@ -84,7 +88,7 @@ private String[] singleThreadApply() {
8488
return schemaInfo;
8589
}
8690

87-
private String[] parallelApply() {
91+
private String[] parallelApply() throws Exception {
8892
final ExecutorService pool = CommonThreadPool.get(k);
8993
try {
9094
final int cols = in.getNumColumns();
@@ -99,9 +103,6 @@ private String[] parallelApply() {
99103

100104
return schemaInfo;
101105
}
102-
catch(ExecutionException | InterruptedException e) {
103-
throw new DMLRuntimeException("Exception interrupted or exception thrown in detectSchema", e);
104-
}
105106
finally{
106107
pool.shutdown();
107108
}

src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameUtil.java

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -290,33 +290,30 @@ public static ValueType isType(double val, ValueType min) {
290290
}
291291

292292
public static FrameBlock mergeSchema(FrameBlock temp1, FrameBlock temp2) {
293-
String[] rowTemp1 = IteratorFactory.getStringRowIterator(temp1).next();
294-
String[] rowTemp2 = IteratorFactory.getStringRowIterator(temp2).next();
293+
final int nCol = temp1.getNumColumns();
295294

296-
if(rowTemp1.length != rowTemp2.length)
297-
throw new DMLRuntimeException("Schema dimension " + "mismatch: " + rowTemp1.length + " vs " + rowTemp2.length);
295+
if(nCol != temp2.getNumColumns())
296+
throw new DMLRuntimeException("Schema dimension mismatch: " + nCol + " vs " + temp2.getNumColumns());
298297

299-
for(int i = 0; i < rowTemp1.length; i++) {
298+
// Hack: reuse the schema of input temp1; this is only valid if temp1 never changes its schema.
299+
// However, this is typically valid.
300+
FrameBlock mergedFrame = new FrameBlock(temp1.getSchema());
301+
mergedFrame.ensureAllocatedColumns(1);
302+
for(int i = 0; i < nCol; i++) {
303+
String s1 = (String) temp1.get(0, i);
304+
String s2 = (String) temp2.get(0, i);
300305
// modify schema1 if necessary (different schema2)
301-
if(!rowTemp1[i].equals(rowTemp2[i])) {
302-
if(rowTemp1[i].equals("STRING") || rowTemp2[i].equals("STRING"))
303-
rowTemp1[i] = "STRING";
304-
else if(rowTemp1[i].equals("FP64") || rowTemp2[i].equals("FP64"))
305-
rowTemp1[i] = "FP64";
306-
else if(rowTemp1[i].equals("FP32") &&
307-
new ArrayList<>(Arrays.asList("INT64", "INT32", "CHARACTER")).contains(rowTemp2[i]))
308-
rowTemp1[i] = "FP32";
309-
else if(rowTemp1[i].equals("INT64") &&
310-
new ArrayList<>(Arrays.asList("INT32", "CHARACTER")).contains(rowTemp2[i]))
311-
rowTemp1[i] = "INT64";
312-
else if(rowTemp1[i].equals("INT32") || rowTemp2[i].equals("CHARACTER"))
313-
rowTemp1[i] = "INT32";
306+
if(!s1.equals(s2)) {
307+
ValueType v1 = ValueType.valueOf(s1);
308+
ValueType v2 = ValueType.valueOf(s2);
309+
ValueType vc = ValueType.getHighestCommonTypeSafe(v1, v2);
310+
mergedFrame.set(0, i, vc.toString());
311+
}
312+
else{
313+
mergedFrame.set(0, i, s1);
314314
}
315315
}
316316

317-
// create output block one row representing the schema as strings
318-
FrameBlock mergedFrame = new FrameBlock(UtilFunctions.nCopies(temp1.getNumColumns(), ValueType.STRING));
319-
mergedFrame.appendRow(rowTemp1);
320317
return mergedFrame;
321318
}
322319

src/main/java/org/apache/sysds/runtime/frame/data/lib/MatrixBlockFromFrame.java

Lines changed: 77 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
import org.apache.commons.logging.Log;
2727
import org.apache.commons.logging.LogFactory;
28+
import org.apache.sysds.runtime.DMLRuntimeException;
2829
import org.apache.sysds.runtime.data.DenseBlock;
2930
import org.apache.sysds.runtime.frame.data.FrameBlock;
3031
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
@@ -40,28 +41,56 @@ public interface MatrixBlockFromFrame {
4041
* Converts a frame block with arbitrary schema into a matrix block. Since matrix block only supports value type
4142
* double, we do a best effort conversion of non-double types which might result in errors for non-numerical data.
4243
*
43-
* @param frame frame block
44-
* @param k parallelization degree
45-
* @return matrix block
44+
* @param frame Frame block to convert
45+
* @param k The parallelization degree
46+
* @return MatrixBlock
4647
*/
4748
public static MatrixBlock convertToMatrixBlock(FrameBlock frame, int k) {
48-
final int m = frame.getNumRows();
49-
final int n = frame.getNumColumns();
50-
final MatrixBlock mb = new MatrixBlock(m, n, false);
51-
mb.allocateDenseBlock();
52-
if(k == -1)
53-
k = InfrastructureAnalyzer.getLocalParallelism();
54-
55-
long nnz = 0;
56-
if(k == 1)
57-
nnz = convert(frame, mb, n, 0, m);
58-
else
59-
nnz = convertParallel(frame, mb, m, n, k);
49+
return convertToMatrixBlock(frame, null, k);
50+
}
6051

61-
mb.setNonZeros(nnz);
52+
/**
53+
* Converts a frame block with arbitrary schema into a matrix block. Since matrix block only supports value type
54+
* double, we do a best effort conversion of non-double types which might result in errors for non-numerical data.
55+
*
56+
* @param frame FrameBlock to convert
57+
* @param ret   MatrixBlock to reuse as the output (reset if its dimensions differ or it is sparse), or null to allocate a new one
58+
* @param k The parallelization degree
59+
* @return MatrixBlock
60+
*/
61+
public static MatrixBlock convertToMatrixBlock(FrameBlock frame, MatrixBlock ret, int k) {
62+
try {
6263

63-
mb.examSparsity();
64-
return mb;
64+
final int m = frame.getNumRows();
65+
final int n = frame.getNumColumns();
66+
ret = allocateRet(ret, m, n);
67+
68+
if(k == -1)
69+
k = InfrastructureAnalyzer.getLocalParallelism();
70+
71+
long nnz = 0;
72+
if(k == 1)
73+
nnz = convert(frame, ret, n, 0, m);
74+
else
75+
nnz = convertParallel(frame, ret, m, n, k);
76+
77+
ret.setNonZeros(nnz);
78+
ret.examSparsity();
79+
return ret;
80+
}
81+
catch(Exception e) {
82+
throw new DMLRuntimeException("Failed to convert FrameBlock to MatrixBlock", e);
83+
}
84+
}
85+
86+
private static MatrixBlock allocateRet(MatrixBlock ret, final int m, final int n) {
87+
if(ret == null)
88+
ret = new MatrixBlock(m, n, false);
89+
else if(ret.getNumRows() != m || ret.getNumColumns() != n || ret.isInSparseFormat())
90+
ret.reset(m, n, false);
91+
if(!ret.isAllocated())
92+
ret.allocateDenseBlock();
93+
return ret;
6594
}
6695

6796
private static long convert(FrameBlock frame, MatrixBlock mb, int n, int rl, int ru) {
@@ -71,27 +100,25 @@ private static long convert(FrameBlock frame, MatrixBlock mb, int n, int rl, int
71100
return convertGeneric(frame, mb, n, rl, ru);
72101
}
73102

74-
private static long convertParallel(FrameBlock frame, MatrixBlock mb, int m, int n, int k){
103+
private static long convertParallel(FrameBlock frame, MatrixBlock mb, int m, int n, int k) throws Exception {
75104
ExecutorService pool = CommonThreadPool.get(k);
76-
try{
105+
try {
77106
List<Future<Long>> tasks = new ArrayList<>();
78107
final int blkz = Math.max(m / k, 1000);
79108

80-
for( int i = 0; i < m; i+= blkz){
81-
final int start = i;
109+
for(int i = 0; i < m; i += blkz) {
110+
final int start = i;
82111
final int end = Math.min(i + blkz, m);
83112
tasks.add(pool.submit(() -> convert(frame, mb, n, start, end)));
84113
}
85114

86115
long nnz = 0;
87-
for( Future<Long> t : tasks)
116+
for(Future<Long> t : tasks)
88117
nnz += t.get();
89118
return nnz;
90119
}
91-
catch(Exception e){
92-
throw new RuntimeException(e);
93-
}
94-
finally{
120+
121+
finally {
95122
pool.shutdown();
96123
}
97124
}
@@ -104,29 +131,42 @@ private static long convertContiguous(final FrameBlock frame, final MatrixBlock
104131
for(int bj = 0; bj < n; bj += blocksizeIJ) {
105132
int bimin = Math.min(bi + blocksizeIJ, ru);
106133
int bjmin = Math.min(bj + blocksizeIJ, n);
107-
for(int i = bi, aix = bi * n; i < bimin; i++, aix += n)
108-
for(int j = bj; j < bjmin; j++)
109-
lnnz += (c[aix + j] = frame.getDoubleNaN(i, j)) != 0 ? 1 : 0;
134+
lnnz = convertBlockContiguous(frame, n, lnnz, c, bi, bj, bimin, bjmin);
110135
}
111136
}
112137
return lnnz;
113138
}
114139

115-
private static long convertGeneric(final FrameBlock frame, final MatrixBlock mb, final int n, final int rl, final int ru) {
140+
private static long convertBlockContiguous(final FrameBlock frame, final int n, long lnnz, double[] c, int rl,
141+
int cl, int ru, int cu) {
142+
for(int i = rl, aix = rl * n; i < ru; i++, aix += n)
143+
for(int j = cl; j < cu; j++)
144+
lnnz += (c[aix + j] = frame.getDoubleNaN(i, j)) != 0 ? 1 : 0;
145+
return lnnz;
146+
}
147+
148+
private static long convertGeneric(final FrameBlock frame, final MatrixBlock mb, final int n, final int rl,
149+
final int ru) {
116150
long lnnz = 0;
117151
final DenseBlock c = mb.getDenseBlock();
118152
for(int bi = rl; bi < ru; bi += blocksizeIJ) {
119153
for(int bj = 0; bj < n; bj += blocksizeIJ) {
120154
int bimin = Math.min(bi + blocksizeIJ, ru);
121155
int bjmin = Math.min(bj + blocksizeIJ, n);
122-
for(int i = bi; i < bimin; i++) {
123-
double[] cvals = c.values(i);
124-
int cpos = c.pos(i);
125-
for(int j = bj; j < bjmin; j++)
126-
lnnz += (cvals[cpos + j] = frame.getDoubleNaN(i, j)) != 0 ? 1 : 0;
127-
}
156+
lnnz = convertBlockGeneric(frame, lnnz, c, bi, bj, bimin, bjmin);
128157
}
129158
}
130159
return lnnz;
131160
}
161+
162+
private static long convertBlockGeneric(final FrameBlock frame, long lnnz, final DenseBlock c, final int rl,
163+
final int cl, final int ru, final int cu) {
164+
for(int i = rl; i < ru; i++) {
165+
final double[] cvals = c.values(i);
166+
final int cpos = c.pos(i);
167+
for(int j = cl; j < cu; j++)
168+
lnnz += (cvals[cpos + j] = frame.getDoubleNaN(i, j)) != 0 ? 1 : 0;
169+
}
170+
return lnnz;
171+
}
132172
}

src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,7 @@ private void resetSparse() {
332332
if(sparseBlock == null)
333333
return;
334334
sparseBlock.reset(estimatedNNzsPerRow, clen);
335+
denseBlock = null;
335336
}
336337

337338
private void resetDense(double val) {
@@ -343,6 +344,7 @@ else if( val != 0 ) {
343344
allocateDenseBlock(false);
344345
denseBlock.set(val);
345346
}
347+
sparseBlock = null;
346348
}
347349

348350
private void resetDense(double val, boolean dedup) {
@@ -354,6 +356,7 @@ else if( val != 0 ) {
354356
allocateDenseBlock(false, dedup);
355357
denseBlock.set(val);
356358
}
359+
sparseBlock = null;
357360
}
358361

359362
/**

src/test/java/org/apache/sysds/test/component/frame/FrameCustomTest.java

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,26 @@
1919

2020
package org.apache.sysds.test.component.frame;
2121

22+
import static org.junit.Assert.assertThrows;
2223
import static org.junit.Assert.assertTrue;
24+
import static org.mockito.ArgumentMatchers.anyInt;
25+
import static org.mockito.Mockito.spy;
26+
import static org.mockito.Mockito.when;
2327

28+
import org.apache.commons.logging.Log;
29+
import org.apache.commons.logging.LogFactory;
2430
import org.apache.sysds.common.Types.ValueType;
31+
import org.apache.sysds.runtime.DMLRuntimeException;
2532
import org.apache.sysds.runtime.frame.data.FrameBlock;
33+
import org.apache.sysds.runtime.frame.data.lib.FrameLibAppend;
34+
import org.apache.sysds.runtime.frame.data.lib.FrameLibDetectSchema;
2635
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
2736
import org.apache.sysds.runtime.util.DataConverter;
2837
import org.apache.sysds.test.TestUtils;
2938
import org.junit.Test;
3039

3140
public class FrameCustomTest {
41+
protected static final Log LOG = LogFactory.getLog(FrameCustomTest.class.getName());
3242

3343
@Test
3444
public void castToFrame() {
@@ -61,4 +71,30 @@ public void castToFrame2() {
6171
assertTrue(f.getSchema()[0] == ValueType.FP64);
6272
}
6373

74+
75+
@Test
76+
public void detectSchemaError(){
77+
FrameBlock f = TestUtils.generateRandomFrameBlock(10, 10, 23);
78+
FrameBlock spy = spy(f);
79+
when(spy.getColumn(anyInt())).thenThrow(new RuntimeException());
80+
81+
Exception e = assertThrows(DMLRuntimeException.class, () -> FrameLibDetectSchema.detectSchema(spy, 3));
82+
83+
assertTrue(e.getMessage().contains("Failed to detect schema"));
84+
}
85+
86+
87+
88+
@Test
89+
public void appendUniqueColNames(){
90+
FrameBlock a = new FrameBlock(new ValueType[]{ValueType.FP32}, new String[]{"Hi"});
91+
a.appendRow(new String[]{"0.2"});
92+
FrameBlock b = new FrameBlock(new ValueType[]{ValueType.FP32}, new String[]{"There"});
93+
b.appendRow(new String[]{"0.5"});
94+
95+
FrameBlock c = FrameLibAppend.append(a, b, true);
96+
97+
assertTrue(c.getColumnName(0).equals("Hi"));
98+
assertTrue(c.getColumnName(1).equals("There"));
99+
}
64100
}

0 commit comments

Comments
 (0)