Skip to content

Commit 0c5acdb

Browse files
committed
update dummy function for processing columns in parallel
1 parent 34a4341 commit 0c5acdb

File tree

3 files changed

+123
-36
lines changed

3 files changed

+123
-36
lines changed

src/main/java/org/apache/sysds/runtime/transform/decode/ColumnDecoderComposite.java

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -68,16 +68,21 @@ public FrameBlock columnDecode(MatrixBlock in, FrameBlock out, final int k) {
6868
List<Future<FrameBlock>> tasks = new ArrayList<>();
6969
for (ColumnDecoder dec : _decoders) {
7070
long t1 = System.nanoTime();
71-
List<MatrixBlock> slices = sliceColumns(in, dec.getColList());
72-
for (int c = 0; c < slices.size(); c++) {
73-
ColumnDecoder sub = dec.getColList().length == 1 ? dec :
74-
dec.subRangeDecoder(dec.getColList()[c], dec.getColList()[c] + 1, 0);
75-
if (sub == null)
76-
throw new RuntimeException("Decoder does not support column slicing: " + dec.getClass());
77-
if (sub != dec)
78-
sub._colList = new int[]{dec.getColList()[c]};
79-
int finalC = c;
80-
tasks.add(pool.submit(() -> sub.columnDecode(slices.get(finalC), out)));
71+
if(dec instanceof ColumnDecoderDummycode) {
72+
tasks.add(pool.submit(() -> dec.columnDecode(in, out)));
73+
}
74+
else {
75+
List<MatrixBlock> slices = sliceColumns(in, dec.getColList());
76+
for (int c = 0; c < slices.size(); c++) {
77+
ColumnDecoder sub = dec.getColList().length == 1 ? dec :
78+
dec.subRangeDecoder(dec.getColList()[c], dec.getColList()[c] + 1, 0);
79+
if (sub == null)
80+
throw new RuntimeException("Decoder does not support column slicing: " + dec.getClass());
81+
if (sub != dec)
82+
sub._colList = new int[]{dec.getColList()[c]};
83+
int finalC = c;
84+
tasks.add(pool.submit(() -> sub.columnDecode(slices.get(finalC), out)));
85+
}
8186
}
8287
long t2 = System.nanoTime();
8388
System.out.println(dec + "time: " + (t2 - t1) / 1e6 + " ms");
@@ -95,6 +100,8 @@ public FrameBlock columnDecode(MatrixBlock in, FrameBlock out, final int k) {
95100
@Override
96101
public void columnDecode(MatrixBlock in, FrameBlock out, int rl, int ru) {
97102
// TODO
103+
for( ColumnDecoder dec : _decoders )
104+
dec.columnDecode(in, out, rl, ru);
98105
}
99106

100107
@Override

src/main/java/org/apache/sysds/runtime/transform/decode/ColumnDecoderDummycode.java

Lines changed: 54 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ public class ColumnDecoderDummycode extends ColumnDecoder {
1919

2020
private int[] _clPos = null;
2121
private int[] _cuPos = null;
22+
// category index for dedicated single-column decoders (-1 if not used)
23+
private int _category = -1;
2224

2325
protected ColumnDecoderDummycode(Types.ValueType[] schema, int[] colList) {
2426
super(schema, colList);
@@ -33,39 +35,65 @@ public FrameBlock columnDecode(MatrixBlock in, FrameBlock out) {
3335

3436
@Override
3537
public void columnDecode(MatrixBlock in, FrameBlock out, int rl, int ru) {
36-
for( int i=rl; i<ru; i++ )
37-
for( int j=0; j<_colList.length; j++ )
38-
for( int k=_clPos[j]; k<_cuPos[j]; k++ )
39-
if( in.get(i, k-1) != 0 ) {
40-
int col = _colList[j] - 1;
41-
out.set(i, col,
42-
UtilFunctions.doubleToObject(out.getSchema()[col], k-_clPos[j]+1));
43-
}
38+
if(_category >= 0) {
39+
int col = _colList[0] - 1;
40+
Object val = UtilFunctions.doubleToObject(out.getSchema()[col], _category);
41+
for(int i = rl; i < ru; i++)
42+
if(in.get(i, _clPos[0]-1) == 1)
43+
synchronized(out) { out.set(i, col, val); }
44+
}
45+
else {
46+
for( int i=rl; i<ru; i++ )
47+
for( int j=0; j<_colList.length; j++ )
48+
for( int k=_clPos[j]; k<_cuPos[j]; k++ )
49+
if( in.get(i, k-1) != 0 ) {
50+
int col = _colList[j] - 1;
51+
Object val = UtilFunctions.doubleToObject(out.getSchema()[col], k-_clPos[j]+1);
52+
synchronized(out) { out.set(i, col, val); }
53+
}
54+
}
4455
}
4556

4657
@Override
4758
public ColumnDecoder subRangeDecoder(int colStart, int colEnd, int dummycodedOffset) {
48-
List<Integer> dcList = new ArrayList<>();
49-
List<Integer> clPosList = new ArrayList<>();
50-
List<Integer> cuPosList = new ArrayList<>();
59+
// special case: request for exactly one encoded column
60+
if(colEnd - colStart == 1) {
61+
int encCol = colStart;
62+
for(int j=0; j<_clPos.length; j++)
63+
if(encCol >= _clPos[j] && encCol < _cuPos[j]) {
64+
ColumnDecoderDummycode dec = new ColumnDecoderDummycode(
65+
new Types.ValueType[]{_schema[_colList[j]-1]},
66+
new int[]{_colList[j]});
67+
dec._clPos = new int[]{1};
68+
dec._cuPos = new int[]{2};
69+
dec._category = encCol - _clPos[j] + 1;
70+
return dec;
71+
}
72+
return null;
73+
}
74+
else {
75+
List<Integer> dcList = new ArrayList<>();
76+
List<Integer> clPosList = new ArrayList<>();
77+
List<Integer> cuPosList = new ArrayList<>();
5178

52-
for( int j=0; j<_colList.length; j++ ) {
53-
int colID = _colList[j];
54-
if (colID >= colStart && colID < colEnd) {
55-
dcList.add(colID - (colStart - 1));
56-
clPosList.add(_clPos[j] - dummycodedOffset);
57-
cuPosList.add(_cuPos[j] - dummycodedOffset);
79+
for( int j=0; j<_colList.length; j++ ) {
80+
int colID = _colList[j];
81+
if (colID >= colStart && colID < colEnd) {
82+
dcList.add(colID - (colStart - 1));
83+
clPosList.add(_clPos[j] - dummycodedOffset);
84+
cuPosList.add(_cuPos[j] - dummycodedOffset);
85+
}
5886
}
59-
}
60-
if (dcList.isEmpty())
61-
return null;
87+
if (dcList.isEmpty())
88+
return null;
6289

63-
ColumnDecoderDummycode dec = new ColumnDecoderDummycode(
64-
Arrays.copyOfRange(_schema, colStart - 1, colEnd - 1),
65-
dcList.stream().mapToInt(i -> i).toArray());
66-
dec._clPos = clPosList.stream().mapToInt(i -> i).toArray();
67-
dec._cuPos = cuPosList.stream().mapToInt(i -> i).toArray();
68-
return dec;
90+
ColumnDecoderDummycode dec = new ColumnDecoderDummycode(
91+
Arrays.copyOfRange(_schema, colStart - 1, colEnd - 1),
92+
dcList.stream().mapToInt(i -> i).toArray());
93+
dec._clPos = clPosList.stream().mapToInt(i -> i).toArray();
94+
dec._cuPos = cuPosList.stream().mapToInt(i -> i).toArray();
95+
return dec;
96+
}
6997
}
7098

7199
@Override
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
package org.apache.sysds.test.functions.transform;
2+
3+
import org.apache.sysds.common.Types.ValueType;
4+
import org.apache.sysds.runtime.frame.data.FrameBlock;
5+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
6+
import org.apache.sysds.runtime.transform.decode.ColumnDecoder;
7+
import org.apache.sysds.runtime.transform.decode.ColumnDecoderFactory;
8+
import org.apache.sysds.runtime.transform.decode.Decoder;
9+
import org.apache.sysds.runtime.transform.decode.DecoderFactory;
10+
import org.apache.sysds.runtime.transform.encode.EncoderFactory;
11+
import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder;
12+
import org.apache.sysds.runtime.util.DataConverter;
13+
import org.apache.sysds.test.AutomatedTestBase;
14+
import org.apache.sysds.test.TestUtils;
15+
import org.junit.Test;
16+
17+
public class ColumnDecoderDummycodeTest extends AutomatedTestBase {
18+
@Override
19+
public void setUp() {
20+
TestUtils.clearAssertionInformation();
21+
}
22+
23+
@Test
24+
public void testColumnDecoderDummycode() {
25+
try {
26+
int rows = 20;
27+
double[][] arr = new double[rows][1];
28+
for (int i = 0; i < rows; i++)
29+
arr[i][0] = (i % 3) + 1;
30+
MatrixBlock mb = DataConverter.convertToMatrixBlock(arr);
31+
FrameBlock data = DataConverter.convertToFrameBlock(mb);
32+
String spec = "{ids:true, dummycode:[1]}";
33+
34+
MultiColumnEncoder enc = EncoderFactory.createEncoder(spec, data.getColumnNames(), 1, null);
35+
MatrixBlock encoded = enc.encode(data);
36+
FrameBlock meta = enc.getMetaData(new FrameBlock(1, ValueType.STRING));
37+
38+
Decoder dec = DecoderFactory.createDecoder(spec, data.getColumnNames(), data.getSchema(), meta, encoded.getNumColumns());
39+
FrameBlock expected = new FrameBlock(data.getSchema());
40+
dec.decode(encoded, expected);
41+
42+
ColumnDecoder cdec = ColumnDecoderFactory.createDecoder(spec, data.getColumnNames(), data.getSchema(), meta);
43+
FrameBlock actual = new FrameBlock(data.getSchema());
44+
cdec.columnDecode(encoded, actual);
45+
46+
TestUtils.compareFrames(expected, actual, false);
47+
}
48+
catch(Exception ex) {
49+
throw new RuntimeException(ex);
50+
}
51+
}
52+
}

0 commit comments

Comments
 (0)