3636import org .apache .sysds .runtime .controlprogram .caching .MatrixObject ;
3737import org .apache .sysds .runtime .controlprogram .context .ExecutionContext ;
3838import org .apache .sysds .runtime .controlprogram .federated .FederatedRange ;
39+ import org .apache .sysds .runtime .controlprogram .federated .FederatedData ;
3940import org .apache .sysds .runtime .controlprogram .federated .FederatedRequest ;
4041import org .apache .sysds .runtime .controlprogram .federated .FederatedResponse ;
4142import org .apache .sysds .runtime .controlprogram .federated .FederatedUDF ;
4243import org .apache .sysds .runtime .controlprogram .federated .FederationMap ;
4344import org .apache .sysds .runtime .controlprogram .federated .FederationUtils ;
4445import org .apache .sysds .runtime .functionobjects .DiagIndex ;
4546import org .apache .sysds .runtime .functionobjects .RevIndex ;
47+ import org .apache .sysds .runtime .functionobjects .RollIndex ;
4648import org .apache .sysds .runtime .functionobjects .SwapIndex ;
4749import org .apache .sysds .runtime .instructions .InstructionUtils ;
4850import org .apache .sysds .runtime .instructions .cp .CPOperand ;
5759import org .apache .sysds .runtime .meta .MatrixCharacteristics ;
5860
5961public class ReorgFEDInstruction extends UnaryFEDInstruction {
62+ // roll-specific attributes
63+ private CPOperand _shift = null ;
6064
6165 public ReorgFEDInstruction (Operator op , CPOperand in1 , CPOperand out , String opcode , String istr , FederatedOutput fedOut ) {
6266 super (FEDType .Reorg , op , in1 , out , opcode , istr , fedOut );
@@ -66,14 +70,29 @@ public ReorgFEDInstruction(Operator op, CPOperand in1, CPOperand out, String opc
6670 super (FEDType .Reorg , op , in1 , out , opcode , istr );
6771 }
6872
73+ private ReorgFEDInstruction (Operator op , CPOperand in , CPOperand shift , CPOperand out , String opcode , String istr , FederatedOutput fedOut ) {
74+ super (FEDType .Reorg , op , in , shift , out , opcode , istr , fedOut );
75+ _shift = shift ;
76+ }
77+
6978 public static ReorgFEDInstruction parseInstruction (ReorgCPInstruction rinst ) {
70- return new ReorgFEDInstruction (rinst .getOperator (), rinst .input1 , rinst .output , rinst .getOpcode (),
71- rinst .getInstructionString (), FederatedOutput .NONE );
79+ if (rinst .input2 != null ) {
80+ return new ReorgFEDInstruction (rinst .getOperator (), rinst .input1 , rinst .input2 , rinst .output , rinst .getOpcode (),
81+ rinst .getInstructionString (), FederatedOutput .NONE );
82+ } else {
83+ return new ReorgFEDInstruction (rinst .getOperator (), rinst .input1 , rinst .output , rinst .getOpcode (),
84+ rinst .getInstructionString (), FederatedOutput .NONE );
85+ }
7286 }
7387
7488 public static ReorgFEDInstruction parseInstruction (ReorgSPInstruction rinst ) {
75- return new ReorgFEDInstruction (rinst .getOperator (), rinst .input1 , rinst .output , rinst .getOpcode (),
76- rinst .getInstructionString (), FederatedOutput .NONE );
89+ if (rinst .input2 != null ) {
90+ return new ReorgFEDInstruction (rinst .getOperator (), rinst .input1 , rinst .input2 , rinst .output , rinst .getOpcode (),
91+ rinst .getInstructionString (), FederatedOutput .NONE );
92+ } else {
93+ return new ReorgFEDInstruction (rinst .getOperator (), rinst .input1 , rinst .output , rinst .getOpcode (),
94+ rinst .getInstructionString (), FederatedOutput .NONE );
95+ }
7796 }
7897
7998 public static ReorgFEDInstruction parseInstruction (String str ) {
@@ -105,6 +124,15 @@ else if(opcode.equalsIgnoreCase("rev")) {
105124 return new ReorgFEDInstruction (new ReorgOperator (RevIndex .getRevIndexFnObject ()), in , out , opcode , str ,
106125 fedOut );
107126 }
127+ else if (opcode .equalsIgnoreCase ("roll" )) {
128+ InstructionUtils .checkNumFields (str , 3 );
129+ in .split (parts [1 ]);
130+ out .split (parts [3 ]);
131+ CPOperand shift = new CPOperand (parts [2 ]);
132+ fedOut = parseFedOutFlag (str , 3 );
133+ return new ReorgFEDInstruction (new ReorgOperator (new RollIndex (0 )),
134+ in , out , shift , opcode , str , fedOut );
135+ }
108136 else {
109137 throw new DMLRuntimeException ("ReorgFEDInstruction: unsupported opcode: " + opcode );
110138 }
@@ -167,6 +195,36 @@ else if(instOpcode.equalsIgnoreCase("rev")) {
167195 .setBlocksize (mo1 .getBlocksize ()).setNonZeros (nnz );
168196 out .setFedMapping (mo1 .getFedMapping ().copyWithNewID (fr1 .getID ()));
169197
198+ optionalForceLocal (out );
199+ } else if (instOpcode .equalsIgnoreCase ("roll" )) {
200+ long rlen = mo1 .getNumRows ();
201+ long shift = ec .getScalarInput (_shift ).getLongValue ();
202+ shift %= (rlen != 0 ? rlen : 1 ); // roll matrix with axis=none
203+
204+ long inID = mo1 .getFedMapping ().getID ();
205+ long outEndID = FederationUtils .getNextFedDataID ();
206+ long outStartID = FederationUtils .getNextFedDataID ();
207+
208+ List <Pair <FederatedRange , FederatedData >> inMap = mo1 .getFedMapping ().getMap ();
209+ Pair <FederationMap , Long > rollResult = rollFedMap (
210+ inMap , inID , outEndID , outStartID , shift , rlen , mo1 .getFedMapping ().getType ());
211+ long length = rollResult .getValue ();
212+ FederationMap outFedMap = rollResult .getKey ();
213+
214+ FederatedRequest frEnd = new FederatedRequest (FederatedRequest .RequestType .EXEC_UDF , outEndID ,
215+ new ReorgFEDInstruction .SliceMatrix (inID , outEndID , length , true ));
216+ FederatedRequest frStart = new FederatedRequest (FederatedRequest .RequestType .EXEC_UDF , outStartID ,
217+ new ReorgFEDInstruction .SliceMatrix (inID , outStartID , length , false ));
218+ Future <FederatedResponse >[] ffr = outFedMap .executeRoll (getTID (), true , frEnd , frStart , rlen );
219+
220+ //derive output federated mapping
221+ MatrixObject out = ec .getMatrixObject (output );
222+ long nnz = (mo1 .getNnz () != -1 ) ? mo1 .getNnz () : FederationUtils .sumNonZeros (ffr );
223+ out .getDataCharacteristics ()
224+ .setDimension (mo1 .getNumRows (), mo1 .getNumColumns ())
225+ .setBlocksize (mo1 .getBlocksize ())
226+ .setNonZeros (nnz );
227+ out .setFedMapping (outFedMap );
170228 optionalForceLocal (out );
171229 }
172230 else if (instOpcode .equals ("rdiag" )) {
@@ -189,6 +247,40 @@ else if (instOpcode.equals("rdiag")) {
189247 }
190248 }
191249
250+
251+ public Pair <FederationMap , Long > rollFedMap (List <Pair <FederatedRange , FederatedData >> oldMap , long inID ,
252+ long outEndID , long outStartID , long shift , long rlen , FType type ) {
253+ List <Pair <FederatedRange , FederatedData >> map = new ArrayList <>();
254+ long length = 0 ;
255+
256+ for (Map .Entry <FederatedRange , FederatedData > e : oldMap ) {
257+ if (e .getKey ().getSize () == 0 ) continue ;
258+ FederatedRange fedRange = new FederatedRange (e .getKey ());
259+ long beginRow = fedRange .getBeginDims ()[0 ] + shift ;
260+ long endRow = fedRange .getEndDims ()[0 ] + shift ;
261+
262+ beginRow = beginRow > rlen ? beginRow - rlen : beginRow ;
263+ endRow = endRow > rlen ? endRow - rlen : endRow ;
264+
265+ if (beginRow < endRow ) {
266+ fedRange .setBeginDim (0 , beginRow );
267+ fedRange .setEndDim (0 , endRow );
268+ map .add (Pair .of (fedRange , e .getValue ().copyWithNewID (inID )));
269+ } else {
270+ length = rlen - beginRow ;
271+ fedRange .setBeginDim (0 , beginRow );
272+ fedRange .setEndDim (0 , rlen );
273+ map .add (Pair .of (fedRange , e .getValue ().copyWithNewID (outEndID )));
274+
275+ FederatedRange startRange = new FederatedRange (fedRange );
276+ startRange .setBeginDim (0 , 0 );
277+ startRange .setEndDim (0 , endRow );
278+ map .add (Pair .of (startRange , e .getValue ().copyWithNewID (outStartID )));
279+ }
280+ }
281+ return Pair .of (new FederationMap (outEndID , map , type ), length );
282+ }
283+
192284 /**
193285 * Update the federated ranges of result and return the updated federation map.
194286 * @param result RdiagResult for which the fedmap is updated
@@ -307,6 +399,51 @@ private RdiagResult rdiagM2V (MatrixObject mo1, ReorgOperator r_op) {
307399 return new RdiagResult (diagFedMap , dcs );
308400 }
309401
402+ public static class SliceMatrix extends FederatedUDF {
403+ private static final long serialVersionUID = -3466926635958851402L ;
404+ private final long _outputID ;
405+ private final int _sliceRow ;
406+ private final boolean _isRight ;
407+
408+ private SliceMatrix (long input , long outputID , long sliceRow , boolean isRight ) {
409+ super (new long [] {input });
410+ _outputID = outputID ;
411+ _sliceRow = (int ) sliceRow ;
412+ _isRight = isRight ;
413+ }
414+
415+ @ Override
416+ public FederatedResponse execute (ExecutionContext ec , Data ... data ) {
417+ MatrixBlock oriBlock = ((MatrixObject ) data [0 ]).acquireReadAndRelease ();
418+ MatrixBlock resBlock ;
419+
420+ if (_sliceRow != 0 ){
421+ if (_isRight ){
422+ resBlock = oriBlock .slice (0 , _sliceRow -1 , 0 ,
423+ oriBlock .getNumColumns ()-1 , new MatrixBlock ());
424+ } else {
425+ resBlock = oriBlock .slice (_sliceRow , oriBlock .getNumRows ()-1 ,
426+ 0 , oriBlock .getNumColumns ()-1 , new MatrixBlock ());
427+ }
428+ } else {
429+ resBlock = oriBlock ;
430+ }
431+ ec .setMatrixOutput (String .valueOf (_outputID ), resBlock );
432+ return new FederatedResponse (FederatedResponse .ResponseType .SUCCESS , resBlock );
433+ }
434+
435+ @ Override
436+ public List <Long > getOutputIds () {
437+ return new ArrayList <>(Arrays .asList (_outputID ));
438+ }
439+
440+ @ Override
441+ public Pair <String , LineageItem > getLineageItem (ExecutionContext ec ) {
442+ return Pair .of (String .valueOf (_outputID ),
443+ new LineageItem ());
444+ }
445+ }
446+
310447 public static class Rdiag extends FederatedUDF {
311448
312449 private static final long serialVersionUID = -3466926635958851402L ;
0 commit comments