6262import org .apache .sysds .runtime .meta .DataCharacteristics ;
6363import org .apache .sysds .runtime .transform .TfUtils ;
6464import org .apache .sysds .runtime .transform .encode .ColumnEncoder ;
65+ import org .apache .sysds .runtime .transform .encode .ColumnEncoderBagOfWords ;
6566import org .apache .sysds .runtime .transform .encode .ColumnEncoderBin ;
6667import org .apache .sysds .runtime .transform .encode .ColumnEncoderComposite ;
6768import org .apache .sysds .runtime .transform .encode .ColumnEncoderRecode ;
@@ -263,6 +264,7 @@ public Iterator<Tuple2<Integer, Object>> call(Iterator<Tuple2<Long, FrameBlock>>
263264 // encoder-specific outputs
264265 List <ColumnEncoderRecode > raEncoders = _encoder .getColumnEncoders (ColumnEncoderRecode .class );
265266 List <ColumnEncoderBin > baEncoders = _encoder .getColumnEncoders (ColumnEncoderBin .class );
267+ List <ColumnEncoderBagOfWords > bowEncoders = _encoder .getColumnEncoders (ColumnEncoderBagOfWords .class );
266268 ArrayList <Tuple2 <Integer , Object >> ret = new ArrayList <>();
267269
268270 // output recode maps as columnID - token pairs
@@ -273,8 +275,14 @@ public Iterator<Tuple2<Integer, Object>> call(Iterator<Tuple2<Long, FrameBlock>>
273275 for (Entry <Integer , HashSet <Object >> e1 : tmp .entrySet ())
274276 for (Object token : e1 .getValue ())
275277 ret .add (new Tuple2 <>(e1 .getKey (), token ));
276- if (!raEncoders .isEmpty ())
277- raEncoders .forEach (columnEncoderRecode -> columnEncoderRecode .getCPRecodeMapsPartial ().clear ());
278+ raEncoders .forEach (columnEncoderRecode -> columnEncoderRecode .getCPRecodeMapsPartial ().clear ());
279+ }
280+
281+ if (!bowEncoders .isEmpty ()){
282+ for (ColumnEncoderBagOfWords bowEnc : bowEncoders )
283+ for (Object token : bowEnc .getPartialTokenDictionary ())
284+ ret .add (new Tuple2 <>(bowEnc .getColID (), token ));
285+ bowEncoders .forEach (enc -> enc .getPartialTokenDictionary ().clear ());
278286 }
279287
280288 // output binning column min/max as columnID - min/max pairs
@@ -321,7 +329,8 @@ public Iterator<String> call(Tuple2<Integer, Iterable<Object>> arg0) throws Exce
321329 StringBuilder sb = new StringBuilder ();
322330
323331 // handle recode maps
324- if (_encoder .containsEncoderForID (colID , ColumnEncoderRecode .class )) {
332+ if (_encoder .containsEncoderForID (colID , ColumnEncoderRecode .class ) ||
333+ _encoder .containsEncoderForID (colID , ColumnEncoderBagOfWords .class )) {
325334 while (iter .hasNext ()) {
326335 String token = TfUtils .sanitizeSpaces (iter .next ().toString ());
327336 sb .append (rowID ).append (' ' ).append (scolID ).append (' ' );
0 commit comments