Skip to content

Commit 1d1b008

Browse files
gaturchenkomboehm7
authored andcommitted
[SYSTEMDS-3789] Fix federated covariance (missing weighted case)
Closes #2137.
1 parent e326add commit 1d1b008

File tree

8 files changed

+739
-58
lines changed

8 files changed

+739
-58
lines changed

src/main/java/org/apache/sysds/runtime/instructions/fed/CovarianceFEDInstruction.java

Lines changed: 358 additions & 55 deletions
Large diffs are not rendered by default.

src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java

Lines changed: 196 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ public class FederatedCovarianceTest extends AutomatedTestBase {
4343

4444
private final static String TEST_NAME1 = "FederatedCovarianceTest";
4545
private final static String TEST_NAME2 = "FederatedCovarianceAlignedTest";
46+
private final static String TEST_NAME3 = "FederatedCovarianceWeightedTest";
47+
private final static String TEST_NAME4 = "FederatedCovarianceAlignedWeightedTest";
48+
private final static String TEST_NAME5 = "FederatedCovarianceAllAlignedWeightedTest";
4649
private final static String TEST_DIR = "functions/federated/";
4750
private static final String TEST_CLASS_DIR = TEST_DIR + FederatedCovarianceTest.class.getSimpleName() + "/";
4851

@@ -64,19 +67,37 @@ public void setUp() {
6467
TestUtils.clearAssertionInformation();
6568
addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"S.scalar"}));
6669
addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"S.scalar"}));
70+
addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {"S.scalar"}));
71+
addTestConfiguration(TEST_NAME4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] {"S.scalar"}));
72+
addTestConfiguration(TEST_NAME5, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] {"S.scalar"}));
6773
}
6874

6975
@Test
7076
public void testCovCP() {
71-
runCovTest(ExecMode.SINGLE_NODE, false);
77+
runCovarianceTest(ExecMode.SINGLE_NODE, false);
7278
}
7379

7480
@Test
7581
public void testAlignedCovCP() {
76-
runCovTest(ExecMode.SINGLE_NODE, true);
82+
runCovarianceTest(ExecMode.SINGLE_NODE, true);
7783
}
7884

79-
private void runCovTest(ExecMode execMode, boolean alignedFedInput) {
85+
@Test
86+
public void testCovarianceWeightedCP() {
87+
runWeightedCovarianceTest(ExecMode.SINGLE_NODE, false, false);
88+
}
89+
90+
@Test
91+
public void testAlignedCovarianceWeightedCP() {
92+
runWeightedCovarianceTest(ExecMode.SINGLE_NODE, true, false);
93+
}
94+
95+
@Test
96+
public void testAllAlignedCovarianceWeightedCP() {
97+
runWeightedCovarianceTest(ExecMode.SINGLE_NODE, true, true);
98+
}
99+
100+
private void runCovarianceTest(ExecMode execMode, boolean alignedFedInput) {
80101
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
81102
ExecMode platformOld = rtplatform;
82103

@@ -190,4 +211,176 @@ private void runCovTest(ExecMode execMode, boolean alignedFedInput) {
190211
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
191212
}
192213
}
214+
215+
private void runWeightedCovarianceTest(ExecMode execMode, boolean alignedInput, boolean alignedWeights) {
216+
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
217+
ExecMode platformOld = rtplatform;
218+
219+
if(rtplatform == ExecMode.SPARK)
220+
DMLScript.USE_LOCAL_SPARK_CONFIG = true;
221+
222+
String TEST_NAME = !alignedInput ? TEST_NAME3 : (!alignedWeights ? TEST_NAME4 : TEST_NAME5);
223+
getAndLoadTestConfiguration(TEST_NAME);
224+
225+
String HOME = SCRIPT_DIR + TEST_DIR;
226+
227+
int r = rows / 4;
228+
int c = cols;
229+
230+
fullDMLScriptName = "";
231+
232+
// Create 4 random 5x1 matrices
233+
double[][] X1 = getRandomMatrix(r, c, 1, 5, 1, 3);
234+
double[][] X2 = getRandomMatrix(r, c, 1, 5, 1, 7);
235+
double[][] X3 = getRandomMatrix(r, c, 1, 5, 1, 8);
236+
double[][] X4 = getRandomMatrix(r, c, 1, 5, 1, 9);
237+
238+
// Create a 20x1 weights matrix
239+
double[][] W = getRandomMatrix(rows, c, 0, 1, 1, 3);
240+
241+
MatrixCharacteristics mc = new MatrixCharacteristics(r, c, blocksize, r * c);
242+
writeInputMatrixWithMTD("X1", X1, false, mc);
243+
writeInputMatrixWithMTD("X2", X2, false, mc);
244+
writeInputMatrixWithMTD("X3", X3, false, mc);
245+
writeInputMatrixWithMTD("X4", X4, false, mc);
246+
247+
writeInputMatrixWithMTD("W", W, false, new MatrixCharacteristics(rows, cols, blocksize, r * c));
248+
249+
// empty script name because we don't execute any script, just start the worker
250+
fullDMLScriptName = "";
251+
int port1 = getRandomAvailablePort();
252+
int port2 = getRandomAvailablePort();
253+
int port3 = getRandomAvailablePort();
254+
int port4 = getRandomAvailablePort();
255+
256+
Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
257+
Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
258+
Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
259+
Process t4 = startLocalFedWorker(port4);
260+
261+
try {
262+
if(!isAlive(t1, t2, t3, t4))
263+
throw new RuntimeException("Failed starting federated worker");
264+
265+
rtplatform = execMode;
266+
if(rtplatform == ExecMode.SPARK) {
267+
System.out.println(7);
268+
DMLScript.USE_LOCAL_SPARK_CONFIG = true;
269+
}
270+
271+
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
272+
loadTestConfiguration(config);
273+
274+
if (alignedInput) {
275+
// Create 4 random 5x1 matrices
276+
double[][] Y1 = getRandomMatrix(r, c, 1, 5, 1, 3);
277+
double[][] Y2 = getRandomMatrix(r, c, 1, 5, 1, 7);
278+
double[][] Y3 = getRandomMatrix(r, c, 1, 5, 1, 8);
279+
double[][] Y4 = getRandomMatrix(r, c, 1, 5, 1, 9);
280+
281+
writeInputMatrixWithMTD("Y1", Y1, false, mc);
282+
writeInputMatrixWithMTD("Y2", Y2, false, mc);
283+
writeInputMatrixWithMTD("Y3", Y3, false, mc);
284+
writeInputMatrixWithMTD("Y4", Y4, false, mc);
285+
286+
if (!alignedWeights) {
287+
// Run reference dml script with a normal matrix
288+
fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
289+
programArgs = new String[] { "-stats", "100", "-args",
290+
input("X1"), input("X2"), input("X3"), input("X4"),
291+
input("Y1"), input("Y2"), input("Y3"), input("Y4"),
292+
input("W"), expected("S")
293+
};
294+
runTest(null);
295+
296+
// Run the dml script with federated matrices
297+
fullDMLScriptName = HOME + TEST_NAME + ".dml";
298+
programArgs = new String[] {"-stats", "100", "-nvargs",
299+
"in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
300+
"in_Y1=" + TestUtils.federatedAddress(port1, input("Y1")),
301+
"in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
302+
"in_Y2=" + TestUtils.federatedAddress(port2, input("Y2")),
303+
"in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
304+
"in_Y3=" + TestUtils.federatedAddress(port3, input("Y3")),
305+
"in_X4=" + TestUtils.federatedAddress(port4, input("X4")),
306+
"in_Y4=" + TestUtils.federatedAddress(port4, input("Y4")),
307+
"in_W1=" + input("W"), "rows=" + rows, "cols=" + cols, "out_S=" + output("S")};
308+
runTest(null);
309+
}
310+
else {
311+
double[][] W1 = getRandomMatrix(r, c, 0, 1, 1, 3);
312+
double[][] W2 = getRandomMatrix(r, c, 0, 1, 1, 7);
313+
double[][] W3 = getRandomMatrix(r, c, 0, 1, 1, 8);
314+
double[][] W4 = getRandomMatrix(r, c, 0, 1, 1, 9);
315+
316+
writeInputMatrixWithMTD("W1", W1, false, mc);
317+
writeInputMatrixWithMTD("W2", W2, false, mc);
318+
writeInputMatrixWithMTD("W3", W3, false, mc);
319+
writeInputMatrixWithMTD("W4", W4, false, mc);
320+
321+
// Run reference dml script with a normal matrix
322+
fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
323+
programArgs = new String[] {"-stats", "100", "-args",
324+
input("X1"), input("X2"), input("X3"), input("X4"),
325+
input("Y1"), input("Y2"), input("Y3"), input("Y4"),
326+
input("W1"), input("W2"), input("W3"), input("W4"), expected("S")
327+
};
328+
runTest(null);
329+
330+
// Run the dml script with federated matrices and weights
331+
fullDMLScriptName = HOME + TEST_NAME + ".dml";
332+
programArgs = new String[] {"-stats", "100", "-nvargs",
333+
"in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
334+
"in_Y1=" + TestUtils.federatedAddress(port1, input("Y1")),
335+
"in_W1=" + TestUtils.federatedAddress(port1, input("W1")),
336+
"in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
337+
"in_Y2=" + TestUtils.federatedAddress(port2, input("Y2")),
338+
"in_W2=" + TestUtils.federatedAddress(port2, input("W2")),
339+
"in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
340+
"in_Y3=" + TestUtils.federatedAddress(port3, input("Y3")),
341+
"in_W3=" + TestUtils.federatedAddress(port3, input("W3")),
342+
"in_X4=" + TestUtils.federatedAddress(port4, input("X4")),
343+
"in_Y4=" + TestUtils.federatedAddress(port4, input("Y4")),
344+
"in_W4=" + TestUtils.federatedAddress(port4, input("W4")),
345+
"rows=" + rows, "cols=" + cols, "out_S=" + output("S")};
346+
runTest(null);
347+
}
348+
349+
}
350+
else {
351+
// Create a random 20x1 input matrix
352+
double[][] Y = getRandomMatrix(rows, c, 1, 5, 1, 3);
353+
writeInputMatrixWithMTD("Y", Y, false, new MatrixCharacteristics(rows, cols, blocksize, r * c));
354+
355+
// Run reference dml script with a normal matrix
356+
fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
357+
programArgs = new String[] {"-stats", "100", "-args",
358+
input("X1"), input("X2"), input("X3"), input("X4"),
359+
input("Y"), input("W"), expected("S")
360+
};
361+
runTest(null);
362+
363+
// Run the dml script with a federated matrix
364+
fullDMLScriptName = HOME + TEST_NAME + ".dml";
365+
programArgs = new String[] {"-stats", "100", "-nvargs",
366+
"in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
367+
"in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
368+
"in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
369+
"in_X4=" + TestUtils.federatedAddress(port4, input("X4")),
370+
"in_W1=" + input("W"), "Y=" + input("Y"),
371+
"rows=" + rows, "cols=" + cols, "out_S=" + output("S")};
372+
runTest(null);
373+
}
374+
375+
// compare via files
376+
compareResults(1e-2);
377+
Assert.assertTrue(heavyHittersContainsString("fed_cov"));
378+
379+
}
380+
finally {
381+
TestUtils.shutdownThreads(t1, t2, t3, t4);
382+
rtplatform = platformOld;
383+
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
384+
}
385+
}
193386
}
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
# 5x1 on 4 workers -> 20x1
23+
X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
24+
ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols),
25+
list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols)));
26+
27+
# 5x1 on 4 workers -> 20x1
28+
Y = federated(addresses=list($in_Y1, $in_Y2, $in_Y3, $in_Y4),
29+
ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols),
30+
list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols)));
31+
32+
W = read($in_W1); # 20x1
33+
34+
s = cov(X, Y, W);
35+
write(s, $out_S);
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
X = rbind(read($1), read($2), read($3), read($4)); # 20x1
23+
Y = rbind(read($5), read($6), read($7), read($8)); # 20x1
24+
W = read($9); # 20x1
25+
26+
s = cov(X, Y, W);
27+
write(s, $10);
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
# 5x1 on 4 workers -> 20x1
23+
X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
24+
ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols),
25+
list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols)));
26+
27+
# 5x1 on 4 workers -> 20x1
28+
Y = federated(addresses=list($in_Y1, $in_Y2, $in_Y3, $in_Y4),
29+
ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols),
30+
list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols)));
31+
32+
# 5x1 on 4 workers -> 20x1
33+
W = federated(addresses=list($in_W1, $in_W2, $in_W3, $in_W4),
34+
ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols),
35+
list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols)));
36+
37+
s = cov(X, Y, W);
38+
write(s, $out_S);
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
X = rbind(read($1), read($2), read($3), read($4)); # 20x1
23+
Y = rbind(read($5), read($6), read($7), read($8)); # 20x1
24+
W = rbind(read($9), read($10), read($11), read($12)); # 20x1
25+
26+
s = cov(X, Y, W);
27+
write(s, $13);
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
# 5x1 on 4 workers -> 20x1
23+
X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
24+
ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols),
25+
list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols)));
26+
27+
Y = read($Y); # 20x1
28+
W = read($in_W1); # 20x1
29+
30+
s = cov(X, Y, W);
31+
write(s, $out_S);

0 commit comments

Comments
 (0)