Skip to content

Commit 209411c

Browse files
committed
Modify the FederatedCovarianceTest to account for weighted covariance
1 parent a7d27ae commit 209411c

File tree

7 files changed

+426
-3
lines changed

7 files changed

+426
-3
lines changed

src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java

Lines changed: 241 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ public class FederatedCovarianceTest extends AutomatedTestBase {
4343

4444
private final static String TEST_NAME1 = "FederatedCovarianceTest";
4545
private final static String TEST_NAME2 = "FederatedCovarianceAlignedTest";
46+
private final static String TEST_NAME3 = "FederatedCovarianceWeightedTest";
47+
private final static String TEST_NAME4 = "FederatedCovarianceAlignedWeightedTest";
48+
private final static String TEST_NAME5 = "FederatedCovarianceAllAlignedWeightedTest";
4649
private final static String TEST_DIR = "functions/federated/";
4750
private static final String TEST_CLASS_DIR = TEST_DIR + FederatedCovarianceTest.class.getSimpleName() + "/";
4851

@@ -64,19 +67,37 @@ public void setUp() {
6467
TestUtils.clearAssertionInformation();
6568
addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"S.scalar"}));
6669
addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"S.scalar"}));
70+
addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {"S.scalar"}));
71+
addTestConfiguration(TEST_NAME4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] {"S.scalar"}));
72+
addTestConfiguration(TEST_NAME5, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] {"S.scalar"}));
6773
}
6874

6975
@Test
7076
public void testCovCP() {
71-
runCovTest(ExecMode.SINGLE_NODE, false);
77+
runCovarianceTest(ExecMode.SINGLE_NODE, false);
7278
}
7379

7480
@Test
7581
public void testAlignedCovCP() {
76-
runCovTest(ExecMode.SINGLE_NODE, true);
82+
runCovarianceTest(ExecMode.SINGLE_NODE, true);
7783
}
7884

79-
private void runCovTest(ExecMode execMode, boolean alignedFedInput) {
85+
@Test
86+
public void testCovarianceWeightedCP() {
87+
runWeightedCovarianceTest(ExecMode.SINGLE_NODE, false, false);
88+
}
89+
90+
@Test
91+
public void testAlignedCovarianceWeightedCP() {
92+
runWeightedCovarianceTest(ExecMode.SINGLE_NODE, true, false);
93+
}
94+
95+
@Test
96+
public void testAllAlignedCovarianceWeightedCP() {
97+
runWeightedCovarianceTest(ExecMode.SINGLE_NODE, true, true);
98+
}
99+
100+
private void runCovarianceTest(ExecMode execMode, boolean alignedFedInput) {
80101
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
81102
ExecMode platformOld = rtplatform;
82103

@@ -190,4 +211,221 @@ private void runCovTest(ExecMode execMode, boolean alignedFedInput) {
190211
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
191212
}
192213
}
214+
215+
private void runWeightedCovarianceTest(ExecMode execMode, boolean alignedInput, boolean alignedWeights) {
216+
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
217+
ExecMode platformOld = rtplatform;
218+
219+
if(rtplatform == ExecMode.SPARK)
220+
DMLScript.USE_LOCAL_SPARK_CONFIG = true;
221+
222+
String TEST_NAME = !alignedInput ? TEST_NAME3 : (!alignedWeights ? TEST_NAME4 : TEST_NAME5);
223+
getAndLoadTestConfiguration(TEST_NAME);
224+
225+
String HOME = SCRIPT_DIR + TEST_DIR;
226+
227+
int r = rows / 4;
228+
int c = cols;
229+
230+
fullDMLScriptName = "";
231+
232+
// Create 4 random 5x1 matrices
233+
double[][] X1 = getRandomMatrix(r, c, 1, 5, 1, 3);
234+
double[][] X2 = getRandomMatrix(r, c, 1, 5, 1, 7);
235+
double[][] X3 = getRandomMatrix(r, c, 1, 5, 1, 8);
236+
double[][] X4 = getRandomMatrix(r, c, 1, 5, 1, 9);
237+
238+
// Create a 20x1 weights matrix
239+
double[][] W = getRandomMatrix(rows, c, 0, 1, 1, 3);
240+
241+
MatrixCharacteristics mc = new MatrixCharacteristics(r, c, blocksize, r * c);
242+
writeInputMatrixWithMTD("X1", X1, false, mc);
243+
writeInputMatrixWithMTD("X2", X2, false, mc);
244+
writeInputMatrixWithMTD("X3", X3, false, mc);
245+
writeInputMatrixWithMTD("X4", X4, false, mc);
246+
247+
writeInputMatrixWithMTD("W", W, false, new MatrixCharacteristics(rows, cols, blocksize, r * c));
248+
249+
// empty script name because we don't execute any script, just start the worker
250+
fullDMLScriptName = "";
251+
int port1 = getRandomAvailablePort();
252+
int port2 = getRandomAvailablePort();
253+
int port3 = getRandomAvailablePort();
254+
int port4 = getRandomAvailablePort();
255+
256+
Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
257+
Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
258+
Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
259+
Process t4 = startLocalFedWorker(port4);
260+
261+
try {
262+
if(!isAlive(t1, t2, t3, t4))
263+
throw new RuntimeException("Failed starting federated worker");
264+
265+
rtplatform = execMode;
266+
if(rtplatform == ExecMode.SPARK) {
267+
System.out.println(7);
268+
DMLScript.USE_LOCAL_SPARK_CONFIG = true;
269+
}
270+
271+
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
272+
loadTestConfiguration(config);
273+
274+
if (alignedInput) {
275+
// Create 4 random 5x1 matrices
276+
double[][] Y1 = getRandomMatrix(r, c, 1, 5, 1, 3);
277+
double[][] Y2 = getRandomMatrix(r, c, 1, 5, 1, 7);
278+
double[][] Y3 = getRandomMatrix(r, c, 1, 5, 1, 8);
279+
double[][] Y4 = getRandomMatrix(r, c, 1, 5, 1, 9);
280+
281+
writeInputMatrixWithMTD("Y1", Y1, false, mc);
282+
writeInputMatrixWithMTD("Y2", Y2, false, mc);
283+
writeInputMatrixWithMTD("Y3", Y3, false, mc);
284+
writeInputMatrixWithMTD("Y4", Y4, false, mc);
285+
286+
if (!alignedWeights) {
287+
// Run reference dml script with a normal matrix
288+
fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
289+
programArgs = new String[] {
290+
"-stats", "100", "-args",
291+
input("X1"),
292+
input("X2"),
293+
input("X3"),
294+
input("X4"),
295+
296+
input("Y1"),
297+
input("Y2"),
298+
input("Y3"),
299+
input("Y4"),
300+
301+
input("W"),
302+
expected("S")
303+
};
304+
runTest(null);
305+
306+
// Run the dml script with federated matrices
307+
fullDMLScriptName = HOME + TEST_NAME + ".dml";
308+
programArgs = new String[] {"-stats", "100", "-nvargs",
309+
"in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
310+
"in_Y1=" + TestUtils.federatedAddress(port1, input("Y1")),
311+
312+
"in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
313+
"in_Y2=" + TestUtils.federatedAddress(port2, input("Y2")),
314+
315+
"in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
316+
"in_Y3=" + TestUtils.federatedAddress(port3, input("Y3")),
317+
318+
"in_X4=" + TestUtils.federatedAddress(port4, input("X4")),
319+
"in_Y4=" + TestUtils.federatedAddress(port4, input("Y4")),
320+
321+
"in_W1=" + input("W"),
322+
"rows=" + rows, "cols=" + cols,
323+
"out_S=" + output("S")};
324+
runTest(null);
325+
}
326+
else {
327+
double[][] W1 = getRandomMatrix(r, c, 0, 1, 1, 3);
328+
double[][] W2 = getRandomMatrix(r, c, 0, 1, 1, 7);
329+
double[][] W3 = getRandomMatrix(r, c, 0, 1, 1, 8);
330+
double[][] W4 = getRandomMatrix(r, c, 0, 1, 1, 9);
331+
332+
writeInputMatrixWithMTD("W1", W1, false, mc);
333+
writeInputMatrixWithMTD("W2", W2, false, mc);
334+
writeInputMatrixWithMTD("W3", W3, false, mc);
335+
writeInputMatrixWithMTD("W4", W4, false, mc);
336+
337+
// Run reference dml script with a normal matrix
338+
fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
339+
programArgs = new String[] {
340+
"-stats", "100", "-args",
341+
input("X1"),
342+
input("X2"),
343+
input("X3"),
344+
input("X4"),
345+
346+
input("Y1"),
347+
input("Y2"),
348+
input("Y3"),
349+
input("Y4"),
350+
351+
input("W1"),
352+
input("W2"),
353+
input("W3"),
354+
input("W4"),
355+
356+
expected("S")
357+
};
358+
runTest(null);
359+
360+
// Run the dml script with federated matrices and weights
361+
fullDMLScriptName = HOME + TEST_NAME + ".dml";
362+
programArgs = new String[] {"-stats", "100", "-nvargs",
363+
"in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
364+
"in_Y1=" + TestUtils.federatedAddress(port1, input("Y1")),
365+
"in_W1=" + TestUtils.federatedAddress(port1, input("W1")),
366+
367+
"in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
368+
"in_Y2=" + TestUtils.federatedAddress(port2, input("Y2")),
369+
"in_W2=" + TestUtils.federatedAddress(port2, input("W2")),
370+
371+
"in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
372+
"in_Y3=" + TestUtils.federatedAddress(port3, input("Y3")),
373+
"in_W3=" + TestUtils.federatedAddress(port3, input("W3")),
374+
375+
"in_X4=" + TestUtils.federatedAddress(port4, input("X4")),
376+
"in_Y4=" + TestUtils.federatedAddress(port4, input("Y4")),
377+
"in_W4=" + TestUtils.federatedAddress(port4, input("W4")),
378+
379+
"rows=" + rows, "cols=" + cols,
380+
"out_S=" + output("S")};
381+
runTest(null);
382+
}
383+
384+
}
385+
else {
386+
// Create a random 20x1 input matrix
387+
double[][] Y = getRandomMatrix(rows, c, 1, 5, 1, 3);
388+
writeInputMatrixWithMTD("Y", Y, false, new MatrixCharacteristics(rows, cols, blocksize, r * c));
389+
390+
// Run reference dml script with a normal matrix
391+
fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
392+
programArgs = new String[] {
393+
"-stats", "100", "-args",
394+
input("X1"),
395+
input("X2"),
396+
input("X3"),
397+
input("X4"),
398+
399+
input("Y"), input("W"), expected("S")
400+
};
401+
runTest(null);
402+
403+
// Run the dml script with a federated matrix
404+
fullDMLScriptName = HOME + TEST_NAME + ".dml";
405+
programArgs = new String[] {"-stats", "100", "-nvargs",
406+
"in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
407+
"in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
408+
"in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
409+
"in_X4=" + TestUtils.federatedAddress(port4, input("X4")),
410+
411+
"in_W1=" + input("W"),
412+
"Y=" + input("Y"),
413+
414+
"rows=" + rows,
415+
"cols=" + cols,
416+
"out_S=" + output("S")};
417+
runTest(null);
418+
}
419+
420+
// compare via files
421+
compareResults(1e-2);
422+
Assert.assertTrue(heavyHittersContainsString("fed_cov"));
423+
424+
}
425+
finally {
426+
TestUtils.shutdownThreads(t1, t2, t3, t4);
427+
rtplatform = platformOld;
428+
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
429+
}
430+
}
193431
}
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
# 5x1 on 4 workers -> 20x1
23+
X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
24+
ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols),
25+
list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols)));
26+
27+
# 5x1 on 4 workers -> 20x1
28+
Y = federated(addresses=list($in_Y1, $in_Y2, $in_Y3, $in_Y4),
29+
ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols),
30+
list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols)));
31+
32+
W = read($in_W1); # 20x1
33+
34+
s = cov(X, Y, W);
35+
write(s, $out_S);
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
X = rbind(read($1), read($2), read($3), read($4)); # 20x1
23+
Y = rbind(read($5), read($6), read($7), read($8)); # 20x1
24+
W = read($9); # 20x1
25+
26+
s = cov(X, Y, W);
27+
write(s, $10);
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
# 5x1 on 4 workers -> 20x1
23+
X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
24+
ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols),
25+
list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols)));
26+
27+
# 5x1 on 4 workers -> 20x1
28+
Y = federated(addresses=list($in_Y1, $in_Y2, $in_Y3, $in_Y4),
29+
ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols),
30+
list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols)));
31+
32+
# 5x1 on 4 workers -> 20x1
33+
W = federated(addresses=list($in_W1, $in_W2, $in_W3, $in_W4),
34+
ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols),
35+
list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols)));
36+
37+
s = cov(X, Y, W);
38+
write(s, $out_S);

0 commit comments

Comments
 (0)