11package com .marginallyclever .nodegraphcore ;
22
3- import java . util . concurrent . ExecutorService ;
4- import java . util . concurrent . Executors ;
5- import java . util . concurrent . TimeUnit ;
3+ import org . slf4j . Logger ;
4+ import org . slf4j . LoggerFactory ;
5+
66import java .util .concurrent .*;
77import java .util .concurrent .atomic .AtomicInteger ;
88
99/**
1010 * A scheduler that runs a graph of nodes using a thread pool.
1111 */
1212public class ThreadPoolScheduler {
13+ private static final Logger logger = LoggerFactory .getLogger (ThreadPoolScheduler .class );
1314 private final ExecutorService threadPool = Executors .newVirtualThreadPerTaskExecutor ();
1415 private final BlockingQueue <Node > readyNodes = new LinkedBlockingQueue <>();
1516 private final AtomicInteger activeTasks = new AtomicInteger (0 );
@@ -22,10 +23,12 @@ public ThreadPoolScheduler() {}
2223 */
2324 public void submit (Node node ) {
2425 if (readyNodes .contains (node )) {
25- // move node to the end of the queue
26+ logger .debug ("defer {}" , node .getName ());
27+ // move node to the tail of the queue
2628 readyNodes .remove (node );
2729 readyNodes .add (node );
2830 } else {
31+ logger .debug ("add {}" , node .getName ());
2932 readyNodes .add (node ); // Add the node to the ready queue
3033 activeTasks .incrementAndGet (); // Increment task count
3134 }
@@ -55,19 +58,21 @@ public void update() {
5558 if (node == null ) return ;
5659
5760 threadPool .submit (() -> {
61+ logger .debug ("start {}" , node .getName ());
5862 try {
5963 node .update ();
6064 node .updateBounds ();
6165 node .setInputsClean ();
6266 for (Node downstreamNode : node .getDownstreamNodes ()) {
63- if (downstreamNode .isDirty ()) {
64- submit (downstreamNode ); // Submit downstream nodes
67+ if (downstreamNode .isDirty () && ! hasQueued ( downstreamNode ) ) {
68+ submit (downstreamNode ); // Submit downstream nodes
6569 }
6670 }
6771 } catch (Exception e ) {
6872 System .err .println ("Error in node execution: " + e .getMessage ());
6973 e .printStackTrace ();
7074 } finally {
75+ logger .debug ("end {}" , node .getName ());
7176 activeTasks .decrementAndGet (); // Mark this task as completed
7277 }
7378 });
@@ -84,4 +89,12 @@ public void shutdown(long timeoutSeconds) {
8489 Thread .currentThread ().interrupt ();
8590 }
8691 }
92+
93+ /**
94+ * @param n the {@link Node} to check
95+ * @return true if the node is in the ready queue
96+ */
97+ public boolean hasQueued (Node n ) {
98+ return readyNodes .contains (n );
99+ }
87100}
0 commit comments