Skip to content

Commit 9a09c45

Browse files
committed
[SYSTEMDS-3845] Federated Threading Bug
The federated back end spawns threads for parallel execution instead of using the thread pool. This commit fixes the issue by naming the worker threads to enable thread pool usage. In a local experiment using the FederatedKmeansTest, the average runtime improves from 4.3 sec to 3.2 sec. To reproduce the results, set the federated k-means test to repeat the federated call 20 times. Closes #2245
1 parent a0dd02b commit 9a09c45

File tree

3 files changed

+11
-3
lines changed

3 files changed

+11
-3
lines changed

src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -616,6 +616,10 @@ private static void exec(ExecutionContext ec, Instruction ins){
616616

617617
try {
618618
// execute single instruction
619+
// TODO move this thread naming to Netty thread creation!
620+
Thread curThread = Thread.currentThread();
621+
long id = curThread.getId();
622+
Thread.currentThread().setName("FedExec_"+ id);
619623
pb.execute(ec);
620624
}
621625
catch(Exception ex) {

src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ public synchronized static ExecutorService get(int k) {
115115
final boolean mainThread = threadName.contains("main");
116116
if(size == k && mainThread)
117117
return shared; // use the default thread pool if main thread and max parallelism.
118-
else if(mainThread || threadName.contains("PARFOR")) {
118+
else if(mainThread || threadName.contains("PARFOR") || threadName.contains("FedExec")) {
119119
CommonThreadPool pool;
120120
if(shared2 == null) // If there is no current shared pool allocate one.
121121
shared2 = new ConcurrentHashMap<>();

src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
import java.util.Arrays;
2323
import java.util.Collection;
2424

25+
import org.apache.commons.logging.Log;
26+
import org.apache.commons.logging.LogFactory;
2527
import org.apache.sysds.common.Types;
2628
import org.apache.sysds.common.Types.ExecMode;
2729
import org.apache.sysds.runtime.controlprogram.ParForProgramBlock;
@@ -31,6 +33,7 @@
3133
import org.apache.sysds.test.AutomatedTestBase;
3234
import org.apache.sysds.test.TestConfiguration;
3335
import org.apache.sysds.test.TestUtils;
36+
import org.apache.sysds.utils.stats.Timing;
3437
import org.junit.Assert;
3538
import org.junit.Ignore;
3639
import org.junit.Test;
@@ -40,6 +43,7 @@
4043
@RunWith(value = Parameterized.class)
4144
@net.jcip.annotations.NotThreadSafe
4245
public class FederatedKmeansTest extends AutomatedTestBase {
46+
protected static final Log LOG = LogFactory.getLog(FederatedKmeansTest.class.getName());
4347

4448
private final static String TEST_DIR = "functions/federated/";
4549
private final static String TEST_NAME = "FederatedKmeansTest";
@@ -120,7 +124,6 @@ public void federatedKmeans(Types.ExecMode execMode, boolean singleWorker) {
120124
programArgs = new String[] {"-args", input("X1"), input("X2"),
121125
String.valueOf(singleWorker).toUpperCase(), String.valueOf(runs), expected("Z")};
122126
runTest(true, false, null, -1);
123-
124127
// Run actual dml script with federated matrix
125128
fullDMLScriptName = HOME + TEST_NAME + ".dml";
126129
programArgs = new String[] {"-stats","20", "-nvargs", "in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
@@ -130,8 +133,9 @@ public void federatedKmeans(Types.ExecMode execMode, boolean singleWorker) {
130133
for(int i = 0; i < rep; i++) {
131134
ParForProgramBlock.resetWorkerIDs();
132135
FederationUtils.resetFedDataID();
136+
Timing t = new Timing();
133137
runTest(true, false, null, -1);
134-
138+
LOG.debug("Federated kmeans runtime: " + t);
135139
// check for federated operations
136140
Assert.assertTrue(heavyHittersContainsString("fed_ba+*"));
137141
// Assert.assertTrue(heavyHittersContainsString("fed_uasqk+"));

0 commit comments

Comments
 (0)