Skip to content

Commit 12367cb

Browse files
committed
[SYSTEMDS-3828] Parallel Compressed Replace
This commit adds a parallel (multi-threaded) kernel for replacing values in compressed matrices. Closes #2209
1 parent b751389 commit 12367cb

File tree

7 files changed

+185
-74
lines changed

7 files changed

+185
-74
lines changed

src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java

Lines changed: 16 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
import org.apache.sysds.runtime.compress.lib.CLALibMMChain;
5959
import org.apache.sysds.runtime.compress.lib.CLALibMatrixMult;
6060
import org.apache.sysds.runtime.compress.lib.CLALibMerge;
61+
import org.apache.sysds.runtime.compress.lib.CLALibReplace;
6162
import org.apache.sysds.runtime.compress.lib.CLALibReshape;
6263
import org.apache.sysds.runtime.compress.lib.CLALibRexpand;
6364
import org.apache.sysds.runtime.compress.lib.CLALibScalar;
@@ -307,7 +308,7 @@ public void putInto(MatrixBlock target, int rowOffset, int colOffset, boolean sp
307308
* @return The cached decompressed matrix, if it does not exist return null
308309
*/
309310
public MatrixBlock getCachedDecompressed() {
310-
if( allowCachingUncompressed && decompressedVersion != null) {
311+
if(allowCachingUncompressed && decompressedVersion != null) {
311312
final MatrixBlock mb = decompressedVersion.get();
312313
if(mb != null) {
313314
DMLCompressionStatistics.addDecompressCacheCount();
@@ -401,8 +402,8 @@ public long estimateCompressedSizeInMemory() {
401402
long total = baseSizeInMemory();
402403
// take into consideration duplicate dictionaries
403404
Set<IDictionary> dicts = new HashSet<>();
404-
for(AColGroup grp : _colGroups){
405-
if(grp instanceof ADictBasedColGroup){
405+
for(AColGroup grp : _colGroups) {
406+
if(grp instanceof ADictBasedColGroup) {
406407
IDictionary dg = ((ADictBasedColGroup) grp).getDictionary();
407408
if(dicts.contains(dg))
408409
total -= dg.getInMemorySize();
@@ -576,8 +577,7 @@ public void append(MatrixValue v2, ArrayList<IndexedMatrixValue> outlist, int bl
576577
}
577578

578579
@Override
579-
public MatrixBlock chainMatrixMultOperations(MatrixBlock v, MatrixBlock w, MatrixBlock out, ChainType ctype,
580-
int k) {
580+
public MatrixBlock chainMatrixMultOperations(MatrixBlock v, MatrixBlock w, MatrixBlock out, ChainType ctype, int k) {
581581

582582
checkMMChain(ctype, v, w);
583583
// multi-threaded MMChain of single uncompressed ColGroup
@@ -629,27 +629,8 @@ public MatrixBlock transposeSelfMatrixMultOperations(MatrixBlock out, MMTSJType
629629
}
630630

631631
@Override
632-
public MatrixBlock replaceOperations(MatrixValue result, double pattern, double replacement) {
633-
if(Double.isInfinite(pattern)) {
634-
LOG.info("Ignoring replace infinite in compression since it does not contain this value");
635-
return this;
636-
}
637-
else if(isOverlapping()) {
638-
final String message = "replaceOperations " + pattern + " -> " + replacement;
639-
return getUncompressed(message).replaceOperations(result, pattern, replacement);
640-
}
641-
else {
642-
643-
CompressedMatrixBlock ret = new CompressedMatrixBlock(getNumRows(), getNumColumns());
644-
final List<AColGroup> prev = getColGroups();
645-
final int colGroupsLength = prev.size();
646-
final List<AColGroup> retList = new ArrayList<>(colGroupsLength);
647-
for(int i = 0; i < colGroupsLength; i++)
648-
retList.add(prev.get(i).replace(pattern, replacement));
649-
ret.allocateColGroupList(retList);
650-
ret.recomputeNonZeros();
651-
return ret;
652-
}
632+
public MatrixBlock replaceOperations(MatrixValue result, double pattern, double replacement, int k) {
633+
return CLALibReplace.replace(this, (MatrixBlock) result, pattern, replacement, k);
653634
}
654635

655636
@Override
@@ -710,10 +691,10 @@ public boolean containsValue(double pattern) {
710691
return false;
711692
}
712693
}
713-
694+
714695
@Override
715696
public boolean containsValue(double pattern, int k) {
716-
//TODO parallel contains value
697+
// TODO parallel contains value
717698
return containsValue(pattern);
718699
}
719700

@@ -775,8 +756,8 @@ public boolean isEmptyBlock(boolean safe) {
775756
return false;
776757
else if(_colGroups == null || nonZeros == 0)
777758
return true;
778-
else{
779-
if(nonZeros == -1){
759+
else {
760+
if(nonZeros == -1) {
780761
// try to use column groups
781762
for(AColGroup g : _colGroups)
782763
if(!g.isEmpty())
@@ -1177,8 +1158,7 @@ public void appendRow(int r, SparseRow row, boolean deep) {
11771158
}
11781159

11791160
@Override
1180-
public void appendRowToSparse(SparseBlock dest, MatrixBlock src, int i, int rowoffset, int coloffset,
1181-
boolean deep) {
1161+
public void appendRowToSparse(SparseBlock dest, MatrixBlock src, int i, int rowoffset, int coloffset, boolean deep) {
11821162
throw new DMLCompressionException("Can't append row to compressed Matrix");
11831163
}
11841164

@@ -1238,7 +1218,7 @@ public void sparseToDense(int k) {
12381218
}
12391219

12401220
@Override
1241-
public void denseToSparse(boolean allowCSR, int k){
1221+
public void denseToSparse(boolean allowCSR, int k) {
12421222
// do nothing
12431223
}
12441224

@@ -1327,13 +1307,13 @@ public void allocateAndResetSparseBlock(boolean clearNNZ, SparseBlock.Type stype
13271307
throw new DMLCompressionException("Invalid to allocate block on a compressed MatrixBlock");
13281308
}
13291309

1330-
@Override
1310+
@Override
13311311
public MatrixBlock transpose(int k) {
13321312
return getUncompressed().transpose(k);
13331313
}
13341314

1335-
@Override
1336-
public MatrixBlock reshape(int rows,int cols, boolean byRow){
1315+
@Override
1316+
public MatrixBlock reshape(int rows, int cols, boolean byRow) {
13371317
return CLALibReshape.reshape(this, rows, cols, byRow);
13381318
}
13391319

src/main/java/org/apache/sysds/runtime/compress/lib/CLALibReplace.java

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.runtime.compress.lib;
21+
22+
import java.util.ArrayList;
23+
import java.util.List;
24+
import java.util.concurrent.ExecutionException;
25+
import java.util.concurrent.ExecutorService;
26+
import java.util.concurrent.Future;
27+
28+
import org.apache.commons.logging.Log;
29+
import org.apache.commons.logging.LogFactory;
30+
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
31+
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
32+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
33+
import org.apache.sysds.runtime.util.CommonThreadPool;
34+
35+
public class CLALibReplace {
36+
private static final Log LOG = LogFactory.getLog(CLALibReplace.class.getName());
37+
38+
private CLALibReplace(){
39+
// private constructor
40+
}
41+
42+
public static MatrixBlock replace(CompressedMatrixBlock in, MatrixBlock out, double pattern, double replacement,
43+
int k) {
44+
try {
45+
46+
if(Double.isInfinite(pattern)) {
47+
LOG.info("Ignoring replace infinite in compression since it does not contain this value");
48+
return in;
49+
}
50+
else if(in.isOverlapping()) {
51+
final String message = "replaceOperations " + pattern + " -> " + replacement;
52+
return in.getUncompressed(message).replaceOperations(out, pattern, replacement);
53+
}
54+
else
55+
return replaceNormal(in, out, pattern, replacement, k);
56+
}
57+
catch(Exception e) {
58+
throw new RuntimeException("Failed replace pattern: " + pattern + " replacement: " + replacement, e);
59+
}
60+
}
61+
62+
private static MatrixBlock replaceNormal(CompressedMatrixBlock in, MatrixBlock out, double pattern,
63+
double replacement, int k) throws Exception {
64+
CompressedMatrixBlock ret = new CompressedMatrixBlock(in.getNumRows(), in.getNumColumns());
65+
final List<AColGroup> prev = in.getColGroups();
66+
final int colGroupsLength = prev.size();
67+
final List<AColGroup> retList = new ArrayList<>(colGroupsLength);
68+
69+
if(k <= 1)
70+
replaceSingleThread(pattern, replacement, prev, colGroupsLength, retList);
71+
else
72+
replaceMultiThread(pattern, replacement, k, prev, colGroupsLength, retList);
73+
74+
ret.allocateColGroupList(retList);
75+
if(replacement == 0) // have to recompute!
76+
ret.recomputeNonZeros();
77+
else if(pattern == 0) // always fully dense.
78+
ret.setNonZeros(((long) in.getNumRows()) * in.getNumColumns());
79+
else // same nonzeros as input
80+
ret.setNonZeros(in.getNonZeros());
81+
return ret;
82+
}
83+
84+
private static void replaceMultiThread(double pattern, double replacement, int k, final List<AColGroup> prev,
85+
final int colGroupsLength, final List<AColGroup> retList) throws InterruptedException, ExecutionException {
86+
ExecutorService pool = CommonThreadPool.get(k);
87+
88+
try {
89+
List<Future<AColGroup>> tasks = new ArrayList<>(colGroupsLength);
90+
for(int i = 0; i < colGroupsLength; i++) {
91+
final int j = i;
92+
tasks.add(pool.submit(() -> prev.get(j).replace(pattern, replacement)));
93+
}
94+
for(int i = 0; i < colGroupsLength; i++) {
95+
retList.add(tasks.get(i).get());
96+
}
97+
}
98+
finally {
99+
pool.shutdown();
100+
}
101+
}
102+
103+
private static void replaceSingleThread(double pattern, double replacement, final List<AColGroup> prev,
104+
final int colGroupsLength, final List<AColGroup> retList) {
105+
for(int i = 0; i < colGroupsLength; i++)
106+
retList.add(prev.get(i).replace(pattern, replacement));
107+
}
108+
}

src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
import org.apache.sysds.runtime.transform.tokenize.TokenizerFactory;
6767
import org.apache.sysds.runtime.util.AutoDiff;
6868
import org.apache.sysds.runtime.util.DataConverter;
69+
import org.apache.sysds.utils.stats.InfrastructureAnalyzer;
6970

7071
public class ParameterizedBuiltinCPInstruction extends ComputationCPInstruction {
7172
private static final Log LOG = LogFactory.getLog(ParameterizedBuiltinCPInstruction.class.getName());
@@ -276,7 +277,8 @@ else if(opcode.equalsIgnoreCase(Opcodes.REPLACE.toString())) {
276277
MatrixBlock target = targetObj.acquireRead();
277278
double pattern = Double.parseDouble(params.get("pattern"));
278279
double replacement = Double.parseDouble(params.get("replacement"));
279-
MatrixBlock ret = target.replaceOperations(new MatrixBlock(), pattern, replacement);
280+
MatrixBlock ret = target.replaceOperations(new MatrixBlock(), pattern, replacement,
281+
InfrastructureAnalyzer.getLocalParallelism());
280282
if( ret == target ) //shallow copy (avoid bufferpool pollution)
281283
ec.setVariable(output.getName(), targetObj);
282284
else

src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5157,9 +5157,13 @@ public MatrixBlock rexpandOperations( MatrixBlock ret, double max, boolean rows,
51575157

51585158

51595159
@Override
5160-
public MatrixBlock replaceOperations(MatrixValue result, double pattern, double replacement) {
5160+
public final MatrixBlock replaceOperations(MatrixValue result, double pattern, double replacement) {
5161+
return replaceOperations(result, pattern, replacement, 1);
5162+
}
5163+
5164+
public MatrixBlock replaceOperations(MatrixValue result, double pattern, double replacement, int k) {
51615165
MatrixBlock ret = checkType(result);
5162-
return LibMatrixReplace.replaceOperations(this, ret, pattern, replacement);
5166+
return LibMatrixReplace.replaceOperations(this, ret, pattern, replacement, k);
51635167
}
51645168

51655169
public MatrixBlock extractTriangular(MatrixBlock ret, boolean lower, boolean diag, boolean values) {

src/test/java/org/apache/sysds/test/component/compress/CompressedCustomTests.java

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
package org.apache.sysds.test.component.compress;
2121

2222
import static org.junit.Assert.assertEquals;
23+
import static org.junit.Assert.assertNull;
2324
import static org.junit.Assert.assertTrue;
2425
import static org.junit.Assert.fail;
2526

@@ -38,6 +39,7 @@
3839
import org.apache.sysds.runtime.compress.cost.CostEstimatorFactory;
3940
import org.apache.sysds.runtime.compress.cost.InstructionTypeCounter;
4041
import org.apache.sysds.runtime.compress.lib.CLALibCBind;
42+
import org.apache.sysds.runtime.compress.lib.CLALibReplace;
4143
import org.apache.sysds.runtime.compress.workload.WTreeRoot;
4244
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
4345
import org.apache.sysds.test.TestUtils;
@@ -397,9 +399,18 @@ public void manyRowsButNotQuite() {
397399
TestUtils.compareMatricesBitAvgDistance(m1, m2, 0, 0, "no");
398400
}
399401

402+
@Test(expected = Exception.class)
403+
public void cbindWithError() {
404+
CLALibCBind.cbind(null, new MatrixBlock[] {null}, 0);
405+
}
400406

401407
@Test(expected = Exception.class)
402-
public void cbindWithError(){
403-
CLALibCBind.cbind(null, new MatrixBlock[]{null}, 0);
408+
public void replaceWithError() {
409+
CLALibReplace.replace(null, null, 0, 0, 10);
410+
}
411+
412+
@Test
413+
public void replaceInf() {
414+
assertNull(CLALibReplace.replace(null, null, Double.POSITIVE_INFINITY, 0, 10));
404415
}
405416
}

src/test/java/org/apache/sysds/test/component/compress/CompressedMatrixTest.java

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -329,38 +329,6 @@ public void testContainsValue_not() {
329329
}
330330
}
331331

332-
@Test
333-
public void testReplaceNotContainedValue() {
334-
double v = min - 1;
335-
if(v != 0)
336-
testReplace(v);
337-
}
338-
339-
@Test
340-
public void testReplace() {
341-
if(min != 0)
342-
testReplace(min);
343-
}
344-
345-
@Test
346-
public void testReplaceZero() {
347-
testReplace(0);
348-
}
349-
350-
private void testReplace(double value) {
351-
try {
352-
if(!(cmb instanceof CompressedMatrixBlock) || rows * cols > 10000)
353-
return;
354-
ucRet = mb.replaceOperations(ucRet, value, 1425);
355-
MatrixBlock ret2 = cmb.replaceOperations(new MatrixBlock(), value, 1425);
356-
compareResultMatrices(ucRet, ret2, 1);
357-
}
358-
catch(Exception e) {
359-
e.printStackTrace();
360-
throw new DMLRuntimeException(e);
361-
}
362-
}
363-
364332
@Test
365333
public void testCompressedMatrixConstruction() {
366334
try {

src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1173,7 +1173,7 @@ public void appendCBindAlignedSelfMultiple() {
11731173
}
11741174
catch(AssertionError e) {
11751175
e.printStackTrace();
1176-
fail("failed Cbind: " + cmb.toString() );
1176+
fail("failed Cbind: " + cmb.toString());
11771177
}
11781178
}
11791179

@@ -1299,4 +1299,42 @@ protected static CompressionSettingsBuilder csb() {
12991299
return new CompressionSettingsBuilder().setSeed(compressionSeed).setMinimumSampleSize(100);
13001300
}
13011301

1302+
@Test
1303+
public void testReplaceNotContainedValue() {
1304+
double v = min - 1;
1305+
if(v != 0)
1306+
testReplace(v, 132);
1307+
}
1308+
1309+
@Test
1310+
public void testReplace() {
1311+
if(min != 0)
1312+
testReplace(min, 323);
1313+
}
1314+
1315+
@Test
1316+
public void testReplaceWithZero() {
1317+
if(min != 0)
1318+
testReplace(min, 0);
1319+
}
1320+
1321+
@Test
1322+
public void testReplaceZero() {
1323+
testReplace(0, 3232);
1324+
}
1325+
1326+
private void testReplace(double value, double replacements) {
1327+
try {
1328+
if(!(cmb instanceof CompressedMatrixBlock) || rows * cols > 10000)
1329+
return;
1330+
ucRet = mb.replaceOperations(ucRet, value, replacements, _k);
1331+
MatrixBlock ret2 = cmb.replaceOperations(new MatrixBlock(), value, replacements, _k);
1332+
compareResultMatrices(ucRet, ret2, 1);
1333+
}
1334+
catch(Exception e) {
1335+
e.printStackTrace();
1336+
throw new DMLRuntimeException(e);
1337+
}
1338+
}
1339+
13021340
}

0 commit comments

Comments
 (0)