3434import org .apache .sysds .runtime .compress .CompressedMatrixBlock ;
3535import org .apache .sysds .runtime .compress .colgroup .AColGroup ;
3636import org .apache .sysds .runtime .compress .colgroup .ColGroupConst ;
37- import org .apache .sysds .runtime .compress .colgroup .ColGroupDDC ;
37+ import org .apache .sysds .runtime .compress .colgroup .ColGroupUncompressed ;
3838import org .apache .sysds .runtime .compress .colgroup .indexes .ColIndexFactory ;
3939import org .apache .sysds .runtime .compress .colgroup .indexes .IColIndex ;
40- import org .apache .sysds .runtime .functionobjects .Plus ;
4140import org .apache .sysds .runtime .matrix .data .LibMatrixMult ;
4241import org .apache .sysds .runtime .matrix .data .MatrixBlock ;
43- import org .apache .sysds .runtime .matrix .operators .BinaryOperator ;
4442import org .apache .sysds .runtime .util .CommonThreadPool ;
4543
4644public final class CLALibRightMultBy {
4745 private static final Log LOG = LogFactory .getLog (CLALibRightMultBy .class .getName ());
4846
/** Not instantiable: this final class only exposes static members. */
private CLALibRightMultBy() {
	// intentionally empty
}
5250
@@ -74,6 +72,11 @@ public static MatrixBlock rightMultByMatrix(CompressedMatrixBlock m1, MatrixBloc
7472 if (m2 instanceof CompressedMatrixBlock )
7573 m2 = ((CompressedMatrixBlock ) m2 ).getUncompressed ("Uncompressed right side of right MM" , k );
7674
75+ if (betterIfDecompressed (m1 )) {
76+ // perform uncompressed multiplication.
77+ return decompressingMatrixMult (m1 , m2 , k );
78+ }
79+
7780 if (!allowOverlap ) {
7881 LOG .trace ("Overlapping output not allowed in call to Right MM" );
7982 return RMM (m1 , m2 , k );
@@ -87,14 +90,67 @@ public static MatrixBlock rightMultByMatrix(CompressedMatrixBlock m1, MatrixBloc
8790 if (retC .isOverlapping ())
8891 retC .setNonZeros ((long ) rr * rc ); // set non zeros to fully dense in case of overlapping.
8992 else
90- retC .recomputeNonZeros (); // recompute if non overlapping compressed out.
93+ retC .recomputeNonZeros (k ); // recompute if non overlapping compressed out.
9194 return retC ;
9295 }
9396 }
97+ }
98+
99+ private static MatrixBlock decompressingMatrixMult (CompressedMatrixBlock m1 , MatrixBlock m2 , int k ) {
100+ ExecutorService pool = CommonThreadPool .get (k );
101+ try {
102+ final int rl = m1 .getNumRows ();
103+ final int cr = m2 .getNumColumns ();
104+ // final int rr = m2.getNumRows(); // shared dim
105+ final MatrixBlock ret = new MatrixBlock (rl , cr , false );
106+ ret .allocateBlock ();
107+
108+ // MatrixBlock m1uc = m1.decompress(k);
109+ final List <Future <Long >> tasks = new ArrayList <>();
110+ final List <AColGroup > groups = m1 .getColGroups ();
111+ final int blkI = Math .max ((int ) Math .ceil ((double ) rl / k ), 16 );
112+ final int blkJ = blkI > 16 ? cr : Math .max ((cr / k ), 512 ); // make it a multiplicative of 8.
113+ for (int i = 0 ; i < rl ; i += blkI ) {
114+ final int startI = i ;
115+ final int endI = Math .min (i + blkI , rl );
116+ for (int j = 0 ; j < cr ; j += blkJ ){
117+ final int startJ = j ;
118+ final int endJ = Math .min (j + blkJ , cr );
119+ tasks .add (pool .submit (() -> {
120+ for (AColGroup g : groups )
121+ g .rightDecompressingMult (m2 , ret , startI , endI , rl , startJ , endJ );
122+ return ret .recomputeNonZeros (startI , endI - 1 , startJ , endJ -1 );
123+ }));
124+ }
125+ }
126+ long nnz = 0 ;
127+ for (Future <Long > t : tasks )
128+ nnz += t .get ();
129+
130+ ret .setNonZeros (nnz );
131+ ret .examSparsity ();
132+ return ret ;
133+ }
134+ catch (InterruptedException | ExecutionException e ) {
135+ throw new DMLRuntimeException (e );
136+ }
137+ finally {
138+ pool .shutdown ();
139+ }
94140
95141 }
96142
143+ private static boolean betterIfDecompressed (CompressedMatrixBlock m ) {
144+ for (AColGroup g : m .getColGroups ()) {
145+ if (!(g instanceof ColGroupUncompressed ) && g .getNumValues () * 2 >= m .getNumRows ()) {
146+ return true ;
147+ }
148+ }
149+ return false ;
150+ }
151+
97152 private static CompressedMatrixBlock RMMOverlapping (CompressedMatrixBlock m1 , MatrixBlock that , int k ) {
153+
98154 final int rl = m1 .getNumRows ();
99155 final int cr = that .getNumColumns ();
100156 final int rr = that .getNumRows (); // shared dim
@@ -103,21 +159,27 @@ private static CompressedMatrixBlock RMMOverlapping(CompressedMatrixBlock m1, Ma
103159 final CompressedMatrixBlock ret = new CompressedMatrixBlock (rl , cr );
104160
105161 final boolean shouldFilter = CLALibUtils .shouldPreFilter (colGroups );
162+ final double [] constV ;
163+ final List <AColGroup > filteredGroups ;
106164
107- double [] constV = shouldFilter ? new double [rr ] : null ;
108- final List <AColGroup > filteredGroups = CLALibUtils .filterGroups (colGroups , constV );
109- if (colGroups == filteredGroups )
165+ if (shouldFilter ) {
166+ constV = new double [rr ];
167+ filteredGroups = CLALibUtils .filterGroups (colGroups , constV );
168+ }
169+ else {
170+ filteredGroups = colGroups ;
110171 constV = null ;
172+ }
111173
112- if (k == 1 )
174+ if (k == 1 || filteredGroups . size () == 1 )
113175 RMMSingle (filteredGroups , that , retCg );
114176 else
115177 RMMParallel (filteredGroups , that , retCg , k );
116178
117179 if (constV != null ) {
118180 final MatrixBlock cb = new MatrixBlock (1 , constV .length , constV );
119181 final MatrixBlock cbRet = new MatrixBlock (1 , that .getNumColumns (), false );
120- LibMatrixMult .matrixMult (cb , that , cbRet );
182+ LibMatrixMult .matrixMult (cb , that , cbRet ); // mm on row vector left.
121183 if (!cbRet .isEmpty ())
122184 addConstant (cbRet , retCg );
123185 }
@@ -133,52 +195,72 @@ private static CompressedMatrixBlock RMMOverlapping(CompressedMatrixBlock m1, Ma
133195 }
134196
135197 private static void addConstant (MatrixBlock constantRow , List <AColGroup > out ) {
136- final int nCol = constantRow .getNumColumns ();
137- int bestCandidate = -1 ;
138- int bestCandidateValuesSize = Integer .MAX_VALUE ;
139- for (int i = 0 ; i < out .size (); i ++) {
140- AColGroup g = out .get (i );
141- if (g instanceof ColGroupDDC && g .getNumCols () == nCol && g .getNumValues () < bestCandidateValuesSize )
142- bestCandidate = i ;
143- }
198+ // it is fairly safe to add the constant row to a column group.
199+ // but it is not necessary the fastest.
200+
201+ // final int nCol = constantRow.getNumColumns();
202+ // int bestCandidate = -1;
203+ // int bestCandidateValuesSize = Integer.MAX_VALUE;
204+ // for(int i = 0; i < out.size(); i++) {
205+ // AColGroup g = out.get(i);
206+ // if(g instanceof ColGroupDDC && g.getNumCols() == nCol && g.getNumValues() < bestCandidateValuesSize)
207+ // bestCandidate = i;
208+ // }
144209
145210 constantRow .sparseToDense ();
146211
147- if (bestCandidate != -1 ) {
148- AColGroup bc = out .get (bestCandidate );
149- out .remove (bestCandidate );
150- AColGroup ng = bc .binaryRowOpRight (new BinaryOperator (Plus .getPlusFnObject (), 1 ),
151- constantRow .getDenseBlockValues (), true );
152- out .add (ng );
153- }
154- else
155- out .add (ColGroupConst .create (constantRow .getDenseBlockValues ()));
212+ // if(bestCandidate != -1) {
213+ // AColGroup bc = out.get(bestCandidate);
214+ // out.remove(bestCandidate);
215+ // AColGroup ng = bc.binaryRowOpRight(new BinaryOperator(Plus.getPlusFnObject(), 1),
216+ // constantRow.getDenseBlockValues(), true);
217+ // out.add(ng);
218+ // }
219+ // else
220+ out .add (ColGroupConst .create (constantRow .getDenseBlockValues ()));
156221 }
157222
158223 private static MatrixBlock RMM (CompressedMatrixBlock m1 , MatrixBlock that , int k ) {
224+
225+ // Timing t = new Timing();
159226 // this version returns a decompressed result.
160227 final int rl = m1 .getNumRows ();
161228 final int cr = that .getNumColumns ();
162229 final int rr = that .getNumRows (); // shared dim
163230 final List <AColGroup > colGroups = m1 .getColGroups ();
164- final List <AColGroup > retCg = new ArrayList <>();
165231
166232 final boolean shouldFilter = CLALibUtils .shouldPreFilter (colGroups );
167233
168234 // start allocation of output.
169235 MatrixBlock ret = new MatrixBlock (rl , cr , false );
170236 final Future <MatrixBlock > f = ret .allocateBlockAsync ();
171237
172- double [] constV = shouldFilter ? new double [rr ] : null ;
173- final List <AColGroup > filteredGroups = CLALibUtils .filterGroups (colGroups , constV );
174- if (colGroups == filteredGroups )
238+ double [] constV ;
239+ final List <AColGroup > filteredGroups ;
240+
241+ if (shouldFilter ) {
242+ if (CLALibUtils .alreadyPreFiltered (colGroups , cr )) {
243+ filteredGroups = new ArrayList <>(colGroups .size () - 1 );
244+ constV = CLALibUtils .filterGroupsAndSplitPreAggOneConst (colGroups , filteredGroups );
245+ }
246+ else {
247+ constV = new double [rr ];
248+ filteredGroups = CLALibUtils .filterGroups (colGroups , constV );
249+ }
250+ }
251+ else {
252+ filteredGroups = colGroups ;
175253 constV = null ;
254+ }
176255
256+
257+ final List <AColGroup > retCg = new ArrayList <>(filteredGroups .size ());
177258 if (k == 1 )
178259 RMMSingle (filteredGroups , that , retCg );
179260 else
180261 RMMParallel (filteredGroups , that , retCg , k );
181262
263+
182264 if (constV != null ) {
183265 MatrixBlock constVMB = new MatrixBlock (1 , constV .length , constV );
184266 MatrixBlock mmTemp = new MatrixBlock (1 , cr , false );
@@ -233,7 +315,7 @@ private static boolean RMMParallel(List<AColGroup> filteredGroups, MatrixBlock t
233315 catch (InterruptedException | ExecutionException e ) {
234316 throw new DMLRuntimeException (e );
235317 }
236- finally {
318+ finally {
237319 pool .shutdown ();
238320 }
239321 return containsNull ;
0 commit comments