Skip to content

Commit a6a1509

Browse files
committed
[SYSTEMDS-3796] Fix flaky federated primitive tests and instructions
1 parent a4b3de3 commit a6a1509

File tree

3 files changed

+52
-59
lines changed

3 files changed

+52
-59
lines changed

src/main/java/org/apache/sysds/runtime/instructions/fed/CovarianceFEDInstruction.java

Lines changed: 46 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import java.util.regex.Pattern;
2929

3030
import org.apache.commons.lang3.tuple.Pair;
31+
import org.apache.sysds.lops.Lop;
3132
import org.apache.sysds.runtime.DMLRuntimeException;
3233
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
3334
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -116,7 +117,7 @@ private void processAlignedFedCov(ExecutionContext ec, MatrixObject mo1, MatrixO
116117

117118
FederatedRequest fr2 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr1.getID());
118119
FederatedRequest fr3 = mo1.getFedMapping().cleanup(getTID(), fr1.getID());
119-
Future<FederatedResponse>[] covTmp = mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3);
120+
Future<FederatedResponse>[] covTmp = mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
120121

121122
//means
122123
Future<FederatedResponse>[] meanTmp1 = processMean(mo1, moLin3, 0);
@@ -145,7 +146,7 @@ private void processFedCovWeights(ExecutionContext ec, MatrixObject mo1, MatrixO
145146
FederatedRequest[] fr1 = mo1.getFedMapping().broadcastSliced(moLin3, false);
146147

147148
// the original instruction encodes weights as "pREADW", change to the new ID
148-
String[] parts = instString.split("°");
149+
String[] parts = instString.split(Lop.OPERAND_DELIMITOR);
149150
String covInstr = instString.replace(parts[4], String.valueOf(fr1[0].getID()) + "·MATRIX·FP64");
150151

151152
FederatedRequest fr2 = FederationUtils.callInstruction(
@@ -305,7 +306,7 @@ private Future<FederatedResponse>[] processMean(MatrixObject mo1, MatrixLineageP
305306
Future<FederatedResponse>[] meanTmp = null;
306307
if (moLin3 == null) {
307308
String meanInstr = instString.replace(getOpcode(), getOpcode().replace("cov", "uamean"));
308-
meanInstr = meanInstr.replace((var == 0 ? parts[2] : parts[3]) + "°", "");
309+
meanInstr = meanInstr.replace((var == 0 ? parts[2] : parts[3]) + Lop.OPERAND_DELIMITOR, "");
309310
meanInstr = meanInstr.replace(parts[4], parts[4].replace("FP64", "STRING°16"));
310311

311312
//create federated commands for aggregation
@@ -321,7 +322,7 @@ private Future<FederatedResponse>[] processMean(MatrixObject mo1, MatrixLineageP
321322
String multOutput = incrementVar(parts[4], 1);
322323
String multInstr = instString
323324
.replace(getOpcode(), getOpcode().replace("cov", "*"))
324-
.replace((var == 0 ? parts[2] : parts[3]) + "°", "")
325+
.replace((var == 0 ? parts[2] : parts[3]) + Lop.OPERAND_DELIMITOR, "")
325326
.replace(parts[5], multOutput);
326327

327328
CPOperand multOutputCPOp = new CPOperand(
@@ -337,13 +338,13 @@ private Future<FederatedResponse>[] processMean(MatrixObject mo1, MatrixLineageP
337338
);
338339

339340
// calculate the sum of the obtained vector
340-
String[] partsMult = multInstr.split("°");
341+
String[] partsMult = multInstr.split(Lop.OPERAND_DELIMITOR);
341342
String sumInstr1Output = incrementVar(multOutput, 1)
342343
.replace("m", "")
343344
.replace("MATRIX", "SCALAR");
344345
String sumInstr1 = multInstr
345346
.replace(partsMult[1], "uak+")
346-
.replace(partsMult[3] + "°", "")
347+
.replace(partsMult[3] + Lop.OPERAND_DELIMITOR, "")
347348
.replace(partsMult[4], sumInstr1Output)
348349
.replace(partsMult[2], multOutput);
349350

@@ -358,7 +359,7 @@ private Future<FederatedResponse>[] processMean(MatrixObject mo1, MatrixLineageP
358359
);
359360

360361
// calculate the sum of weights
361-
String[] partsSum1 = sumInstr1.split("°");
362+
String[] partsSum1 = sumInstr1.split(Lop.OPERAND_DELIMITOR);
362363
String sumInstr2Output = incrementVar(sumInstr1Output, 1);
363364
String sumInstr2 = sumInstr1
364365
.replace(partsSum1[2], parts[4])
@@ -367,20 +368,21 @@ private Future<FederatedResponse>[] processMean(MatrixObject mo1, MatrixLineageP
367368
FederatedRequest sumFr2 = FederationUtils.callInstruction(
368369
sumInstr2,
369370
new CPOperand(
370-
sumInstr2Output.substring(0, sumInstr2Output.indexOf("·")),
371+
sumInstr2Output.substring(0, sumInstr2Output.indexOf(Lop.DATATYPE_PREFIX)),
371372
output.getValueType(), output.getDataType()
372373
),
373374
new CPOperand[]{input3},
374375
new long[]{moLin3.getFedMapping().getID()}
375376
);
376377

377378
// divide sum(X*W) by sum(W)
378-
String[] partsSum2 = sumInstr2.split("°");
379+
String[] partsSum2 = sumInstr2.split(Lop.OPERAND_DELIMITOR);
379380
String divInstrOutput = incrementVar(sumInstr2Output, 1);
380-
String divInstrInput1 = partsSum2[2].replace(partsSum2[2], sumInstr1Output + "·false");
381-
String divInstrInput2 = partsSum2[3].replace(partsSum2[3], sumInstr2Output + "·false");
382-
String divInstr = partsSum2[0] + "°" + partsSum2[1].replace("uak+", "/") + "°" +
383-
divInstrInput1 + "°" + divInstrInput2 + "°" + divInstrOutput + "°" + partsSum2[4];
381+
String divInstrInput1 = partsSum2[2].replace(partsSum2[2], sumInstr1Output + Lop.DATATYPE_PREFIX + "false");
382+
String divInstrInput2 = partsSum2[3].replace(partsSum2[3], sumInstr2Output + Lop.DATATYPE_PREFIX + "false");
383+
String divInstr = partsSum2[0] + Lop.OPERAND_DELIMITOR + partsSum2[1].replace("uak+", "/")
384+
+ Lop.OPERAND_DELIMITOR + divInstrInput1 + Lop.OPERAND_DELIMITOR + divInstrInput2
385+
+ Lop.OPERAND_DELIMITOR + divInstrOutput + Lop.OPERAND_DELIMITOR + partsSum2[4];
384386

385387
FederatedRequest divFr1 = FederationUtils.callInstruction(
386388
divInstr,
@@ -390,11 +392,11 @@ private Future<FederatedResponse>[] processMean(MatrixObject mo1, MatrixLineageP
390392
),
391393
new CPOperand[]{
392394
new CPOperand(
393-
sumInstr1Output.substring(0, sumInstr1Output.indexOf("·")),
395+
sumInstr1Output.substring(0, sumInstr1Output.indexOf(Lop.DATATYPE_PREFIX)),
394396
output.getValueType(), output.getDataType(), output.isLiteral()
395397
),
396398
new CPOperand(
397-
sumInstr2Output.substring(0, sumInstr2Output.indexOf("·")),
399+
sumInstr2Output.substring(0, sumInstr2Output.indexOf(Lop.DATATYPE_PREFIX)),
398400
output.getValueType(), output.getDataType(), output.isLiteral()
399401
)
400402
},
@@ -409,19 +411,19 @@ private Future<FederatedResponse>[] processMean(MatrixObject mo1, MatrixLineageP
409411
}
410412

411413
private Future<FederatedResponse>[] processMean(MatrixObject mo1, int var, long weightsID){
412-
String[] parts = instString.split("°");
414+
String[] parts = instString.split(Lop.OPERAND_DELIMITOR);
413415
Future<FederatedResponse>[] meanTmp = null;
414416

415417
// multiply input X by weights W element-wise
416418
String multOutput = (var == 0 ? incrementVar(parts[2], 5) : incrementVar(parts[3], 3));
417419
String multInstr = instString
418420
.replace(getOpcode(), getOpcode().replace("cov", "*"))
419-
.replace((var == 0 ? parts[2] : parts[3]) + "°", "")
421+
.replace((var == 0 ? parts[2] : parts[3]) + Lop.OPERAND_DELIMITOR, "")
420422
.replace(parts[4], String.valueOf(weightsID) + "·MATRIX·FP64")
421423
.replace(parts[5], multOutput);
422424

423425
CPOperand multOutputCPOp = new CPOperand(
424-
multOutput.substring(0, multOutput.indexOf("·")),
426+
multOutput.substring(0, multOutput.indexOf(Lop.DATATYPE_PREFIX)),
425427
mo1.getValueType(), mo1.getDataType()
426428
);
427429

@@ -433,28 +435,28 @@ private Future<FederatedResponse>[] processMean(MatrixObject mo1, int var, long
433435
);
434436

435437
// calculate the sum of the obtained vector
436-
String[] partsMult = multInstr.split("°");
438+
String[] partsMult = multInstr.split(Lop.OPERAND_DELIMITOR);
437439
String sumInstr1Output = incrementVar(multOutput, 1)
438440
.replace("m", "")
439441
.replace("MATRIX", "SCALAR");
440442
String sumInstr1 = multInstr
441443
.replace(partsMult[1], "uak+")
442-
.replace(partsMult[3] + "°", "")
444+
.replace(partsMult[3] + Lop.OPERAND_DELIMITOR, "")
443445
.replace(partsMult[4], sumInstr1Output)
444446
.replace(partsMult[2], multOutput);
445447

446448
FederatedRequest sumFr1 = FederationUtils.callInstruction(
447449
sumInstr1,
448450
new CPOperand(
449-
sumInstr1Output.substring(0, sumInstr1Output.indexOf("·")),
451+
sumInstr1Output.substring(0, sumInstr1Output.indexOf(Lop.DATATYPE_PREFIX)),
450452
output.getValueType(), output.getDataType()
451453
),
452454
new CPOperand[]{multOutputCPOp},
453455
new long[]{multFr.getID()}
454456
);
455457

456458
// calculate the sum of weights
457-
String[] partsSum1 = sumInstr1.split("°");
459+
String[] partsSum1 = sumInstr1.split(Lop.OPERAND_DELIMITOR);
458460
String sumInstr2Output = incrementVar(sumInstr1Output, 1);
459461
String sumInstr2 = sumInstr1
460462
.replace(partsSum1[2], String.valueOf(weightsID) + "·MATRIX·FP64")
@@ -463,34 +465,35 @@ private Future<FederatedResponse>[] processMean(MatrixObject mo1, int var, long
463465
FederatedRequest sumFr2 = FederationUtils.callInstruction(
464466
sumInstr2,
465467
new CPOperand(
466-
sumInstr2Output.substring(0, sumInstr2Output.indexOf("·")),
468+
sumInstr2Output.substring(0, sumInstr2Output.indexOf(Lop.DATATYPE_PREFIX)),
467469
output.getValueType(), output.getDataType()
468470
),
469471
new CPOperand[]{input3},
470472
new long[]{weightsID}
471473
);
472474

473475
// divide sum(X*W) by sum(W)
474-
String[] partsSum2 = sumInstr2.split("°");
476+
String[] partsSum2 = sumInstr2.split(Lop.OPERAND_DELIMITOR);
475477
String divInstrOutput = incrementVar(sumInstr2Output, 1);
476-
String divInstrInput1 = partsSum2[2].replace(partsSum2[2], sumInstr1Output + "·false");
477-
String divInstrInput2 = partsSum2[3].replace(partsSum2[3], sumInstr2Output + "·false");
478-
String divInstr = partsSum2[0] + "°" + partsSum2[1].replace("uak+", "/") + "°" +
479-
divInstrInput1 + "°" + divInstrInput2 + "°" + divInstrOutput + "°" + partsSum2[4];
478+
String divInstrInput1 = partsSum2[2].replace(partsSum2[2], sumInstr1Output + Lop.DATATYPE_PREFIX + "false");
479+
String divInstrInput2 = partsSum2[3].replace(partsSum2[3], sumInstr2Output + Lop.DATATYPE_PREFIX + "false");
480+
String divInstr = partsSum2[0] + Lop.OPERAND_DELIMITOR + partsSum2[1].replace("uak+", "/") + Lop.OPERAND_DELIMITOR
481+
+ divInstrInput1 + Lop.OPERAND_DELIMITOR + divInstrInput2 + Lop.OPERAND_DELIMITOR
482+
+ divInstrOutput + Lop.OPERAND_DELIMITOR + partsSum2[4];
480483

481484
FederatedRequest divFr1 = FederationUtils.callInstruction(
482485
divInstr,
483486
new CPOperand(
484-
divInstrOutput.substring(0, divInstrOutput.indexOf("·")),
487+
divInstrOutput.substring(0, divInstrOutput.indexOf(Lop.DATATYPE_PREFIX)),
485488
output.getValueType(), output.getDataType()
486489
),
487490
new CPOperand[]{
488491
new CPOperand(
489-
sumInstr1Output.substring(0, sumInstr1Output.indexOf("·")),
492+
sumInstr1Output.substring(0, sumInstr1Output.indexOf(Lop.DATATYPE_PREFIX)),
490493
output.getValueType(), output.getDataType(), output.isLiteral()
491494
),
492495
new CPOperand(
493-
sumInstr2Output.substring(0, sumInstr2Output.indexOf("·")),
496+
sumInstr2Output.substring(0, sumInstr2Output.indexOf(Lop.DATATYPE_PREFIX)),
494497
output.getValueType(), output.getDataType(), output.isLiteral()
495498
)
496499
},
@@ -506,14 +509,15 @@ private Future<FederatedResponse>[] processMean(MatrixObject mo1, int var, long
506509
private Future<FederatedResponse>[] getWeightsSum(MatrixLineagePair moLin3, long weightsID, String instString, FederationMap fedMap) {
507510
Future<FederatedResponse>[] weightsSumTmp = null;
508511

509-
String[] parts = instString.split("°");
512+
String[] parts = instString.split(Lop.OPERAND_DELIMITOR);
510513
if (!instString.contains("pREADW")) {
511-
String sumInstr = "CP°uak+°" + parts[4] + "°" + parts[5] + "°" + parts[6];
514+
String sumInstr = "CP"+Lop.OPERAND_DELIMITOR+"uak+" + Lop.OPERAND_DELIMITOR
515+
+ parts[4] + Lop.OPERAND_DELIMITOR + parts[5] + Lop.OPERAND_DELIMITOR + parts[6];
512516

513517
FederatedRequest sumFr = FederationUtils.callInstruction(
514518
sumInstr,
515519
new CPOperand(
516-
parts[5].substring(0, parts[5].indexOf("·")),
520+
parts[5].substring(0, parts[5].indexOf(Lop.DATATYPE_PREFIX)),
517521
output.getValueType(),
518522
output.getDataType()
519523
),
@@ -526,11 +530,13 @@ private Future<FederatedResponse>[] getWeightsSum(MatrixLineagePair moLin3, long
526530
weightsSumTmp = fedMap.execute(getTID(), sumFr, sumFr2, sumFr3);
527531
}
528532
else {
529-
String sumInstr = "CP°uak+°" + String.valueOf(weightsID) + "·MATRIX·FP64" + "°" + parts[5] + "°" + parts[6];
533+
String sumInstr = "CP"+Lop.OPERAND_DELIMITOR+"uak+"+Lop.OPERAND_DELIMITOR
534+
+ String.valueOf(weightsID) + "·MATRIX·FP64" + Lop.OPERAND_DELIMITOR + parts[5]
535+
+ Lop.OPERAND_DELIMITOR + parts[6];
530536
FederatedRequest sumFr = FederationUtils.callInstruction(
531537
sumInstr,
532538
new CPOperand(
533-
parts[5].substring(0, parts[5].indexOf("·")),
539+
parts[5].substring(0, parts[5].indexOf(Lop.DATATYPE_PREFIX)),
534540
output.getValueType(),
535541
output.getDataType()
536542
),
@@ -576,7 +582,8 @@ public FederatedResponse execute(ExecutionContext ec, Data... data) {
576582
return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, mb.covOperations(_op, _mo2));
577583
}
578584

579-
@Override public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
585+
@Override
586+
public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
580587
return null;
581588
}
582589
}
@@ -600,7 +607,8 @@ public FederatedResponse execute(ExecutionContext ec, Data... data) {
600607
return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, mb.covOperations(_op, _mo2, _weights));
601608
}
602609

603-
@Override public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
610+
@Override
611+
public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
604612
return null;
605613
}
606614
}

src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -136,17 +136,13 @@ private void runCovarianceTest(ExecMode execMode, boolean alignedFedInput) {
136136
Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
137137
Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
138138
Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
139-
Process t4 = startLocalFedWorker(port4);
139+
Process t4 = startLocalFedWorker(port4, FED_WORKER_WAIT);
140140

141141
try {
142142
if(!isAlive(t1, t2, t3, t4))
143143
throw new RuntimeException("Failed starting federated worker");
144144

145-
rtplatform = execMode;
146-
if(rtplatform == ExecMode.SPARK) {
147-
System.out.println(7);
148-
DMLScript.USE_LOCAL_SPARK_CONFIG = true;
149-
}
145+
setExecMode(execMode);
150146
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
151147
loadTestConfiguration(config);
152148

@@ -214,11 +210,7 @@ private void runCovarianceTest(ExecMode execMode, boolean alignedFedInput) {
214210

215211
private void runWeightedCovarianceTest(ExecMode execMode, boolean alignedInput, boolean alignedWeights) {
216212
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
217-
ExecMode platformOld = rtplatform;
218-
219-
if(rtplatform == ExecMode.SPARK)
220-
DMLScript.USE_LOCAL_SPARK_CONFIG = true;
221-
213+
ExecMode platformOld = setExecMode(execMode);
222214
String TEST_NAME = !alignedInput ? TEST_NAME3 : (!alignedWeights ? TEST_NAME4 : TEST_NAME5);
223215
getAndLoadTestConfiguration(TEST_NAME);
224216

@@ -256,18 +248,11 @@ private void runWeightedCovarianceTest(ExecMode execMode, boolean alignedInput,
256248
Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
257249
Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
258250
Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
259-
Process t4 = startLocalFedWorker(port4);
251+
Process t4 = startLocalFedWorker(port4, FED_WORKER_WAIT);
260252

261253
try {
262254
if(!isAlive(t1, t2, t3, t4))
263255
throw new RuntimeException("Failed starting federated worker");
264-
265-
rtplatform = execMode;
266-
if(rtplatform == ExecMode.SPARK) {
267-
System.out.println(7);
268-
DMLScript.USE_LOCAL_SPARK_CONFIG = true;
269-
}
270-
271256
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
272257
loadTestConfiguration(config);
273258

src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedMatrixScalarOperationsTest.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.junit.runner.RunWith;
2525
import org.apache.sysds.api.DMLScript;
2626
import org.apache.sysds.common.Types;
27+
import org.apache.sysds.runtime.util.CommonThreadPool;
2728
import org.apache.sysds.test.AutomatedTestBase;
2829
import org.apache.sysds.test.TestConfiguration;
2930
import org.apache.sysds.test.TestUtils;
@@ -197,8 +198,7 @@ private void runGenericTest(String dmlFile, int scalar) {
197198
// we need the reference file to not be written to hdfs, so we get the correct format
198199
rtplatform = Types.ExecMode.SINGLE_NODE;
199200
programArgs = new String[] {"-w", Integer.toString(FEDERATED_WORKER_PORT)};
200-
t = new Thread(() -> runTest(true, false, null, -1));
201-
t.start();
201+
CommonThreadPool.get().submit(() -> runTest(true, false, null, -1));
202202
sleep(FED_WORKER_WAIT);
203203
fullDMLScriptName = SCRIPT_DIR + TEST_DIR + dmlFile + ".dml";
204204
programArgs = new String[] {"-nvargs",

0 commit comments

Comments
 (0)