diff --git a/agentscope-core/src/main/java/io/agentscope/core/ReActAgent.java b/agentscope-core/src/main/java/io/agentscope/core/ReActAgent.java index acd25db63..e295d3d8e 100644 --- a/agentscope-core/src/main/java/io/agentscope/core/ReActAgent.java +++ b/agentscope-core/src/main/java/io/agentscope/core/ReActAgent.java @@ -61,6 +61,7 @@ import java.util.Comparator; import java.util.List; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; import java.util.stream.Collectors; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -681,52 +682,60 @@ private class HookNotifier { Mono> notifyPreReasoning(AgentBase agent, List msgs) { PreReasoningEvent event = new PreReasoningEvent(agent, model.getModelName(), null, msgs); - Mono result = Mono.just(event); - for (Hook hook : getSortedHooks()) { - result = result.flatMap(e -> hook.onEvent(e)); - } - return result.map(PreReasoningEvent::getInputMessages); + return Flux.fromIterable(getSortedHooksCache()) + .reduce( + Mono.just(event), + (currentMono, hook) -> currentMono.flatMap(hook::onEvent)) + .flatMap(Function.identity()) + .map(PreReasoningEvent::getInputMessages); } Mono notifyPostReasoning(Msg reasoningMsg) { PostReasoningEvent event = new PostReasoningEvent( ReActAgent.this, model.getModelName(), null, reasoningMsg); - Mono result = Mono.just(event); - for (Hook hook : getSortedHooks()) { - result = result.flatMap(e -> hook.onEvent(e)); - } - return result.map(PostReasoningEvent::getReasoningMessage); + return Flux.fromIterable(getSortedHooksCache()) + .reduce( + Mono.just(event), + (currentMono, hook) -> currentMono.flatMap(hook::onEvent)) + .flatMap(Function.identity()) + .map(PostReasoningEvent::getReasoningMessage); } Mono notifyReasoningChunk(Msg chunk, Msg accumulated) { ReasoningChunkEvent event = new ReasoningChunkEvent( ReActAgent.this, model.getModelName(), null, chunk, accumulated); - return Flux.fromIterable(getSortedHooks()).flatMap(hook -> hook.onEvent(event)).then(); + return Flux.fromIterable(getSortedHooksCache()) + .concatMap(hook -> hook.onEvent(event)) + .then(); } Mono notifyPreActing(ToolUseBlock toolUse) { PreActingEvent event = new PreActingEvent(ReActAgent.this, toolkit, toolUse); - Mono result = Mono.just(event); - for (Hook hook : getSortedHooks()) { - result = result.flatMap(e -> hook.onEvent(e)); - } - return result.map(PreActingEvent::getToolUse); + return Flux.fromIterable(getSortedHooksCache()) + .reduce( + Mono.just(event), + (currentMono, hook) -> currentMono.flatMap(hook::onEvent)) + .flatMap(Function.identity()) + .map(PreActingEvent::getToolUse); } Mono notifyActingChunk(ToolUseBlock toolUse, ToolResultBlock chunk) { ActingChunkEvent event = new ActingChunkEvent(ReActAgent.this, toolkit, toolUse, chunk); - return Flux.fromIterable(getSortedHooks()).flatMap(hook -> hook.onEvent(event)).then(); + return Flux.fromIterable(getSortedHooksCache()) + .concatMap(hook -> hook.onEvent(event)) + .then(); } Mono notifyPostActing(ToolUseBlock toolUse, ToolResultBlock toolResult) { var event = new PostActingEvent(ReActAgent.this, toolkit, toolUse, toolResult); - Mono result = Mono.just(event); - for (Hook hook : getSortedHooks()) { - result = result.flatMap(e -> hook.onEvent(e)); - } - return result.map(PostActingEvent::getToolResult); + return Flux.fromIterable(getSortedHooksCache()) + .reduce( + Mono.just(event), + (currentMono, hook) -> currentMono.flatMap(hook::onEvent)) + .flatMap(Function.identity()) + .map(PostActingEvent::getToolResult); } Mono notifyStreamingMsg(Msg msg, ReasoningContext context) { diff --git a/agentscope-core/src/main/java/io/agentscope/core/agent/AgentBase.java b/agentscope-core/src/main/java/io/agentscope/core/agent/AgentBase.java index 89b9b01a6..95f8c2dac 100644 --- a/agentscope-core/src/main/java/io/agentscope/core/agent/AgentBase.java +++ b/agentscope-core/src/main/java/io/agentscope/core/agent/AgentBase.java @@ -25,6 +25,8 @@ import io.agentscope.core.state.StateModuleBase; import io.agentscope.core.tracing.TracerRegistry; import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; import java.util.List; import java.util.Map; import java.util.UUID; @@ -91,6 +93,7 @@ public abstract class AgentBase extends StateModuleBase implements Agent { private final AtomicBoolean running = new AtomicBoolean(false); private final boolean checkRunning; private final List hooks; + private volatile List sortedHooks; private static final List systemHooks = new CopyOnWriteArrayList<>(); private final Map> hubSubscribers = new ConcurrentHashMap<>(); @@ -133,6 +136,7 @@ public AgentBase(String name, String description, boolean checkRunning, List(hooks != null ? hooks : List.of()); this.hooks.addAll(systemHooks); + this.sortedHooks = refreshSortedHooks(); // Register basic agent state registerState("id", obj -> this.agentId, obj -> obj); @@ -393,24 +397,87 @@ protected Mono doObserve(Msg msg) { */ protected abstract Mono handleInterrupt(InterruptContext context, Msg... originalArgs); + /** + * Hook to obtain a unified entrance + * + * @param needSorted Do you need to return sorted cache hooks + * @return List of hooks corresponding to the state + */ + public List getHooks(boolean needSorted) { + return needSorted ? getSortedHooksCache() : getHooks(); + } + /** * Get the list of hooks for this agent. * Protected to allow subclasses to access hooks for custom notification logic. * * @return List of hooks */ - public List getHooks() { - return hooks; + protected List getHooks() { + return Collections.unmodifiableList(this.hooks); + } + + /** + * Get the pre-sorted Hook cache (for subclass calls, high-frequency read without locking) + * + * @return Immutable sorted list of Hooks + */ + protected List getSortedHooksCache() { + return this.sortedHooks; + } + + /** + * Fully replace the Hook list (based on atomicity of CopyOnWriteArrayList) + * + * @param newHooks New list of Hooks (can be null, system hooks are automatically preserved) + */ + public void updateHooks(List newHooks) { + List combinedHooks = new CopyOnWriteArrayList<>(systemHooks); + if (newHooks != null) { + combinedHooks.addAll(newHooks); + } + + this.hooks.clear(); + this.hooks.addAll(combinedHooks); + + this.sortedHooks = refreshSortedHooks(); + } + + /** + * Incrementally add a single Hook (based on atomicity of CopyOnWriteArrayList) + * + * @param hooks The Hook to be added (must not be null) + */ + public void addHook(List hooks) { + if (hooks == null || hooks.isEmpty()) { + return; + } + this.hooks.addAll(hooks); + this.sortedHooks = refreshSortedHooks(); + } + + /** + * Incrementally remove a single Hook (based on atomicity of CopyOnWriteArrayList, system hooks cannot be removed) + * + * @param hook The Hook to be removed + */ + public void removeHook(Hook hook) { + if (hook == null || systemHooks.contains(hook)) { + return; + } + boolean removed = this.hooks.remove(hook); + if (removed) { + this.sortedHooks = refreshSortedHooks(); + } } /** - * Get hooks sorted by priority (lower value = higher priority). - * Hooks with the same priority maintain registration order. + * Refresh the pre-sorted Hook cache (all agents share the sorting logic) * - * @return Sorted list of hooks + * @return Immutable list of Hooks sorted by priority */ - protected List getSortedHooks() { - return hooks.stream().sorted(java.util.Comparator.comparingInt(Hook::priority)).toList(); + private List refreshSortedHooks() { + return this.hooks.stream().sorted(Comparator.comparingInt(Hook::priority)).toList(); } /** @@ -425,7 +492,7 @@ protected List getSortedHooks() { private Mono> notifyPreCall(List msgs) { PreCallEvent event = new PreCallEvent(this, msgs); Mono result = Mono.just(event); - for (Hook hook : getSortedHooks()) { + for (Hook hook : getSortedHooksCache()) { result = result.flatMap(hook::onEvent); } return result.map(PreCallEvent::getInputMessages); @@ -444,7 +511,7 @@ private Mono notifyPostCall(Msg finalMsg) { } PostCallEvent event = new PostCallEvent(this, finalMsg); Mono result = Mono.just(event); - for (Hook hook : getSortedHooks()) { + for (Hook hook : getSortedHooksCache()) { result = result.flatMap(hook::onEvent); } // After hooks, broadcast to subscribers @@ -460,7 +527,7 @@ private Mono notifyPostCall(Msg finalMsg) { */ private Mono notifyError(Throwable error) { ErrorEvent event = new ErrorEvent(this, error); - return Flux.fromIterable(getSortedHooks()).flatMap(hook -> hook.onEvent(event)).then(); + return Flux.fromIterable(getSortedHooksCache()).flatMap(hook -> hook.onEvent(event)).then(); } /** diff --git a/agentscope-extensions/agentscope-extensions-a2a/agentscope-extensions-a2a-client/src/main/java/io/agentscope/core/a2a/agent/A2aAgent.java b/agentscope-extensions/agentscope-extensions-a2a/agentscope-extensions-a2a-client/src/main/java/io/agentscope/core/a2a/agent/A2aAgent.java index 8b1f4d4b5..3e5c59879 100644 --- a/agentscope-extensions/agentscope-extensions-a2a/agentscope-extensions-a2a-client/src/main/java/io/agentscope/core/a2a/agent/A2aAgent.java +++ b/agentscope-extensions/agentscope-extensions-a2a/agentscope-extensions-a2a-client/src/main/java/io/agentscope/core/a2a/agent/A2aAgent.java @@ -44,6 +44,7 @@ import io.agentscope.core.message.Msg; import io.agentscope.core.message.TextBlock; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.UUID; import java.util.function.BiConsumer; @@ -107,7 +108,7 @@ private A2aAgent( this.agentCardResolver = agentCardResolver; this.memory = memory; LoggerUtil.debug(log, "A2aAgent init with config: {}", a2aAgentConfig); - getHooks().add(new A2aClientLifecycleHook()); + super.addHook(Collections.singletonList(new A2aClientLifecycleHook())); this.clientEventHandlerRouter = new ClientEventHandlerRouter(); } @@ -119,7 +120,7 @@ protected Mono doCall(List msgs) { LoggerUtil.info(log, "[{}] A2aAgent start call.", currentRequestId); LoggerUtil.debug(log, "[{}] A2aAgent call with input messages: ", currentRequestId); LoggerUtil.logTextMsgDetail(log, memory.getMessages()); - clientEventContext.setHooks(getSortedHooks()); + clientEventContext.setHooks(getSortedHooksCache()); return Mono.defer( () -> { Message message = MessageConvertUtil.convertFromMsg(memory.getMessages());