@@ -61,18 +61,20 @@ public abstract class ColumnEncoder implements Encoder, Comparable<ColumnEncoder
6161 private static final long serialVersionUID = 2299156350718979064L ;
6262 protected int _colID ;
6363 protected ArrayList <Integer > _sparseRowsWZeros = null ;
64+ protected int [] sparseRowPointerOffset = null ; // offsets created by bag of words encoders (multiple nnz)
6465 protected long _estMetaSize = 0 ;
6566 protected int _estNumDistincts = 0 ;
6667 protected int _nBuildPartitions = 0 ;
6768 protected int _nApplyPartitions = 0 ;
69+ protected long _avgEntrySize = 0 ;
6870
6971 //Override in ColumnEncoderWordEmbedding
7072 public void initEmbeddings (MatrixBlock embeddings ){
7173 return ;
7274 }
7375
7476 protected enum TransformType {
75- BIN , RECODE , DUMMYCODE , FEATURE_HASH , PASS_THROUGH , UDF , WORD_EMBEDDING , N_A
77+ BIN , RECODE , DUMMYCODE , FEATURE_HASH , PASS_THROUGH , UDF , WORD_EMBEDDING , BAG_OF_WORDS , N_A
7678 }
7779
7880 protected ColumnEncoder (int colID ) {
@@ -115,6 +117,9 @@ public MatrixBlock apply(CacheBlock<?> in, MatrixBlock out, int outputCol, int r
115117 case WORD_EMBEDDING :
116118 TransformStatistics .incWordEmbeddingApplyTime (t );
117119 break ;
120+ case BAG_OF_WORDS :
121+ TransformStatistics .incBagOfWordsApplyTime (t );
122+ break ;
118123 case FEATURE_HASH :
119124 TransformStatistics .incFeatureHashingApplyTime (t );
120125 break ;
@@ -152,6 +157,7 @@ protected void applySparse(CacheBlock<?> in, MatrixBlock out, int outputCol, int
152157 for (int i = rowStart ; i < rowEnd ; i +=B ) {
153158 int lim = Math .min (i +B , rowEnd );
154159 for (int ii =i ; ii <lim ; ii ++) {
160+ int indexWithOffset = sparseRowPointerOffset != null ? sparseRowPointerOffset [ii ] - 1 + index : index ;
155161 if (mcsr ) {
156162 SparseRowVector row = (SparseRowVector ) out .getSparseBlock ().get (ii );
157163 row .values ()[index ] = codes [ii -rowStart ];
@@ -161,8 +167,8 @@ protected void applySparse(CacheBlock<?> in, MatrixBlock out, int outputCol, int
161167 // Manually fill the column-indexes and values array
162168 SparseBlockCSR csrblock = (SparseBlockCSR )out .getSparseBlock ();
163169 int rptr [] = csrblock .rowPointers ();
164- csrblock .indexes ()[rptr [ii ]+index ] = outputCol ;
165- csrblock .values ()[rptr [ii ]+index ] = codes [ii -rowStart ];
170+ csrblock .indexes ()[rptr [ii ]+indexWithOffset ] = outputCol ;
171+ csrblock .values ()[rptr [ii ]+indexWithOffset ] = codes [ii -rowStart ];
166172 }
167173 }
168174 }
@@ -336,6 +342,11 @@ public int getEstNumDistincts() {
336342 return _estNumDistincts ;
337343 }
338344
345+ public void computeMapSizeEstimate (CacheBlock <?> in , int [] sampleIndices ) {
346+ throw new DMLRuntimeException (this + " does not need map size estimation" );
347+ }
348+
349+
339350 @ Override
340351 public int compareTo (ColumnEncoder o ) {
341352 return Integer .compare (getEncoderType (this ), getEncoderType (o ));
@@ -355,9 +366,11 @@ public List<DependencyTask<?>> getBuildTasks(CacheBlock<?> in) {
355366 tasks .add (getBuildTask (in ));
356367 }
357368 else {
369+ if (this instanceof ColumnEncoderBagOfWords )
370+ ((ColumnEncoderBagOfWords ) this ).initNnzPartials (in .getNumRows (), blockSizes .length );
358371 HashMap <Integer , Object > ret = new HashMap <>();
359372 for (int startRow = 0 , i = 0 ; i < blockSizes .length ; startRow +=blockSizes [i ], i ++)
360- tasks .add (getPartialBuildTask (in , startRow , blockSizes [i ], ret ));
373+ tasks .add (getPartialBuildTask (in , startRow , blockSizes [i ], ret , i ));
361374 tasks .add (getPartialMergeBuildTask (ret ));
362375 dep = new ArrayList <>(Collections .nCopies (tasks .size () - 1 , null ));
363376 dep .add (tasks .subList (0 , tasks .size () - 1 ));
@@ -370,7 +383,7 @@ public Callable<Object> getBuildTask(CacheBlock<?> in) {
370383 }
371384
372385 public Callable <Object > getPartialBuildTask (CacheBlock <?> in , int startRow ,
373- int blockSize , HashMap <Integer , Object > ret ) {
386+ int blockSize , HashMap <Integer , Object > ret , int p ) {
374387 throw new DMLRuntimeException (
375388 "Trying to get the PartialBuild task of an Encoder which does not support partial building" );
376389 }
@@ -381,11 +394,12 @@ public Callable<Object> getPartialMergeBuildTask(HashMap<Integer, ?> ret) {
381394 }
382395
383396
384- public List <DependencyTask <?>> getApplyTasks (CacheBlock <?> in , MatrixBlock out , int outputCol ) {
397+ public List <DependencyTask <?>> getApplyTasks (CacheBlock <?> in , MatrixBlock out , int outputCol , int [] sparseRowPointerOffsets ) {
385398 List <Callable <Object >> tasks = new ArrayList <>();
386399 List <List <? extends Callable <?>>> dep = null ;
400+ //for now single threaded apply for bag of words
387401 int [] blockSizes = getBlockSizes (in .getNumRows (), _nApplyPartitions );
388-
402+ this . sparseRowPointerOffset = out . isInSparseFormat () ? sparseRowPointerOffsets : null ;
389403 for (int startRow = 0 , i = 0 ; i < blockSizes .length ; startRow +=blockSizes [i ], i ++){
390404 if (out .isInSparseFormat ())
391405 tasks .add (getSparseTask (in , out , outputCol , startRow , blockSizes [i ]));
@@ -419,7 +433,7 @@ public Set<Integer> getSparseRowsWZeros(){
419433 return null ;
420434 }
421435
422- protected void addSparseRowsWZeros (ArrayList <Integer > sparseRowsWZeros ){
436+ protected void addSparseRowsWZeros (List <Integer > sparseRowsWZeros ){
423437 synchronized (this ){
424438 if (_sparseRowsWZeros == null )
425439 _sparseRowsWZeros = new ArrayList <>();
0 commit comments