11package com .javaaidev .agenticpatterns .parallelizationworkflow ;
22
33import com .javaaidev .agenticpatterns .taskexecution .TaskExecutionAgent ;
4+ import io .micrometer .context .ContextExecutorService ;
5+ import io .micrometer .context .ContextSnapshotFactory ;
46import io .micrometer .observation .ObservationRegistry ;
57import java .lang .reflect .Type ;
68import java .time .Duration ;
7- import java .time .Instant ;
89import java .util .List ;
910import java .util .Map ;
1011import java .util .Map .Entry ;
12+ import java .util .Objects ;
1113import java .util .concurrent .CopyOnWriteArrayList ;
12- import java .util .concurrent .StructuredTaskScope ;
13- import java .util .concurrent .StructuredTaskScope .Subtask ;
14- import java .util .concurrent .StructuredTaskScope .Subtask .State ;
14+ import java .util .concurrent .ExecutionException ;
15+ import java .util .concurrent .ExecutorService ;
16+ import java .util .concurrent .Executors ;
17+ import java .util .concurrent .Future ;
18+ import java .util .concurrent .TimeUnit ;
1519import java .util .concurrent .TimeoutException ;
1620import java .util .function .Function ;
1721import java .util .stream .Collectors ;
@@ -31,13 +35,32 @@ public record SubtaskCreationRequest<Request>(
3135
3236 }
3337
34- public record SubtaskContext < Request > (
35- SubtaskCreationRequest < Request > creationRequest ,
36- @ Nullable Subtask <?> job ,
38+ public record TaskExecutionContext (
39+ Future <?> job ,
40+ Duration maxWaitTime ,
3741 @ Nullable Object result ,
3842 @ Nullable Throwable error
3943 ) {
4044
45+ public TaskExecutionContext (Future <?> job , Duration maxWaitTime ) {
46+ this (job , maxWaitTime , null , null );
47+ }
48+
49+ public TaskExecutionContext collectResult () {
50+ try {
51+ var result = job ().get (maxWaitTime ().toSeconds (), TimeUnit .SECONDS );
52+ return new TaskExecutionContext (job (), maxWaitTime (), result , null );
53+ } catch (InterruptedException | ExecutionException | TimeoutException e ) {
54+ return new TaskExecutionContext (job (), maxWaitTime (), null , job ().exceptionNow ());
55+ }
56+ }
57+ }
58+
59+ public record SubtaskContext <Request >(
60+ SubtaskCreationRequest <Request > creationRequest ,
61+ @ Nullable TaskExecutionContext taskExecutionContext
62+ ) {
63+
4164 public static <Request , TaskRequest , TaskResponse > SubtaskContext <Request > create (String taskId ,
4265 TaskExecutionAgent <TaskRequest , TaskResponse > task ,
4366 Function <Request , TaskRequest > requestTransformer ) {
@@ -46,27 +69,31 @@ public static <Request, TaskRequest, TaskResponse> SubtaskContext<Request> creat
4669
4770 public static <Request > SubtaskContext <Request > create (
4871 SubtaskCreationRequest <Request > creationRequest ) {
49- return new SubtaskContext <>(creationRequest , null , null , null );
72+ return new SubtaskContext <>(creationRequest , null );
5073 }
5174
52- public SubtaskContext <Request > taskStarted (Subtask <?> job ) {
53- return new SubtaskContext <>(this .creationRequest (), job , null ,
54- null );
75+ public SubtaskContext <Request > taskStarted (Future <?> job , Duration maxWaitTime ) {
76+ return new SubtaskContext <>(this .creationRequest (),
77+ new TaskExecutionContext ( job , maxWaitTime ) );
5578 }
5679
5780 public SubtaskContext <Request > collectResult () {
58- if (this .job () == null ) {
59- return this ;
60- }
61- var state = this .job ().state ();
62- return new SubtaskContext <>(this .creationRequest (), this .job (),
63- state == State .SUCCESS ? this .job ().get () : null ,
64- state == State .FAILED ? this .job ().exception () : null );
81+ return new SubtaskContext <>(creationRequest (),
82+ Objects .requireNonNull (taskExecutionContext (), "task execution context cannot be null" )
83+ .collectResult ());
6584 }
6685
6786 public String taskId () {
6887 return creationRequest ().taskId ();
6988 }
89+
90+ public @ Nullable Object result () {
91+ return taskExecutionContext () != null ? taskExecutionContext ().result () : null ;
92+ }
93+
94+ public @ Nullable Throwable error () {
95+ return taskExecutionContext () != null ? taskExecutionContext ().error () : null ;
96+ }
7097 }
7198
7299 public ParallelizationWorkflowAgent (ChatClient chatClient ,
@@ -90,7 +117,7 @@ protected <TaskRequest, TaskResponse> void addSubtask(String taskId,
90117 subtasks .add (SubtaskContext .create (taskId , subtask , requestTransformer ));
91118 }
92119
93- protected Duration getMaxExecutionDuration () {
120+ protected Duration getMaxTaskExecutionDuration () {
94121 return Duration .ofMinutes (3 );
95122 }
96123
@@ -118,26 +145,28 @@ public Map<String, Object> allSuccessfulResults() {
118145 }
119146 }
120147
148+ protected ExecutorService getTaskExecutorService () {
149+ var executor = Executors .newThreadPerTaskExecutor (
150+ Thread .ofVirtual ().name ("agent-task-" , 1 ).factory ());
151+ return ContextExecutorService .wrap (executor ,
152+ ContextSnapshotFactory .builder ().clearMissing (true ).build ());
153+ }
154+
121155 protected TaskExecutionResults runSubtasks (@ Nullable Request request ) {
122156 var createdTasks = createTasks (request );
123157 if (createdTasks != null ) {
124158 subtasks .addAll (createdTasks .stream ().map (SubtaskContext ::create ).toList ());
125159 }
126- try (var scope = new StructuredTaskScope <> ()) {
160+ try (var executor = getTaskExecutorService ()) {
127161 var jobs = subtasks .stream ().map (context -> {
128162 var creationRequest = context .creationRequest ();
129163 LOGGER .info ("Starting subtask {}" , creationRequest .taskId ());
130- var job = scope . fork (
164+ var job = executor . submit (
131165 () -> creationRequest .task ().call (creationRequest .requestTransformer ().apply (request )));
132- return context .taskStarted (job );
166+ return context .taskStarted (job , getMaxTaskExecutionDuration () );
133167 }).toList ();
134- try {
135- LOGGER .info ("Waiting for all subtasks to finish, timeout in {}" , getMaxExecutionDuration ());
136- scope .joinUntil (Instant .now ().plus (getMaxExecutionDuration ()));
137- } catch (InterruptedException | TimeoutException e ) {
138- LOGGER .error ("Error occurred when executing subtask, check status for individual subtask" ,
139- e );
140- }
168+ LOGGER .info ("Waiting for all subtasks to finish" );
169+ jobs .forEach (SubtaskContext ::collectResult );
141170 LOGGER .info ("All subtasks completed, assembling the results" );
142171 var results = jobs .stream ().map (SubtaskContext ::collectResult )
143172 .collect (Collectors .toMap (SubtaskContext ::taskId ,
0 commit comments