Skip to content

Commit 53cba12

Browse files
gaturchenkomboehm7
authored andcommitted
[SYSTEMDS-3800] Improve Code Coverage for Federated Operations
Closes #2148.
1 parent b2f3966 commit 53cba12

File tree

10 files changed

+329
-108
lines changed

10 files changed

+329
-108
lines changed

src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,38 @@ public void processInstruction(ExecutionContext ec) {
122122
if(matrixInputsCount == 3)
123123
processMatrixInput(ec, mo1, mo2, mo3);
124124
else if(matrixInputsCount == 1) {
125-
CPOperand in = mo1 == null ? mo2 == null ? input3 : input2 : input1;
125+
CPOperand in;
126+
// determine the position of a matrix in the input and whether any of the scalars are not literals
127+
if (mo1 == null) {
128+
if (mo2 == null) { // sc, sc, mat
129+
in = input3;
130+
instString = InstructionUtils.replaceOperand(instString, 2,
131+
InstructionUtils.createLiteralOperand(ec.getScalarInput(input1).getStringValue(), Types.ValueType.FP64));
132+
if (!input2.isLiteral()) {
133+
instString = InstructionUtils.replaceOperand(instString, 3,
134+
InstructionUtils.createLiteralOperand(ec.getScalarInput(input2).getStringValue(), Types.ValueType.FP64));
135+
}
136+
} else { // sc, mat, sc
137+
in = input2;
138+
instString = InstructionUtils.replaceOperand(instString, 2,
139+
InstructionUtils.createLiteralOperand(ec.getScalarInput(input1).getStringValue(), Types.ValueType.FP64));
140+
if (!input3.isLiteral()) {
141+
instString = InstructionUtils.replaceOperand(instString, 4,
142+
InstructionUtils.createLiteralOperand(ec.getScalarInput(input3).getStringValue(), Types.ValueType.FP64));
143+
}
144+
}
145+
} else { // mat, sc, sc
146+
in = input1;
147+
if (!input2.isLiteral()) {
148+
instString = InstructionUtils.replaceOperand(instString, 3,
149+
InstructionUtils.createLiteralOperand(ec.getScalarInput(input2).getStringValue(), Types.ValueType.FP64));
150+
}
151+
if (!input3.isLiteral()) {
152+
instString = InstructionUtils.replaceOperand(instString, 4,
153+
InstructionUtils.createLiteralOperand(ec.getScalarInput(input3).getStringValue(), Types.ValueType.FP64));
154+
}
155+
}
156+
126157
mo1 = mo1 == null ? mo2 == null ? mo3 : mo2 : mo1;
127158
processMatrixScalarInput(ec, mo1, in);
128159
}
@@ -150,11 +181,10 @@ else if(mo1 != null && mo3 != null) {
150181

151182
private void processMatrixScalarInput(ExecutionContext ec, MatrixLineagePair mo1, CPOperand in) {
152183
long id = FederationUtils.getNextFedDataID();
153-
FederatedRequest fr1 = new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, id, new MatrixCharacteristics(-1, -1), mo1.getDataType());
154-
155-
FederatedRequest fr2 = FederationUtils.callInstruction(instString, output, id, new CPOperand[] {in}, new long[] {mo1.getFedMapping().getID()},
184+
FederatedRequest fr = FederationUtils.callInstruction(instString, output, id, new CPOperand[] {in}, new long[] {mo1.getFedMapping().getID()},
156185
InstructionUtils.getExecType(instString), false);
157-
sendFederatedRequests(ec, mo1.getMO(), fr1.getID(), fr1, fr2);
186+
187+
sendFederatedRequests(ec, mo1.getMO(), fr.getID(), fr);
158188
}
159189

160190
private void process2MatrixScalarInput(ExecutionContext ec, MatrixLineagePair mo1, MatrixLineagePair mo2, CPOperand in1, CPOperand in2) {

src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCentralMomentTest.java

Lines changed: 70 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -40,52 +40,71 @@
4040
public class FederatedCentralMomentTest extends AutomatedTestBase {
4141

4242
private final static String TEST_DIR = "functions/federated/";
43-
private final static String TEST_NAME = "FederatedCentralMomentTest";
43+
private final static String TEST_NAME1 = "FederatedCentralMomentTest";
44+
private final static String TEST_NAME2 = "FederatedCentralMomentWeightedTest";
4445
private final static String TEST_CLASS_DIR = TEST_DIR + FederatedCentralMomentTest.class.getSimpleName() + "/";
4546

4647
private final static int blocksize = 1024;
4748
@Parameterized.Parameter()
4849
public int rows;
4950

5051
@Parameterized.Parameter(1)
52+
public int cols;
53+
54+
@Parameterized.Parameter(2)
5155
public int k;
5256

5357
@Parameterized.Parameters
5458
public static Collection<Object[]> data() {
55-
return Arrays.asList(new Object[][] {{1000, 2}, {1000, 3}, {1000, 4}});
59+
return Arrays.asList(new Object[][] {{20, 1, 2}});
5660
}
5761

5862
@Override
5963
public void setUp() {
6064
TestUtils.clearAssertionInformation();
61-
addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"S.scalar"}));
65+
addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"S.scalar"}));
66+
addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"S.scalar"}));
6267
}
6368

6469
@Test
65-
@Ignore // infinite runtime online but works locally.
6670
public void federatedCentralMomentCP() {
67-
federatedCentralMoment(Types.ExecMode.SINGLE_NODE);
71+
federatedCentralMoment(Types.ExecMode.SINGLE_NODE, false);
72+
}
73+
74+
@Test
75+
public void federatedCentralMomentWeightedCP() {
76+
federatedCentralMoment(Types.ExecMode.SINGLE_NODE, true);
6877
}
6978

7079
@Test
71-
@Ignore
7280
public void federatedCentralMomentSP() {
73-
federatedCentralMoment(Types.ExecMode.SPARK);
81+
federatedCentralMoment(Types.ExecMode.SPARK, false);
82+
}
83+
84+
// The test fails due to an error while executing rmvar instruction after cm calculation
85+
// The CacheStatus of the weights variable is READ hence it can't be modified
86+
// In this test the input matrix is federated and weights are read from file
87+
@Ignore
88+
@Test
89+
public void federatedCentralMomentWeightedSP() {
90+
federatedCentralMoment(Types.ExecMode.SPARK, true);
7491
}
7592

76-
public void federatedCentralMoment(Types.ExecMode execMode) {
93+
public void federatedCentralMoment(Types.ExecMode execMode, boolean isWeighted) {
7794
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
7895
Types.ExecMode platformOld = rtplatform;
7996

97+
String TEST_NAME = isWeighted ? TEST_NAME2 : TEST_NAME1;
8098
getAndLoadTestConfiguration(TEST_NAME);
8199
String HOME = SCRIPT_DIR + TEST_DIR;
82100

83101
int r = rows / 4;
102+
int c = cols;
84103

85-
double[][] X1 = getRandomMatrix(r, 1, 1, 5, 1, 3);
86-
double[][] X2 = getRandomMatrix(r, 1, 1, 5, 1, 7);
87-
double[][] X3 = getRandomMatrix(r, 1, 1, 5, 1, 8);
88-
double[][] X4 = getRandomMatrix(r, 1, 1, 5, 1, 9);
104+
double[][] X1 = getRandomMatrix(r, c, 1, 5, 1, 3);
105+
double[][] X2 = getRandomMatrix(r, c, 1, 5, 1, 7);
106+
double[][] X3 = getRandomMatrix(r, c, 1, 5, 1, 8);
107+
double[][] X4 = getRandomMatrix(r, c, 1, 5, 1, 9);
89108

90109
MatrixCharacteristics mc = new MatrixCharacteristics(r, 1, blocksize, r);
91110
writeInputMatrixWithMTD("X1", X1, false, mc);
@@ -114,24 +133,47 @@ public void federatedCentralMoment(Types.ExecMode execMode) {
114133
if(rtplatform == Types.ExecMode.SPARK) {
115134
DMLScript.USE_LOCAL_SPARK_CONFIG = true;
116135
}
117-
// Run reference dml script with normal matrix for Row/Col
118-
fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
119-
programArgs = new String[] {"-stats", "100", "-args", input("X1"), input("X2"), input("X3"), input("X4"),
120-
expected("S"), String.valueOf(k)};
121-
runTest(null);
122-
123136
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
124137
loadTestConfiguration(config);
125-
126-
fullDMLScriptName = HOME + TEST_NAME + ".dml";
127-
programArgs = new String[] {"-stats", "100", "-nvargs",
128-
"in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
129-
"in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
130-
"in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
131-
"in_X4=" + TestUtils.federatedAddress(port4, input("X4")), "rows=" + rows, "cols=" + 1,
132-
"out_S=" + output("S"), "k=" + k};
133-
runTest(null);
134-
138+
if (isWeighted) {
139+
double[][] W1 = getRandomMatrix(r, c, 0, 1, 1, 3);
140+
141+
writeInputMatrixWithMTD("W1", W1, false, mc);
142+
143+
// Run reference dml script with normal matrix
144+
fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
145+
programArgs = new String[] {"-stats", "100", "-args", input("X1"), input("X2"), input("X3"), input("X4"),
146+
input("W1"), expected("S"), "" + k};
147+
runTest(null);
148+
149+
fullDMLScriptName = HOME + TEST_NAME + ".dml";
150+
programArgs = new String[] {"-stats", "100", "-nvargs",
151+
"in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
152+
"in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
153+
"in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
154+
"in_X4=" + TestUtils.federatedAddress(port4, input("X4")),
155+
"in_W1=" + input("W1"),
156+
"rows=" + rows, "cols=" + cols, "k=" + k,
157+
"out_S=" + output("S")};
158+
runTest(null);
159+
}
160+
else {
161+
// Run reference dml script with normal matrix for Row/Col
162+
fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
163+
programArgs = new String[]{"-stats", "100", "-args", input("X1"), input("X2"), input("X3"), input("X4"),
164+
expected("S"), String.valueOf(k)};
165+
runTest(null);
166+
167+
168+
fullDMLScriptName = HOME + TEST_NAME + ".dml";
169+
programArgs = new String[]{"-stats", "100", "-nvargs",
170+
"in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
171+
"in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
172+
"in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
173+
"in_X4=" + TestUtils.federatedAddress(port4, input("X4")), "rows=" + rows, "cols=" + 1,
174+
"out_S=" + output("S"), "k=" + k};
175+
runTest(null);
176+
}
135177
// compare all sums via files
136178
compareResults(0.01);
137179

src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFullCumulativeTest.java

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,16 @@ public void testMinDenseMatrixCP() {
8989
runCumOperationTest(OpType.MIN, ExecType.CP);
9090
}
9191

92+
@Test
93+
public void testProdDenseMatrixCP() {
94+
runCumOperationTest(OpType.PROD, ExecType.CP);
95+
}
96+
97+
@Test
98+
public void testSumProdDenseMatrixCP() {
99+
runCumOperationTest(OpType.SUMPROD, ExecType.CP);
100+
}
101+
92102
@Test
93103
@Ignore
94104
public void testSumDenseMatrixSP() {
@@ -189,7 +199,10 @@ private void runCumOperationTest(OpType type, ExecType instType) {
189199
runTest(true, false, null, -1);
190200

191201
// compare via files
192-
compareResults(1e-6, "DML1", "DML2");
202+
if (type != OpType.SUMPROD && type != OpType.PROD)
203+
compareResults(1e-6, "DML1", "DML2");
204+
else // we sum over the cumsumprod matrix and get a very large number, hence the large tolerance
205+
compareResults(1e+73, "DML1", "DML2");
193206

194207
switch(type) {
195208
case SUM:
@@ -208,12 +221,20 @@ private void runCumOperationTest(OpType type, ExecType instType) {
208221
heavyHittersContainsString(instType == ExecType.SPARK ? "fed_bcumoffmin" : "fed_ucummin"));
209222
break;
210223
case SUMPROD:
211-
Assert.assertTrue(heavyHittersContainsString(instType == ExecType.SPARK ? "fed_bcumoff+*" : "ucumk+*"));
224+
// when input is column-partitioned, ucumk+* is executed instead of fed_ucumk+*
225+
Assert.assertTrue(heavyHittersContainsString(instType == ExecType.SPARK ? "fed_bcumoff+*" :
226+
rowPartitioned ? "fed_ucumk+*" : "ucumk+*"));
212227
break;
213228
}
214229

215-
if(instType != ExecType.SPARK) // verify output is federated
216-
Assert.assertTrue(heavyHittersContainsString("fed_uak+"));
230+
if(instType != ExecType.SPARK) { // verify output is federated
231+
if (type == OpType.SUMPROD && !rowPartitioned) {
232+
Assert.assertTrue(heavyHittersContainsString("uak+"));
233+
} else {
234+
Assert.assertTrue(heavyHittersContainsString("fed_uak+"));
235+
}
236+
}
237+
217238

218239
// check that federated input files are still existing
219240
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));

0 commit comments

Comments
 (0)