@@ -43,6 +43,9 @@ public class FederatedCovarianceTest extends AutomatedTestBase {
4343
4444 private final static String TEST_NAME1 = "FederatedCovarianceTest" ;
4545 private final static String TEST_NAME2 = "FederatedCovarianceAlignedTest" ;
46+ private final static String TEST_NAME3 = "FederatedCovarianceWeightedTest" ;
47+ private final static String TEST_NAME4 = "FederatedCovarianceAlignedWeightedTest" ;
48+ private final static String TEST_NAME5 = "FederatedCovarianceAllAlignedWeightedTest" ;
4649 private final static String TEST_DIR = "functions/federated/" ;
4750 private static final String TEST_CLASS_DIR = TEST_DIR + FederatedCovarianceTest .class .getSimpleName () + "/" ;
4851
@@ -64,19 +67,37 @@ public void setUp() {
6467 TestUtils .clearAssertionInformation ();
6568 addTestConfiguration (TEST_NAME1 , new TestConfiguration (TEST_CLASS_DIR , TEST_NAME1 , new String [] {"S.scalar" }));
6669 addTestConfiguration (TEST_NAME2 , new TestConfiguration (TEST_CLASS_DIR , TEST_NAME2 , new String [] {"S.scalar" }));
70+ addTestConfiguration (TEST_NAME3 , new TestConfiguration (TEST_CLASS_DIR , TEST_NAME3 , new String [] {"S.scalar" }));
71+ addTestConfiguration (TEST_NAME4 , new TestConfiguration (TEST_CLASS_DIR , TEST_NAME4 , new String [] {"S.scalar" }));
72+ addTestConfiguration (TEST_NAME5 , new TestConfiguration (TEST_CLASS_DIR , TEST_NAME5 , new String [] {"S.scalar" }));
6773 }
6874
6975 @ Test
7076 public void testCovCP () {
71- runCovTest (ExecMode .SINGLE_NODE , false );
77+ runCovarianceTest (ExecMode .SINGLE_NODE , false );
7278 }
7379
7480 @ Test
7581 public void testAlignedCovCP () {
76- runCovTest (ExecMode .SINGLE_NODE , true );
82+ runCovarianceTest (ExecMode .SINGLE_NODE , true );
7783 }
7884
79- private void runCovTest (ExecMode execMode , boolean alignedFedInput ) {
85+ @ Test
86+ public void testCovarianceWeightedCP () {
87+ runWeightedCovarianceTest (ExecMode .SINGLE_NODE , false , false );
88+ }
89+
90+ @ Test
91+ public void testAlignedCovarianceWeightedCP () {
92+ runWeightedCovarianceTest (ExecMode .SINGLE_NODE , true , false );
93+ }
94+
95+ @ Test
96+ public void testAllAlignedCovarianceWeightedCP () {
97+ runWeightedCovarianceTest (ExecMode .SINGLE_NODE , true , true );
98+ }
99+
100+ private void runCovarianceTest (ExecMode execMode , boolean alignedFedInput ) {
80101 boolean sparkConfigOld = DMLScript .USE_LOCAL_SPARK_CONFIG ;
81102 ExecMode platformOld = rtplatform ;
82103
@@ -190,4 +211,221 @@ private void runCovTest(ExecMode execMode, boolean alignedFedInput) {
190211 DMLScript .USE_LOCAL_SPARK_CONFIG = sparkConfigOld ;
191212 }
192213 }
214+
215+ private void runWeightedCovarianceTest (ExecMode execMode , boolean alignedInput , boolean alignedWeights ) {
216+ boolean sparkConfigOld = DMLScript .USE_LOCAL_SPARK_CONFIG ;
217+ ExecMode platformOld = rtplatform ;
218+
219+ if (rtplatform == ExecMode .SPARK )
220+ DMLScript .USE_LOCAL_SPARK_CONFIG = true ;
221+
222+ String TEST_NAME = !alignedInput ? TEST_NAME3 : (!alignedWeights ? TEST_NAME4 : TEST_NAME5 );
223+ getAndLoadTestConfiguration (TEST_NAME );
224+
225+ String HOME = SCRIPT_DIR + TEST_DIR ;
226+
227+ int r = rows / 4 ;
228+ int c = cols ;
229+
230+ fullDMLScriptName = "" ;
231+
232+ // Create 4 random 5x1 matrices
233+ double [][] X1 = getRandomMatrix (r , c , 1 , 5 , 1 , 3 );
234+ double [][] X2 = getRandomMatrix (r , c , 1 , 5 , 1 , 7 );
235+ double [][] X3 = getRandomMatrix (r , c , 1 , 5 , 1 , 8 );
236+ double [][] X4 = getRandomMatrix (r , c , 1 , 5 , 1 , 9 );
237+
238+ // Create a 20x1 weights matrix
239+ double [][] W = getRandomMatrix (rows , c , 0 , 1 , 1 , 3 );
240+
241+ MatrixCharacteristics mc = new MatrixCharacteristics (r , c , blocksize , r * c );
242+ writeInputMatrixWithMTD ("X1" , X1 , false , mc );
243+ writeInputMatrixWithMTD ("X2" , X2 , false , mc );
244+ writeInputMatrixWithMTD ("X3" , X3 , false , mc );
245+ writeInputMatrixWithMTD ("X4" , X4 , false , mc );
246+
247+ writeInputMatrixWithMTD ("W" , W , false , new MatrixCharacteristics (rows , cols , blocksize , r * c ));
248+
249+ // empty script name because we don't execute any script, just start the worker
250+ fullDMLScriptName = "" ;
251+ int port1 = getRandomAvailablePort ();
252+ int port2 = getRandomAvailablePort ();
253+ int port3 = getRandomAvailablePort ();
254+ int port4 = getRandomAvailablePort ();
255+
256+ Process t1 = startLocalFedWorker (port1 , FED_WORKER_WAIT_S );
257+ Process t2 = startLocalFedWorker (port2 , FED_WORKER_WAIT_S );
258+ Process t3 = startLocalFedWorker (port3 , FED_WORKER_WAIT_S );
259+ Process t4 = startLocalFedWorker (port4 );
260+
261+ try {
262+ if (!isAlive (t1 , t2 , t3 , t4 ))
263+ throw new RuntimeException ("Failed starting federated worker" );
264+
265+ rtplatform = execMode ;
266+ if (rtplatform == ExecMode .SPARK ) {
267+ System .out .println (7 );
268+ DMLScript .USE_LOCAL_SPARK_CONFIG = true ;
269+ }
270+
271+ TestConfiguration config = availableTestConfigurations .get (TEST_NAME );
272+ loadTestConfiguration (config );
273+
274+ if (alignedInput ) {
275+ // Create 4 random 5x1 matrices
276+ double [][] Y1 = getRandomMatrix (r , c , 1 , 5 , 1 , 3 );
277+ double [][] Y2 = getRandomMatrix (r , c , 1 , 5 , 1 , 7 );
278+ double [][] Y3 = getRandomMatrix (r , c , 1 , 5 , 1 , 8 );
279+ double [][] Y4 = getRandomMatrix (r , c , 1 , 5 , 1 , 9 );
280+
281+ writeInputMatrixWithMTD ("Y1" , Y1 , false , mc );
282+ writeInputMatrixWithMTD ("Y2" , Y2 , false , mc );
283+ writeInputMatrixWithMTD ("Y3" , Y3 , false , mc );
284+ writeInputMatrixWithMTD ("Y4" , Y4 , false , mc );
285+
286+ if (!alignedWeights ) {
287+ // Run reference dml script with a normal matrix
288+ fullDMLScriptName = HOME + TEST_NAME + "Reference.dml" ;
289+ programArgs = new String [] {
290+ "-stats" , "100" , "-args" ,
291+ input ("X1" ),
292+ input ("X2" ),
293+ input ("X3" ),
294+ input ("X4" ),
295+
296+ input ("Y1" ),
297+ input ("Y2" ),
298+ input ("Y3" ),
299+ input ("Y4" ),
300+
301+ input ("W" ),
302+ expected ("S" )
303+ };
304+ runTest (null );
305+
306+ // Run the dml script with federated matrices
307+ fullDMLScriptName = HOME + TEST_NAME + ".dml" ;
308+ programArgs = new String [] {"-stats" , "100" , "-nvargs" ,
309+ "in_X1=" + TestUtils .federatedAddress (port1 , input ("X1" )),
310+ "in_Y1=" + TestUtils .federatedAddress (port1 , input ("Y1" )),
311+
312+ "in_X2=" + TestUtils .federatedAddress (port2 , input ("X2" )),
313+ "in_Y2=" + TestUtils .federatedAddress (port2 , input ("Y2" )),
314+
315+ "in_X3=" + TestUtils .federatedAddress (port3 , input ("X3" )),
316+ "in_Y3=" + TestUtils .federatedAddress (port3 , input ("Y3" )),
317+
318+ "in_X4=" + TestUtils .federatedAddress (port4 , input ("X4" )),
319+ "in_Y4=" + TestUtils .federatedAddress (port4 , input ("Y4" )),
320+
321+ "in_W1=" + input ("W" ),
322+ "rows=" + rows , "cols=" + cols ,
323+ "out_S=" + output ("S" )};
324+ runTest (null );
325+ }
326+ else {
327+ double [][] W1 = getRandomMatrix (r , c , 0 , 1 , 1 , 3 );
328+ double [][] W2 = getRandomMatrix (r , c , 0 , 1 , 1 , 7 );
329+ double [][] W3 = getRandomMatrix (r , c , 0 , 1 , 1 , 8 );
330+ double [][] W4 = getRandomMatrix (r , c , 0 , 1 , 1 , 9 );
331+
332+ writeInputMatrixWithMTD ("W1" , W1 , false , mc );
333+ writeInputMatrixWithMTD ("W2" , W2 , false , mc );
334+ writeInputMatrixWithMTD ("W3" , W3 , false , mc );
335+ writeInputMatrixWithMTD ("W4" , W4 , false , mc );
336+
337+ // Run reference dml script with a normal matrix
338+ fullDMLScriptName = HOME + TEST_NAME + "Reference.dml" ;
339+ programArgs = new String [] {
340+ "-stats" , "100" , "-args" ,
341+ input ("X1" ),
342+ input ("X2" ),
343+ input ("X3" ),
344+ input ("X4" ),
345+
346+ input ("Y1" ),
347+ input ("Y2" ),
348+ input ("Y3" ),
349+ input ("Y4" ),
350+
351+ input ("W1" ),
352+ input ("W2" ),
353+ input ("W3" ),
354+ input ("W4" ),
355+
356+ expected ("S" )
357+ };
358+ runTest (null );
359+
360+ // Run the dml script with federated matrices and weights
361+ fullDMLScriptName = HOME + TEST_NAME + ".dml" ;
362+ programArgs = new String [] {"-stats" , "100" , "-nvargs" ,
363+ "in_X1=" + TestUtils .federatedAddress (port1 , input ("X1" )),
364+ "in_Y1=" + TestUtils .federatedAddress (port1 , input ("Y1" )),
365+ "in_W1=" + TestUtils .federatedAddress (port1 , input ("W1" )),
366+
367+ "in_X2=" + TestUtils .federatedAddress (port2 , input ("X2" )),
368+ "in_Y2=" + TestUtils .federatedAddress (port2 , input ("Y2" )),
369+ "in_W2=" + TestUtils .federatedAddress (port2 , input ("W2" )),
370+
371+ "in_X3=" + TestUtils .federatedAddress (port3 , input ("X3" )),
372+ "in_Y3=" + TestUtils .federatedAddress (port3 , input ("Y3" )),
373+ "in_W3=" + TestUtils .federatedAddress (port3 , input ("W3" )),
374+
375+ "in_X4=" + TestUtils .federatedAddress (port4 , input ("X4" )),
376+ "in_Y4=" + TestUtils .federatedAddress (port4 , input ("Y4" )),
377+ "in_W4=" + TestUtils .federatedAddress (port4 , input ("W4" )),
378+
379+ "rows=" + rows , "cols=" + cols ,
380+ "out_S=" + output ("S" )};
381+ runTest (null );
382+ }
383+
384+ }
385+ else {
386+ // Create a random 20x1 input matrix
387+ double [][] Y = getRandomMatrix (rows , c , 1 , 5 , 1 , 3 );
388+ writeInputMatrixWithMTD ("Y" , Y , false , new MatrixCharacteristics (rows , cols , blocksize , r * c ));
389+
390+ // Run reference dml script with a normal matrix
391+ fullDMLScriptName = HOME + TEST_NAME + "Reference.dml" ;
392+ programArgs = new String [] {
393+ "-stats" , "100" , "-args" ,
394+ input ("X1" ),
395+ input ("X2" ),
396+ input ("X3" ),
397+ input ("X4" ),
398+
399+ input ("Y" ), input ("W" ), expected ("S" )
400+ };
401+ runTest (null );
402+
403+ // Run the dml script with a federated matrix
404+ fullDMLScriptName = HOME + TEST_NAME + ".dml" ;
405+ programArgs = new String [] {"-stats" , "100" , "-nvargs" ,
406+ "in_X1=" + TestUtils .federatedAddress (port1 , input ("X1" )),
407+ "in_X2=" + TestUtils .federatedAddress (port2 , input ("X2" )),
408+ "in_X3=" + TestUtils .federatedAddress (port3 , input ("X3" )),
409+ "in_X4=" + TestUtils .federatedAddress (port4 , input ("X4" )),
410+
411+ "in_W1=" + input ("W" ),
412+ "Y=" + input ("Y" ),
413+
414+ "rows=" + rows ,
415+ "cols=" + cols ,
416+ "out_S=" + output ("S" )};
417+ runTest (null );
418+ }
419+
420+ // compare via files
421+ compareResults (1e-2 );
422+ Assert .assertTrue (heavyHittersContainsString ("fed_cov" ));
423+
424+ }
425+ finally {
426+ TestUtils .shutdownThreads (t1 , t2 , t3 , t4 );
427+ rtplatform = platformOld ;
428+ DMLScript .USE_LOCAL_SPARK_CONFIG = sparkConfigOld ;
429+ }
430+ }
193431}
0 commit comments