Skip to content

Commit 9a318ee

Browse files
committed
[SYSTEMDS-3796] Fix robustness federated weighted covariance and tests
1 parent a6a1509 commit 9a318ee

File tree

3 files changed

+19
-35
lines changed

3 files changed

+19
-35
lines changed

src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -147,10 +147,8 @@ public FederatedRequest[] broadcastSliced(CacheableData<?> data, boolean transpo
147147
return broadcastSliced(data, null, transposed);
148148
}
149149

150-
public FederatedRequest[] broadcastSliced(MatrixLineagePair moLin,
151-
boolean transposed) {
152-
return broadcastSliced(moLin.getMO(), moLin.getLI(),
153-
transposed);
150+
public FederatedRequest[] broadcastSliced(MatrixLineagePair moLin, boolean transposed) {
151+
return broadcastSliced(moLin.getMO(), moLin.getLI(), transposed);
154152
}
155153

156154
/**

src/main/java/org/apache/sysds/runtime/instructions/fed/CovarianceFEDInstruction.java

Lines changed: 13 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -114,27 +114,20 @@ private void processAlignedFedCov(ExecutionContext ec, MatrixObject mo1, MatrixO
114114
new CPOperand[]{input1, input2, input3}, new long[]{mo1.getFedMapping().getID(),
115115
mo2.getFedMapping().getID(), moLin3.getFedMapping().getID()});
116116
}
117-
117+
118118
FederatedRequest fr2 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr1.getID());
119119
FederatedRequest fr3 = mo1.getFedMapping().cleanup(getTID(), fr1.getID());
120-
Future<FederatedResponse>[] covTmp = mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
121-
122-
//means
123-
Future<FederatedResponse>[] meanTmp1 = processMean(mo1, moLin3, 0);
124-
Future<FederatedResponse>[] meanTmp2 = processMean(mo2, moLin3, 1);
125-
126-
Double[] cov = getResponses(covTmp);
127-
Double[] mean1 = getResponses(meanTmp1);
128-
Double[] mean2 = getResponses(meanTmp2);
120+
Double[] cov = getResponses(mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3));
121+
Double[] mean1 = getResponses(processMean(mo1, moLin3, 0));
122+
Double[] mean2 = getResponses(processMean(mo2, moLin3, 1));
129123

130124
if (moLin3 == null) {
131125
double result = aggCov(cov, mean1, mean2, mo1.getFedMapping().getFederatedRanges());
132126
ec.setVariable(output.getName(), new DoubleObject(result));
133127
}
134128
else {
135-
Future<FederatedResponse>[] weightsSumTmp = getWeightsSum(moLin3, moLin3.getFedMapping().getID(), instString, moLin3.getFedMapping());
136-
Double[] weights = getResponses(weightsSumTmp);
137-
129+
Double[] weights = getResponses(
130+
getWeightsSum(moLin3, moLin3.getFedMapping().getID(), instString, moLin3.getFedMapping()));
138131
double result = aggWeightedCov(cov, mean1, mean2, weights);
139132
ec.setVariable(output.getName(), new DoubleObject(result));
140133
}
@@ -154,21 +147,13 @@ private void processFedCovWeights(ExecutionContext ec, MatrixObject mo1, MatrixO
154147
new CPOperand[]{input1, input2, input3},
155148
new long[]{mo1.getFedMapping().getID(), mo2.getFedMapping().getID(), fr1[0].getID()}
156149
);
150+
//sequential execution of cov and means for robustness
157151
FederatedRequest fr3 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
158152
FederatedRequest fr4 = mo1.getFedMapping().cleanup(getTID(), fr2.getID());
159-
Future<FederatedResponse>[] covTmp = mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3, fr4);
160-
161-
//means
162-
Future<FederatedResponse>[] meanTmp1 = processMean(mo1, 0, fr1[0].getID());
163-
Future<FederatedResponse>[] meanTmp2 = processMean(mo2, 1, fr1[0].getID());
164-
165-
Double[] cov = getResponses(covTmp);
166-
Double[] mean1 = getResponses(meanTmp1);
167-
Double[] mean2 = getResponses(meanTmp2);
168-
169-
Future<FederatedResponse>[] weightsSumTmp = getWeightsSum(moLin3, fr1[0].getID(), instString, mo1.getFedMapping());
170-
Double[] weights = getResponses(weightsSumTmp);
171-
153+
Double[] cov = getResponses(mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3, fr4));
154+
Double[] mean1 = getResponses(processMean(mo1, 0, fr1[0].getID()));
155+
Double[] mean2 = getResponses(processMean(mo2, 1, fr1[0].getID()));
156+
Double[] weights = getResponses(getWeightsSum(moLin3, fr1[0].getID(), instString, mo1.getFedMapping()));
172157
double result = aggWeightedCov(cov, mean1, mean2, weights);
173158
ec.setVariable(output.getName(), new DoubleObject(result));
174159
}
@@ -243,7 +228,7 @@ private static Double[] getResponses(Future<FederatedResponse>[] ffr) {
243228
fr[i] = ((ScalarObject) ffr[i].get().getData()[0]).getDoubleValue();
244229
}
245230
catch(Exception e) {
246-
throw new DMLRuntimeException("CovarianceFEDInstruction: incorrect means or cov.");
231+
throw new DMLRuntimeException("CovarianceFEDInstruction: incorrect means or cov.", e);
247232
}
248233
});
249234

@@ -302,7 +287,7 @@ private static double aggWeightedCov(Double[] covValues, Double[] mean1, Double[
302287
}
303288

304289
private Future<FederatedResponse>[] processMean(MatrixObject mo1, MatrixLineagePair moLin3, int var){
305-
String[] parts = instString.split("°");
290+
String[] parts = instString.split(Lop.OPERAND_DELIMITOR);
306291
Future<FederatedResponse>[] meanTmp = null;
307292
if (moLin3 == null) {
308293
String meanInstr = instString.replace(getOpcode(), getOpcode().replace("cov", "uamean"));

src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,16 +49,17 @@ public class FederatedCovarianceTest extends AutomatedTestBase {
4949
private final static String TEST_DIR = "functions/federated/";
5050
private static final String TEST_CLASS_DIR = TEST_DIR + FederatedCovarianceTest.class.getSimpleName() + "/";
5151

52-
private final static int blocksize = 1024;
52+
private final static int blocksize = 1000;
5353
@Parameterized.Parameter
5454
public int rows;
5555
@Parameterized.Parameter(1)
5656
public int cols;
5757

5858
@Parameterized.Parameters
5959
public static Collection<Object[]> data() {
60-
return Arrays.asList(new Object[][] {{20, 1},
61-
// {100, 1}, {1000, 1}
60+
return Arrays.asList(new Object[][] {
61+
{120, 1},
62+
{1100, 1},
6263
});
6364
}
6465

0 commit comments

Comments
 (0)