Skip to content

Commit 0825f0f

Browse files
min-gukmboehm7
authored andcommitted
[SYSTEMDS-3729] Add roll reorg operations in SP
Closes #2112.
1 parent 95c74be commit 0825f0f

File tree

6 files changed

+286
-53
lines changed

6 files changed

+286
-53
lines changed

src/main/java/org/apache/sysds/hops/AggUnaryOp.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
3232
import org.apache.sysds.lops.Lop;
3333
import org.apache.sysds.common.Types.ExecType;
34-
import org.apache.sysds.lops.Nary;
3534
import org.apache.sysds.lops.PartialAggregate;
3635
import org.apache.sysds.lops.TernaryAggregate;
3736
import org.apache.sysds.lops.UAggOuterChain;

src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ public class SPInstructionParser extends InstructionParser
166166
// Reorg Instruction Opcodes (repositioning of existing values)
167167
String2SPInstructionType.put( "r'", SPType.Reorg);
168168
String2SPInstructionType.put( "rev", SPType.Reorg);
169+
String2SPInstructionType.put( "roll", SPType.Reorg);
169170
String2SPInstructionType.put( "rdiag", SPType.Reorg);
170171
String2SPInstructionType.put( "rshape", SPType.MatrixReshape);
171172
String2SPInstructionType.put( "rsort", SPType.Reorg);

src/main/java/org/apache/sysds/runtime/instructions/spark/ReorgSPInstruction.java

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
3737
import org.apache.sysds.runtime.functionobjects.DiagIndex;
3838
import org.apache.sysds.runtime.functionobjects.RevIndex;
39+
import org.apache.sysds.runtime.functionobjects.RollIndex;
3940
import org.apache.sysds.runtime.functionobjects.SortIndex;
4041
import org.apache.sysds.runtime.functionobjects.SwapIndex;
4142
import org.apache.sysds.runtime.instructions.InstructionUtils;
@@ -68,6 +69,7 @@ public class ReorgSPInstruction extends UnarySPInstruction {
6869
private CPOperand _desc = null;
6970
private CPOperand _ixret = null;
7071
private boolean _bSortIndInMem = false;
72+
private CPOperand _shift = null;
7173

7274
private ReorgSPInstruction(Operator op, CPOperand in, CPOperand out, String opcode, String istr) {
7375
super(SPType.Reorg, op, in, out, opcode, istr);
@@ -82,6 +84,11 @@ private ReorgSPInstruction(Operator op, CPOperand in, CPOperand col, CPOperand d
8284
_bSortIndInMem = bSortIndInMem;
8385
}
8486

87+
private ReorgSPInstruction(Operator op, CPOperand in, CPOperand out, CPOperand shift, String opcode, String istr) {
88+
this(op, in, out, opcode, istr);
89+
_shift = shift;
90+
}
91+
8592
public static ReorgSPInstruction parseInstruction ( String str ) {
8693
CPOperand in = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
8794
CPOperand out = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
@@ -95,6 +102,15 @@ else if ( opcode.equalsIgnoreCase("rev") ) {
95102
parseUnaryInstruction(str, in, out); //max 2 operands
96103
return new ReorgSPInstruction(new ReorgOperator(RevIndex.getRevIndexFnObject()), in, out, opcode, str);
97104
}
105+
else if (opcode.equalsIgnoreCase("roll")) {
106+
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
107+
InstructionUtils.checkNumFields(str, 3);
108+
in.split(parts[1]);
109+
out.split(parts[3]);
110+
CPOperand shift = new CPOperand(parts[2]);
111+
return new ReorgSPInstruction(new ReorgOperator(new RollIndex(0)),
112+
in, out, shift, opcode, str);
113+
}
98114
else if ( opcode.equalsIgnoreCase("rdiag") ) {
99115
parseUnaryInstruction(str, in, out); //max 2 operands
100116
return new ReorgSPInstruction(new ReorgOperator(DiagIndex.getDiagIndexFnObject()), in, out, opcode, str);
@@ -141,6 +157,14 @@ else if( opcode.equalsIgnoreCase("rev") ) //REVERSE
141157
if( mcIn.getRows() % mcIn.getBlocksize() != 0 )
142158
out = RDDAggregateUtils.mergeByKey(out, false);
143159
}
160+
else if (opcode.equalsIgnoreCase("roll")) // ROLL
161+
{
162+
int shift = (int) ec.getScalarInput(_shift).getLongValue();
163+
164+
//execute roll reorg operation
165+
out = in1.flatMapToPair(new RDDRollFunction(mcIn, shift));
166+
out = RDDAggregateUtils.mergeByKey(out, false);
167+
}
144168
else if ( opcode.equalsIgnoreCase("rdiag") ) // DIAG
145169
{
146170
if(mcIn.getCols() == 1) { // diagV2M
@@ -233,7 +257,7 @@ else if ( getOpcode().equalsIgnoreCase("rsort") ) {
233257
boolean ixret = sec.getScalarInput(_ixret).getBooleanValue();
234258
mcOut.set(mc1.getRows(), ixret?1:mc1.getCols(), mc1.getBlocksize(), mc1.getBlocksize());
235259
}
236-
else { //e.g., rev
260+
else { //e.g., rev, roll
237261
mcOut.set(mc1);
238262
}
239263
}
@@ -243,7 +267,7 @@ else if ( getOpcode().equalsIgnoreCase("rsort") ) {
243267
boolean sortIx = getOpcode().equalsIgnoreCase("rsort") && sec.getScalarInput(_ixret).getBooleanValue();
244268
if( sortIx )
245269
mcOut.setNonZeros(mc1.getRows());
246-
else //default (r', rdiag, rev, rsort data)
270+
else //default (r', rdiag, rev, roll, rsort data)
247271
mcOut.setNonZeros(mc1.getNonZeros());
248272
}
249273
}
@@ -315,6 +339,31 @@ public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call( Tuple2<MatrixIndexes,
315339
}
316340
}
317341

342+
private static class RDDRollFunction implements PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
343+
private static final long serialVersionUID = 1183373828539843938L;
344+
345+
private DataCharacteristics _mcIn = null;
346+
private int _shift = 0;
347+
348+
public RDDRollFunction(DataCharacteristics mcIn, int shift) {
349+
_mcIn = mcIn;
350+
_shift = shift;
351+
}
352+
353+
@Override
354+
public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Tuple2<MatrixIndexes, MatrixBlock> arg0) {
355+
//construct input
356+
IndexedMatrixValue in = SparkUtils.toIndexedMatrixBlock(arg0);
357+
358+
//execute roll operation
359+
ArrayList<IndexedMatrixValue> out = new ArrayList<>();
360+
LibMatrixReorg.roll(in, _mcIn.getRows(), _mcIn.getBlocksize(), _shift, out);
361+
362+
//construct output
363+
return SparkUtils.fromIndexedMatrixBlock(out).iterator();
364+
}
365+
}
366+
318367
private static class ExtractColumn implements Function<MatrixBlock, MatrixBlock>
319368
{
320369
private static final long serialVersionUID = -1472164797288449559L;

src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java

Lines changed: 88 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,36 @@ public static MatrixBlock roll(MatrixBlock in, MatrixBlock out, int shift) {
445445
return out;
446446
}
447447

448+
public static void roll(IndexedMatrixValue in, long rlen, int blen, int shift, ArrayList<IndexedMatrixValue> out) {
449+
MatrixIndexes inMtxIdx = in.getIndexes();
450+
MatrixBlock inMtxBlk = (MatrixBlock) in.getValue();
451+
shift %= ((rlen != 0) ? (int) rlen : 1); // Handle row length boundaries for shift
452+
453+
long inRowIdx = UtilFunctions.computeCellIndex(inMtxIdx.getRowIndex(), blen, 0) - 1;
454+
455+
int totalCopyLen = 0;
456+
while (totalCopyLen < inMtxBlk.getNumRows()) {
457+
// Calculate row and block index for the current part
458+
long outRowIdx = (inRowIdx + shift) % rlen;
459+
long outBlkIdx = UtilFunctions.computeBlockIndex(outRowIdx + 1, blen);
460+
int outBlkLen = UtilFunctions.computeBlockSize(rlen, outBlkIdx, blen);
461+
int outRowIdxInBlk = (int) (outRowIdx % blen);
462+
463+
// Calculate copy length
464+
int copyLen = Math.min((int) (outBlkLen - outRowIdxInBlk), inMtxBlk.getNumRows() - totalCopyLen);
465+
466+
// Create the output block and copy data
467+
MatrixIndexes outMtxIdx = new MatrixIndexes(outBlkIdx, inMtxIdx.getColumnIndex());
468+
MatrixBlock outMtxBlk = new MatrixBlock(outBlkLen, inMtxBlk.getNumColumns(), inMtxBlk.isInSparseFormat());
469+
copyMtx(inMtxBlk, outMtxBlk, totalCopyLen, outRowIdxInBlk, copyLen, false, false);
470+
out.add(new IndexedMatrixValue(outMtxIdx, outMtxBlk));
471+
472+
// Update counters for next iteration
473+
totalCopyLen += copyLen;
474+
inRowIdx += totalCopyLen;
475+
}
476+
}
477+
448478
public static MatrixBlock diag( MatrixBlock in, MatrixBlock out ) {
449479
//Timing time = new Timing(true);
450480

@@ -2274,77 +2304,85 @@ private static void reverseSparse(MatrixBlock in, MatrixBlock out) {
22742304

22752305
private static void rollDense(MatrixBlock in, MatrixBlock out, int shift) {
22762306
final int m = in.rlen;
2277-
final int n = in.clen;
2307+
shift %= (m != 0 ? m : 1); // roll matrix with axis=none
22782308

2279-
//set basic meta data and allocate output
2280-
out.sparse = false;
2281-
out.nonZeros = in.nonZeros;
2282-
out.allocateDenseBlock(false);
2309+
copyDenseMtx(in, out, 0, shift, m - shift, false, true);
2310+
copyDenseMtx(in, out, m - shift, 0, shift, true, true);
2311+
}
22832312

2284-
//copy all rows into target positions
2285-
if (n == 1) { //column vector
2313+
private static void rollSparse(MatrixBlock in, MatrixBlock out, int shift) {
2314+
final int m = in.rlen;
2315+
shift %= (m != 0 ? m : 1); // roll matrix with axis=0
2316+
2317+
copySparseMtx(in, out, 0, shift, m - shift, false, true);
2318+
copySparseMtx(in, out, m-shift, 0, shift, false, true);
2319+
}
2320+
2321+
public static void copyMtx(MatrixBlock in, MatrixBlock out, int inStart, int outStart, int copyLen,
2322+
boolean isAllocated, boolean copyTotalNonZeros) {
2323+
if (in.isInSparseFormat()){
2324+
copySparseMtx(in, out, inStart, outStart, copyLen, isAllocated, copyTotalNonZeros);
2325+
} else {
2326+
copyDenseMtx(in, out, inStart, outStart, copyLen, isAllocated, copyTotalNonZeros);
2327+
}
2328+
}
2329+
2330+
public static void copyDenseMtx(MatrixBlock in, MatrixBlock out, int inIdx, int outIdx, int copyLen,
2331+
boolean isAllocated, boolean copyTotalNonZeros) {
2332+
int clen = in.clen;
2333+
2334+
// set basic meta data and allocate output
2335+
if (!isAllocated){
2336+
out.sparse = false;
2337+
if (copyTotalNonZeros) out.nonZeros = in.nonZeros;
2338+
out.allocateDenseBlock(false);
2339+
}
2340+
2341+
// copy all rows into target positions
2342+
if (clen == 1) { //column vector
22862343
double[] a = in.getDenseBlockValues();
22872344
double[] c = out.getDenseBlockValues();
22882345

2289-
// roll matrix with axis=none
2290-
shift %= (m != 0 ? m : 1);
2291-
2292-
System.arraycopy(a, 0, c, shift, m - shift);
2293-
System.arraycopy(a, m - shift, c, 0, shift);
2294-
} else { //general matrix case
2346+
System.arraycopy(a, inIdx, c, outIdx, copyLen);
2347+
} else {
22952348
DenseBlock a = in.getDenseBlock();
22962349
DenseBlock c = out.getDenseBlock();
22972350

2298-
// roll matrix with axis=0
2299-
shift %= (m != 0 ? m : 1);
2351+
while (copyLen > 0) {
2352+
System.arraycopy(a.values(inIdx), a.pos(inIdx),
2353+
c.values(outIdx), c.pos(outIdx), clen);
23002354

2301-
for (int i = 0; i < m - shift; i++) {
2302-
System.arraycopy(a.values(i), a.pos(i), c.values(i + shift), c.pos(i + shift), n);
2303-
}
2304-
2305-
for (int i = m - shift; i < m; i++) {
2306-
System.arraycopy(a.values(i), a.pos(i), c.values(i + shift - m), c.pos(i + shift - m), n);
2355+
inIdx++; outIdx++; copyLen--;
23072356
}
23082357
}
23092358
}
23102359

2311-
private static void rollSparse(MatrixBlock in, MatrixBlock out, int shift) {
2312-
final int m = in.rlen;
2313-
2360+
private static void copySparseMtx(MatrixBlock in, MatrixBlock out, int inIdx, int outIdx, int copyLen,
2361+
boolean isAllocated, boolean copyTotalNonZeros) {
23142362
//set basic meta data and allocate output
2315-
out.sparse = true;
2316-
out.nonZeros = in.nonZeros;
2317-
out.allocateSparseRowsBlock(false);
2363+
if (!isAllocated){
2364+
out.sparse = true;
2365+
if (copyTotalNonZeros) out.nonZeros = in.nonZeros;
2366+
out.allocateSparseRowsBlock(false);
2367+
}
23182368

2319-
//copy all rows into target positions
23202369
SparseBlock a = in.getSparseBlock();
23212370
SparseBlock c = out.getSparseBlock();
23222371

2323-
// roll matrix with axis=0
2324-
shift %= (m != 0 ? m : 1);
2325-
2326-
for (int i = 0; i < m - shift; i++) {
2327-
if (a.isEmpty(i)) continue; // skip empty rows
2372+
while (copyLen > 0) {
2373+
if (a.isEmpty(inIdx)) continue; // skip empty rows
23282374

2329-
rollSparseRow(a, c, i, i + shift);
2330-
}
2331-
2332-
for (int i = m - shift; i < m; i++) {
2333-
if (a.isEmpty(i)) continue; // skip empty rows
2334-
2335-
rollSparseRow(a, c, i, i + shift - m);
2336-
}
2337-
}
2375+
final int apos = a.pos(inIdx);
2376+
final int alen = a.size(inIdx) + apos;
2377+
final int[] aix = a.indexes(inIdx);
2378+
final double[] avals = a.values(inIdx);
23382379

2339-
private static void rollSparseRow(SparseBlock a, SparseBlock c, int oriIdx, int shiftIdx) {
2340-
final int apos = a.pos(oriIdx);
2341-
final int alen = a.size(oriIdx) + apos;
2342-
final int[] aix = a.indexes(oriIdx);
2343-
final double[] avals = a.values(oriIdx);
2380+
// copy only non-zero elements
2381+
for (int k = apos; k < alen; k++) {
2382+
c.set(outIdx, aix[k], avals[k]);
2383+
}
23442384

2345-
// copy only non-zero elements
2346-
for (int k = apos; k < alen; k++) {
2347-
c.set(shiftIdx, aix[k], avals[k]);
2385+
inIdx++; outIdx++; copyLen--;
23482386
}
23492387
}
23502388

0 commit comments

Comments
 (0)