Skip to content

Commit c21273c

Browse files
e-strauss authored and mboehm7 committed
[SYSTEMDS-3782] Bag-of-words encoder for Spark backend
Closes #2145.
1 parent cdc2e2c commit c21273c

File tree

12 files changed

+827
-197
lines changed

12 files changed

+827
-197
lines changed

src/main/java/org/apache/sysds/runtime/instructions/spark/MultiReturnParameterizedBuiltinSPInstruction.java

Lines changed: 12 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -62,6 +62,7 @@
6262
import org.apache.sysds.runtime.meta.DataCharacteristics;
6363
import org.apache.sysds.runtime.transform.TfUtils;
6464
import org.apache.sysds.runtime.transform.encode.ColumnEncoder;
65+
import org.apache.sysds.runtime.transform.encode.ColumnEncoderBagOfWords;
6566
import org.apache.sysds.runtime.transform.encode.ColumnEncoderBin;
6667
import org.apache.sysds.runtime.transform.encode.ColumnEncoderComposite;
6768
import org.apache.sysds.runtime.transform.encode.ColumnEncoderRecode;
@@ -263,6 +264,7 @@ public Iterator<Tuple2<Integer, Object>> call(Iterator<Tuple2<Long, FrameBlock>>
263264
// encoder-specific outputs
264265
List<ColumnEncoderRecode> raEncoders = _encoder.getColumnEncoders(ColumnEncoderRecode.class);
265266
List<ColumnEncoderBin> baEncoders = _encoder.getColumnEncoders(ColumnEncoderBin.class);
267+
List<ColumnEncoderBagOfWords> bowEncoders = _encoder.getColumnEncoders(ColumnEncoderBagOfWords.class);
266268
ArrayList<Tuple2<Integer, Object>> ret = new ArrayList<>();
267269

268270
// output recode maps as columnID - token pairs
@@ -273,8 +275,14 @@ public Iterator<Tuple2<Integer, Object>> call(Iterator<Tuple2<Long, FrameBlock>>
273275
for(Entry<Integer, HashSet<Object>> e1 : tmp.entrySet())
274276
for(Object token : e1.getValue())
275277
ret.add(new Tuple2<>(e1.getKey(), token));
276-
if(!raEncoders.isEmpty())
277-
raEncoders.forEach(columnEncoderRecode -> columnEncoderRecode.getCPRecodeMapsPartial().clear());
278+
raEncoders.forEach(columnEncoderRecode -> columnEncoderRecode.getCPRecodeMapsPartial().clear());
279+
}
280+
281+
if(!bowEncoders.isEmpty()){
282+
for (ColumnEncoderBagOfWords bowEnc : bowEncoders)
283+
for (Object token : bowEnc.getPartialTokenDictionary())
284+
ret.add(new Tuple2<>(bowEnc.getColID(), token));
285+
bowEncoders.forEach(enc -> enc.getPartialTokenDictionary().clear());
278286
}
279287

280288
// output binning column min/max as columnID - min/max pairs
@@ -321,7 +329,8 @@ public Iterator<String> call(Tuple2<Integer, Iterable<Object>> arg0) throws Exce
321329
StringBuilder sb = new StringBuilder();
322330

323331
// handle recode maps
324-
if(_encoder.containsEncoderForID(colID, ColumnEncoderRecode.class)) {
332+
if(_encoder.containsEncoderForID(colID, ColumnEncoderRecode.class) ||
333+
_encoder.containsEncoderForID(colID, ColumnEncoderBagOfWords.class)) {
325334
while(iter.hasNext()) {
326335
String token = TfUtils.sanitizeSpaces(iter.next().toString());
327336
sb.append(rowID).append(' ').append(scolID).append(' ');

src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -88,6 +88,7 @@
8888
import org.apache.sysds.runtime.transform.TfUtils.TfMethod;
8989
import org.apache.sysds.runtime.transform.decode.Decoder;
9090
import org.apache.sysds.runtime.transform.decode.DecoderFactory;
91+
import org.apache.sysds.runtime.transform.encode.ColumnEncoderBagOfWords;
9192
import org.apache.sysds.runtime.transform.encode.EncoderFactory;
9293
import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder;
9394
import org.apache.sysds.runtime.transform.meta.TfMetaUtils;
@@ -1056,6 +1057,13 @@ public Tuple2<Long, FrameBlock> call(Tuple2<Long, FrameBlock> in) throws Excepti
10561057

10571058
// execute block transform apply
10581059
MultiColumnEncoder encoder = _bencoder.getValue();
1060+
// we need to create a copy of the encoder since the bag-of-words encoder stores frameblock-specific state
1061+
// which would be overwritten when multiple blocks are located on an executor
1062+
// to avoid this, we need to create a shallow copy of the MCEncoder, where we only instantiate new bow
1063+
// encoder objects with the frameblock-specific fields and shallow copy the other fields (like meta)
1064+
// other encoders are reused and not newly instantiated
1065+
if(!encoder.getColumnEncoders(ColumnEncoderBagOfWords.class).isEmpty())
1066+
encoder = new MultiColumnEncoder(encoder); // create copy
10591067
MatrixBlock tmp = encoder.apply(blk);
10601068
// remap keys
10611069
if(_omap != null) {

src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -450,7 +450,7 @@ protected void setApplyRowBlocksPerColumn(int nPart) {
450450
}
451451

452452
public enum EncoderType {
453-
Recode, FeatureHash, PassThrough, Bin, Dummycode, Omit, MVImpute, Composite, WordEmbedding,
453+
Recode, FeatureHash, PassThrough, Bin, Dummycode, Omit, MVImpute, Composite, WordEmbedding, BagOfWords
454454
}
455455

456456
/*

0 commit comments

Comments
 (0)