Skip to content

Commit 11c5a74

Browse files
committed
[SYSTEMDS-3730] Fixed and improved multi-threaded reverse operations
This patch consolidates the new multi-threaded reverse operation, by using common single- and multi-threaded kernels, fixing parallelization decisions and preallocation, and consolidating the tests.
1 parent 559f770 commit 11c5a74

File tree

3 files changed

+56
-176
lines changed

3 files changed

+56
-176
lines changed

src/main/java/org/apache/sysds/hops/ReorgOp.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -159,12 +159,10 @@ else if( getDim1()==1 && getDim2()==1 )
159159
break;
160160
}
161161
case REV: {
162-
long numel = getDim1() * getDim2();
163-
int k = (numel < 3000_000) ?
164-
1 : OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
165162
Transform transform1 = new Transform(
166163
getInput().get(0).constructLops(),
167-
_op, getDataType(), getValueType(), et, k);
164+
_op, getDataType(), getValueType(), et,
165+
OptimizerUtils.getConstrainedNumThreads(_maxNumThreads));
168166
setOutputDimensions(transform1);
169167
setLineNumbers(transform1);
170168
setLops(transform1);

src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java

Lines changed: 28 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -382,18 +382,29 @@ public static MatrixBlock rev( MatrixBlock in, MatrixBlock out ) {
382382
return out;
383383
}
384384

385-
if( in.sparse )
386-
reverseSparse( in, out );
387-
else
388-
reverseDense( in, out );
385+
//set basic meta data and allocate output
386+
out.sparse = in.sparse;
387+
out.nonZeros = in.nonZeros;
388+
389+
390+
if( in.sparse ) {
391+
out.allocateSparseRowsBlock(false);
392+
reverseSparse(in, out, 0, in.rlen);
393+
}
394+
else {
395+
out.allocateDenseBlock(false);
396+
reverseDense(in, out, 0, in.rlen);
397+
}
389398

390399
//System.out.println("rev ("+in.rlen+", "+in.clen+", "+in.sparse+") in "+time.stop()+" ms.");
391400

392401
return out;
393402
}
394403

395404
public static MatrixBlock rev(MatrixBlock in, MatrixBlock out, int k) {
396-
if (k <= 1 || in.isEmptyBlock(false) ) {
405+
if (k <= 1 || in.isEmptyBlock(false)
406+
|| in.getLength() < PAR_NUMCELL_THRESHOLD )
407+
{
397408
return rev(in, out); // fallback to single-threaded
398409

399410
}
@@ -405,10 +416,11 @@ public static MatrixBlock rev(MatrixBlock in, MatrixBlock out, int k) {
405416
out.reset(numRows, numCols, sparse);
406417

407418
// Before starting threads, ensure the output sparse block is allocated!
408-
if (sparse) {
419+
if (sparse)
409420
out.allocateSparseRowsBlock(false);
410-
}
411-
421+
else
422+
out.allocateDenseBlock(false);
423+
412424
// Set up thread pool
413425
ExecutorService pool = CommonThreadPool.get(k);
414426
try {
@@ -420,25 +432,10 @@ public static MatrixBlock rev(MatrixBlock in, MatrixBlock out, int k) {
420432
final int endRow = Math.min((i + 1) * blklen, numRows);
421433

422434
tasks.add(pool.submit(() -> {
423-
if (!sparse) {
424-
// Dense case
425-
double[] inVals = in.getDenseBlockValues();
426-
double[] outVals = out.getDenseBlockValues();
427-
for (int r = startRow; r < endRow; r++) {
428-
int revRow = numRows - r - 1;
429-
System.arraycopy(inVals, revRow * numCols, outVals, r * numCols, numCols);
430-
}
431-
} else {
432-
// Sparse case
433-
SparseBlock inBlk = in.getSparseBlock();
434-
SparseBlock outBlk = out.getSparseBlock();
435-
for (int r = startRow; r < endRow; r++) {
436-
int revRow = numRows - r - 1;
437-
if (!inBlk.isEmpty(revRow)) {
438-
outBlk.set(r, inBlk.get(revRow), true);
439-
}
440-
}
441-
}
435+
if( in.sparse )
436+
reverseSparse(in, out, startRow, endRow);
437+
else
438+
reverseDense(in, out, startRow, endRow);
442439
}));
443440
}
444441

@@ -2523,41 +2520,30 @@ public static int[] mergeNnzCounts(int[] cnt, int[] cnt2) {
25232520
return cnt;
25242521
}
25252522

2526-
private static void reverseDense(MatrixBlock in, MatrixBlock out) {
2523+
private static void reverseDense(MatrixBlock in, MatrixBlock out, int rl, int ru) {
25272524
final int m = in.rlen;
25282525
final int n = in.clen;
25292526

2530-
//set basic meta data and allocate output
2531-
out.sparse = false;
2532-
out.nonZeros = in.nonZeros;
2533-
out.allocateDenseBlock(false);
2534-
25352527
//copy all rows into target positions
25362528
if( n == 1 ) { //column vector
25372529
double[] a = in.getDenseBlockValues();
25382530
double[] c = out.getDenseBlockValues();
2539-
for( int i=0; i<m; i++ )
2531+
for( int i=rl; i<ru; i++ )
25402532
c[m-1-i] = a[i];
25412533
}
25422534
else { //general matrix case
25432535
DenseBlock a = in.getDenseBlock();
25442536
DenseBlock c = out.getDenseBlock();
2545-
for( int i=0; i<m; i++ ) {
2537+
for( int i=rl; i<ru; i++ ) {
25462538
final int ri = m - 1 - i;
25472539
System.arraycopy(a.values(i), a.pos(i), c.values(ri), c.pos(ri), n);
25482540
}
25492541
}
25502542
}
25512543

2552-
private static void reverseSparse(MatrixBlock in, MatrixBlock out) {
2544+
private static void reverseSparse(MatrixBlock in, MatrixBlock out, int rl, int ru) {
25532545
final int m = in.rlen;
25542546

2555-
//set basic meta data and allocate output
2556-
out.sparse = true;
2557-
out.nonZeros = in.nonZeros;
2558-
2559-
out.allocateSparseRowsBlock(false);
2560-
25612547
//copy all rows into target positions
25622548
SparseBlock a = in.getSparseBlock();
25632549
SparseBlock c = out.getSparseBlock();

src/test/java/org/apache/sysds/test/functions/reorg/FullReverseTest.java

Lines changed: 26 additions & 130 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,8 @@
2222
import java.util.HashMap;
2323

2424
import org.apache.sysds.common.Opcodes;
25-
import org.apache.sysds.utils.stats.InfrastructureAnalyzer;
2625
import org.junit.Assert;
2726
import org.junit.Test;
28-
import org.apache.sysds.api.DMLScript;
2927
import org.apache.sysds.common.Types.ExecMode;
3028
import org.apache.sysds.common.Types.ExecType;
3129
import org.apache.sysds.runtime.instructions.Instruction;
@@ -44,18 +42,16 @@ public class FullReverseTest extends AutomatedTestBase
4442
private final static String TEST_DIR = "functions/reorg/";
4543
private static final String TEST_CLASS_DIR = TEST_DIR + FullReverseTest.class.getSimpleName() + "/";
4644

47-
private final static int rows1 = 2017;
48-
private final static int cols1 = 1001;
45+
//single-threaded execution
46+
private final static int rows1 = 201;
47+
private final static int cols1 = 100;
48+
//multi-threaded / distributed execution
49+
private final static int rows2 = 2017;
50+
private final static int cols2 = 1001;
51+
4952
private final static double sparsity1 = 0.7;
5053
private final static double sparsity2 = 0.1;
5154

52-
// Multi-threading test parameters
53-
private final static int rows_mt = 5018; // Larger for multi-threading benefits
54-
private final static int cols_mt = 1001; // Larger for multi-threading benefits
55-
private final static int[] threadCounts = {1, 2, 4, 8};
56-
// Set global parallelism for SystemDS to enable multi-threading
57-
private final static int oldPar = InfrastructureAnalyzer.getLocalParallelism();
58-
5955
@Override
6056
public void setUp() {
6157
TestUtils.clearAssertionInformation();
@@ -65,97 +61,74 @@ public void setUp() {
6561

6662
@Test
6763
public void testReverseVectorDenseCP() {
68-
runReverseTest(TEST_NAME1, false, false, ExecType.CP);
64+
runReverseTest(TEST_NAME1, false, rows1, 1, ExecType.CP);
6965
}
7066

7167
@Test
7268
public void testReverseVectorSparseCP() {
73-
runReverseTest(TEST_NAME1, false, true, ExecType.CP);
69+
runReverseTest(TEST_NAME1, true, rows1, 1, ExecType.CP);
7470
}
7571

7672
@Test
7773
public void testReverseVectorDenseCPMultiThread() {
78-
runReverseTestMultiThread(TEST_NAME1, false, false, ExecType.CP);
74+
runReverseTest(TEST_NAME1, false, rows2, 1, ExecType.CP);
7975
}
8076

8177
@Test
8278
public void testReverseVectorSparseCPMultiThread() {
83-
runReverseTestMultiThread(TEST_NAME1, false, true, ExecType.CP);
84-
}
85-
86-
@Test
87-
public void testReverseVectorDenseSPMultiThread() {
88-
runReverseTestMultiThread(TEST_NAME1, false, false, ExecType.SPARK);
79+
runReverseTest(TEST_NAME1, true, rows2, 1, ExecType.CP);
8980
}
9081

9182
@Test
9283
public void testReverseVectorDenseSP() {
93-
runReverseTest(TEST_NAME1, false, false, ExecType.SPARK);
84+
runReverseTest(TEST_NAME1, false, rows2, 1, ExecType.SPARK);
9485
}
9586

9687
@Test
9788
public void testReverseVectorSparseSP() {
98-
runReverseTest(TEST_NAME1, false, true, ExecType.SPARK);
89+
runReverseTest(TEST_NAME1, true, rows2, 1, ExecType.SPARK);
9990
}
10091

10192
@Test
10293
public void testReverseMatrixDenseCP() {
103-
runReverseTest(TEST_NAME1, true, false, ExecType.CP);
94+
runReverseTest(TEST_NAME1, false, rows1, cols1, ExecType.CP);
10495
}
10596

10697
@Test
10798
public void testReverseMatrixSparseCP() {
108-
runReverseTest(TEST_NAME1, true, true, ExecType.CP);
99+
runReverseTest(TEST_NAME1, true, rows1, cols1, ExecType.CP);
109100
}
110101

111102
@Test
112103
public void testReverseMatrixDenseSP() {
113-
runReverseTest(TEST_NAME1, true, false, ExecType.SPARK);
104+
runReverseTest(TEST_NAME1, false, rows2, cols2, ExecType.SPARK);
114105
}
115106

116107
@Test
117108
public void testReverseMatrixSparseSP() {
118-
runReverseTest(TEST_NAME1, true, true, ExecType.SPARK);
109+
runReverseTest(TEST_NAME1, true, rows2, cols2, ExecType.SPARK);
119110
}
120111

121112
@Test
122113
public void testReverseVectorDenseRewriteCP() {
123-
runReverseTest(TEST_NAME2, false, false, ExecType.CP);
114+
runReverseTest(TEST_NAME2, false, rows1, 1, ExecType.CP);
124115
}
125116

126117
@Test
127118
public void testReverseMatrixDenseRewriteCP() {
128-
runReverseTest(TEST_NAME2, true, false, ExecType.CP);
129-
}
130-
119+
runReverseTest(TEST_NAME2, false, rows1, 1, ExecType.CP);
120+
}
131121

132-
/**
133-
*
134-
* @param sparseM1
135-
* @param sparseM2
136-
* @param instType
137-
*/
138-
private void runReverseTest(String testname, boolean matrix, boolean sparse, ExecType instType)
122+
private void runReverseTest(String testname, boolean sparse, int rows, int cols, ExecType instType)
139123
{
140-
//rtplatform for MR
141-
ExecMode platformOld = rtplatform;
142-
switch( instType ){
143-
case SPARK: rtplatform = ExecMode.SPARK; break;
144-
default: rtplatform = ExecMode.HYBRID; break;
145-
}
146-
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
147-
if( rtplatform == ExecMode.SPARK )
148-
DMLScript.USE_LOCAL_SPARK_CONFIG = true;
149-
124+
ExecMode platformOld = setExecMode(instType);
150125
String TEST_NAME = testname;
151126

152127
try
153128
{
154-
int cols = matrix ? cols1 : 1;
155129
double sparsity = sparse ? sparsity2 : sparsity1;
156130
getAndLoadTestConfiguration(TEST_NAME);
157131

158-
/* This is for running the junit test the new way, i.e., construct the arguments directly */
159132
String HOME = SCRIPT_DIR + TEST_DIR;
160133
fullDMLScriptName = HOME + TEST_NAME + ".dml";
161134
programArgs = new String[]{"-stats","-explain","-args", input("A"), output("B") };
@@ -164,10 +137,10 @@ private void runReverseTest(String testname, boolean matrix, boolean sparse, Exe
164137
rCmd = "Rscript" + " " + fullRScriptName + " " + inputDir() + " " + expectedDir();
165138

166139
//generate actual dataset
167-
double[][] A = getRandomMatrix(rows1, cols, -1, 1, sparsity, 7);
140+
double[][] A = getRandomMatrix(rows, cols, -1, 1, sparsity, 7);
168141
writeInputMatrixWithMTD("A", A, true);
169142

170-
runTest(true, false, null, -1);
143+
runTest(true, false, null, -1);
171144
runRScript(true);
172145

173146
//compare matrices
@@ -181,85 +154,8 @@ private void runReverseTest(String testname, boolean matrix, boolean sparse, Exe
181154
else if ( instType == ExecType.SPARK )
182155
Assert.assertTrue("Missing opcode: "+Instruction.SP_INST_PREFIX+Opcodes.REV.toString(), Statistics.getCPHeavyHitterOpCodes().contains(Instruction.SP_INST_PREFIX+Opcodes.REV));
183156
}
184-
finally
185-
{
186-
//reset flags
187-
rtplatform = platformOld;
188-
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
189-
}
190-
}
191-
192-
private void runReverseTestMultiThread(String testname, boolean matrix, boolean sparse, ExecType instType)
193-
{
194-
// Compare single-thread vs multi-thread results
195-
// HashMap<CellIndex, Double> stResult = runReverseWithThreads(testname, matrix, sparse, instType, 1);
196-
HashMap<CellIndex, Double> mtResult = runReverseWithThreads(testname, matrix, sparse, instType, 8);
197-
198-
// Compare results to ensure consistency
199-
// TestUtils.compareMatrices(stResult, mtResult, 0, "ST-Result", "MT-Result");
200-
}
201-
202-
private HashMap<CellIndex, Double> runReverseWithThreads(String testname, boolean matrix, boolean sparse, ExecType instType, int numThreads)
203-
{
204-
//rtplatform for MR
205-
ExecMode platformOld = rtplatform;
206-
switch( instType ){
207-
case SPARK: rtplatform = ExecMode.SPARK; break;
208-
default: rtplatform = ExecMode.HYBRID; break;
209-
}
210-
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
211-
if( rtplatform == ExecMode.SPARK )
212-
DMLScript.USE_LOCAL_SPARK_CONFIG = true;
213-
214-
String TEST_NAME = testname;
215-
216-
System.out.println("I am trying to run multi-thread");
217-
218-
try
219-
{
220-
System.setProperty("sysds.parallel.threads", String.valueOf(numThreads));
221-
222-
// int cols = matrix ? cols_mt : 1;
223-
double sparsity = sparse ? sparsity2 : sparsity1;
224-
getAndLoadTestConfiguration(TEST_NAME);
225-
226-
/* This is for running the junit test the new way, i.e., construct the arguments directly */
227-
String HOME = SCRIPT_DIR + TEST_DIR;
228-
fullDMLScriptName = HOME + TEST_NAME + ".dml";
229-
230-
// Add thread count to program arguments
231-
programArgs = new String[]{"-stats","-explain","-args", input("A"), output("B") };
232-
233-
fullRScriptName = HOME + TEST_NAME + ".R";
234-
rCmd = "Rscript" + " " + fullRScriptName + " " + inputDir() + " " + expectedDir();
235-
236-
//generate actual dataset
237-
double[][] A = getRandomMatrix(rows_mt, cols_mt, -1, 1, sparsity, 7);
238-
writeInputMatrixWithMTD("A", A, true);
239-
240-
// Run with specified thread count (this is the key part)
241-
runTest(true, false, null, -1);
242-
243-
//read and return results
244-
HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("B");
245-
246-
//check generated opcode
247-
if( instType == ExecType.CP )
248-
Assert.assertTrue("Missing opcode: rev", Statistics.getCPHeavyHitterOpCodes().contains(Opcodes.REV.toString()));
249-
else if ( instType == ExecType.SPARK )
250-
Assert.assertTrue("Missing opcode: "+Instruction.SP_INST_PREFIX+Opcodes.REV.toString(), Statistics.getCPHeavyHitterOpCodes().contains(Instruction.SP_INST_PREFIX+Opcodes.REV));
251-
252-
return dmlfile;
253-
}
254-
catch(Exception ex) {
255-
throw new RuntimeException(ex);
256-
}
257157
finally {
258-
//reset flags
259-
rtplatform = platformOld;
260-
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
261-
System.setProperty("sysds.parallel.threads", String.valueOf(oldPar));
158+
resetExecMode(platformOld);
262159
}
263160
}
264-
265-
}
161+
}

0 commit comments

Comments
 (0)