@@ -189,7 +189,7 @@ public MatrixBlock execute(ArrayList<MatrixBlock> inputs, ArrayList<ScalarObject
189189 //setup thread-local memory if necessary
190190 if ( allocTmp &&_reqVectMem > 0 )
191191 if (inputs .get (0 ).isInSparseFormat () && DMLScript .SPARSE_INTERMEDIATE ) {
192- LibSpoofPrimitives .setupSparseThreadLocalMemory (_reqVectMem , n / 2 , n2 );
192+ LibSpoofPrimitives .setupSparseThreadLocalMemory (_reqVectMem , n , n2 );
193193 LibSpoofPrimitives .setupThreadLocalMemory (_reqVectMem , n , n2 );
194194 } else {
195195 LibSpoofPrimitives .setupThreadLocalMemory (_reqVectMem , n , n2 );
@@ -442,7 +442,12 @@ public DenseBlock call() {
442442
443443 //allocate vector intermediates and partial output
444444 if ( _reqVectMem > 0 )
445- LibSpoofPrimitives .setupThreadLocalMemory (_reqVectMem , _clen , _clen2 );
445+ if (_a .isInSparseFormat () && DMLScript .SPARSE_INTERMEDIATE ) {
446+ LibSpoofPrimitives .setupSparseThreadLocalMemory (_reqVectMem , _clen , _clen2 );
447+ LibSpoofPrimitives .setupThreadLocalMemory (_reqVectMem , _clen , _clen2 );
448+ } else {
449+ LibSpoofPrimitives .setupThreadLocalMemory (_reqVectMem , _clen , _clen2 );
450+ }
446451 DenseBlock c = DenseBlockFactory .createDenseBlock (1 , _outLen );
447452
448453 if ( !_a .isInSparseFormat () )
@@ -451,7 +456,12 @@ public DenseBlock call() {
451456 executeSparse (_a .getSparseBlock (), _b , _scalars , c , _clen , _rl , _ru , 0 );
452457
453458 if ( _reqVectMem > 0 )
454- LibSpoofPrimitives .cleanupThreadLocalMemory ();
459+ if (_a .isInSparseFormat () && DMLScript .SPARSE_INTERMEDIATE ) {
460+ LibSpoofPrimitives .cleanupSparseThreadLocalMemory ();
461+ LibSpoofPrimitives .cleanupThreadLocalMemory ();
462+ } else {
463+ LibSpoofPrimitives .cleanupThreadLocalMemory ();
464+ }
455465 return c ;
456466 }
457467 }
@@ -485,15 +495,25 @@ protected ParExecTask( MatrixBlock a, SideInput[] b, MatrixBlock c, double[] sca
485495 public Long call () {
486496 //allocate vector intermediates
487497 if ( _reqVectMem > 0 )
488- LibSpoofPrimitives .setupThreadLocalMemory (_reqVectMem , _clen , _clen2 );
498+ if (_a .isInSparseFormat () && DMLScript .SPARSE_INTERMEDIATE ) {
499+ LibSpoofPrimitives .setupSparseThreadLocalMemory (_reqVectMem , _clen , _clen2 );
500+ LibSpoofPrimitives .setupThreadLocalMemory (_reqVectMem , _clen , _clen2 );
501+ } else {
502+ LibSpoofPrimitives .setupThreadLocalMemory (_reqVectMem , _clen , _clen2 );
503+ }
489504
490505 if ( !_a .isInSparseFormat () )
491506 executeDense (_a .getDenseBlock (), _b , _scalars , _c .getDenseBlock (), _clen , _rl , _ru , 0 );
492507 else
493508 executeSparse (_a .getSparseBlock (), _b , _scalars , _c .getDenseBlock (), _clen , _rl , _ru , 0 );
494-
509+
495510 if ( _reqVectMem > 0 )
496- LibSpoofPrimitives .cleanupThreadLocalMemory ();
511+ if (_a .isInSparseFormat () && DMLScript .SPARSE_INTERMEDIATE ) {
512+ LibSpoofPrimitives .cleanupSparseThreadLocalMemory ();
513+ LibSpoofPrimitives .cleanupThreadLocalMemory ();
514+ } else {
515+ LibSpoofPrimitives .cleanupThreadLocalMemory ();
516+ }
497517
498518 //maintain nnz for row partition
499519 return _c .recomputeNonZeros (_rl , _ru -1 , 0 , _c .getNumColumns ()-1 );
0 commit comments