Skip to content

Commit b8c4f88

Browse files
committed
Update recode function for processing columns and test it
1 parent 6dfa925 commit b8c4f88

File tree

3 files changed

+89
-38
lines changed

3 files changed

+89
-38
lines changed

src/main/java/org/apache/sysds/runtime/transform/decode/ColumnDecoderFactory.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,10 @@ public static ColumnDecoder createDecoder(String spec, String[] colnames, ValueT
9696
ArrayUtils.toPrimitive(dcIDs.toArray(new Integer[0])),currOffset));
9797
}
9898
if( !rcIDs.isEmpty() ) {
99-
ldecoders.add(new ColumnDecoderRecode(schema, !dcIDs.isEmpty(),
100-
ArrayUtils.toPrimitive(rcIDs.toArray(new Integer[0])),currOffset));
99+
for( int col : rcIDs ) {
100+
ldecoders.add(new ColumnDecoderRecode(schema[col-1], !dcIDs.isEmpty(), col-1, currOffset));
101+
currOffset++;
102+
}
101103
}
102104
if( !ptIDs.isEmpty() ) {
103105
for (int col : ptIDs) {
@@ -137,7 +139,7 @@ public static ColumnDecoder createInstance(int type) {
137139
switch(dtype) {
138140
case Dummycode: return new ColumnDecoderDummycode(null, null, -1);
139141
case PassThrough: return new ColumnDecoderPassThrough(null, -1, null, -1);
140-
case Recode: return new ColumnDecoderRecode(null, false, null, -1);
142+
case Recode: return new ColumnDecoderRecode(null, false, -1, -1);
141143
default:
142144
throw new DMLRuntimeException("Unsupported Encoder Type used: " + dtype);
143145
}

src/main/java/org/apache/sysds/runtime/transform/decode/ColumnDecoderRecode.java

Lines changed: 79 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121

2222
import org.apache.sysds.common.Types.ValueType;
2323
import org.apache.sysds.runtime.frame.data.FrameBlock;
24+
import org.apache.sysds.runtime.frame.data.columns.ABooleanArray;
25+
import org.apache.sysds.runtime.frame.data.columns.Array;
2426
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
2527
import org.apache.sysds.runtime.matrix.data.Pair;
2628
import org.apache.sysds.runtime.transform.TfUtils;
@@ -36,53 +38,41 @@ public class ColumnDecoderRecode extends ColumnDecoder {
3638

3739
private static final long serialVersionUID = -3784249774608228805L;
3840

39-
private HashMap<Long, Object> _rcMaps = null;
40-
private Object[] _rcMapsDirect = null;
41+
private HashMap<Long, Object> _rcMap = null;
42+
private Object[] _rcMapDirect = null;
4143
private boolean _onOut = false;
4244

4345
/**
 * No-argument constructor. Passes placeholder values (null, -1, -1) to the
 * superclass constructor and leaves the recode map unset.
 * NOTE(review): presumably required for Externalizable deserialization via
 * readExternal -- confirm against the serialization callers.
 */
public ColumnDecoderRecode() {
4446
super(null, -1, -1);
4547
}
4648

47-
protected ColumnDecoderRecode(ValueType[] schema, boolean onOut, int[] rcCols, int offset) {
48-
super(schema, rcCols,offset);
49+
/**
 * Creates a recode decoder for a single column.
 *
 * @param schema value type of the decoded column, forwarded to the superclass
 * @param onOut  stored in _onOut; when true, columnDecode reads the recode key
 *               from the output frame column instead of the input matrix
 * @param rcCol  column index forwarded to the superclass (the factory passes
 *               col-1, i.e. a 0-based index)
 * @param offset offset forwarded to the superclass
 */
protected ColumnDecoderRecode(ValueType schema, boolean onOut, int rcCol, int offset) {
50+
super(schema, rcCol, offset);
4951
_onOut = onOut;
5052
}
5153

5254
/**
 * Decodes all rows of the input: allocates the output columns for
 * in.getNumRows() rows, delegates to the row-range variant over
 * [0, in.getNumRows()), and returns the same output frame.
 * The lines marked as removed in this hunk are the former nanoTime/println
 * timing instrumentation.
 *
 * @param in  encoded input matrix
 * @param out output frame that receives the decoded values
 * @return the output frame passed in as out
 */
@Override
5355
public FrameBlock columnDecode(MatrixBlock in, FrameBlock out) {
5456

55-
long r1 = System.nanoTime();
56-
// TODO
5757
out.ensureAllocatedColumns(in.getNumRows());
5858
columnDecode(in, out, 0, in.getNumRows());
59-
long r2 = System.nanoTime();
60-
System.out.println(this.getClass() + "time: " + (r2 - r1) / 1e6 + " ms");
6159
return out;
6260
}
6361

6462
/**
 * Decodes rows [rl, ru) of the single recode column _colID.
 * Added side of this hunk: fetches the output column array once, then per row
 * derives a long recode key -- from the current output value when _onOut is
 * set, otherwise from the input matrix cell (i, _colID) -- and writes the
 * value returned by getRcMapValue(key) through setArrayValue.
 * Removed side: the former multi-column loop over _colList.
 *
 * NOTE(review): the removed code read the input with in.get(i, j) (position
 * within the recode column list) while the new code reads in.get(i, _colID);
 * confirm this is the intended input column when other encoders shift columns.
 *
 * @param in  encoded input matrix
 * @param out output frame whose column _colID is overwritten
 * @param rl  first row to decode (inclusive)
 * @param ru  last row to decode (exclusive)
 */
@Override
6563
public void columnDecode(MatrixBlock in, FrameBlock out, int rl, int ru) {
6664
// TODO
67-
if( _onOut ) { //recode on output (after dummy)
68-
for( int i=rl; i<ru; i++ ) {
69-
for( int j=0; j<_colList.length; j++ ) {
70-
int colID = _colList[j];
71-
double val = UtilFunctions.objectToDouble(
72-
out.getSchema()[colID-1], out.get(i, colID-1));
73-
long key = UtilFunctions.toLong(val);
74-
out.set(i, colID-1, getRcMapValue(j, key));
75-
}
76-
}
65+
Array<?> a = out.getColumn(_colID);
66+
if(_onOut) {
67+
for(int i = rl; i < ru; i++) {
68+
double val = UtilFunctions.objectToDouble(_schema, a.get(i));
69+
long key = UtilFunctions.toLong(val);
70+
setArrayValue(a, i, getRcMapValue(key)); }
7771
}
78-
else { //recode on input (no dummy)
79-
out.ensureAllocatedColumns(in.getNumRows());
80-
for( int i=rl; i<ru; i++ ) {
81-
for( int j=0; j<_colList.length; j++ ) {
82-
//double val = in.get(i, _colList[j]-1);
83-
long key = UtilFunctions.toLong(in.get(i, j));
84-
out.set(i, _colList[j]-1, getRcMapValue(j, key));
85-
}
72+
else {
73+
for(int i = rl; i < ru; i++) {
74+
long key = UtilFunctions.toLong(in.get(i, _colID));
75+
setArrayValue(a, i, getRcMapValue(key));
8676
}
8777
}
8878
}
@@ -117,7 +107,24 @@ public ColumnDecoder subRangeDecoder(int colStart, int colEnd, int dummycodedOff
117107
@Override
118108
@SuppressWarnings("unchecked")
119109
public void initMetaData(FrameBlock meta) {
120-
110+
int col = _colID; // already 0-based
111+
_rcMap = new HashMap<>();
112+
long max = 0;
113+
for(int i=0; i<meta.getNumRows(); i++) {
114+
Object val = meta.get(i, col);
115+
if(val == null)
116+
break;
117+
String[] tmp = ColumnEncoderRecode.splitRecodeMapEntry(val.toString());
118+
Object obj = UtilFunctions.stringToObject(_schema, tmp[0]);
119+
long lval = Long.parseLong(tmp[1]);
120+
_rcMap.put(lval, obj);
121+
max = Math.max(max, lval);
122+
}
123+
if(max < Integer.MAX_VALUE) {
124+
_rcMapDirect = new Object[(int)max];
125+
for(Map.Entry<Long,Object> e : _rcMap.entrySet())
126+
_rcMapDirect[e.getKey().intValue()-1] = e.getValue();
127+
}
121128
//initialize recode maps according to schema
122129
//_rcMaps = new HashMap[_colList.length];
123130
//long[] max = new long[_colList.length];
@@ -146,12 +153,29 @@ public void initMetaData(FrameBlock meta) {
146153
// }
147154
//}
148155
}
149-
public Object getRcMapValue(int i, long key) {
150-
return null;
151-
//return (_rcMapsDirect != null) ?
152-
// _rcMapsDirect[i][(int)key-1] : _rcMaps[i].get(key);
156+
public Object getRcMapValue(long key) {
157+
return (_rcMapDirect != null && key > 0 && key <= _rcMapDirect.length) ?
158+
_rcMapDirect[(int)key-1] : _rcMap.get(key);
153159
}
154-
160+
private void setArrayValue(Array<?> a, int index, Object val) {
161+
if(val == null) {
162+
if(_schema == ValueType.STRING || _schema == ValueType.CHARACTER)
163+
a.set(index, (String)null);
164+
else if(_schema == ValueType.BOOLEAN)
165+
((ABooleanArray)a).set(index, (Boolean)null);
166+
else
167+
a.set(index, Double.NaN);
168+
}
169+
else if(_schema.isNumeric()) {
170+
a.set(index, UtilFunctions.objectToDouble(_schema, val));
171+
}
172+
else if(_schema == ValueType.BOOLEAN) {
173+
((ABooleanArray)a).set(index, UtilFunctions.objectToBoolean(_schema, val));
174+
}
175+
else { // STRING or CHARACTER
176+
a.set(index, val.toString());
177+
}
178+
}
155179
/**
156180
* Parses a line of &lt;token, ID, count&gt; into &lt;token, ID&gt; pairs, where
157181
* quoted tokens (potentially including separators) are supported.
@@ -182,6 +206,13 @@ public void writeExternal(ObjectOutput out) throws IOException {
182206
// out.writeUTF(e1.getValue().toString());
183207
// }
184208
//}
209+
super.writeExternal(out);
210+
out.writeBoolean(_onOut);
211+
out.writeInt(_rcMap.size());
212+
for(Map.Entry<Long,Object> e : _rcMap.entrySet()) {
213+
out.writeLong(e.getKey());
214+
out.writeUTF(e.getValue().toString());
215+
}
185216
}
186217

187218
@Override
@@ -197,5 +228,21 @@ public void readExternal(ObjectInput in) throws IOException {
197228
// maps.put(in.readLong(), in.readUTF());
198229
// _rcMaps[i] = maps;
199230
//}
231+
super.readExternal(in);
232+
_onOut = in.readBoolean();
233+
int size = in.readInt();
234+
_rcMap = new HashMap<>();
235+
long max = 0;
236+
for(int i = 0; i < size; i++) {
237+
long key = in.readLong();
238+
String val = in.readUTF();
239+
_rcMap.put(key, val);
240+
max = Math.max(max, key);
241+
}
242+
if(max < Integer.MAX_VALUE) {
243+
_rcMapDirect = new Object[(int)max];
244+
for(Map.Entry<Long,Object> e : _rcMap.entrySet())
245+
_rcMapDirect[e.getKey().intValue()-1] = e.getValue();
246+
}
200247
}
201248
}

src/test/java/org/apache/sysds/test/functions/transform/ColumnDecoderMixedMethodsTest.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,18 +43,20 @@ public void setUp() {
4343
public void testColumnDecoderMixedMethods() {
4444
try {
4545
int rows = 50;
46-
double[][] arr = new double[rows][3];
46+
double[][] arr = new double[rows][2];
4747
for (int i = 0; i < rows; i++) {
4848
arr[i][0] = 2*i + 1; // bin column
4949
arr[i][1] = 101 + i; // recode column
50-
arr[i][2] = (i % 4) + 2; // dummy column
50+
//arr[i][2] = (i % 4) + 2; // dummy column
5151
//arr[i][3] = 2*i + 1; // pass through column
5252
//arr[i][4] = 100 + i; // bin column
5353
//arr[i][5] = (i % 2) + 1; // recode
5454
}
5555
MatrixBlock mb = DataConverter.convertToMatrixBlock(arr);
5656
FrameBlock data = DataConverter.convertToFrameBlock(mb);
57-
String spec = "{ids:true,bin:[{id:1, method:equi-width, numbins:4},{id:3, method:equi-width, numbins:4}]}";//, dummycode:[6]
57+
//String spec = "{ids:true,bin:[{id:1, method:equi-width, numbins:4},{id:3, method:equi-width, numbins:4}]}";//, dummycode:[6]
58+
String spec = "{ids:true, bin:[{id:1, method:equi-width, numbins:4}], recode:[2]}";
59+
5860
// recode:[1,3],
5961
MultiColumnEncoder enc = EncoderFactory.createEncoder(spec, data.getColumnNames(), data.getNumColumns(), null);
6062
MatrixBlock encoded = enc.encode(data);

0 commit comments

Comments
 (0)