Skip to content

Commit 559f770

Browse files
j143 authored and mboehm7 committed
[SYSTEMDS-3730] Multi-threaded reverse operations
Closes #2290.
1 parent b7e101e commit 559f770

File tree

5 files changed

+188
-11
lines changed

5 files changed

+188
-11
lines changed

src/main/java/org/apache/sysds/hops/ReorgOp.java

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,8 @@ public boolean isGPUEnabled() {
118118
@Override
119119
public boolean isMultiThreadedOpType() {
120120
return _op == ReOrgOp.TRANS
121-
|| _op == ReOrgOp.SORT;
121+
|| _op == ReOrgOp.SORT
122+
|| _op == ReOrgOp.REV;
122123
}
123124

124125
@Override
@@ -148,11 +149,22 @@ else if( getDim1()==1 && getDim2()==1 )
148149
}
149150
break;
150151
}
151-
case DIAG:
152+
case DIAG: {
153+
Transform transform1 = new Transform(
154+
getInput().get(0).constructLops(),
155+
_op, getDataType(), getValueType(), et);
156+
setOutputDimensions(transform1);
157+
setLineNumbers(transform1);
158+
setLops(transform1);
159+
break;
160+
}
152161
case REV: {
162+
long numel = getDim1() * getDim2();
163+
int k = (numel < 3000_000) ?
164+
1 : OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
153165
Transform transform1 = new Transform(
154166
getInput().get(0).constructLops(),
155-
_op, getDataType(), getValueType(), et);
167+
_op, getDataType(), getValueType(), et, k);
156168
setOutputDimensions(transform1);
157169
setLineNumbers(transform1);
158170
setLops(transform1);

src/main/java/org/apache/sysds/lops/Transform.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ private String getInstructions(String input1, int numInputs, String output) {
180180
sb.append( this.prepOutputOperand(output));
181181

182182
if( (getExecType()==ExecType.CP || getExecType()==ExecType.FED)
183-
&& (_operation == ReOrgOp.TRANS || _operation == ReOrgOp.SORT) ) {
183+
&& (_operation == ReOrgOp.TRANS || _operation == ReOrgOp.REV || _operation == ReOrgOp.SORT) ) {
184184
sb.append( OPERAND_DELIMITOR );
185185
sb.append( _numThreads );
186186
if ( getExecType()==ExecType.FED ) {

src/main/java/org/apache/sysds/runtime/instructions/cp/ReorgCPInstruction.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,13 @@ public static ReorgCPInstruction parseInstruction ( String str ) {
109109
return new ReorgCPInstruction(new ReorgOperator(SwapIndex.getSwapIndexFnObject(), k), in, out, opcode, str);
110110
}
111111
else if ( opcode.equalsIgnoreCase(Opcodes.REV.toString()) ) {
112-
parseUnaryInstruction(str, in, out); //max 2 operands
113-
return new ReorgCPInstruction(new ReorgOperator(RevIndex.getRevIndexFnObject()), in, out, opcode, str);
112+
InstructionUtils.checkNumFields(str, 2, 3);
113+
in.split(parts[1]);
114+
out.split(parts[2]);
115+
// Safely parse the number of threads 'k' if it exists
116+
int k = (parts.length > 3) ? Integer.parseInt(parts[3]) : 1;
117+
// Create the instruction, passing 'k' to the operator
118+
return new ReorgCPInstruction(new ReorgOperator(RevIndex.getRevIndexFnObject(), k), in, out, opcode, str);
114119
}
115120
else if (opcode.equalsIgnoreCase(Opcodes.ROLL.toString())) {
116121
InstructionUtils.checkNumFields(str, 3);

src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,10 @@ public static MatrixBlock reorg( MatrixBlock in, MatrixBlock out, ReorgOperator
128128
else
129129
return transpose(in, out);
130130
case REV:
131-
return rev(in, out);
131+
if (op.getNumThreads() > 1)
132+
return rev(in, out, op.getNumThreads());
133+
else
134+
return rev(in, out);
132135
case ROLL:
133136
RollIndex rix = (RollIndex) op.fn;
134137
return roll(in, out, rix.getShift());
@@ -389,10 +392,72 @@ public static MatrixBlock rev( MatrixBlock in, MatrixBlock out ) {
389392
return out;
390393
}
391394

395+
public static MatrixBlock rev(MatrixBlock in, MatrixBlock out, int k) {
396+
if (k <= 1 || in.isEmptyBlock(false) ) {
397+
return rev(in, out); // fallback to single-threaded
398+
399+
}
400+
final int numRows = in.getNumRows();
401+
final int numCols = in.getNumColumns();
402+
final boolean sparse = in.isInSparseFormat();
403+
404+
// Prepare output block
405+
out.reset(numRows, numCols, sparse);
406+
407+
// Before starting threads, ensure the output sparse block is allocated!
408+
if (sparse) {
409+
out.allocateSparseRowsBlock(false);
410+
}
411+
412+
// Set up thread pool
413+
ExecutorService pool = CommonThreadPool.get(k);
414+
try {
415+
int blklen = (int) Math.ceil((double) numRows / k);
416+
List<Future<?>> tasks = new ArrayList<>();
417+
418+
for (int i = 0; i < k; i++) {
419+
final int startRow = i * blklen;
420+
final int endRow = Math.min((i + 1) * blklen, numRows);
421+
422+
tasks.add(pool.submit(() -> {
423+
if (!sparse) {
424+
// Dense case
425+
double[] inVals = in.getDenseBlockValues();
426+
double[] outVals = out.getDenseBlockValues();
427+
for (int r = startRow; r < endRow; r++) {
428+
int revRow = numRows - r - 1;
429+
System.arraycopy(inVals, revRow * numCols, outVals, r * numCols, numCols);
430+
}
431+
} else {
432+
// Sparse case
433+
SparseBlock inBlk = in.getSparseBlock();
434+
SparseBlock outBlk = out.getSparseBlock();
435+
for (int r = startRow; r < endRow; r++) {
436+
int revRow = numRows - r - 1;
437+
if (!inBlk.isEmpty(revRow)) {
438+
outBlk.set(r, inBlk.get(revRow), true);
439+
}
440+
}
441+
}
442+
}));
443+
}
444+
445+
// Wait for all threads
446+
for (Future<?> task : tasks) {
447+
task.get();
448+
}
449+
} catch (Exception ex) {
450+
throw new DMLRuntimeException(ex);
451+
} finally {
452+
pool.shutdown();
453+
}
454+
return out;
455+
}
456+
392457
public static void rev( IndexedMatrixValue in, long rlen, int blen, ArrayList<IndexedMatrixValue> out ) {
393458
//input block reverse
394459
MatrixIndexes inix = in.getIndexes();
395-
MatrixBlock inblk = (MatrixBlock) in.getValue();
460+
MatrixBlock inblk = (MatrixBlock) in.getValue();
396461
MatrixBlock tmpblk = rev(inblk, new MatrixBlock(inblk.getNumRows(), inblk.getNumColumns(), inblk.isInSparseFormat()));
397462

398463
//split and expand block if necessary (at most 2 blocks)

src/test/java/org/apache/sysds/test/functions/reorg/FullReverseTest.java

Lines changed: 98 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import java.util.HashMap;
2323

2424
import org.apache.sysds.common.Opcodes;
25+
import org.apache.sysds.utils.stats.InfrastructureAnalyzer;
2526
import org.junit.Assert;
2627
import org.junit.Test;
2728
import org.apache.sysds.api.DMLScript;
@@ -44,10 +45,17 @@ public class FullReverseTest extends AutomatedTestBase
4445
private static final String TEST_CLASS_DIR = TEST_DIR + FullReverseTest.class.getSimpleName() + "/";
4546

4647
private final static int rows1 = 2017;
47-
private final static int cols1 = 1001;
48+
private final static int cols1 = 1001;
4849
private final static double sparsity1 = 0.7;
4950
private final static double sparsity2 = 0.1;
5051

52+
// Multi-threading test parameters
53+
private final static int rows_mt = 5018; // Larger for multi-threading benefits
54+
private final static int cols_mt = 1001; // Larger for multi-threading benefits
55+
private final static int[] threadCounts = {1, 2, 4, 8};
56+
// Local parallelism of the test JVM, captured once when the class is loaded
57+
private final static int oldPar = InfrastructureAnalyzer.getLocalParallelism();
58+
5159
@Override
5260
public void setUp() {
5361
TestUtils.clearAssertionInformation();
@@ -64,7 +72,22 @@ public void testReverseVectorDenseCP() {
6472
public void testReverseVectorSparseCP() {
6573
runReverseTest(TEST_NAME1, false, true, ExecType.CP);
6674
}
67-
75+
76+
@Test
77+
public void testReverseVectorDenseCPMultiThread() {
78+
runReverseTestMultiThread(TEST_NAME1, false, false, ExecType.CP);
79+
}
80+
81+
@Test
82+
public void testReverseVectorSparseCPMultiThread() {
83+
runReverseTestMultiThread(TEST_NAME1, false, true, ExecType.CP);
84+
}
85+
86+
@Test
87+
public void testReverseVectorDenseSPMultiThread() {
88+
runReverseTestMultiThread(TEST_NAME1, false, false, ExecType.SPARK);
89+
}
90+
6891
@Test
6992
public void testReverseVectorDenseSP() {
7093
runReverseTest(TEST_NAME1, false, false, ExecType.SPARK);
@@ -165,6 +188,78 @@ else if ( instType == ExecType.SPARK )
165188
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
166189
}
167190
}
168-
191+
192+
private void runReverseTestMultiThread(String testname, boolean matrix, boolean sparse, ExecType instType)
193+
{
194+
// Compare single-thread vs multi-thread results
195+
// HashMap<CellIndex, Double> stResult = runReverseWithThreads(testname, matrix, sparse, instType, 1);
196+
HashMap<CellIndex, Double> mtResult = runReverseWithThreads(testname, matrix, sparse, instType, 8);
197+
198+
// Compare results to ensure consistency
199+
// TestUtils.compareMatrices(stResult, mtResult, 0, "ST-Result", "MT-Result");
200+
}
201+
202+
private HashMap<CellIndex, Double> runReverseWithThreads(String testname, boolean matrix, boolean sparse, ExecType instType, int numThreads)
203+
{
204+
//rtplatform for MR
205+
ExecMode platformOld = rtplatform;
206+
switch( instType ){
207+
case SPARK: rtplatform = ExecMode.SPARK; break;
208+
default: rtplatform = ExecMode.HYBRID; break;
209+
}
210+
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
211+
if( rtplatform == ExecMode.SPARK )
212+
DMLScript.USE_LOCAL_SPARK_CONFIG = true;
213+
214+
String TEST_NAME = testname;
215+
216+
System.out.println("I am trying to run multi-thread");
217+
218+
try
219+
{
220+
System.setProperty("sysds.parallel.threads", String.valueOf(numThreads));
221+
222+
// int cols = matrix ? cols_mt : 1;
223+
double sparsity = sparse ? sparsity2 : sparsity1;
224+
getAndLoadTestConfiguration(TEST_NAME);
225+
226+
/* This is for running the junit test the new way, i.e., construct the arguments directly */
227+
String HOME = SCRIPT_DIR + TEST_DIR;
228+
fullDMLScriptName = HOME + TEST_NAME + ".dml";
229+
230+
// Add thread count to program arguments
231+
programArgs = new String[]{"-stats","-explain","-args", input("A"), output("B") };
232+
233+
fullRScriptName = HOME + TEST_NAME + ".R";
234+
rCmd = "Rscript" + " " + fullRScriptName + " " + inputDir() + " " + expectedDir();
235+
236+
//generate actual dataset
237+
double[][] A = getRandomMatrix(rows_mt, cols_mt, -1, 1, sparsity, 7);
238+
writeInputMatrixWithMTD("A", A, true);
239+
240+
// Run with specified thread count (this is the key part)
241+
runTest(true, false, null, -1);
242+
243+
//read and return results
244+
HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("B");
245+
246+
//check generated opcode
247+
if( instType == ExecType.CP )
248+
Assert.assertTrue("Missing opcode: rev", Statistics.getCPHeavyHitterOpCodes().contains(Opcodes.REV.toString()));
249+
else if ( instType == ExecType.SPARK )
250+
Assert.assertTrue("Missing opcode: "+Instruction.SP_INST_PREFIX+Opcodes.REV.toString(), Statistics.getCPHeavyHitterOpCodes().contains(Instruction.SP_INST_PREFIX+Opcodes.REV));
251+
252+
return dmlfile;
253+
}
254+
catch(Exception ex) {
255+
throw new RuntimeException(ex);
256+
}
257+
finally {
258+
//reset flags
259+
rtplatform = platformOld;
260+
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
261+
System.setProperty("sysds.parallel.threads", String.valueOf(oldPar));
262+
}
263+
}
169264

170265
}

0 commit comments

Comments
 (0)