Skip to content

Commit bc1fd71

Browse files
committed
update implementation of ParallelizationWorkflowAgent
1 parent 68b92fa commit bc1fd71

File tree

4 files changed

+59
-59
lines changed

4 files changed

+59
-59
lines changed

examples/pom.xml

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -104,19 +104,4 @@
104104
<artifactId>micrometer-tracing-bridge-otel</artifactId>
105105
</dependency>
106106
</dependencies>
107-
108-
<build>
109-
<plugins>
110-
<plugin>
111-
<groupId>org.apache.maven.plugins</groupId>
112-
<artifactId>maven-compiler-plugin</artifactId>
113-
<version>3.13.0</version>
114-
<configuration>
115-
<compilerArgs>
116-
<arg>--enable-preview</arg>
117-
</compilerArgs>
118-
</configuration>
119-
</plugin>
120-
</plugins>
121-
</build>
122107
</project>

examples/src/main/resources/application.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ spring:
77
api-key: ${OPENAI_API_KEY:}
88
chat:
99
options:
10-
model: gpt-4o
10+
model: gpt-4o-mini
1111
temperature: 0.0
1212
logging:
1313
level:

parallelization-workflow/pom.xml

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -26,18 +26,4 @@
2626
</dependency>
2727
</dependencies>
2828

29-
<build>
30-
<plugins>
31-
<plugin>
32-
<groupId>org.apache.maven.plugins</groupId>
33-
<artifactId>maven-compiler-plugin</artifactId>
34-
<configuration>
35-
<source>21</source>
36-
<target>21</target>
37-
<compilerArgs>--enable-preview</compilerArgs>
38-
</configuration>
39-
</plugin>
40-
</plugins>
41-
</build>
42-
4329
</project>

parallelization-workflow/src/main/java/com/javaaidev/agenticpatterns/parallelizationworkflow/ParallelizationWorkflowAgent.java

Lines changed: 58 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,21 @@
11
package com.javaaidev.agenticpatterns.parallelizationworkflow;
22

33
import com.javaaidev.agenticpatterns.taskexecution.TaskExecutionAgent;
4+
import io.micrometer.context.ContextExecutorService;
5+
import io.micrometer.context.ContextSnapshotFactory;
46
import io.micrometer.observation.ObservationRegistry;
57
import java.lang.reflect.Type;
68
import java.time.Duration;
7-
import java.time.Instant;
89
import java.util.List;
910
import java.util.Map;
1011
import java.util.Map.Entry;
12+
import java.util.Objects;
1113
import java.util.concurrent.CopyOnWriteArrayList;
12-
import java.util.concurrent.StructuredTaskScope;
13-
import java.util.concurrent.StructuredTaskScope.Subtask;
14-
import java.util.concurrent.StructuredTaskScope.Subtask.State;
14+
import java.util.concurrent.ExecutionException;
15+
import java.util.concurrent.ExecutorService;
16+
import java.util.concurrent.Executors;
17+
import java.util.concurrent.Future;
18+
import java.util.concurrent.TimeUnit;
1519
import java.util.concurrent.TimeoutException;
1620
import java.util.function.Function;
1721
import java.util.stream.Collectors;
@@ -31,13 +35,32 @@ public record SubtaskCreationRequest<Request>(
3135

3236
}
3337

34-
public record SubtaskContext<Request>(
35-
SubtaskCreationRequest<Request> creationRequest,
36-
@Nullable Subtask<?> job,
38+
public record TaskExecutionContext(
39+
Future<?> job,
40+
Duration maxWaitTime,
3741
@Nullable Object result,
3842
@Nullable Throwable error
3943
) {
4044

45+
public TaskExecutionContext(Future<?> job, Duration maxWaitTime) {
46+
this(job, maxWaitTime, null, null);
47+
}
48+
49+
public TaskExecutionContext collectResult() {
50+
try {
51+
var result = job().get(maxWaitTime().toSeconds(), TimeUnit.SECONDS);
52+
return new TaskExecutionContext(job(), maxWaitTime(), result, null);
53+
} catch (InterruptedException | ExecutionException | TimeoutException e) {
54+
return new TaskExecutionContext(job(), maxWaitTime(), null, job().exceptionNow());
55+
}
56+
}
57+
}
58+
59+
public record SubtaskContext<Request>(
60+
SubtaskCreationRequest<Request> creationRequest,
61+
@Nullable TaskExecutionContext taskExecutionContext
62+
) {
63+
4164
public static <Request, TaskRequest, TaskResponse> SubtaskContext<Request> create(String taskId,
4265
TaskExecutionAgent<TaskRequest, TaskResponse> task,
4366
Function<Request, TaskRequest> requestTransformer) {
@@ -46,27 +69,31 @@ public static <Request, TaskRequest, TaskResponse> SubtaskContext<Request> creat
4669

4770
public static <Request> SubtaskContext<Request> create(
4871
SubtaskCreationRequest<Request> creationRequest) {
49-
return new SubtaskContext<>(creationRequest, null, null, null);
72+
return new SubtaskContext<>(creationRequest, null);
5073
}
5174

52-
public SubtaskContext<Request> taskStarted(Subtask<?> job) {
53-
return new SubtaskContext<>(this.creationRequest(), job, null,
54-
null);
75+
public SubtaskContext<Request> taskStarted(Future<?> job, Duration maxWaitTime) {
76+
return new SubtaskContext<>(this.creationRequest(),
77+
new TaskExecutionContext(job, maxWaitTime));
5578
}
5679

5780
public SubtaskContext<Request> collectResult() {
58-
if (this.job() == null) {
59-
return this;
60-
}
61-
var state = this.job().state();
62-
return new SubtaskContext<>(this.creationRequest(), this.job(),
63-
state == State.SUCCESS ? this.job().get() : null,
64-
state == State.FAILED ? this.job().exception() : null);
81+
return new SubtaskContext<>(creationRequest(),
82+
Objects.requireNonNull(taskExecutionContext(), "task execution context cannot be null")
83+
.collectResult());
6584
}
6685

6786
public String taskId() {
6887
return creationRequest().taskId();
6988
}
89+
90+
public @Nullable Object result() {
91+
return taskExecutionContext() != null ? taskExecutionContext().result() : null;
92+
}
93+
94+
public @Nullable Throwable error() {
95+
return taskExecutionContext() != null ? taskExecutionContext().error() : null;
96+
}
7097
}
7198

7299
public ParallelizationWorkflowAgent(ChatClient chatClient,
@@ -90,7 +117,7 @@ protected <TaskRequest, TaskResponse> void addSubtask(String taskId,
90117
subtasks.add(SubtaskContext.create(taskId, subtask, requestTransformer));
91118
}
92119

93-
protected Duration getMaxExecutionDuration() {
120+
protected Duration getMaxTaskExecutionDuration() {
94121
return Duration.ofMinutes(3);
95122
}
96123

@@ -118,26 +145,28 @@ public Map<String, Object> allSuccessfulResults() {
118145
}
119146
}
120147

148+
protected ExecutorService getTaskExecutorService() {
149+
var executor = Executors.newThreadPerTaskExecutor(
150+
Thread.ofVirtual().name("agent-task-", 1).factory());
151+
return ContextExecutorService.wrap(executor,
152+
ContextSnapshotFactory.builder().clearMissing(true).build());
153+
}
154+
121155
protected TaskExecutionResults runSubtasks(@Nullable Request request) {
122156
var createdTasks = createTasks(request);
123157
if (createdTasks != null) {
124158
subtasks.addAll(createdTasks.stream().map(SubtaskContext::create).toList());
125159
}
126-
try (var scope = new StructuredTaskScope<>()) {
160+
try (var executor = getTaskExecutorService()) {
127161
var jobs = subtasks.stream().map(context -> {
128162
var creationRequest = context.creationRequest();
129163
LOGGER.info("Starting subtask {}", creationRequest.taskId());
130-
var job = scope.fork(
164+
var job = executor.submit(
131165
() -> creationRequest.task().call(creationRequest.requestTransformer().apply(request)));
132-
return context.taskStarted(job);
166+
return context.taskStarted(job, getMaxTaskExecutionDuration());
133167
}).toList();
134-
try {
135-
LOGGER.info("Waiting for all subtasks to finish, timeout in {}", getMaxExecutionDuration());
136-
scope.joinUntil(Instant.now().plus(getMaxExecutionDuration()));
137-
} catch (InterruptedException | TimeoutException e) {
138-
LOGGER.error("Error occurred when executing subtask, check status for individual subtask",
139-
e);
140-
}
168+
LOGGER.info("Waiting for all subtasks to finish");
169+
jobs.forEach(SubtaskContext::collectResult);
141170
LOGGER.info("All subtasks completed, assembling the results");
142171
var results = jobs.stream().map(SubtaskContext::collectResult)
143172
.collect(Collectors.toMap(SubtaskContext::taskId,

0 commit comments

Comments
 (0)