Skip to content

Commit 29b3c61

Browse files
min-gukmboehm7
authored andcommitted
[SYSTEMDS-3729] Add missing federated roll reorg operations
Closes #2126.
1 parent 80332e0 commit 29b3c61

File tree

9 files changed

+422
-10
lines changed

9 files changed

+422
-10
lines changed

src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -406,14 +406,40 @@ public Future<FederatedResponse>[] executeMultipleSlices(long tid, boolean wait,
406406
return ret.toArray(new Future[0]);
407407
}
408408

409+
@SuppressWarnings("unchecked")
410+
public Future<FederatedResponse>[] executeRoll(long tid, boolean wait,
411+
FederatedRequest frEnd, FederatedRequest frStart, long rlen)
412+
{
413+
// executes step1[] - step 2 - ... step4 (only first step federated-data-specific)
414+
setThreadID(tid, new FederatedRequest[]{frStart, frEnd});
415+
List<Future<FederatedResponse>> ret = new ArrayList<>();
416+
417+
for(Pair<FederatedRange, FederatedData> e : _fedMap) {
418+
if (e.getKey().getEndDims()[0] == rlen) {
419+
ret.add(e.getValue().executeFederatedOperation(frEnd));
420+
} else if (e.getKey().getBeginDims()[0] == 0){
421+
ret.add(e.getValue().executeFederatedOperation(frStart));
422+
}
423+
}
424+
425+
// prepare results (future federated responses), with optional wait to ensure the
426+
// order of requests without data dependencies (e.g., cleanup RPCs)
427+
if(wait)
428+
FederationUtils.waitFor(ret);
429+
return (Future<FederatedResponse>[])ret.toArray(new Future[0]);
430+
}
431+
409432
public List<Pair<FederatedRange, Future<FederatedResponse>>> requestFederatedData() {
410433
if(!isInitialized())
411434
throw new DMLRuntimeException("Federated matrix read only supported on initialized FederatedData");
412435

413436
List<Pair<FederatedRange, Future<FederatedResponse>>> readResponses = new ArrayList<>();
414-
FederatedRequest request = new FederatedRequest(RequestType.GET_VAR, _ID);
415-
for(Pair<FederatedRange, FederatedData> e : _fedMap)
437+
438+
for(Pair<FederatedRange, FederatedData> e : _fedMap){
439+
FederatedRequest request = new FederatedRequest(RequestType.GET_VAR, e.getValue().getVarID());
416440
readResponses.add(Pair.of(e.getKey(), e.getValue().executeFederatedOperation(request)));
441+
}
442+
417443
return readResponses;
418444
}
419445

@@ -692,6 +718,7 @@ public void reverseFedMap() {
692718
}
693719
}
694720

721+
695722
private static class MappingTask implements Callable<Void> {
696723
private final FederatedRange _range;
697724
private final FederatedData _data;

src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ public class FEDInstructionParser extends InstructionParser
8686
String2FEDInstructionType.put( "r'" , FEDType.Reorg );
8787
String2FEDInstructionType.put( "rdiag" , FEDType.Reorg );
8888
String2FEDInstructionType.put( "rev" , FEDType.Reorg );
89+
String2FEDInstructionType.put( "roll" , FEDType.Reorg );
8990
//String2FEDInstructionType.put( "rshape" , FEDType.Reorg ); Not supported by ReorgFEDInstruction parser!
9091
//String2FEDInstructionType.put( "rsort" , FEDType.Reorg ); Not supported by ReorgFEDInstruction parser!
9192

src/main/java/org/apache/sysds/runtime/instructions/cp/ReorgCPInstruction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ private ReorgCPInstruction(Operator op, CPOperand in, CPOperand out, CPOperand c
8686
* @param istr ?
8787
*/
8888
private ReorgCPInstruction(Operator op, CPOperand in, CPOperand out, CPOperand shift, String opcode, String istr) {
89-
super(CPType.Reorg, op, in, out, opcode, istr);
89+
super(CPType.Reorg, op, in, shift, out, opcode, istr);
9090
_col = null;
9191
_desc = null;
9292
_ixret = null;

src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java

Lines changed: 141 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,15 @@
3636
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
3737
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
3838
import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
39+
import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
3940
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
4041
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
4142
import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
4243
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
4344
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
4445
import org.apache.sysds.runtime.functionobjects.DiagIndex;
4546
import org.apache.sysds.runtime.functionobjects.RevIndex;
47+
import org.apache.sysds.runtime.functionobjects.RollIndex;
4648
import org.apache.sysds.runtime.functionobjects.SwapIndex;
4749
import org.apache.sysds.runtime.instructions.InstructionUtils;
4850
import org.apache.sysds.runtime.instructions.cp.CPOperand;
@@ -57,6 +59,8 @@
5759
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
5860

5961
public class ReorgFEDInstruction extends UnaryFEDInstruction {
62+
// roll-specific attributes
63+
private CPOperand _shift = null;
6064

6165
public ReorgFEDInstruction(Operator op, CPOperand in1, CPOperand out, String opcode, String istr, FederatedOutput fedOut) {
6266
super(FEDType.Reorg, op, in1, out, opcode, istr, fedOut);
@@ -66,14 +70,29 @@ public ReorgFEDInstruction(Operator op, CPOperand in1, CPOperand out, String opc
6670
super(FEDType.Reorg, op, in1, out, opcode, istr);
6771
}
6872

73+
private ReorgFEDInstruction(Operator op, CPOperand in, CPOperand shift, CPOperand out, String opcode, String istr, FederatedOutput fedOut) {
74+
super(FEDType.Reorg, op, in, shift, out, opcode, istr, fedOut);
75+
_shift = shift;
76+
}
77+
6978
public static ReorgFEDInstruction parseInstruction(ReorgCPInstruction rinst) {
70-
return new ReorgFEDInstruction(rinst.getOperator(), rinst.input1, rinst.output, rinst.getOpcode(),
71-
rinst.getInstructionString(), FederatedOutput.NONE);
79+
if (rinst.input2 != null) {
80+
return new ReorgFEDInstruction(rinst.getOperator(), rinst.input1, rinst.input2, rinst.output, rinst.getOpcode(),
81+
rinst.getInstructionString(), FederatedOutput.NONE);
82+
} else{
83+
return new ReorgFEDInstruction(rinst.getOperator(), rinst.input1, rinst.output, rinst.getOpcode(),
84+
rinst.getInstructionString(), FederatedOutput.NONE);
85+
}
7286
}
7387

7488
public static ReorgFEDInstruction parseInstruction(ReorgSPInstruction rinst) {
75-
return new ReorgFEDInstruction(rinst.getOperator(), rinst.input1, rinst.output, rinst.getOpcode(),
76-
rinst.getInstructionString(), FederatedOutput.NONE);
89+
if (rinst.input2 != null) {
90+
return new ReorgFEDInstruction(rinst.getOperator(), rinst.input1, rinst.input2, rinst.output, rinst.getOpcode(),
91+
rinst.getInstructionString(), FederatedOutput.NONE);
92+
} else{
93+
return new ReorgFEDInstruction(rinst.getOperator(), rinst.input1, rinst.output, rinst.getOpcode(),
94+
rinst.getInstructionString(), FederatedOutput.NONE);
95+
}
7796
}
7897

7998
public static ReorgFEDInstruction parseInstruction(String str) {
@@ -105,6 +124,15 @@ else if(opcode.equalsIgnoreCase("rev")) {
105124
return new ReorgFEDInstruction(new ReorgOperator(RevIndex.getRevIndexFnObject()), in, out, opcode, str,
106125
fedOut);
107126
}
127+
else if (opcode.equalsIgnoreCase("roll")) {
128+
InstructionUtils.checkNumFields(str, 3);
129+
in.split(parts[1]);
130+
out.split(parts[3]);
131+
CPOperand shift = new CPOperand(parts[2]);
132+
fedOut = parseFedOutFlag(str, 3);
133+
return new ReorgFEDInstruction(new ReorgOperator(new RollIndex(0)),
134+
in, out, shift, opcode, str, fedOut);
135+
}
108136
else {
109137
throw new DMLRuntimeException("ReorgFEDInstruction: unsupported opcode: " + opcode);
110138
}
@@ -167,6 +195,36 @@ else if(instOpcode.equalsIgnoreCase("rev")) {
167195
.setBlocksize(mo1.getBlocksize()).setNonZeros(nnz);
168196
out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID()));
169197

198+
optionalForceLocal(out);
199+
} else if (instOpcode.equalsIgnoreCase("roll")) {
200+
long rlen = mo1.getNumRows();
201+
long shift = ec.getScalarInput(_shift).getLongValue();
202+
shift %= (rlen != 0 ? rlen : 1); // roll matrix with axis=none
203+
204+
long inID = mo1.getFedMapping().getID();
205+
long outEndID = FederationUtils.getNextFedDataID();
206+
long outStartID = FederationUtils.getNextFedDataID();
207+
208+
List<Pair<FederatedRange, FederatedData>> inMap = mo1.getFedMapping().getMap();
209+
Pair<FederationMap, Long> rollResult = rollFedMap(
210+
inMap, inID, outEndID, outStartID, shift, rlen, mo1.getFedMapping().getType());
211+
long length = rollResult.getValue();
212+
FederationMap outFedMap = rollResult.getKey();
213+
214+
FederatedRequest frEnd = new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, outEndID,
215+
new ReorgFEDInstruction.SliceMatrix(inID, outEndID, length, true));
216+
FederatedRequest frStart = new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, outStartID,
217+
new ReorgFEDInstruction.SliceMatrix(inID, outStartID, length, false));
218+
Future<FederatedResponse>[] ffr = outFedMap.executeRoll(getTID(), true, frEnd, frStart, rlen);
219+
220+
//derive output federated mapping
221+
MatrixObject out = ec.getMatrixObject(output);
222+
long nnz = (mo1.getNnz() != -1) ? mo1.getNnz() : FederationUtils.sumNonZeros(ffr);
223+
out.getDataCharacteristics()
224+
.setDimension(mo1.getNumRows(), mo1.getNumColumns())
225+
.setBlocksize(mo1.getBlocksize())
226+
.setNonZeros(nnz);
227+
out.setFedMapping(outFedMap);
170228
optionalForceLocal(out);
171229
}
172230
else if (instOpcode.equals("rdiag")) {
@@ -189,6 +247,40 @@ else if (instOpcode.equals("rdiag")) {
189247
}
190248
}
191249

250+
251+
public Pair<FederationMap, Long> rollFedMap(List<Pair<FederatedRange, FederatedData>> oldMap, long inID,
252+
long outEndID, long outStartID, long shift, long rlen, FType type) {
253+
List<Pair<FederatedRange, FederatedData>> map = new ArrayList<>();
254+
long length = 0;
255+
256+
for(Map.Entry<FederatedRange, FederatedData> e : oldMap) {
257+
if(e.getKey().getSize() == 0) continue;
258+
FederatedRange fedRange = new FederatedRange(e.getKey());
259+
long beginRow = fedRange.getBeginDims()[0] + shift;
260+
long endRow = fedRange.getEndDims()[0] + shift;
261+
262+
beginRow = beginRow > rlen ? beginRow - rlen : beginRow;
263+
endRow = endRow > rlen ? endRow - rlen : endRow;
264+
265+
if (beginRow < endRow) {
266+
fedRange.setBeginDim(0, beginRow);
267+
fedRange.setEndDim(0, endRow);
268+
map.add(Pair.of(fedRange, e.getValue().copyWithNewID(inID)));
269+
} else {
270+
length = rlen - beginRow;
271+
fedRange.setBeginDim(0, beginRow);
272+
fedRange.setEndDim(0, rlen);
273+
map.add(Pair.of(fedRange, e.getValue().copyWithNewID(outEndID)));
274+
275+
FederatedRange startRange = new FederatedRange(fedRange);
276+
startRange.setBeginDim(0, 0);
277+
startRange.setEndDim(0, endRow);
278+
map.add(Pair.of(startRange, e.getValue().copyWithNewID(outStartID)));
279+
}
280+
}
281+
return Pair.of(new FederationMap(outEndID, map, type), length);
282+
}
283+
192284
/**
193285
* Update the federated ranges of result and return the updated federation map.
194286
* @param result RdiagResult for which the fedmap is updated
@@ -307,6 +399,51 @@ private RdiagResult rdiagM2V (MatrixObject mo1, ReorgOperator r_op) {
307399
return new RdiagResult(diagFedMap, dcs);
308400
}
309401

402+
public static class SliceMatrix extends FederatedUDF {
403+
private static final long serialVersionUID = -3466926635958851402L;
404+
private final long _outputID;
405+
private final int _sliceRow;
406+
private final boolean _isRight;
407+
408+
private SliceMatrix(long input, long outputID, long sliceRow, boolean isRight) {
409+
super(new long[] {input});
410+
_outputID = outputID;
411+
_sliceRow = (int) sliceRow;
412+
_isRight = isRight;
413+
}
414+
415+
@Override
416+
public FederatedResponse execute(ExecutionContext ec, Data... data) {
417+
MatrixBlock oriBlock = ((MatrixObject) data[0]).acquireReadAndRelease();
418+
MatrixBlock resBlock;
419+
420+
if (_sliceRow != 0){
421+
if (_isRight){
422+
resBlock = oriBlock.slice(0, _sliceRow-1, 0,
423+
oriBlock.getNumColumns()-1, new MatrixBlock());
424+
} else{
425+
resBlock = oriBlock.slice(_sliceRow, oriBlock.getNumRows()-1,
426+
0, oriBlock.getNumColumns()-1, new MatrixBlock());
427+
}
428+
} else{
429+
resBlock = oriBlock;
430+
}
431+
ec.setMatrixOutput(String.valueOf(_outputID), resBlock);
432+
return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, resBlock);
433+
}
434+
435+
@Override
436+
public List<Long> getOutputIds() {
437+
return new ArrayList<>(Arrays.asList(_outputID));
438+
}
439+
440+
@Override
441+
public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
442+
return Pair.of(String.valueOf(_outputID),
443+
new LineageItem());
444+
}
445+
}
446+
310447
public static class Rdiag extends FederatedUDF {
311448

312449
private static final long serialVersionUID = -3466926635958851402L;

src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryFEDInstruction.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,8 @@ public static UnaryFEDInstruction parseInstruction(UnaryCPInstruction inst, Exec
8888
}
8989
}
9090
else if(inst instanceof ReorgCPInstruction &&
91-
(inst.getOpcode().equals("r'") || inst.getOpcode().equals("rdiag") || inst.getOpcode().equals("rev"))) {
91+
(inst.getOpcode().equals("r'") || inst.getOpcode().equals("rdiag")
92+
|| inst.getOpcode().equals("rev") || inst.getOpcode().equals("roll"))) {
9293
ReorgCPInstruction rinst = (ReorgCPInstruction) inst;
9394
CacheableData<?> mo = ec.getCacheableData(rinst.input1);
9495

@@ -157,7 +158,8 @@ else if(inst instanceof AggregateUnarySPInstruction) {
157158
return AggregateUnaryFEDInstruction.parseInstruction(auinstruction);
158159
}
159160
else if(inst instanceof ReorgSPInstruction &&
160-
(inst.getOpcode().equals("r'") || inst.getOpcode().equals("rdiag") || inst.getOpcode().equals("rev"))) {
161+
(inst.getOpcode().equals("r'") || inst.getOpcode().equals("rdiag")
162+
|| inst.getOpcode().equals("rev") || inst.getOpcode().equals("roll"))) {
161163
ReorgSPInstruction rinst = (ReorgSPInstruction) inst;
162164
CacheableData<?> mo = ec.getCacheableData(rinst.input1);
163165
if((mo instanceof MatrixObject || mo instanceof FrameObject) && mo.isFederated() &&

src/main/java/org/apache/sysds/runtime/instructions/spark/ReorgSPInstruction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ private ReorgSPInstruction(Operator op, CPOperand in, CPOperand col, CPOperand d
8585
}
8686

8787
private ReorgSPInstruction(Operator op, CPOperand in, CPOperand out, CPOperand shift, String opcode, String istr) {
88-
this(op, in, out, opcode, istr);
88+
super(SPType.Reorg, op, in, shift, null, out, opcode, istr);
8989
_shift = shift;
9090
}
9191

0 commit comments

Comments
 (0)