@@ -43,6 +43,9 @@ public class FederatedCovarianceTest extends AutomatedTestBase {
4343
4444 private final static String TEST_NAME1 = "FederatedCovarianceTest" ;
4545 private final static String TEST_NAME2 = "FederatedCovarianceAlignedTest" ;
46+ private final static String TEST_NAME3 = "FederatedCovarianceWeightedTest" ;
47+ private final static String TEST_NAME4 = "FederatedCovarianceAlignedWeightedTest" ;
48+ private final static String TEST_NAME5 = "FederatedCovarianceAllAlignedWeightedTest" ;
4649 private final static String TEST_DIR = "functions/federated/" ;
4750 private static final String TEST_CLASS_DIR = TEST_DIR + FederatedCovarianceTest .class .getSimpleName () + "/" ;
4851
@@ -64,19 +67,37 @@ public void setUp() {
6467 TestUtils .clearAssertionInformation ();
6568 addTestConfiguration (TEST_NAME1 , new TestConfiguration (TEST_CLASS_DIR , TEST_NAME1 , new String [] {"S.scalar" }));
6669 addTestConfiguration (TEST_NAME2 , new TestConfiguration (TEST_CLASS_DIR , TEST_NAME2 , new String [] {"S.scalar" }));
70+ addTestConfiguration (TEST_NAME3 , new TestConfiguration (TEST_CLASS_DIR , TEST_NAME3 , new String [] {"S.scalar" }));
71+ addTestConfiguration (TEST_NAME4 , new TestConfiguration (TEST_CLASS_DIR , TEST_NAME4 , new String [] {"S.scalar" }));
72+ addTestConfiguration (TEST_NAME5 , new TestConfiguration (TEST_CLASS_DIR , TEST_NAME5 , new String [] {"S.scalar" }));
6773 }
6874
6975 @ Test
7076 public void testCovCP () {
71- runCovTest (ExecMode .SINGLE_NODE , false );
77+ runCovarianceTest (ExecMode .SINGLE_NODE , false );
7278 }
7379
7480 @ Test
7581 public void testAlignedCovCP () {
76- runCovTest (ExecMode .SINGLE_NODE , true );
82+ runCovarianceTest (ExecMode .SINGLE_NODE , true );
7783 }
7884
79- private void runCovTest (ExecMode execMode , boolean alignedFedInput ) {
85+ @ Test
86+ public void testCovarianceWeightedCP () {
87+ runWeightedCovarianceTest (ExecMode .SINGLE_NODE , false , false );
88+ }
89+
90+ @ Test
91+ public void testAlignedCovarianceWeightedCP () {
92+ runWeightedCovarianceTest (ExecMode .SINGLE_NODE , true , false );
93+ }
94+
95+ @ Test
96+ public void testAllAlignedCovarianceWeightedCP () {
97+ runWeightedCovarianceTest (ExecMode .SINGLE_NODE , true , true );
98+ }
99+
100+ private void runCovarianceTest (ExecMode execMode , boolean alignedFedInput ) {
80101 boolean sparkConfigOld = DMLScript .USE_LOCAL_SPARK_CONFIG ;
81102 ExecMode platformOld = rtplatform ;
82103
@@ -190,4 +211,176 @@ private void runCovTest(ExecMode execMode, boolean alignedFedInput) {
190211 DMLScript .USE_LOCAL_SPARK_CONFIG = sparkConfigOld ;
191212 }
192213 }
214+
215+ private void runWeightedCovarianceTest (ExecMode execMode , boolean alignedInput , boolean alignedWeights ) {
216+ boolean sparkConfigOld = DMLScript .USE_LOCAL_SPARK_CONFIG ;
217+ ExecMode platformOld = rtplatform ;
218+
219+ if (rtplatform == ExecMode .SPARK )
220+ DMLScript .USE_LOCAL_SPARK_CONFIG = true ;
221+
222+ String TEST_NAME = !alignedInput ? TEST_NAME3 : (!alignedWeights ? TEST_NAME4 : TEST_NAME5 );
223+ getAndLoadTestConfiguration (TEST_NAME );
224+
225+ String HOME = SCRIPT_DIR + TEST_DIR ;
226+
227+ int r = rows / 4 ;
228+ int c = cols ;
229+
230+ fullDMLScriptName = "" ;
231+
232+ // Create 4 random 5x1 matrices
233+ double [][] X1 = getRandomMatrix (r , c , 1 , 5 , 1 , 3 );
234+ double [][] X2 = getRandomMatrix (r , c , 1 , 5 , 1 , 7 );
235+ double [][] X3 = getRandomMatrix (r , c , 1 , 5 , 1 , 8 );
236+ double [][] X4 = getRandomMatrix (r , c , 1 , 5 , 1 , 9 );
237+
238+ // Create a 20x1 weights matrix
239+ double [][] W = getRandomMatrix (rows , c , 0 , 1 , 1 , 3 );
240+
241+ MatrixCharacteristics mc = new MatrixCharacteristics (r , c , blocksize , r * c );
242+ writeInputMatrixWithMTD ("X1" , X1 , false , mc );
243+ writeInputMatrixWithMTD ("X2" , X2 , false , mc );
244+ writeInputMatrixWithMTD ("X3" , X3 , false , mc );
245+ writeInputMatrixWithMTD ("X4" , X4 , false , mc );
246+
247+ writeInputMatrixWithMTD ("W" , W , false , new MatrixCharacteristics (rows , cols , blocksize , r * c ));
248+
249+ // empty script name because we don't execute any script, just start the worker
250+ fullDMLScriptName = "" ;
251+ int port1 = getRandomAvailablePort ();
252+ int port2 = getRandomAvailablePort ();
253+ int port3 = getRandomAvailablePort ();
254+ int port4 = getRandomAvailablePort ();
255+
256+ Process t1 = startLocalFedWorker (port1 , FED_WORKER_WAIT_S );
257+ Process t2 = startLocalFedWorker (port2 , FED_WORKER_WAIT_S );
258+ Process t3 = startLocalFedWorker (port3 , FED_WORKER_WAIT_S );
259+ Process t4 = startLocalFedWorker (port4 );
260+
261+ try {
262+ if (!isAlive (t1 , t2 , t3 , t4 ))
263+ throw new RuntimeException ("Failed starting federated worker" );
264+
265+ rtplatform = execMode ;
266+ if (rtplatform == ExecMode .SPARK ) {
267+ System .out .println (7 );
268+ DMLScript .USE_LOCAL_SPARK_CONFIG = true ;
269+ }
270+
271+ TestConfiguration config = availableTestConfigurations .get (TEST_NAME );
272+ loadTestConfiguration (config );
273+
274+ if (alignedInput ) {
275+ // Create 4 random 5x1 matrices
276+ double [][] Y1 = getRandomMatrix (r , c , 1 , 5 , 1 , 3 );
277+ double [][] Y2 = getRandomMatrix (r , c , 1 , 5 , 1 , 7 );
278+ double [][] Y3 = getRandomMatrix (r , c , 1 , 5 , 1 , 8 );
279+ double [][] Y4 = getRandomMatrix (r , c , 1 , 5 , 1 , 9 );
280+
281+ writeInputMatrixWithMTD ("Y1" , Y1 , false , mc );
282+ writeInputMatrixWithMTD ("Y2" , Y2 , false , mc );
283+ writeInputMatrixWithMTD ("Y3" , Y3 , false , mc );
284+ writeInputMatrixWithMTD ("Y4" , Y4 , false , mc );
285+
286+ if (!alignedWeights ) {
287+ // Run reference dml script with a normal matrix
288+ fullDMLScriptName = HOME + TEST_NAME + "Reference.dml" ;
289+ programArgs = new String [] { "-stats" , "100" , "-args" ,
290+ input ("X1" ), input ("X2" ), input ("X3" ), input ("X4" ),
291+ input ("Y1" ), input ("Y2" ), input ("Y3" ), input ("Y4" ),
292+ input ("W" ), expected ("S" )
293+ };
294+ runTest (null );
295+
296+ // Run the dml script with federated matrices
297+ fullDMLScriptName = HOME + TEST_NAME + ".dml" ;
298+ programArgs = new String [] {"-stats" , "100" , "-nvargs" ,
299+ "in_X1=" + TestUtils .federatedAddress (port1 , input ("X1" )),
300+ "in_Y1=" + TestUtils .federatedAddress (port1 , input ("Y1" )),
301+ "in_X2=" + TestUtils .federatedAddress (port2 , input ("X2" )),
302+ "in_Y2=" + TestUtils .federatedAddress (port2 , input ("Y2" )),
303+ "in_X3=" + TestUtils .federatedAddress (port3 , input ("X3" )),
304+ "in_Y3=" + TestUtils .federatedAddress (port3 , input ("Y3" )),
305+ "in_X4=" + TestUtils .federatedAddress (port4 , input ("X4" )),
306+ "in_Y4=" + TestUtils .federatedAddress (port4 , input ("Y4" )),
307+ "in_W1=" + input ("W" ), "rows=" + rows , "cols=" + cols , "out_S=" + output ("S" )};
308+ runTest (null );
309+ }
310+ else {
311+ double [][] W1 = getRandomMatrix (r , c , 0 , 1 , 1 , 3 );
312+ double [][] W2 = getRandomMatrix (r , c , 0 , 1 , 1 , 7 );
313+ double [][] W3 = getRandomMatrix (r , c , 0 , 1 , 1 , 8 );
314+ double [][] W4 = getRandomMatrix (r , c , 0 , 1 , 1 , 9 );
315+
316+ writeInputMatrixWithMTD ("W1" , W1 , false , mc );
317+ writeInputMatrixWithMTD ("W2" , W2 , false , mc );
318+ writeInputMatrixWithMTD ("W3" , W3 , false , mc );
319+ writeInputMatrixWithMTD ("W4" , W4 , false , mc );
320+
321+ // Run reference dml script with a normal matrix
322+ fullDMLScriptName = HOME + TEST_NAME + "Reference.dml" ;
323+ programArgs = new String [] {"-stats" , "100" , "-args" ,
324+ input ("X1" ), input ("X2" ), input ("X3" ), input ("X4" ),
325+ input ("Y1" ), input ("Y2" ), input ("Y3" ), input ("Y4" ),
326+ input ("W1" ), input ("W2" ), input ("W3" ), input ("W4" ), expected ("S" )
327+ };
328+ runTest (null );
329+
330+ // Run the dml script with federated matrices and weights
331+ fullDMLScriptName = HOME + TEST_NAME + ".dml" ;
332+ programArgs = new String [] {"-stats" , "100" , "-nvargs" ,
333+ "in_X1=" + TestUtils .federatedAddress (port1 , input ("X1" )),
334+ "in_Y1=" + TestUtils .federatedAddress (port1 , input ("Y1" )),
335+ "in_W1=" + TestUtils .federatedAddress (port1 , input ("W1" )),
336+ "in_X2=" + TestUtils .federatedAddress (port2 , input ("X2" )),
337+ "in_Y2=" + TestUtils .federatedAddress (port2 , input ("Y2" )),
338+ "in_W2=" + TestUtils .federatedAddress (port2 , input ("W2" )),
339+ "in_X3=" + TestUtils .federatedAddress (port3 , input ("X3" )),
340+ "in_Y3=" + TestUtils .federatedAddress (port3 , input ("Y3" )),
341+ "in_W3=" + TestUtils .federatedAddress (port3 , input ("W3" )),
342+ "in_X4=" + TestUtils .federatedAddress (port4 , input ("X4" )),
343+ "in_Y4=" + TestUtils .federatedAddress (port4 , input ("Y4" )),
344+ "in_W4=" + TestUtils .federatedAddress (port4 , input ("W4" )),
345+ "rows=" + rows , "cols=" + cols , "out_S=" + output ("S" )};
346+ runTest (null );
347+ }
348+
349+ }
350+ else {
351+ // Create a random 20x1 input matrix
352+ double [][] Y = getRandomMatrix (rows , c , 1 , 5 , 1 , 3 );
353+ writeInputMatrixWithMTD ("Y" , Y , false , new MatrixCharacteristics (rows , cols , blocksize , r * c ));
354+
355+ // Run reference dml script with a normal matrix
356+ fullDMLScriptName = HOME + TEST_NAME + "Reference.dml" ;
357+ programArgs = new String [] {"-stats" , "100" , "-args" ,
358+ input ("X1" ), input ("X2" ), input ("X3" ), input ("X4" ),
359+ input ("Y" ), input ("W" ), expected ("S" )
360+ };
361+ runTest (null );
362+
363+ // Run the dml script with a federated matrix
364+ fullDMLScriptName = HOME + TEST_NAME + ".dml" ;
365+ programArgs = new String [] {"-stats" , "100" , "-nvargs" ,
366+ "in_X1=" + TestUtils .federatedAddress (port1 , input ("X1" )),
367+ "in_X2=" + TestUtils .federatedAddress (port2 , input ("X2" )),
368+ "in_X3=" + TestUtils .federatedAddress (port3 , input ("X3" )),
369+ "in_X4=" + TestUtils .federatedAddress (port4 , input ("X4" )),
370+ "in_W1=" + input ("W" ), "Y=" + input ("Y" ),
371+ "rows=" + rows , "cols=" + cols , "out_S=" + output ("S" )};
372+ runTest (null );
373+ }
374+
375+ // compare via files
376+ compareResults (1e-2 );
377+ Assert .assertTrue (heavyHittersContainsString ("fed_cov" ));
378+
379+ }
380+ finally {
381+ TestUtils .shutdownThreads (t1 , t2 , t3 , t4 );
382+ rtplatform = platformOld ;
383+ DMLScript .USE_LOCAL_SPARK_CONFIG = sparkConfigOld ;
384+ }
385+ }
193386}
0 commit comments