@@ -42,6 +42,10 @@ private CLALibTSMM() {
4242 // private constructor
4343 }
4444
45+ public static MatrixBlock leftMultByTransposeSelf (CompressedMatrixBlock cmb , int k ) {
46+ return leftMultByTransposeSelf (cmb , new MatrixBlock (), k );
47+ }
48+
4549 /**
4650 * Self left Matrix multiplication (tsmm)
4751 *
@@ -51,17 +55,25 @@ private CLALibTSMM() {
5155 * @param ret The output matrix to put the result into
5256 * @param k The parallelization degree allowed
5357 */
54- public static void leftMultByTransposeSelf (CompressedMatrixBlock cmb , MatrixBlock ret , int k ) {
58+ public static MatrixBlock leftMultByTransposeSelf (CompressedMatrixBlock cmb , MatrixBlock ret , int k ) {
5559
60+ final int numColumns = cmb .getNumColumns ();
61+ final int numRows = cmb .getNumRows ();
62+ if (cmb .isEmpty ())
63+ return new MatrixBlock (numColumns , numColumns , true );
64+ // create output matrix block
65+ if (ret == null )
66+ ret = new MatrixBlock (numColumns , numColumns , false );
67+ else
68+ ret .reset (numColumns , numColumns , false );
69+ ret .allocateDenseBlock ();
5670 final List <AColGroup > groups = cmb .getColGroups ();
5771
58- final int numColumns = cmb .getNumColumns ();
5972 if (groups .size () >= numColumns ) {
6073 MatrixBlock m = cmb .getUncompressed ("TSMM to many columngroups" , k );
6174 LibMatrixMult .matrixMultTransposeSelf (m , ret , true , k );
62- return ;
75+ return ret ;
6376 }
64- final int numRows = cmb .getNumRows ();
6577 final boolean shouldFilter = CLALibUtils .shouldPreFilter (groups );
6678 final boolean overlapping = cmb .isOverlapping ();
6779 if (shouldFilter ) {
@@ -77,6 +89,7 @@ public static void leftMultByTransposeSelf(CompressedMatrixBlock cmb, MatrixBloc
7789
7890 ret .setNonZeros (LibMatrixMult .copyUpperToLowerTriangle (ret ));
7991 ret .examSparsity ();
92+ return ret ;
8093 }
8194
8295 private static void addCorrectionLayer (List <AColGroup > filteredGroups , MatrixBlock result , int nRows , int nCols ,
@@ -86,8 +99,6 @@ private static void addCorrectionLayer(List<AColGroup> filteredGroups, MatrixBlo
8699 addCorrectionLayer (constV , filteredColSum , nRows , retV );
87100 }
88101
89-
90-
91102 private static void tsmmColGroups (List <AColGroup > groups , MatrixBlock ret , int nRows , boolean overlapping , int k ) {
92103 if (k <= 1 )
93104 tsmmColGroupsSingleThread (groups , ret , nRows );
@@ -136,12 +147,12 @@ private static void tsmmColGroupsMultiThread(List<AColGroup> groups, MatrixBlock
136147
137148 public static void addCorrectionLayer (double [] constV , double [] filteredColSum , int nRow , double [] ret ) {
138149 final int nColRow = constV .length ;
139- for (int row = 0 ; row < nColRow ; row ++){
150+ for (int row = 0 ; row < nColRow ; row ++) {
140151 int offOut = nColRow * row ;
141152 final double v1l = constV [row ];
142153 final double v2l = filteredColSum [row ] + constV [row ] * nRow ;
143- for (int col = row ; col < nColRow ; col ++){
144- ret [offOut + col ] += v1l * filteredColSum [col ] + v2l * constV [col ];
154+ for (int col = row ; col < nColRow ; col ++) {
155+ ret [offOut + col ] += v1l * filteredColSum [col ] + v2l * constV [col ];
145156 }
146157 }
147158 }
0 commit comments