Skip to content

Commit 5c5fb65

Browse files
committed
CLALib Reshape Tests
1 parent 4c0aee7 commit 5c5fb65

File tree

3 files changed

+164
-34
lines changed

3 files changed

+164
-34
lines changed

src/main/java/org/apache/sysds/runtime/compress/lib/CLALibReshape.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ private CLALibReshape(CompressedMatrixBlock in, int rows, int cols, boolean roww
5959
this.rows = rows;
6060
this.cols = cols;
6161
this.rowwise = rowwise;
62-
this.pool = k >= 1 ? CommonThreadPool.get(k) : null;
62+
this.pool = k > 1 ? CommonThreadPool.get(k) : null;
6363
}
6464

6565
public static MatrixBlock reshape(CompressedMatrixBlock in, int rows, int cols, boolean rowwise) {
@@ -161,7 +161,7 @@ private void checkValidity() {
161161

162162
private boolean shouldItBeCompressedOutputs() {
163163
// The number of rows in the reshaped allocations is fairly large.
164-
return rlen > COMPRESSED_RESHAPE_THRESHOLD &&
164+
return rlen > COMPRESSED_RESHAPE_THRESHOLD && rowwise &&
165165
// the reshape is a clean multiplier of number of rows, meaning each column group cleanly reshape into x others
166166
(double) rlen / rows % 1.0 == 0.0;
167167
}

src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java

Lines changed: 49 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1016,11 +1016,12 @@ public static MatrixBlock fusedSeqRexpand(int seqHeight, MatrixBlock A, double w
10161016
"Invalid input A in table(seq(1, nrow(A)), A, w): A should only have one column but has: "
10171017
+ A.getNumColumns());
10181018

1019-
if(!Double.isNaN(w)) {
1019+
if(!Double.isNaN(w) && w != 0) {
10201020
if((CLALibRexpand.compressedTableSeq() || A instanceof CompressedMatrixBlock) && w == 1)
10211021
return CLALibRexpand.rexpand(seqHeight, A, updateClen ? -1 : ret.getNumColumns(), k);
1022-
else
1023-
return fusedSeqRexpandSparseBlock(seqHeight, A, w, ret, updateClen, k);
1022+
else{
1023+
return fusedSeqRexpandSparse(seqHeight, A, w, ret, updateClen);
1024+
}
10241025
}
10251026
else {
10261027
if(ret == null) {
@@ -1040,22 +1041,58 @@ public static MatrixBlock fusedSeqRexpand(int seqHeight, MatrixBlock A, double w
10401041

10411042
}
10421043

1043-
private static MatrixBlock fusedSeqRexpandSparseBlock(final int rlen, final MatrixBlock A, final double w, MatrixBlock ret,
1044-
boolean updateClen, int k ) {
1045-
1044+
private static MatrixBlock fusedSeqRexpandSparse(int seqHeight, MatrixBlock A, double w, MatrixBlock ret, boolean updateClen) {
1045+
if(ret == null) {
1046+
ret = new MatrixBlock();
1047+
updateClen = true;
1048+
}
1049+
final int rlen = seqHeight;
10461050
// prepare allocation of CSR sparse block
10471051
final int[] rowPointers = new int[rlen + 1];
10481052
final int[] indexes = new int[rlen];
10491053
final double[] values = new double[rlen];
10501054

1051-
// sparse-unsafe table execution
1052-
// (because input values of 0 are invalid and have to result in errors)
1053-
// resultBlock guaranteed to be allocated for table expand
1054-
// each row in resultBlock will be allocated and will contain exactly one value
1055+
ret.rlen = rlen;
1056+
// assign the output
1057+
ret.sparse = true;
1058+
ret.denseBlock = null;
1059+
// construct sparse CSR block from filled arrays
1060+
SparseBlockCSR csr = new SparseBlockCSR(rowPointers, indexes, values, rlen);
1061+
ret.sparseBlock = csr;
1062+
int blkz = Math.min(1024, rlen);
1063+
int maxcol = 0;
1064+
boolean containsNull = false;
1065+
for(int i = 0; i < rlen; i += blkz) {
1066+
// blocked execution for earlier JIT compilation
1067+
int t = fusedSeqRexpandSparseBlock(csr, A, w, i, Math.min(i + blkz, rlen));
1068+
if(t < 0) {
1069+
t = Math.abs(t);
1070+
containsNull = true;
1071+
}
1072+
maxcol = Math.max(t, maxcol);
1073+
}
1074+
1075+
if(containsNull)
1076+
csr.compact();
1077+
1078+
rowPointers[rlen] = rlen;
1079+
ret.setNonZeros(ret.sparseBlock.size());
1080+
if(updateClen)
1081+
ret.setNumColumns(maxcol);
1082+
return ret;
1083+
}
1084+
1085+
private static int fusedSeqRexpandSparseBlock(final SparseBlockCSR csr, final MatrixBlock A, final double w, int rl, int ru) {
1086+
1087+
// prepare allocation of CSR sparse block
1088+
final int[] rowPointers = csr.rowPointers();
1089+
final int[] indexes = csr.indexes();
1090+
final double[] values = csr.values();
1091+
10551092
boolean containsNull = false;
10561093
int maxCol = 0;
10571094

1058-
for(int i = 0; i < rlen; i++) {
1095+
for(int i = rl; i < ru; i++) {
10591096
int c = rexpandSingleRow(i, A.get(i, 0), w, indexes, values);
10601097
if(c < 0)
10611098
containsNull = true;
@@ -1064,27 +1101,7 @@ private static MatrixBlock fusedSeqRexpandSparseBlock(final int rlen, final Matr
10641101
rowPointers[i] = i;
10651102
}
10661103

1067-
rowPointers[rlen] = rlen;
1068-
1069-
if(ret == null) {
1070-
ret = new MatrixBlock();
1071-
updateClen = true;
1072-
}
1073-
1074-
ret.rlen = rlen;
1075-
// assign the output
1076-
ret.sparse = true;
1077-
ret.denseBlock = null;
1078-
// construct sparse CSR block from filled arrays
1079-
ret.sparseBlock = new SparseBlockCSR(rowPointers, indexes, values, rlen);
1080-
// compact all the null entries.
1081-
if(containsNull){
1082-
((SparseBlockCSR) ret.sparseBlock).compact();
1083-
}
1084-
ret.setNonZeros(ret.sparseBlock.size());
1085-
1086-
updateClenRexpand(ret, maxCol, updateClen);
1087-
return ret;
1104+
return containsNull ? -maxCol: maxCol;
10881105
}
10891106

10901107
private static void updateClenRexpand(MatrixBlock ret, int maxCol, boolean updateClen) {
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.test.component.compress.lib;
21+
22+
import org.apache.commons.logging.Log;
23+
import org.apache.commons.logging.LogFactory;
24+
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
25+
import org.apache.sysds.runtime.compress.CompressedMatrixBlockFactory;
26+
import org.apache.sysds.runtime.compress.lib.CLALibReshape;
27+
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
28+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
29+
import org.apache.sysds.test.TestUtils;
30+
import org.junit.Test;
31+
32+
public class CLALibReshapeTests {
33+
protected static final Log LOG = LogFactory.getLog(CLALibReshapeTests.class.getName());
34+
35+
static{
36+
Thread.currentThread().setName("test_reshape");
37+
}
38+
39+
@Test
40+
public void reshapeSimple() {
41+
MatrixBlock mb = TestUtils.generateTestMatrixBlock(1000, 5, 1, 1, 0.5, 235);
42+
MatrixBlock m2 = CompressedMatrixBlockFactory.compress(mb).getLeft();
43+
44+
MatrixBlock m3 = CLALibReshape.reshape((CompressedMatrixBlock) m2, 500, 10, false);
45+
MatrixBlock ref = LibMatrixReorg.reshape(mb, 500, 10, false);
46+
47+
TestUtils.compareMatrices(ref, m3, 0);
48+
}
49+
50+
@Test
51+
public void reshapeSimple2Rowwise() {
52+
MatrixBlock mb = TestUtils.generateTestMatrixBlock(3000, 1, 1, 1, 0.5, 235);
53+
MatrixBlock m2 = CompressedMatrixBlockFactory.compress(mb).getLeft();
54+
55+
MatrixBlock m3 = CLALibReshape.reshape((CompressedMatrixBlock) m2, 1500, 2, true);
56+
MatrixBlock ref = LibMatrixReorg.reshape(mb, 1500, 2, true);
57+
58+
TestUtils.compareMatrices(ref, m3, 0);
59+
}
60+
61+
@Test
62+
public void reshapeMulti2Rowwise() {
63+
MatrixBlock mb = TestUtils.generateTestMatrixBlock(3000, 4, 1, 1, 0.5, 235);
64+
MatrixBlock m2 = CompressedMatrixBlockFactory.compress(mb).getLeft();
65+
66+
MatrixBlock m3 = CLALibReshape.reshape((CompressedMatrixBlock) m2, 1500, 8, true);
67+
MatrixBlock ref = LibMatrixReorg.reshape(mb, 1500, 8, true);
68+
69+
TestUtils.compareMatrices(ref, m3, 0);
70+
}
71+
72+
73+
@Test
74+
public void reshapeMulti2RowwiseSingleThread() {
75+
MatrixBlock mb = TestUtils.generateTestMatrixBlock(3000, 4, 1, 1, 0.5, 235);
76+
MatrixBlock m2 = CompressedMatrixBlockFactory.compress(mb).getLeft();
77+
78+
MatrixBlock m3 = CLALibReshape.reshape((CompressedMatrixBlock) m2, 1500, 8, true, 1);
79+
MatrixBlock ref = LibMatrixReorg.reshape(mb, 1500, 8, true);
80+
81+
TestUtils.compareMatrices(ref, m3, 0);
82+
}
83+
84+
@Test
85+
public void reshapeSimple2RowwiseNotMultiply() {
86+
MatrixBlock mb = TestUtils.generateTestMatrixBlock(3000, 2, 1, 1, 0.5, 235);
87+
MatrixBlock m2 = CompressedMatrixBlockFactory.compress(mb).getLeft();
88+
89+
MatrixBlock m3 = CLALibReshape.reshape((CompressedMatrixBlock) m2, 2000, 3, true);
90+
MatrixBlock ref = LibMatrixReorg.reshape(mb, 2000, 3, true);
91+
92+
TestUtils.compareMatrices(ref, m3, 0);
93+
}
94+
95+
@Test
96+
public void reshapeSimple2ColWise() {
97+
MatrixBlock mb = TestUtils.generateTestMatrixBlock(3000, 1, 1, 1, 0.5, 235);
98+
MatrixBlock m2 = CompressedMatrixBlockFactory.compress(mb).getLeft();
99+
100+
MatrixBlock m3 = CLALibReshape.reshape((CompressedMatrixBlock) m2, 1500, 2, false);
101+
MatrixBlock ref = LibMatrixReorg.reshape(mb, 1500, 2, false);
102+
103+
TestUtils.compareMatrices(ref, m3, 0);
104+
}
105+
106+
@Test(expected = Exception.class)
107+
public void reshapeInvalid() {
108+
MatrixBlock mb = TestUtils.generateTestMatrixBlock(1000, 5, 1, 1, 0.5, 235);
109+
MatrixBlock m2 = CompressedMatrixBlockFactory.compress(mb).getLeft();
110+
111+
CLALibReshape.reshape((CompressedMatrixBlock) m2, 501, 10, false);
112+
}
113+
}

0 commit comments

Comments
 (0)