Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 31 additions & 22 deletions agentscope-core/src/main/java/io/agentscope/core/ReActAgent.java
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
import java.util.Comparator;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -681,52 +682,60 @@ private class HookNotifier {
Mono<List<Msg>> notifyPreReasoning(AgentBase agent, List<Msg> msgs) {
PreReasoningEvent event =
new PreReasoningEvent(agent, model.getModelName(), null, msgs);
Mono<PreReasoningEvent> result = Mono.just(event);
for (Hook hook : getSortedHooks()) {
result = result.flatMap(e -> hook.onEvent(e));
}
return result.map(PreReasoningEvent::getInputMessages);
return Flux.fromIterable(getSortedHooksCache())
.reduce(
Mono.just(event),
(currentMono, hook) -> currentMono.flatMap(hook::onEvent))
.flatMap(Function.identity())
.map(PreReasoningEvent::getInputMessages);
}

Mono<Msg> notifyPostReasoning(Msg reasoningMsg) {
PostReasoningEvent event =
new PostReasoningEvent(
ReActAgent.this, model.getModelName(), null, reasoningMsg);
Mono<PostReasoningEvent> result = Mono.just(event);
for (Hook hook : getSortedHooks()) {
result = result.flatMap(e -> hook.onEvent(e));
}
return result.map(PostReasoningEvent::getReasoningMessage);
return Flux.fromIterable(getSortedHooksCache())
.reduce(
Mono.just(event),
(currentMono, hook) -> currentMono.flatMap(hook::onEvent))
.flatMap(Function.identity())
.map(PostReasoningEvent::getReasoningMessage);
}

Mono<Void> notifyReasoningChunk(Msg chunk, Msg accumulated) {
ReasoningChunkEvent event =
new ReasoningChunkEvent(
ReActAgent.this, model.getModelName(), null, chunk, accumulated);
return Flux.fromIterable(getSortedHooks()).flatMap(hook -> hook.onEvent(event)).then();
return Flux.fromIterable(getSortedHooksCache())
.concatMap(hook -> hook.onEvent(event))
.then();
}

Mono<ToolUseBlock> notifyPreActing(ToolUseBlock toolUse) {
PreActingEvent event = new PreActingEvent(ReActAgent.this, toolkit, toolUse);
Mono<PreActingEvent> result = Mono.just(event);
for (Hook hook : getSortedHooks()) {
result = result.flatMap(e -> hook.onEvent(e));
}
return result.map(PreActingEvent::getToolUse);
return Flux.fromIterable(getSortedHooksCache())
.reduce(
Mono.just(event),
(currentMono, hook) -> currentMono.flatMap(hook::onEvent))
.flatMap(Function.identity())
.map(PreActingEvent::getToolUse);
}

Mono<Void> notifyActingChunk(ToolUseBlock toolUse, ToolResultBlock chunk) {
ActingChunkEvent event = new ActingChunkEvent(ReActAgent.this, toolkit, toolUse, chunk);
return Flux.fromIterable(getSortedHooks()).flatMap(hook -> hook.onEvent(event)).then();
return Flux.fromIterable(getSortedHooksCache())
.concatMap(hook -> hook.onEvent(event))
.then();
}

Mono<ToolResultBlock> notifyPostActing(ToolUseBlock toolUse, ToolResultBlock toolResult) {
var event = new PostActingEvent(ReActAgent.this, toolkit, toolUse, toolResult);
Mono<PostActingEvent> result = Mono.just(event);
for (Hook hook : getSortedHooks()) {
result = result.flatMap(e -> hook.onEvent(e));
}
return result.map(PostActingEvent::getToolResult);
return Flux.fromIterable(getSortedHooksCache())
.reduce(
Mono.just(event),
(currentMono, hook) -> currentMono.flatMap(hook::onEvent))
.flatMap(Function.identity())
.map(PostActingEvent::getToolResult);
}

Mono<Void> notifyStreamingMsg(Msg msg, ReasoningContext context) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import io.agentscope.core.state.StateModuleBase;
import io.agentscope.core.tracing.TracerRegistry;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.UUID;
Expand Down Expand Up @@ -91,6 +93,7 @@ public abstract class AgentBase extends StateModuleBase implements Agent {
private final AtomicBoolean running = new AtomicBoolean(false);
private final boolean checkRunning;
private final List<Hook> hooks;
private volatile List<Hook> sortedHooks;
private static final List<Hook> systemHooks = new CopyOnWriteArrayList<>();
private final Map<String, List<AgentBase>> hubSubscribers = new ConcurrentHashMap<>();

Expand Down Expand Up @@ -133,6 +136,7 @@ public AgentBase(String name, String description, boolean checkRunning, List<Hoo
this.checkRunning = checkRunning;
this.hooks = new CopyOnWriteArrayList<>(hooks != null ? hooks : List.of());
this.hooks.addAll(systemHooks);
this.sortedHooks = refreshSortedHooks();

// Register basic agent state
registerState("id", obj -> this.agentId, obj -> obj);
Expand Down Expand Up @@ -393,24 +397,87 @@ protected Mono<Void> doObserve(Msg msg) {
*/
protected abstract Mono<Msg> handleInterrupt(InterruptContext context, Msg... originalArgs);

/**
* Hook to obtain a unified entrance
*
* @param needSorted Do you need to return sorted cache hooks
* @return List of hooks corresponding to the state
*/
public List<Hook> getHooks(boolean needSorted) {
return needSorted ? getSortedHooksCache() : getHooks();
}

/**
* Get the list of hooks for this agent.
* Protected to allow subclasses to access hooks for custom notification logic.
*
* @return List of hooks
*/
public List<Hook> getHooks() {
return hooks;
protected List<Hook> getHooks() {
return Collections.unmodifiableList(this.hooks);
}

/**
* Get the pre-sorted Hook cache (for subclass calls, high-frequency read without locking)
*
* @return Immutable sorted list of Hooks
*/
protected List<Hook> getSortedHooksCache() {
return this.sortedHooks;
}

/**
* Fully replace the Hook list (based on atomicity of CopyOnWriteArrayList)
*
* @param newHooks New list of Hooks (can be null, system hooks are automatically preserved)
*/
public void updateHooks(List<Hook> newHooks) {
List<Hook> combinedHooks = new CopyOnWriteArrayList<>(systemHooks);
if (newHooks != null) {
combinedHooks.addAll(newHooks);
}

this.hooks.clear();
this.hooks.addAll(combinedHooks);

this.sortedHooks = refreshSortedHooks();
}

/**
* Incrementally add a single Hook (based on atomicity of CopyOnWriteArrayList)
*
* @param hooks The Hook to be added (must not be null)
*/
public void addHook(List<Hook> hooks) {
if (hooks == null || hooks.isEmpty()) {
return;
}
this.hooks.addAll(hooks);
this.sortedHooks = refreshSortedHooks();
}

/**
* Incrementally remove a single Hook (based on atomicity of CopyOnWriteArrayList, system hooks cannot be removed)
*
* @param hook The Hook to be removed
*/
public void removeHook(Hook hook) {
if (hook == null || systemHooks.contains(hook)) {
return;
}
boolean removed = this.hooks.remove(hook);
if (removed) {
this.sortedHooks = refreshSortedHooks();
}
}

/**
* Get hooks sorted by priority (lower value = higher priority).
* Hooks with the same priority maintain registration order.
* Refresh the pre-sorted Hook cache (all agents share the sorting logic)
*
* @return Sorted list of hooks
* @return Immutable list of Hooks sorted by priority
*/
protected List<Hook> getSortedHooks() {
return hooks.stream().sorted(java.util.Comparator.comparingInt(Hook::priority)).toList();
private List<Hook> refreshSortedHooks() {
return this.hooks.stream().sorted(Comparator.comparingInt(Hook::priority)).toList();
}

/**
Expand All @@ -425,7 +492,7 @@ protected List<Hook> getSortedHooks() {
private Mono<List<Msg>> notifyPreCall(List<Msg> msgs) {
PreCallEvent event = new PreCallEvent(this, msgs);
Mono<PreCallEvent> result = Mono.just(event);
for (Hook hook : getSortedHooks()) {
for (Hook hook : getSortedHooksCache()) {
result = result.flatMap(hook::onEvent);
}
return result.map(PreCallEvent::getInputMessages);
Expand All @@ -444,7 +511,7 @@ private Mono<Msg> notifyPostCall(Msg finalMsg) {
}
PostCallEvent event = new PostCallEvent(this, finalMsg);
Mono<PostCallEvent> result = Mono.just(event);
for (Hook hook : getSortedHooks()) {
for (Hook hook : getSortedHooksCache()) {
result = result.flatMap(hook::onEvent);
}
// After hooks, broadcast to subscribers
Expand All @@ -460,7 +527,7 @@ private Mono<Msg> notifyPostCall(Msg finalMsg) {
*/
private Mono<Void> notifyError(Throwable error) {
ErrorEvent event = new ErrorEvent(this, error);
return Flux.fromIterable(getSortedHooks()).flatMap(hook -> hook.onEvent(event)).then();
return Flux.fromIterable(getSortedHooksCache()).flatMap(hook -> hook.onEvent(event)).then();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import io.agentscope.core.message.Msg;
import io.agentscope.core.message.TextBlock;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.UUID;
import java.util.function.BiConsumer;
Expand Down Expand Up @@ -107,7 +108,7 @@ private A2aAgent(
this.agentCardResolver = agentCardResolver;
this.memory = memory;
LoggerUtil.debug(log, "A2aAgent init with config: {}", a2aAgentConfig);
getHooks().add(new A2aClientLifecycleHook());
super.addHook(Collections.singletonList(new A2aClientLifecycleHook()));
this.clientEventHandlerRouter = new ClientEventHandlerRouter();
}

Expand All @@ -119,7 +120,7 @@ protected Mono<Msg> doCall(List<Msg> msgs) {
LoggerUtil.info(log, "[{}] A2aAgent start call.", currentRequestId);
LoggerUtil.debug(log, "[{}] A2aAgent call with input messages: ", currentRequestId);
LoggerUtil.logTextMsgDetail(log, memory.getMessages());
clientEventContext.setHooks(getSortedHooks());
clientEventContext.setHooks(getSortedHooksCache());
return Mono.defer(
() -> {
Message message = MessageConvertUtil.convertFromMsg(memory.getMessages());
Expand Down
Loading