Skip to content

Commit 31aff0e

Browse files
committed
[MINOR] Federated Compressed Workload Estimation Fixes
This commit fixes a bug for asynchronous compression on federated workers. Previously, the compression would only tigger if the sum of federated requests instructions summed to % 10 == 9. This bug effectively made it impossible to perform compression if all requests send contained an even number of instructions. This commit change the logic to instruction counter >= 10. Closes #2159 Signed-off-by: Sebastian Baunsgaard <[email protected]>
1 parent 29b4d92 commit 31aff0e

File tree

1 file changed

+22
-4
lines changed

1 file changed

+22
-4
lines changed

src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkloadAnalyzer.java

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ public class FederatedWorkloadAnalyzer {
3535
protected static final Log LOG = LogFactory.getLog(FederatedWorkerHandler.class.getName());
3636

3737
/** Frequency value for how many instructions before we do a pass for compression */
38-
private static int compressRunFrequency = 10;
38+
private static final int compressRunFrequency = 10;
3939

4040
/** Instruction maps to interesting variables */
4141
private final ConcurrentHashMap<Long, ConcurrentHashMap<Long, InstructionTypeCounter>> m;
@@ -49,14 +49,17 @@ public FederatedWorkloadAnalyzer() {
4949
}
5050

5151
public void incrementWorkload(ExecutionContext ec, long tid, Instruction ins) {
52+
LOG.error("Increment Workload " + tid + " " + ins + "\n" + this);
5253
if(ins instanceof ComputationCPInstruction)
5354
incrementWorkload(ec, tid, (ComputationCPInstruction) ins);
5455
// currently we ignore everything that is not CP instructions
5556
}
5657

5758
public void compressRun(ExecutionContext ec, long tid) {
58-
if(counter % compressRunFrequency == compressRunFrequency - 1)
59+
if(counter >= compressRunFrequency ){
60+
counter = 0;
5961
get(tid).forEach((K, V) -> CompressedMatrixBlockFactory.compressAsync(ec, Long.toString(K), V));
62+
}
6063
}
6164

6265
private void incrementWorkload(ExecutionContext ec, long tid, ComputationCPInstruction cpIns) {
@@ -77,13 +80,16 @@ public void incrementWorkload(ExecutionContext ec, ConcurrentHashMap<Long, Instr
7780
int r2 = (int) d2.getDim(0);
7881
int c2 = (int) d2.getDim(1);
7982
if(validSize(r1, c1)) {
80-
getOrMakeCounter(mm, Long.parseLong(n1)).incRMM(r1);
83+
getOrMakeCounter(mm, Long.parseLong(n1)).incRMM(c2);
84+
// safety add overlapping decompress for RMM
85+
getOrMakeCounter(mm, Long.parseLong(n1)).incOverlappingDecompressions();
8186
counter++;
8287
}
8388
if(validSize(r2, c2)) {
84-
getOrMakeCounter(mm, Long.parseLong(n2)).incLMM(c2);
89+
getOrMakeCounter(mm, Long.parseLong(n2)).incLMM(r1);
8590
counter++;
8691
}
92+
8793
}
8894
}
8995

@@ -111,4 +117,16 @@ private ConcurrentHashMap<Long, InstructionTypeCounter> get(long tid) {
111117
private static boolean validSize(int nRow, int nCol) {
112118
return nRow > 90 && nRow >= nCol;
113119
}
120+
121+
@Override
122+
public String toString(){
123+
StringBuilder sb = new StringBuilder();
124+
sb.append(this.getClass().getSimpleName());
125+
sb.append(" Counter: ");
126+
sb.append(counter);
127+
sb.append("\n");
128+
sb.append(m);
129+
130+
return sb.toString();
131+
}
114132
}

0 commit comments

Comments
 (0)