Skip to content

Commit 407090d

Browse files
committed
fixed ctable with seq fuse
rewrite fused ctable with given output dim (disaled: performance decrease, need to fix it first)
1 parent 749ec56 commit 407090d

File tree

6 files changed

+74
-56
lines changed

6 files changed

+74
-56
lines changed

scripts/builtin/kmeans.dml

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -145,11 +145,10 @@ m_kmeans = function(Matrix[Double] X, Integer k = 10, Integer runs = 10, Integer
145145
}
146146

147147
# Find the closest centroid for each record
148-
# P = D <= minD;
148+
P = D <= minD;
149149
# If some records belong to multiple centroids, share them equally
150-
# P = P / rowSums (P);
151-
P = table(seq(1,nrow(D)), rowIndexMin(D))
152-
# P = table(seq(1,nrow(D)),compress(rowIndexMin(D)))
150+
P = P / rowSums (P);
151+
# P = table(seq(1,num_records), rowIndexMin(D), num_records, num_centroids)
153152
# Compute the column normalization factor for P
154153
P_denom = colSums (P);
155154
# Compute new centroids as weighted averages over the records

src/main/java/org/apache/sysds/hops/TernaryOp.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -651,6 +651,8 @@ public boolean isSequenceRewriteApplicable( boolean left )
651651

652652
try
653653
{
654+
// TODO: to rewrite is not currently not triggered if outdim are given --> getInput().size()>=3
655+
// currently disabled due performance decrease
654656
if( getInput().size()==2 || (getInput().size()==3 && getInput().get(2).getDataType()==DataType.SCALAR) )
655657
{
656658
Hop input1 = getInput().get(0);

src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRexpand.java

Lines changed: 36 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory;
4040
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
4141
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
42+
import org.apache.sysds.runtime.matrix.data.Pair;
4243
import org.apache.sysds.runtime.util.CommonThreadPool;
4344
import org.apache.sysds.runtime.util.UtilFunctions;
4445

@@ -71,19 +72,23 @@ public static MatrixBlock rexpand(int seqHeight, MatrixBlock A, int nColOut, int
7172

7273
try {
7374
final int[] map = new int[seqHeight];
74-
int maxCol = constructInitialMapping(map, A, k);
75+
Pair<Integer, Integer> meta = constructInitialMapping(map, A, k, nColOut);
76+
int maxCol = meta.getKey();
77+
int nZeros = meta.getValue();
7578
boolean containsNull = maxCol < 0;
7679
maxCol = Math.abs(maxCol);
7780

81+
boolean cutOff = false;
7882
if(nColOut == -1)
7983
nColOut = maxCol;
8084
else if(nColOut < maxCol)
81-
throw new DMLRuntimeException("invalid nColOut, requested: " + nColOut + " but have to be : " + maxCol);
85+
cutOff = true;
8286

83-
final int nNulls = containsNull ? correctNulls(map, nColOut) : 0;
87+
if(containsNull)
88+
correctNulls(map, nColOut);
8489
if(nColOut == 0) // edge case of empty zero dimension block.
8590
return new MatrixBlock(seqHeight, 0, 0.0);
86-
return createCompressedReturn(map, nColOut, seqHeight, nNulls, containsNull, k);
91+
return createCompressedReturn(map, nColOut, seqHeight, nZeros, containsNull || cutOff, k);
8792
}
8893
catch(Exception e) {
8994
throw new RuntimeException("Failed table seq operator", e);
@@ -139,7 +144,7 @@ private static int correctNulls(int[] map, int nColOut) {
139144
return nNulls;
140145
}
141146

142-
private static int constructInitialMapping(int[] map, MatrixBlock A, int k) {
147+
private static Pair<Integer,Integer> constructInitialMapping(int[] map, MatrixBlock A, int k, int maxOutCol) {
143148
if(A.isEmpty() || A.isInSparseFormat())
144149
throw new DMLRuntimeException("not supported empty or sparse construction of seq table");
145150
final MatrixBlock Ac;
@@ -155,20 +160,23 @@ private static int constructInitialMapping(int[] map, MatrixBlock A, int k) {
155160
try {
156161

157162
int blkz = Math.max((map.length / k), 1000);
158-
List<Future<Integer>> tasks = new ArrayList<>();
163+
List<Future<Pair<Integer,Integer>>> tasks = new ArrayList<>();
159164
for(int i = 0; i < map.length; i += blkz) {
160165
final int start = i;
161166
final int end = Math.min(i + blkz, map.length);
162-
tasks.add(pool.submit(() -> partialMapping(map, Ac, start, end)));
167+
tasks.add(pool.submit(() -> partialMapping(map, Ac, start, end, maxOutCol)));
163168
}
164169

165170
int maxCol = 0;
166-
for(Future<Integer> f : tasks) {
167-
int tmp = f.get();
168-
if(Math.abs(tmp) > Math.abs(maxCol))
169-
maxCol = tmp;
171+
int zeros = 0;
172+
for(Future<Pair<Integer,Integer>> f : tasks) {
173+
int tmpMaxCol = f.get().getKey();
174+
int tmpZeros = f.get().getValue();
175+
if(Math.abs(tmpMaxCol) > Math.abs(maxCol))
176+
maxCol = tmpMaxCol;
177+
zeros += tmpZeros;
170178
}
171-
return maxCol;
179+
return new Pair<Integer,Integer>(maxCol, zeros);
172180
}
173181
catch(Exception e) {
174182
throw new DMLRuntimeException(e);
@@ -179,33 +187,32 @@ private static int constructInitialMapping(int[] map, MatrixBlock A, int k) {
179187

180188
}
181189

182-
private static int partialMapping(int[] map, MatrixBlock A, int start, int end) {
190+
private static Pair<Integer, Integer> partialMapping(int[] map, MatrixBlock A, int start, int end, int maxOutCol) {
183191

184192
int maxCol = 0;
185-
boolean containsNull = false;
186-
193+
int zeros = 0;
194+
int notHandledNulls = 0;
187195
final double[] aVals = A.getDenseBlockValues();
188196

189197
for(int i = start; i < end; i++) {
190198
final double v2 = aVals[i];
191-
if(Double.isNaN(v2)) {
192-
map[i] = -1; // assign temporarily to -1
193-
containsNull = true;
194-
}
195-
else {
196-
// safe casts to long for consistent behavior with indexing
197-
int col = UtilFunctions.toInt(v2);
198-
if(col <= 0)
199-
throw new DMLRuntimeException(
199+
int colUnsafe = UtilFunctions.toInt(v2);
200+
if(colUnsafe <= 0)
201+
throw new DMLRuntimeException(
200202
"Erroneous input while computing the contingency table (value <= zero): " + v2);
203+
boolean invalid = Double.isNaN(v2) || (maxOutCol != -1 && colUnsafe > maxOutCol);
204+
final int colSafe = invalid ? maxOutCol : colUnsafe - 1;
205+
zeros += invalid ? 1 : 0;
206+
notHandledNulls += Double.isNaN(v2) ? maxOutCol : 0;
207+
maxCol = Math.max(colUnsafe, maxCol);
208+
map[i] = colSafe;
209+
}
201210

202-
map[i] = col - 1;
203-
// maintain max seen col
204-
maxCol = Math.max(col, maxCol);
205-
}
211+
if (notHandledNulls < 0){
212+
maxCol *= -1;
206213
}
207214

208-
return containsNull ? maxCol * -1 : maxCol;
215+
return new Pair<Integer, Integer>(maxCol, zeros);
209216
}
210217

211218
public static boolean compressedTableSeq() {

src/main/java/org/apache/sysds/runtime/instructions/cp/CtableCPInstruction.java

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -110,13 +110,17 @@ public void processInstruction(ExecutionContext ec) {
110110

111111
boolean outputDimsKnown = (outputDim1 != -1 && outputDim2 != -1);
112112
if ( outputDimsKnown ) {
113-
int inputRows = matBlock1.getNumRows();
114-
int inputCols = matBlock1.getNumColumns();
115-
boolean sparse = MatrixBlock.evalSparseFormatInMemory(outputDim1, outputDim2, inputRows*inputCols);
116-
//only create result block if dense; it is important not to aggregate on sparse result
117-
//blocks because it would implicitly turn the O(N) algorithm into O(N log N).
118-
if( !sparse )
119-
resultBlock = new MatrixBlock((int)outputDim1, (int)outputDim2, false);
113+
if(_isExpand){
114+
resultBlock = new MatrixBlock((int)outputDim1, (int)outputDim2, true);
115+
} else {
116+
int inputRows = matBlock1.getNumRows();
117+
int inputCols = matBlock1.getNumColumns();
118+
boolean sparse = MatrixBlock.evalSparseFormatInMemory(outputDim1, outputDim2, inputRows*inputCols);
119+
//only create result block if dense; it is important not to aggregate on sparse result
120+
//blocks because it would implicitly turn the O(N) algorithm into O(N log N).
121+
if( !sparse )
122+
resultBlock = new MatrixBlock((int)outputDim1, (int)outputDim2, false);
123+
}
120124
}
121125

122126
switch(ctableOp) {
@@ -140,7 +144,8 @@ public void processInstruction(ExecutionContext ec) {
140144
}
141145
matBlock2 = ec.getMatrixInput(input2.getName());
142146
cst1 = ec.getScalarInput(input3).getDoubleValue();
143-
resultBlock = LibMatrixReorg.fusedSeqRexpand(matBlock2.getNumRows(), matBlock2, cst1, resultBlock, true, _k);
147+
resultBlock = LibMatrixReorg.fusedSeqRexpand(matBlock2.getNumRows(), matBlock2, cst1, resultBlock,
148+
!outputDimsKnown, _k);
144149
break;
145150
case CTABLE_TRANSFORM_HISTOGRAM: //(VECTOR)
146151
// F=ctable(A,1) or F = ctable(A,1,1)

src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1044,11 +1044,13 @@ public static MatrixBlock fusedSeqRexpand(int seqHeight, MatrixBlock A, double w
10441044

10451045
}
10461046

1047-
private static MatrixBlock fusedSeqRexpandSparse(int seqHeight, MatrixBlock A, double w, MatrixBlock ret, boolean updateClen) {
1047+
private static MatrixBlock fusedSeqRexpandSparse(int seqHeight, MatrixBlock A, double w, MatrixBlock ret,
1048+
boolean updateClen) {
10481049
if(ret == null) {
10491050
ret = new MatrixBlock();
10501051
updateClen = true;
10511052
}
1053+
int outCols = updateClen ? -1 : ret.getNumColumns();
10521054
final int rlen = seqHeight;
10531055
// prepare allocation of CSR sparse block
10541056
final int[] rowPointers = new int[rlen + 1];
@@ -1060,14 +1062,14 @@ private static MatrixBlock fusedSeqRexpandSparse(int seqHeight, MatrixBlock A, d
10601062
ret.sparse = true;
10611063
ret.denseBlock = null;
10621064
// construct sparse CSR block from filled arrays
1063-
SparseBlockCSR csr = new SparseBlockCSR(rowPointers, indexes, values, rlen);
1065+
SparseBlockCSR csr = new SparseBlockCSR(rowPointers, indexes, values, seqHeight);
10641066
ret.sparseBlock = csr;
1065-
int blkz = Math.min(1024, rlen);
1067+
int blkz = Math.min(1024, seqHeight);
10661068
int maxcol = 0;
10671069
boolean containsNull = false;
1068-
for(int i = 0; i < rlen; i += blkz) {
1070+
for(int i = 0; i < seqHeight; i += blkz) {
10691071
// blocked execution for earlier JIT compilation
1070-
int t = fusedSeqRexpandSparseBlock(csr, A, w, i, Math.min(i + blkz, rlen));
1072+
int t = fusedSeqRexpandSparseBlock(csr, A, w, i, Math.min(i + blkz, seqHeight), (int) outCols);
10711073
if(t < 0) {
10721074
t = Math.abs(t);
10731075
containsNull = true;
@@ -1078,14 +1080,15 @@ private static MatrixBlock fusedSeqRexpandSparse(int seqHeight, MatrixBlock A, d
10781080
if(containsNull)
10791081
csr.compact();
10801082

1081-
rowPointers[rlen] = rlen;
1083+
rowPointers[seqHeight] = seqHeight;
10821084
ret.setNonZeros(ret.sparseBlock.size());
10831085
if(updateClen)
1084-
ret.setNumColumns(maxcol);
1086+
ret.setNumColumns(outCols == -1 ? maxcol : (int) outCols);
10851087
return ret;
10861088
}
10871089

1088-
private static int fusedSeqRexpandSparseBlock(final SparseBlockCSR csr, final MatrixBlock A, final double w, int rl, int ru) {
1090+
private static int fusedSeqRexpandSparseBlock(final SparseBlockCSR csr, final MatrixBlock A, final double w, int rl,
1091+
int ru, int maxOutCol) {
10891092

10901093
// prepare allocation of CSR sparse block
10911094
final int[] rowPointers = csr.rowPointers();
@@ -1096,7 +1099,7 @@ private static int fusedSeqRexpandSparseBlock(final SparseBlockCSR csr, final Ma
10961099
int maxCol = 0;
10971100

10981101
for(int i = rl; i < ru; i++) {
1099-
int c = rexpandSingleRow(i, A.get(i, 0), w, indexes, values);
1102+
int c = rexpandSingleRow(i, A.get(i, 0), w, indexes, values, maxOutCol);
11001103
if(c < 0)
11011104
containsNull = true;
11021105
else
@@ -1114,7 +1117,7 @@ private static void updateClenRexpand(MatrixBlock ret, int maxCol, boolean updat
11141117
ret.clen = maxCol;
11151118
}
11161119

1117-
public static int rexpandSingleRow(int row, double v2, double w, int[] retIx, double[] retVals) {
1120+
public static int rexpandSingleRow(int row, double v2, double w, int[] retIx, double[] retVals, int maxOutCol) {
11181121
// If any of the values are NaN (i.e., missing) then
11191122
// we skip this tuple, proceed to the next tuple
11201123
if(Double.isNaN(v2))
@@ -1124,10 +1127,12 @@ public static int rexpandSingleRow(int row, double v2, double w, int[] retIx, d
11241127
int col = UtilFunctions.toInt(v2);
11251128
if(col <= 0)
11261129
throw new DMLRuntimeException("Erroneous input while computing the contingency table (value <= zero): " + v2);
1127-
1128-
// set weight as value (expand is guaranteed to address different cells)
1129-
retIx[row] = col - 1;
1130-
retVals[row] = w;
1130+
// maxOutCol = - 1 if not specified --> TRUE
1131+
if(col <= maxOutCol){
1132+
// set weight as value (expand is guaranteed to address different cells)
1133+
retIx[row] = col - 1;
1134+
retVals[row] = w;
1135+
}
11311136

11321137
// maintain max seen col
11331138
return col;

src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ public void federatedKmeans(Types.ExecMode execMode, boolean singleWorker) {
123123

124124
// Run actual dml script with federated matrix
125125
fullDMLScriptName = HOME + TEST_NAME + ".dml";
126-
programArgs = new String[] {"-stats", "-nvargs", "in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
126+
programArgs = new String[] {"-stats","20", "-nvargs", "in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
127127
"in_X2=" + TestUtils.federatedAddress(port2, input("X2")), "rows=" + rows, "cols=" + cols,
128128
"single=" + String.valueOf(singleWorker).toUpperCase(), "runs=" + String.valueOf(runs), "out=" + output("Z")};
129129

0 commit comments

Comments
 (0)