Skip to content

Commit 6a104d4

Browse files
committed
Single thread
1 parent aea8162 commit 6a104d4

File tree

2 files changed

+116
-1
lines changed

2 files changed

+116
-1
lines changed

src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java

Lines changed: 115 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
package org.apache.sysds.runtime.util;
2121

22+
import java.util.ArrayList;
2223
import java.util.Collection;
2324
import java.util.List;
2425
import java.util.Map.Entry;
@@ -29,6 +30,7 @@
2930
import java.util.concurrent.Executors;
3031
import java.util.concurrent.ForkJoinPool;
3132
import java.util.concurrent.Future;
33+
import java.util.concurrent.FutureTask;
3234
import java.util.concurrent.TimeUnit;
3335
import java.util.concurrent.TimeoutException;
3436

@@ -103,11 +105,15 @@ public static ExecutorService get() {
103105
* @return The executor with specified parallelism
104106
*/
105107
public synchronized static ExecutorService get(int k) {
108+
if(k <= 1){
109+
LOG.warn("Invalid to create thread pool with <= one thread returning single thread executor", new RuntimeException());
110+
return new SameThreadExecutorService();
111+
}
112+
106113
final Thread thisThread = Thread.currentThread();
107114
final String threadName = thisThread.getName();
108115
// Contains main, because we name our test threads TestRunner_main
109116
final boolean mainThread = threadName.contains("main");
110-
111117
if(size == k && mainThread)
112118
return shared; // use the default thread pool if main thread and max parallelism.
113119
else if(mainThread || threadName.contains("PARFOR")) {
@@ -134,6 +140,8 @@ else if(mainThread || threadName.contains("PARFOR")) {
134140
return new CommonThreadPool(new ForkJoinPool(k));
135141
}
136142
return Executors.newFixedThreadPool(k);
143+
144+
137145
}
138146

139147
}
@@ -330,4 +338,110 @@ else if(name.contains("test"))
330338
else
331339
return false;
332340
}
341+
342+
343+
private static class SameThreadExecutorService implements ExecutorService {
344+
345+
@Override
346+
public void execute(Runnable command) {
347+
command.run();
348+
}
349+
350+
@Override
351+
public void shutdown() {
352+
// nothing
353+
}
354+
355+
@Override
356+
public List<Runnable> shutdownNow() {
357+
return new ArrayList<>();
358+
}
359+
360+
@Override
361+
public boolean isShutdown() {
362+
return false;
363+
}
364+
365+
@Override
366+
public boolean isTerminated() {
367+
return false;
368+
369+
}
370+
371+
@Override
372+
public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
373+
return true;
374+
}
375+
376+
@Override
377+
public <T> Future<T> submit(Callable<T> task) {
378+
return new FutureTask<>(task);
379+
}
380+
381+
@Override
382+
public <T> Future<T> submit(Runnable task, T result) {
383+
return new FutureTask<>(() -> {
384+
task.run();
385+
return result;
386+
});
387+
}
388+
389+
@Override
390+
public Future<?> submit(Runnable task) {
391+
return new FutureTask<>(() -> {
392+
task.run();
393+
return null;
394+
});
395+
}
396+
397+
@Override
398+
public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> tasks) throws InterruptedException {
399+
List<Future<T>> ret = new ArrayList<>();
400+
for(Callable<T> t : tasks)
401+
ret.add(new FutureTask<>(t));
402+
return ret;
403+
}
404+
405+
@Override
406+
public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> tasks, long timeout, TimeUnit unit)
407+
throws InterruptedException {
408+
List<Future<T>> ret = new ArrayList<>();
409+
for(Callable<T> t : tasks)
410+
ret.add(new FutureTask<>(t));
411+
return ret;
412+
}
413+
414+
@Override
415+
public <T> T invokeAny(Collection<? extends Callable<T>> tasks) throws InterruptedException, ExecutionException {
416+
Exception e = null;
417+
for(Callable<T> t : tasks) {
418+
try {
419+
T r = t.call();
420+
return r;
421+
}
422+
catch(Exception ee) {
423+
e = ee;
424+
}
425+
426+
}
427+
throw new ExecutionException("failed", e);
428+
}
429+
430+
@Override
431+
public <T> T invokeAny(Collection<? extends Callable<T>> tasks, long timeout, TimeUnit unit)
432+
throws InterruptedException, ExecutionException, TimeoutException {
433+
Exception e = null;
434+
for(Callable<T> t : tasks) {
435+
try {
436+
T r = t.call();
437+
return r;
438+
}
439+
catch(Exception ee) {
440+
e = ee;
441+
}
442+
443+
}
444+
throw new ExecutionException("failed", e);
445+
}
446+
}
333447
}

src/test/java/org/apache/sysds/test/functions/codegen/RowVectorComparisonTest.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ private void testCodegenIntegration( String testname, boolean rewrites, ExecType
128128
{
129129
boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
130130
ExecMode platformOld = setExecMode(instType);
131+
setOutputBuffering(true);
131132

132133
try
133134
{

0 commit comments

Comments
 (0)