2121
2222import java .util .ArrayList ;
2323import java .util .List ;
24+ import java .util .concurrent .ExecutionException ;
25+ import java .util .concurrent .ExecutorService ;
26+ import java .util .concurrent .Future ;
2427
2528import org .apache .commons .logging .Log ;
2629import org .apache .commons .logging .LogFactory ;
2730import org .apache .sysds .runtime .compress .CompressedMatrixBlock ;
2831import org .apache .sysds .runtime .compress .CompressedMatrixBlockFactory ;
32+ import org .apache .sysds .runtime .compress .DMLCompressionException ;
2933import org .apache .sysds .runtime .compress .colgroup .AColGroup ;
3034import org .apache .sysds .runtime .compress .colgroup .ColGroupEmpty ;
3135import org .apache .sysds .runtime .compress .colgroup .ColGroupUncompressed ;
3236import org .apache .sysds .runtime .compress .colgroup .indexes .ColIndexFactory ;
3337import org .apache .sysds .runtime .compress .colgroup .indexes .IColIndex ;
3438import org .apache .sysds .runtime .matrix .data .MatrixBlock ;
39+ import org .apache .sysds .runtime .util .CommonThreadPool ;
3540
36- public final class CLALibAppend {
41+ public final class CLALibCBind {
3742
38- private CLALibAppend () {
/** Private constructor: instances of this class cannot be created from outside the class. */
43+ private CLALibCBind () {
3944 // private constructor.
4045 }
4146
42- private static final Log LOG = LogFactory .getLog (CLALibAppend .class .getName ());
47+ private static final Log LOG = LogFactory .getLog (CLALibCBind .class .getName ());
4348
44- public static MatrixBlock append (MatrixBlock left , MatrixBlock right , int k ) {
49+ public static MatrixBlock cbind (MatrixBlock left , MatrixBlock [] right , int k ) {
50+ try {
51+
52+ if (right .length == 1 ) {
53+ return cbind (left , right [0 ], k );
54+ }
55+ else {
56+ boolean allCompressed = true ;
57+ for (int i = 0 ; i < right .length && allCompressed ; i ++)
58+ allCompressed = right [i ] instanceof CompressedMatrixBlock ;
59+ if (allCompressed )
60+ return cbindAllCompressed ((CompressedMatrixBlock ) left , right , k );
61+ else
62+ return cbindAllNormalCompressed (left , right , k );
63+ }
64+ }
65+ catch (Exception e ) {
66+ throw new DMLCompressionException ("Failed to Cbind with compressed input" , e );
67+ }
68+ }
69+
70+ private static MatrixBlock cbindAllNormalCompressed (MatrixBlock left , MatrixBlock [] right , int k ) {
71+ for (int i = 0 ; i < right .length ; i ++) {
72+ left = cbind (left , right [i ], k );
73+ }
74+ return left ;
75+ }
76+
77+ public static MatrixBlock cbind (MatrixBlock left , MatrixBlock right , int k ) {
4578
4679 final int m = left .getNumRows ();
4780 final int n = left .getNumColumns () + right .getNumColumns ();
@@ -66,15 +99,96 @@ else if(right.isEmpty() && left instanceof CompressedMatrixBlock)
6699 final double spar = (left .getNonZeros () + right .getNonZeros ()) / ((double ) m * n );
67100 final double estSizeUncompressed = MatrixBlock .estimateSizeInMemory (m , n , spar );
68101 final double estSizeCompressed = left .getInMemorySize () + right .getInMemorySize ();
102+ // if(isAligned((CompressedMatrixBlock) left, (CompressedMatrixBlock) right))
103+ // return combineCompressed((CompressedMatrixBlock) left, (CompressedMatrixBlock) right);
104+ // else
69105 if (estSizeUncompressed < estSizeCompressed )
70106 return uc (left ).append (uc (right ), null );
71107 else if (left instanceof CompressedMatrixBlock )
72108 return appendRightUncompressed ((CompressedMatrixBlock ) left , right , m , n );
73109 else
74110 return appendLeftUncompressed (left , (CompressedMatrixBlock ) right , m , n );
75111 }
112+ if (isAligned ((CompressedMatrixBlock ) left , (CompressedMatrixBlock ) right ))
113+ return combineCompressed ((CompressedMatrixBlock ) left , (CompressedMatrixBlock ) right );
114+ else
115+ return append ((CompressedMatrixBlock ) left , (CompressedMatrixBlock ) right , m , n );
116+ }
117+
118+ private static MatrixBlock cbindAllCompressed (CompressedMatrixBlock left , MatrixBlock [] right , int k )
119+ throws InterruptedException , ExecutionException {
120+
121+ final int nCol = left .getNumColumns ();
122+ for (int i = 0 ; i < right .length ; i ++) {
123+ CompressedMatrixBlock rightCM = ((CompressedMatrixBlock ) right [i ]);
124+ if (nCol != right [i ].getNumColumns () || !isAligned (left , rightCM ))
125+ return cbindAllNormalCompressed (left , right , k );
126+ }
127+ return cbindAllCompressedAligned (left , right , k );
128+
129+ }
130+
131+ private static boolean isAligned (CompressedMatrixBlock left , CompressedMatrixBlock right ) {
132+ final List <AColGroup > gl = left .getColGroups ();
133+ for (int j = 0 ; j < gl .size (); j ++) {
134+ final AColGroup glj = gl .get (j );
135+ final int aColumnInGroup = glj .getColIndices ().get (0 );
136+ final AColGroup grj = right .getColGroupForColumn (aColumnInGroup );
137+
138+ if (!glj .sameIndexStructure (grj ) || glj .getNumCols () != grj .getNumCols ())
139+ return false ;
140+
141+ }
142+ return true ;
143+ }
144+
145+ private static CompressedMatrixBlock combineCompressed (CompressedMatrixBlock left , CompressedMatrixBlock right ) {
146+ final List <AColGroup > gl = left .getColGroups ();
147+ final List <AColGroup > retCG = new ArrayList <>(gl .size ());
148+ for (int j = 0 ; j < gl .size (); j ++) {
149+ AColGroup glj = gl .get (j );
150+ int aColumnInGroup = glj .getColIndices ().get (0 );
151+ AColGroup grj = right .getColGroupForColumn (aColumnInGroup );
152+ // parallel combine...
153+ retCG .add (glj .combineWithSameIndex (left .getNumRows (), left .getNumColumns (), grj ));
154+ }
155+ return new CompressedMatrixBlock (left .getNumRows (), left .getNumColumns () + right .getNumColumns (),
156+ left .getNonZeros () + right .getNonZeros (), false , retCG );
157+ }
158+
159+ private static CompressedMatrixBlock cbindAllCompressedAligned (CompressedMatrixBlock left , MatrixBlock [] right ,
160+ final int k ) throws InterruptedException , ExecutionException {
161+
162+ final ExecutorService pool = CommonThreadPool .get (k );
163+ try {
164+ final List <AColGroup > gl = left .getColGroups ();
165+ final List <Future <AColGroup >> tasks = new ArrayList <>();
166+ final int nCol = left .getNumColumns ();
167+ final int nRow = left .getNumRows ();
168+ for (int i = 0 ; i < gl .size (); i ++) {
169+ final AColGroup gli = gl .get (i );
170+ tasks .add (pool .submit (() -> {
171+ List <AColGroup > combines = new ArrayList <>();
172+ final int cId = gli .getColIndices ().get (0 );
173+ for (int j = 0 ; j < right .length ; j ++) {
174+ combines .add (((CompressedMatrixBlock ) right [j ]).getColGroupForColumn (cId ));
175+ }
176+ return gli .combineWithSameIndex (nRow , nCol , combines );
177+ }));
178+ }
179+
180+ final List <AColGroup > retCG = new ArrayList <>(gl .size ());
181+ for (Future <AColGroup > t : tasks )
182+ retCG .add (t .get ());
183+
184+ int totalCol = nCol + right .length * nCol ;
185+
186+ return new CompressedMatrixBlock (left .getNumRows (), totalCol , -1 , false , retCG );
187+ }
188+ finally {
189+ pool .shutdown ();
190+ }
76191
77- return append ((CompressedMatrixBlock ) left , (CompressedMatrixBlock ) right , m , n );
78192 }
79193
80194 private static MatrixBlock appendLeftUncompressed (MatrixBlock left , CompressedMatrixBlock right , final int m ,
@@ -123,17 +237,17 @@ private static MatrixBlock append(CompressedMatrixBlock left, CompressedMatrixBl
123237 ret .setNonZeros (left .getNonZeros () + right .getNonZeros ());
124238 ret .setOverlapping (left .isOverlapping () || right .isOverlapping ());
125239
126- final double compressedSize = ret .getInMemorySize ();
127- final double uncompressedSize = MatrixBlock .estimateSizeInMemory (m , n , ret .getSparsity ());
240+ // final double compressedSize = ret.getInMemorySize();
241+ // final double uncompressedSize = MatrixBlock.estimateSizeInMemory(m, n, ret.getSparsity());
128242
129- if (compressedSize < uncompressedSize )
130- return ret ;
131- else {
132- final double ratio = uncompressedSize / compressedSize ;
133- String message = String .format ("Decompressing c bind matrix because it had to small compression ratio: %2.3f" ,
134- ratio );
135- return ret .getUncompressed (message );
136- }
243+ // if(compressedSize < uncompressedSize)
244+ return ret ;
245+ // else {
246+ // final double ratio = uncompressedSize / compressedSize;
247+ // String message = String.format("Decompressing c bind matrix because it had to small compression ratio: %2.3f",
248+ // ratio);
249+ // return ret.getUncompressed(message);
250+ // }
137251 }
138252
139253 private static MatrixBlock appendRightEmpty (CompressedMatrixBlock left , MatrixBlock right , int m , int n ) {
0 commit comments