1919
2020package org .apache .sysds .runtime .util ;
2121
22+ import java .util .ArrayList ;
2223import java .util .Collection ;
2324import java .util .List ;
2425import java .util .Map .Entry ;
2930import java .util .concurrent .Executors ;
3031import java .util .concurrent .ForkJoinPool ;
3132import java .util .concurrent .Future ;
33+ import java .util .concurrent .FutureTask ;
3234import java .util .concurrent .TimeUnit ;
3335import java .util .concurrent .TimeoutException ;
3436
@@ -103,11 +105,15 @@ public static ExecutorService get() {
103105 * @return The executor with specified parallelism
104106 */
105107 public synchronized static ExecutorService get (int k ) {
108+ if (k <= 1 ){
109+ LOG .warn ("Invalid to create thread pool with <= one thread returning single thread executor" , new RuntimeException ());
110+ return new SameThreadExecutorService ();
111+ }
112+
106113 final Thread thisThread = Thread .currentThread ();
107114 final String threadName = thisThread .getName ();
108115 // Contains main, because we name our test threads TestRunner_main
109116 final boolean mainThread = threadName .contains ("main" );
110-
111117 if (size == k && mainThread )
112118 return shared ; // use the default thread pool if main thread and max parallelism.
113119 else if (mainThread || threadName .contains ("PARFOR" )) {
@@ -134,6 +140,8 @@ else if(mainThread || threadName.contains("PARFOR")) {
134140 return new CommonThreadPool (new ForkJoinPool (k ));
135141 }
136142 return Executors .newFixedThreadPool (k );
143+
144+
137145 }
138146
139147 }
@@ -330,4 +338,110 @@ else if(name.contains("test"))
330338 else
331339 return false ;
332340 }
341+
342+
343+ private static class SameThreadExecutorService implements ExecutorService {
344+
345+ @ Override
346+ public void execute (Runnable command ) {
347+ command .run ();
348+ }
349+
350+ @ Override
351+ public void shutdown () {
352+ // nothing
353+ }
354+
355+ @ Override
356+ public List <Runnable > shutdownNow () {
357+ return new ArrayList <>();
358+ }
359+
360+ @ Override
361+ public boolean isShutdown () {
362+ return false ;
363+ }
364+
365+ @ Override
366+ public boolean isTerminated () {
367+ return false ;
368+
369+ }
370+
371+ @ Override
372+ public boolean awaitTermination (long timeout , TimeUnit unit ) throws InterruptedException {
373+ return true ;
374+ }
375+
376+ @ Override
377+ public <T > Future <T > submit (Callable <T > task ) {
378+ return new FutureTask <>(task );
379+ }
380+
381+ @ Override
382+ public <T > Future <T > submit (Runnable task , T result ) {
383+ return new FutureTask <>(() -> {
384+ task .run ();
385+ return result ;
386+ });
387+ }
388+
389+ @ Override
390+ public Future <?> submit (Runnable task ) {
391+ return new FutureTask <>(() -> {
392+ task .run ();
393+ return null ;
394+ });
395+ }
396+
397+ @ Override
398+ public <T > List <Future <T >> invokeAll (Collection <? extends Callable <T >> tasks ) throws InterruptedException {
399+ List <Future <T >> ret = new ArrayList <>();
400+ for (Callable <T > t : tasks )
401+ ret .add (new FutureTask <>(t ));
402+ return ret ;
403+ }
404+
405+ @ Override
406+ public <T > List <Future <T >> invokeAll (Collection <? extends Callable <T >> tasks , long timeout , TimeUnit unit )
407+ throws InterruptedException {
408+ List <Future <T >> ret = new ArrayList <>();
409+ for (Callable <T > t : tasks )
410+ ret .add (new FutureTask <>(t ));
411+ return ret ;
412+ }
413+
414+ @ Override
415+ public <T > T invokeAny (Collection <? extends Callable <T >> tasks ) throws InterruptedException , ExecutionException {
416+ Exception e = null ;
417+ for (Callable <T > t : tasks ) {
418+ try {
419+ T r = t .call ();
420+ return r ;
421+ }
422+ catch (Exception ee ) {
423+ e = ee ;
424+ }
425+
426+ }
427+ throw new ExecutionException ("failed" , e );
428+ }
429+
430+ @ Override
431+ public <T > T invokeAny (Collection <? extends Callable <T >> tasks , long timeout , TimeUnit unit )
432+ throws InterruptedException , ExecutionException , TimeoutException {
433+ Exception e = null ;
434+ for (Callable <T > t : tasks ) {
435+ try {
436+ T r = t .call ();
437+ return r ;
438+ }
439+ catch (Exception ee ) {
440+ e = ee ;
441+ }
442+
443+ }
444+ throw new ExecutionException ("failed" , e );
445+ }
446+ }
333447}
0 commit comments