@@ -114,27 +114,20 @@ private void processAlignedFedCov(ExecutionContext ec, MatrixObject mo1, MatrixO
114114 new CPOperand []{input1 , input2 , input3 }, new long []{mo1 .getFedMapping ().getID (),
115115 mo2 .getFedMapping ().getID (), moLin3 .getFedMapping ().getID ()});
116116 }
117-
117+
118118 FederatedRequest fr2 = new FederatedRequest (FederatedRequest .RequestType .GET_VAR , fr1 .getID ());
119119 FederatedRequest fr3 = mo1 .getFedMapping ().cleanup (getTID (), fr1 .getID ());
120- Future <FederatedResponse >[] covTmp = mo1 .getFedMapping ().execute (getTID (), true , fr1 , fr2 , fr3 );
121-
122- //means
123- Future <FederatedResponse >[] meanTmp1 = processMean (mo1 , moLin3 , 0 );
124- Future <FederatedResponse >[] meanTmp2 = processMean (mo2 , moLin3 , 1 );
125-
126- Double [] cov = getResponses (covTmp );
127- Double [] mean1 = getResponses (meanTmp1 );
128- Double [] mean2 = getResponses (meanTmp2 );
120+ Double [] cov = getResponses (mo1 .getFedMapping ().execute (getTID (), fr1 , fr2 , fr3 ));
121+ Double [] mean1 = getResponses (processMean (mo1 , moLin3 , 0 ));
122+ Double [] mean2 = getResponses (processMean (mo2 , moLin3 , 1 ));
129123
130124 if (moLin3 == null ) {
131125 double result = aggCov (cov , mean1 , mean2 , mo1 .getFedMapping ().getFederatedRanges ());
132126 ec .setVariable (output .getName (), new DoubleObject (result ));
133127 }
134128 else {
135- Future <FederatedResponse >[] weightsSumTmp = getWeightsSum (moLin3 , moLin3 .getFedMapping ().getID (), instString , moLin3 .getFedMapping ());
136- Double [] weights = getResponses (weightsSumTmp );
137-
129+ Double [] weights = getResponses (
130+ getWeightsSum (moLin3 , moLin3 .getFedMapping ().getID (), instString , moLin3 .getFedMapping ()));
138131 double result = aggWeightedCov (cov , mean1 , mean2 , weights );
139132 ec .setVariable (output .getName (), new DoubleObject (result ));
140133 }
@@ -154,21 +147,13 @@ private void processFedCovWeights(ExecutionContext ec, MatrixObject mo1, MatrixO
154147 new CPOperand []{input1 , input2 , input3 },
155148 new long []{mo1 .getFedMapping ().getID (), mo2 .getFedMapping ().getID (), fr1 [0 ].getID ()}
156149 );
150+ //sequential execution of cov and means for robustness
157151 FederatedRequest fr3 = new FederatedRequest (FederatedRequest .RequestType .GET_VAR , fr2 .getID ());
158152 FederatedRequest fr4 = mo1 .getFedMapping ().cleanup (getTID (), fr2 .getID ());
159- Future <FederatedResponse >[] covTmp = mo1 .getFedMapping ().execute (getTID (), fr1 , fr2 , fr3 , fr4 );
160-
161- //means
162- Future <FederatedResponse >[] meanTmp1 = processMean (mo1 , 0 , fr1 [0 ].getID ());
163- Future <FederatedResponse >[] meanTmp2 = processMean (mo2 , 1 , fr1 [0 ].getID ());
164-
165- Double [] cov = getResponses (covTmp );
166- Double [] mean1 = getResponses (meanTmp1 );
167- Double [] mean2 = getResponses (meanTmp2 );
168-
169- Future <FederatedResponse >[] weightsSumTmp = getWeightsSum (moLin3 , fr1 [0 ].getID (), instString , mo1 .getFedMapping ());
170- Double [] weights = getResponses (weightsSumTmp );
171-
153+ Double [] cov = getResponses (mo1 .getFedMapping ().execute (getTID (), true , fr1 , fr2 , fr3 , fr4 ));
154+ Double [] mean1 = getResponses (processMean (mo1 , 0 , fr1 [0 ].getID ()));
155+ Double [] mean2 = getResponses (processMean (mo2 , 1 , fr1 [0 ].getID ()));
156+ Double [] weights = getResponses (getWeightsSum (moLin3 , fr1 [0 ].getID (), instString , mo1 .getFedMapping ()));
172157 double result = aggWeightedCov (cov , mean1 , mean2 , weights );
173158 ec .setVariable (output .getName (), new DoubleObject (result ));
174159 }
@@ -243,7 +228,7 @@ private static Double[] getResponses(Future<FederatedResponse>[] ffr) {
243228 fr [i ] = ((ScalarObject ) ffr [i ].get ().getData ()[0 ]).getDoubleValue ();
244229 }
245230 catch (Exception e ) {
246- throw new DMLRuntimeException ("CovarianceFEDInstruction: incorrect means or cov." );
231+ throw new DMLRuntimeException ("CovarianceFEDInstruction: incorrect means or cov." , e );
247232 }
248233 });
249234
@@ -302,7 +287,7 @@ private static double aggWeightedCov(Double[] covValues, Double[] mean1, Double[
302287 }
303288
304289 private Future <FederatedResponse >[] processMean (MatrixObject mo1 , MatrixLineagePair moLin3 , int var ){
305- String [] parts = instString .split ("°" );
290+ String [] parts = instString .split (Lop . OPERAND_DELIMITOR );
306291 Future <FederatedResponse >[] meanTmp = null ;
307292 if (moLin3 == null ) {
308293 String meanInstr = instString .replace (getOpcode (), getOpcode ().replace ("cov" , "uamean" ));
0 commit comments