diff --git a/.github/ISSUE_TEMPLATE/0_bug_report.yml b/.github/ISSUE_TEMPLATE/0_bug_report.yml index 9baf846a6d8..2b77ff9ea90 100644 --- a/.github/ISSUE_TEMPLATE/0_bug_report.yml +++ b/.github/ISSUE_TEMPLATE/0_bug_report.yml @@ -1,7 +1,8 @@ name: 🐛 Bug Report description: Create a report to help us improve -title: '[Bug]: ' -labels: ['BUG'] +title: "[Bug]: " +labels: [] +type: Bug body: - type: markdown attributes: diff --git a/.github/ISSUE_TEMPLATE/1_feature_request.yml b/.github/ISSUE_TEMPLATE/1_feature_request.yml index d980cb83677..de77750cf5a 100644 --- a/.github/ISSUE_TEMPLATE/1_feature_request.yml +++ b/.github/ISSUE_TEMPLATE/1_feature_request.yml @@ -1,7 +1,8 @@ name: 💡 Feature Request description: Suggest an idea for this project -title: '[Feature]: ' -labels: ['feature'] +title: "[Feature]: " +labels: [] +type: Feature body: - type: markdown attributes: diff --git a/docs/zh/references/cherryclaw/channels.md b/docs/zh/references/cherryclaw/channels.md new file mode 100644 index 00000000000..080686369f8 --- /dev/null +++ b/docs/zh/references/cherryclaw/channels.md @@ -0,0 +1,139 @@ +# CherryClaw 频道系统 + +频道系统为 CherryClaw 提供 IM 集成能力,允许用户通过 Telegram 等即时通讯平台与代理交互。系统采用抽象适配器模式,支持未来扩展到 Discord、Slack 等平台。 + +## 架构 + +``` +ChannelManager (单例, 生命周期管理) + ├── adapters Map — 活跃的适配器实例 + ├── notifyChannels Set — 标记为通知接收者的频道 + ├── start() → 加载所有 CherryClaw agent,为启用的频道创建适配器 + ├── stop() → 断开所有适配器 + └── syncAgent(agentId) → 断开旧适配器,根据当前配置重建 + +ChannelAdapter (抽象 EventEmitter) + ├── connect() / disconnect() + ├── sendMessage(chatId, text, opts?) + ├── sendMessageDraft(chatId, draftId, text) — 流式草稿更新 + ├── sendTypingIndicator(chatId) + └── Events: 'message' → ChannelMessageEvent + 'command' → ChannelCommandEvent + +ChannelMessageHandler (单例, 无状态消息路由) + ├── handleIncoming(adapter, message) — 路由到代理 session + ├── handleCommand(adapter, command) — 处理 /new /compact /help + └── sessionTracker Map — 每个 agent 的活跃 session +``` + +## 适配器注册 + +适配器通过 `registerAdapterFactory(type, factory)` 自注册。导入适配器模块即触发注册: + +```typescript +// src/main/services/agents/services/channels/adapters/TelegramAdapter.ts +registerAdapterFactory('telegram', (channel, agentId) => { + return new TelegramAdapter({ channelId: channel.id, agentId, channelConfig: channel.config }) +}) +``` + +`ChannelManager` 启动时导入所有适配器模块(通过 `channels/index.ts`),适配器的 `registerAdapterFactory` 调用作为模块副作用执行。 + +## 消息处理流程 + +### 用户消息 + +``` +用户在 Telegram 发送消息 + → TelegramAdapter 触发 'message' 事件 + → ChannelManager 转发给 ChannelMessageHandler.handleIncoming() + 1. resolveSession(agentId) + → 检查 sessionTracker → 查询已有 session → 创建新 session + 2. 发送 typing indicator(每 4s 刷新一次) + 3. 生成随机 draftId + 4. collectStreamResponse(session, text, abort, onDraft): + - 创建 session message(persist: true) + - 读取 stream: + text-delta → 更新 currentBlockText(块内累积) + text-end → 提交到 completedText,重置当前块 + - 每 500ms 通过 sendMessageDraft 发送草稿 + 5. sendMessage(chatId, finalText) — 超过 4096 字符自动分块 +``` + +### 命令处理 + +| 命令 | 行为 | +|---|---| +| `/new` | 创建新 session,更新 sessionTracker | +| `/compact` | 向当前 session 发送 `/compact`,收集响应 | +| `/help` | 返回代理名称、描述和可用命令列表 | + +## 流式响应 + +CherryClaw 的流式响应遵循以下规则: + +- `text-delta` 事件在同一个文本块内是**累积的**——每个事件包含到目前为止的完整文本,而非增量 +- `ChannelMessageHandler` 在块内使用 `text = value.text`(替换),在 `text-end` 时提交 +- 草稿通过 `sendMessageDraft` 以 500ms 节流频率发送 +- typing indicator 每 4s 刷新一次 + +## Telegram 适配器 + +### 配置 + +```typescript +{ + type: 'telegram', + id: 'unique-channel-id', + enabled: true, + is_notify_receiver: true, // 是否接收通知 + config: { + bot_token: 'YOUR_BOT_TOKEN', + allowed_chat_ids: ['123456789'] // 授权的 chat ID 列表 + } +} +``` + +### 特性 + +- 使用 **grammY** 库,仅支持长轮询(桌面应用在 NAT 后面,不支持 webhook) +- **授权守卫**:第一个中间件检查 chat ID 是否在白名单中,未授权消息直接丢弃 +- **消息分块**:超过 4096 字符的消息自动按段落/行/硬分割发送 +- **草稿流式**:通过 Telegram 的 `sendMessageDraft` API 实现实时响应流式展示 +- **通知目标**:`notifyChatIds` 等于 `allowed_chat_ids`,所有授权的 chat 都接收通知 + +### 已知限制 + +| 限制 | 说明 | +|---|---| +| 速率限制 | `sendMessage` 全局 30/s,每 chat 1/s。草稿节流 500ms,typing 4s | +| 纯文本输出 | 代理响应以纯文本发送(无 `parse_mode`),避免 MarkdownV2 转义问题 | +| 仅长轮询 | 桌面应用无法接收 webhook | + +## 通知频道 + +`ChannelManager` 通过 `notifyChannels` Set 跟踪哪些适配器的频道配置了 `is_notify_receiver: true`。`getNotifyAdapters(agentId)` 返回指定 agent 的所有通知适配器,供 `notify` MCP 工具和调度器任务通知使用。 + +## 生命周期 + +- **启动**: `channelManager.start()` 在应用就绪时与调度器一起调用 +- **停止**: `channelManager.stop()` 在应用退出时调用 +- **同步**: `channelManager.syncAgent(agentId)` 在 agent 更新/删除时调用,断开旧适配器并根据新配置重建 + +## 扩展新频道 + +添加新的频道类型只需: + +1. 实现 `ChannelAdapter` 抽象类 +2. 在模块中调用 `registerAdapterFactory(type, factory)` +3. 在 `channels/index.ts` 中导入该模块 + +## 关键文件 + +| 文件 | 说明 | +|---|---| +| `src/main/services/agents/services/channels/ChannelAdapter.ts` | 抽象接口 + 事件类型 | +| `src/main/services/agents/services/channels/ChannelManager.ts` | 生命周期管理 + 适配器工厂注册 | +| `src/main/services/agents/services/channels/ChannelMessageHandler.ts` | 消息路由 + 流式响应收集 | +| `src/main/services/agents/services/channels/adapters/TelegramAdapter.ts` | Telegram 适配器实现 | +| `src/main/services/agents/services/channels/index.ts` | 公开导出 + 适配器模块导入 | diff --git a/docs/zh/references/cherryclaw/cherryclaw.png b/docs/zh/references/cherryclaw/cherryclaw.png new file mode 100644 index 00000000000..06c34bf6900 Binary files /dev/null and b/docs/zh/references/cherryclaw/cherryclaw.png differ diff --git a/docs/zh/references/cherryclaw/mcp-claw.md b/docs/zh/references/cherryclaw/mcp-claw.md new file mode 100644 index 00000000000..e17873380df --- /dev/null +++ b/docs/zh/references/cherryclaw/mcp-claw.md @@ -0,0 +1,183 @@ +# Claw MCP 服务器 + +Claw MCP 服务器是一个内置的 MCP(Model Context Protocol)服务器,自动注入到每个 CherryClaw 会话中。它为代理提供了四个自主管理工具:`cron`(任务调度)、`notify`(通知)、`skills`(技能管理)和 `memory`(记忆管理)。 + +## 架构 + +``` +CherryClawService.invoke() + → 创建 ClawServer 实例(每次调用一个新实例) + → 注入为内存中的 MCP 服务器: + _internalMcpServers = { claw: { type: 'inmem', instance: clawServer.mcpServer } } + → ClaudeCodeService 合并到 SDK options.mcpServers + → SDK 自动发现工具: mcp__claw__cron, mcp__claw__notify, mcp__claw__skills, mcp__claw__memory +``` + +ClawServer 使用 `@modelcontextprotocol/sdk` 的 `McpServer` 类,以内存模式运行(无需 HTTP 传输)。每个 CherryClaw 会话调用时创建新实例,绑定到当前 agent 的 ID。 + +## 工具白名单 + +当 agent 配置了显式的 `allowed_tools` 白名单时,`CherryClawService` 自动追加 `mcp__claw__*` 通配符,确保 SDK 不会过滤掉内部 MCP 工具。当 `allowed_tools` 为 undefined(无限制)时,所有工具已可用,无需注入。 + +--- + +## cron 工具 + +管理代理的调度任务。代理可以自主创建、查看和删除定期执行的任务。 + +### 动作 + +#### `add` — 创建任务 + +| 参数 | 类型 | 必填 | 说明 | +|---|---|---|---| +| `name` | string | 是 | 任务名称 | +| `message` | string | 是 | 执行时的提示词/指令 | +| `cron` | string | 三选一 | cron 表达式,如 `0 9 * * 1-5` | +| `every` | string | 三选一 | 持续时间,如 `30m`、`2h`、`1h30m` | +| `at` | string | 三选一 | RFC3339 时间戳,用于一次性任务 | +| `session_mode` | string | 否 | `reuse`(默认,保留对话历史)或 `new`(每次新会话) | + +`cron`、`every`、`at` 三者只能选一个。`every` 格式支持 `30m`、`2h`、`1h30m` 等人类友好的时间表示,内部转换为分钟数。 + +调度类型映射: +- `cron` → `schedule_type: 'cron'` +- `every` → `schedule_type: 'interval'`(值为分钟数) +- `at` → `schedule_type: 'once'`(值为 ISO 时间戳) + +会话模式映射: +- `reuse` → `context_mode: 'session'` +- `new` → `context_mode: 'isolated'` + +#### `list` — 列出任务 + +无参数。返回当前 agent 的所有调度任务(上限 100 条),JSON 格式。 + +#### `remove` — 删除任务 + +| 参数 | 类型 | 必填 | 说明 | +|---|---|---|---| +| `id` | string | 是 | 任务 ID | + +--- + +## notify 工具 + +通过已连接的频道(如 Telegram)向用户发送通知消息。代理可以主动通知用户任务结果、状态更新或其他重要信息。 + +### 参数 + +| 参数 | 类型 | 必填 | 说明 | +|---|---|---|---| +| `message` | string | 是 | 通知内容 | +| `channel_id` | string | 否 | 仅发送到指定频道(省略则发送到所有通知频道) | + +### 行为 + +1. 获取当前 agent 的所有 `is_notify_receiver: true` 的频道适配器 +2. 如果指定了 `channel_id`,过滤到该频道 +3. 向每个适配器的所有 `notifyChatIds` 发送消息 +4. 返回发送数量和可能的错误 + +如果没有配置通知频道,返回提示信息而非报错。 + +--- + +## skills 工具 + +管理代理工作区中的 Claude 技能。支持从市场搜索、安装、卸载和列出已安装的技能。 + +### 动作 + +#### `search` — 搜索技能 + +| 参数 | 类型 | 必填 | 说明 | +|---|---|---|---| +| `query` | string | 是 | 搜索关键词 | + +查询公开市场 API(`claude-plugins.dev/api/skills`),返回匹配的技能列表,包含 `name`、`description`、`author`、`identifier`(用于安装)和 `installs` 数量。搜索词中的 `-` 和 `_` 会被替换为空格以提高匹配率。 + +#### `install` — 安装技能 + +| 参数 | 类型 | 必填 | 说明 | +|---|---|---|---| +| `identifier` | string | 是 | 市场技能标识符,格式 `owner/repo/skill-name` | + +内部构造 `marketplace:skill:{identifier}` 路径,委托给 `PluginService.install()` 完成安装。 + +#### `remove` — 卸载技能 + +| 参数 | 类型 | 必填 | 说明 | +|---|---|---|---| +| `name` | string | 是 | 技能文件夹名称(从 list 结果获取) | + +委托给 `PluginService.uninstall()` 完成卸载。 + +#### `list` — 列出已安装技能 + +无参数。返回当前 agent 已安装的所有技能,包含 `name`、`folder` 和 `description`。 + +--- + +## memory 工具 + +管理跨会话的持久化记忆。这是 CherryClaw 记忆系统的写入接口(读取通过系统提示词中的内联内容实现)。 + +### 设计原则 + +工具描述中编码了记忆决策逻辑: + +> 写入 FACT.md 之前,问自己:这个信息 6 个月后还重要吗?如果不是,用 append 代替。 + +### 动作 + +#### `update` — 更新 FACT.md + +| 参数 | 类型 | 必填 | 说明 | +|---|---|---|---| +| `content` | string | 是 | FACT.md 的完整 markdown 内容 | + +原子写入:先写临时文件,再通过 `rename` 替换。确保不会出现写入中途崩溃导致的文件损坏。 + +文件路径支持大小写不敏感匹配。如果 `memory/` 目录不存在会自动创建。 + +**注意**:此操作是全量覆盖,不是增量编辑。代理需要先读取现有内容,修改后再写回完整内容。 + +#### `append` — 追加日志条目 + +| 参数 | 类型 | 必填 | 说明 | +|---|---|---|---| +| `text` | string | 是 | 日志条目文本 | +| `tags` | string[] | 否 | 标签列表 | + +追加一行 JSON 到 `memory/JOURNAL.jsonl`,格式: + +```json +{"ts":"2026-03-10T12:00:00.000Z","tags":["deploy","production"],"text":"部署 v2.1 到生产环境"} +``` + +时间戳自动生成。适用于一次性事件、已完成任务、会话摘要等短期信息。 + +#### `search` — 搜索日志 + +| 参数 | 类型 | 必填 | 说明 | +|---|---|---|---| +| `query` | string | 否 | 大小写不敏感的子串匹配 | +| `tag` | string | 否 | 按标签过滤 | +| `limit` | integer | 否 | 最大返回数量(默认 20) | + +返回匹配的日志条目,按时间倒序排列。`query` 和 `tag` 可以组合使用。 + +--- + +## 错误处理 + +所有工具调用在内部 try-catch 中执行。当发生错误时,返回 `{ isError: true }` 的 MCP 响应,包含错误消息。错误同时记录到 `loggerService`。 + +## 关键文件 + +| 文件 | 说明 | +|---|---| +| `src/main/mcpServers/claw.ts` | ClawServer 完整实现(4 个工具 + 辅助函数) | +| `src/main/mcpServers/__tests__/claw.test.ts` | 37 个单元测试 | +| `src/main/services/agents/services/cherryclaw/index.ts` | MCP 服务器注入逻辑 | diff --git a/docs/zh/references/cherryclaw/overview.md b/docs/zh/references/cherryclaw/overview.md new file mode 100644 index 00000000000..231b99fdcca --- /dev/null +++ b/docs/zh/references/cherryclaw/overview.md @@ -0,0 +1,126 @@ +# CherryClaw 整体设计 + +

+ CherryClaw +

+ +CherryClaw 是 Cherry Studio 中的自主代理(autonomous agent)类型,基于 Claude Agent SDK 构建。与标准的 claude-code 代理不同,CherryClaw 拥有独立的人格系统、基于任务的调度器、IM 频道集成,以及一组通过内部 MCP 服务器提供的自主管理工具。 + +## 架构概览 + +``` +CherryClawService + ├── PromptBuilder — 从工作区文件组装完整系统提示词 + ├── HeartbeatReader — 读取心跳文件内容(用于调度任务前置提示) + ├── ClawServer (MCP) — 内置 MCP 服务器,提供 cron / notify / skills / memory 工具 + ├── SchedulerService — 60s 轮询调度器,从 DB 查询到期任务并执行 + ├── TaskService — 任务 CRUD + 下次运行时间计算 + └── ChannelManager — 频道适配器生命周期管理(Telegram 等) +``` + +## 核心设计决策 + +### AgentServiceRegistry 模式 + +`SessionMessageService` 不再硬编码 `ClaudeCodeService`,而是通过 `AgentServiceRegistry` 根据 `AgentType` 查找对应的服务实现。CherryClaw 在运行时通过注册表委托给 claude-code 执行。 + +```typescript +// src/main/services/agents/services/AgentServiceRegistry.ts +agentServiceRegistry.register('claude-code', new ClaudeCodeService()) +agentServiceRegistry.register('cherry-claw', new CherryClawService()) +``` + +### 自定义系统提示词(替换 Claude Code 预设) + +CherryClaw 不使用 Claude Code 的预设系统提示词。`PromptBuilder` 从工作区文件组装完整的自定义提示词,通过 `_systemPrompt` 字段传递给 `ClaudeCodeService`。当该字段存在时,它作为完整的系统提示词使用,而非预设 + 追加模式。 + +### 禁用不适用的内置工具 + +CherryClaw 通过 `_disallowedTools` 禁用了一组不适合自主运行的 SDK 内置工具: + +| 被禁用的工具 | 原因 | +|---|---| +| `CronCreate` / `CronDelete` / `CronList` | 由内部 MCP cron 工具替代 | +| `TodoWrite` | 不适合自主代理 | +| `AskUserQuestion` | 自主代理不应向用户提问 | +| `EnterPlanMode` / `ExitPlanMode` | 不适合自主代理 | +| `EnterWorktree` / `NotebookEdit` | 不适合自主代理 | + +## 调用流程 + +``` +CherryClawService.invoke() + 1. PromptBuilder.buildSystemPrompt(workspacePath) + → 加载 system.md(可选覆盖)+ soul.md + user.md + memory/FACT.md + → 组装为完整系统提示词 + 2. 创建 ClawServer 实例(内存中的 MCP 服务器) + → 注入为 _internalMcpServers = { claw: { type: 'inmem', instance } } + 3. 设置 _disallowedTools(禁用不适用工具) + 4. 如果 agent 有 allowed_tools 白名单,追加 mcp__claw__* 通配符 + 5. 委托给 ClaudeCodeService.invoke() + → 使用 _systemPrompt 作为完整替换 + → 合并 _internalMcpServers 到 SDK options.mcpServers + → Claude SDK 自动发现 cron / notify / skills / memory 工具 +``` + +## 记忆系统 + +CherryClaw 采用受 Anna 启发的三文件记忆模型,每个文件有独立的职责范围: + +``` +{workspace}/ + system.md — 可选的系统提示词覆盖(替换默认 CherryClaw 身份) + soul.md — 你是谁:人格、语气、沟通风格 + user.md — 用户是谁:名字、偏好、个人上下文 + memory/ + FACT.md — 你知道什么:持久的项目知识、技术决策(6 个月以上) + JOURNAL.jsonl — 事件日志:一次性事件、已完成任务、会话笔记(仅追加) +``` + +关键规则: +- 每个文件有独立作用域,不跨文件重复信息 +- `soul.md` 和 `user.md` 通过 Read/Edit 工具直接编辑 +- `FACT.md` 和 `JOURNAL.jsonl` 通过 `memory` MCP 工具管理 +- 代理自主更新,不请求用户批准 +- 文件名不区分大小写 + +### PromptBuilder 缓存机制 + +`PromptBuilder` 对所有文件读取使用基于 mtime 的缓存。每次读取时仅执行一次 `fs.stat` 检查——如果文件修改时间未变,直接返回缓存内容,无需持久化文件监听器。 + +## 数据库 + +CherryClaw 使用 Drizzle ORM + LibSQL(SQLite)存储任务数据: + +| 表名 | 用途 | +|---|---| +| `scheduled_tasks` | 调度任务(名称、提示词、调度类型、下次运行时间、状态) | +| `task_run_logs` | 任务运行日志(运行时间、耗时、状态、结果/错误) | + +两个表均通过外键级联关联到 agents 表。 + +## API 端点 + +| 方法 | 路径 | 说明 | +|---|---|---| +| `GET` | `/v1/agents/:agentId/tasks` | 列出任务 | +| `POST` | `/v1/agents/:agentId/tasks` | 创建任务 | +| `GET` | `/v1/agents/:agentId/tasks/:taskId` | 获取任务详情 | +| `PATCH` | `/v1/agents/:agentId/tasks/:taskId` | 更新任务 | +| `DELETE` | `/v1/agents/:agentId/tasks/:taskId` | 删除任务 | +| `POST` | `/v1/agents/:agentId/tasks/:taskId/run` | 手动触发运行 | +| `GET` | `/v1/agents/:agentId/tasks/:taskId/logs` | 获取运行日志 | + +## 关键文件 + +| 文件 | 说明 | +|---|---| +| `src/main/services/agents/services/cherryclaw/index.ts` | CherryClawService 入口 | +| `src/main/services/agents/services/cherryclaw/prompt.ts` | PromptBuilder 系统提示词组装 | +| `src/main/services/agents/services/cherryclaw/heartbeat.ts` | HeartbeatReader 心跳文件读取 | +| `src/main/services/agents/services/AgentServiceRegistry.ts` | 代理服务注册表 | +| `src/main/services/agents/services/TaskService.ts` | 任务 CRUD + 调度计算 | +| `src/main/services/agents/services/SchedulerService.ts` | 轮询调度器 | +| `src/main/mcpServers/claw.ts` | Claw MCP 服务器 | +| `src/main/services/agents/services/channels/` | 频道抽象层 | +| `src/main/services/agents/database/schema/tasks.schema.ts` | 任务表 schema | diff --git a/docs/zh/references/cherryclaw/scheduler.md b/docs/zh/references/cherryclaw/scheduler.md new file mode 100644 index 00000000000..c9927729a7b --- /dev/null +++ b/docs/zh/references/cherryclaw/scheduler.md @@ -0,0 +1,119 @@ +# CherryClaw 调度器 + +CherryClaw 的调度器采用受 nanoclaw 启发的基于任务的轮询设计。数据库是唯一的状态源——无需在内存中维护定时器状态,应用重启后自动恢复。 + +## 架构 + +``` +SchedulerService (单例, 轮询循环) + startLoop() + → 每 60s 执行一次 tick() + → taskService.getDueTasks() + → SELECT * FROM scheduled_tasks WHERE status='active' AND next_run <= now() + → 对每个到期任务调用 runTask(task) (fire-and-forget) + + runTask(task) + 1. 加载 agent 配置 + 2. 读取心跳文件,拼接到任务提示词前面(可选) + 3. 根据 context_mode 查找或创建 session + 4. sessionMessageService.createSessionMessage({ persist: true }) + 5. 排空 stream 等待 completion + 6. 记录运行日志到 task_run_logs + 7. computeNextRun() 计算下次运行时间 + 8. 通过频道发送任务完成/失败通知(可选) + + stopLoop() + → 清除定时器,abort 所有运行中的任务 +``` + +## 调度类型 + +| 类型 | `schedule_value` 格式 | 说明 | +|---|---|---| +| `cron` | cron 表达式,如 `0 9 * * 1-5` | 标准 cron 调度(使用 cron-parser v5) | +| `interval` | 分钟数,如 `30` | 固定间隔执行 | +| `once` | ISO 8601 时间戳 | 一次性任务,执行后自动标记为 completed | + +## 防漂移间隔计算 + +`computeNextRun()` 锚定到上一次的 `next_run` 时间戳,而非当前时间。如果错过了多个间隔(例如应用关闭期间),它会跳过已过期的间隔,直接计算下一个未来时间点: + +```typescript +// 锚定到计划时间,防止累积漂移 +let next = new Date(task.next_run).getTime() + intervalMs +while (next <= now) { + next += intervalMs +} +``` + +这种方式确保了间隔调度不会因任务执行耗时或轮询延迟产生累积偏差。 + +## 上下文模式 + +每个任务可以配置 `context_mode`: + +| 模式 | 行为 | +|---|---| +| `session` | 复用已有 session,保持多轮对话上下文 | +| `isolated` | 每次执行创建新 session,无历史上下文 | + +当使用 `session` 模式时,`SessionMessageService` 会捕获 SDK 的 `session_id`(来自 `system/init` 消息)并持久化为 `agent_session_id`,下次运行时作为 `options.resume` 传入,实现跨执行的对话连续性。 + +## 心跳文件 + +如果 agent 配置了 `heartbeat_enabled: true`,调度器会在执行任务前读取心跳文件(默认路径由 `heartbeat_file` 配置指定)并作为前置上下文拼接到任务提示词中: + +``` +[Heartbeat] +{心跳文件内容} + +[Task] +{任务提示词} +``` + +`HeartbeatReader` 内置路径遍历保护,确保心跳文件路径不会逃逸出工作区目录。 + +## 连续错误处理 + +调度器跟踪每个任务的连续错误次数。连续失败 3 次后,任务自动暂停(`status: 'paused'`)。错误计数在下一次成功运行时重置。此状态在内存中跟踪,不持久化。 + +## 任务完成通知 + +每次任务运行后,`notifyTaskResult()` 向所有启用了 `is_notify_receiver` 的频道发送状态消息: + +``` +[Task completed] 任务名称 +Duration: 12s +``` + +或失败时: + +``` +[Task failed] 任务名称 +Duration: 5s +Error: 错误信息 +``` + +通知以 fire-and-forget 方式发送,不阻塞调度循环。 + +## 手动触发 + +除了自动调度,每个任务也可以通过 API 或 UI 手动触发: + +- API: `POST /v1/agents/:agentId/tasks/:taskId/run` +- UI: 任务设置列表中的「运行」按钮 + +`runTaskNow()` 会验证任务是否存在、是否正在运行(重复运行返回 409),然后在后台触发执行。 + +## 向后兼容 + +`startScheduler(agent)` 和 `stopScheduler(agentId)` 保留为空操作(no-op)以兼容现有的 agent handler 代码。所有调度逻辑由轮询循环通过数据库状态驱动。 + +## 关键文件 + +| 文件 | 说明 | +|---|---| +| `src/main/services/agents/services/SchedulerService.ts` | 轮询调度器主逻辑 | +| `src/main/services/agents/services/TaskService.ts` | 任务 CRUD、getDueTasks、computeNextRun | +| `src/main/services/agents/database/schema/tasks.schema.ts` | scheduled_tasks + task_run_logs 表定义 | +| `resources/database/drizzle/0003_wise_meltdown.sql` | 数据库迁移脚本 | diff --git a/handoff.md b/handoff.md new file mode 100644 index 00000000000..d569075ec86 --- /dev/null +++ b/handoff.md @@ -0,0 +1,320 @@ +# Handoff + +## Goal + +Implement CherryClaw — a new autonomous agent type for Cherry Studio with soul-driven personality, scheduler-based autonomous operation, heartbeat-driven task execution, and IM channel integration. Full implementation across all 4 phases from `.agents/sessions/2026-03-10-cherry-claw/plan.md`, plus a task-based scheduler redesign inspired by nanoclaw, plus an internal claw MCP server so the agent can autonomously manage its own scheduled tasks, plus a channel abstraction layer with Telegram and QQ adapters. + +## Progress + +All 4 phases are complete, plus the scheduler redesign and claw MCP tool: + +- **Phase 1**: Type system, config defaults, i18n keys — DONE +- **Phase 2**: Backend services (registry, soul, heartbeat, claw service, scheduler, lifecycle hooks) — DONE +- **Phase 3**: Frontend UI (creation modal, settings tabs, list differentiation) — DONE +- **Phase 4**: Unit tests (22 tests across 4 files) — DONE +- **Phase 5**: Scheduler redesign — tasks as first-class DB entities, poll-loop scheduler, task management UI — DONE +- **Phase 6**: Claw MCP server — internal `cron` tool auto-injected into CherryClaw sessions — DONE +- **Phase 7**: Channel abstraction layer + Telegram adapter + channel settings UI — DONE +- **Phase 7b**: QQ channel adapter — WebSocket gateway, REST API message sending, multi-message type support (c2c/group/guild/dm) — DONE +- **Phase 8**: Channel streaming — `sendMessageDraft` for real-time response streaming, multi-turn accumulation, typing indicators — DONE +- **Phase 9**: Headless message persistence — channel and scheduler messages now persist to DB — DONE +- **Phase 10**: Basic sandbox — PreToolUse hook path enforcement + OS-level sandbox + UI toggle — DONE (basic restriction only, needs hardening) +- **Phase 11**: Notify tool — `notify` MCP tool for CherryClaw to send messages to users via channels, scheduler auto-notifications on task completion/failure — DONE +- **Phase 12**: Manual task run — `POST /:taskId/run` API endpoint + "Run" button in task settings UI for manually triggering scheduled tasks — DONE +- **Phase 13**: Scheduler session resume + claw MCP tool injection — SDK session_id capture for `options.resume`, auto-add claw MCP tools to `allowed_tools` — DONE +- **Phase 14**: Claw MCP skills tool — `skills` MCP tool with search/install/remove/list actions, reuses `PluginService` for install/uninstall/list and marketplace API for search — DONE +- **Phase 15**: System prompt & memory — full custom system prompt replaces Claude Code preset; workspace files (system.md, soul.md, user.md, memory/FACT.md) assembled by `PromptBuilder`; `memory` MCP tool with update/append/search actions for FACT.md + JOURNAL.jsonl — DONE +- **Phase 16**: Heartbeat redesign — heartbeat.md as workspace file, heartbeat as auto-created scheduled task (name='heartbeat'), toggle + interval UI in Tasks settings, `HeartbeatReader` simplified (no filename param), `TaskService.listTasks` filters heartbeat by default — DONE +- **Validation**: `pnpm lint`, `pnpm test`, `pnpm format` all pass (198 test files, 3617 tests) + +## Key Decisions + +- **AgentServiceRegistry pattern** — replaced hardcoded `ClaudeCodeService` in `SessionMessageService` with a registry mapping `AgentType` → `AgentServiceInterface`. CherryClaw delegates to claude-code at runtime via registry lookup. +- **Task-based scheduler (nanoclaw-inspired)** — replaced per-agent setTimeout chains with a single 60s poll loop that queries `scheduled_tasks WHERE status='active' AND next_run <= now()`. DB is the source of truth; no timer state to restore on restart. +- **Drift-resistant interval computation** — `computeNextRun()` anchors to the previous `next_run` timestamp and skips past missed intervals, preventing cumulative drift (ported from nanoclaw). +- **Tasks as first-class entities** — new `scheduled_tasks` and `task_run_logs` Drizzle tables with FK cascades to agents. Users can create/edit/pause/delete multiple tasks per agent via the UI. +- **cron-parser v5** — uses `CronExpressionParser.parse()` API (not the older `parseExpression`). +- **mtime-based cache for workspace files** — `PromptBuilder` caches all file reads (soul.md, user.md, system.md, FACT.md) with single `fs.stat` check per read, no persistent file watchers. +- **Heartbeat as a scheduled task** — heartbeat is a special task with `name='heartbeat'` auto-created for each CherryClaw agent. Reuses the existing `TaskService` + `SchedulerService` poll loop infrastructure. Config: `heartbeat_enabled` (boolean, default true) + `heartbeat_interval` (minutes, default 30). On each tick, `SchedulerService.runTask()` detects `task.name === 'heartbeat'` and reads `{workspace}/heartbeat.md` via `HeartbeatReader`. If the file exists and heartbeat is enabled, its content is sent to the agent's main session. If the file is missing or heartbeat is disabled, the tick is skipped silently. `TaskService.listTasks()` excludes heartbeat tasks by default (pass `{ includeHeartbeat: true }` to include). `SchedulerService.ensureHeartbeatTask(agentId, intervalMinutes)` creates or updates the heartbeat task — called on agent create and update. UI shows a toggle + interval input at the top of the Tasks settings page. +- **Default emoji 🦞** — CherryClaw agents get lobster claw emoji as default avatar in the agent list. +- **Placeholder cherry-claw.png** — copied from claude.png; needs a proper distinct avatar image. +- **i18n strict nesting** — task keys use proper nested objects (e.g., `tasks.contextMode.session` not `tasks.contextMode.session` + `tasks.contextMode.session.desc`) to pass the i18n checker. +- **Internal claw MCP server (anna-inspired)** — `cron` tool with `add`/`list`/`remove` actions + `notify` tool for sending messages to users via channels + `memory` tool for persistent knowledge + `skills` tool for marketplace skill management, auto-injected into every CherryClaw session via `_internalMcpServers`. Uses the `@modelcontextprotocol/sdk` Server class, served over Streamable HTTP at `/v1/claw/:agentId/claw-mcp`. The cron tool maps anna-style inputs (`cron`, `every`, `at`, `session_mode`) to TaskService's schema (`schedule_type`, `schedule_value`, `context_mode`). The notify tool sends messages to all channels with `is_notify_receiver: true`, or to a specific channel by ID. +- **Notify channels** — `ChannelManager` tracks which adapters have `is_notify_receiver: true` via `notifyChannels` set. `getNotifyAdapters(agentId)` returns connected adapters for notification. Each adapter exposes `notifyChatIds` (set by subclass) for target chat IDs. +- **Scheduler task notifications** — After each task run, `SchedulerService.notifyTaskResult()` sends a status message (`[Task completed/failed] name, duration, error`) to notify-enabled channels. Fire-and-forget, never blocks scheduling. +- **Manual task run** — `POST /v1/agents/:agentId/tasks/:taskId/run` triggers `schedulerService.runTaskNow()` which validates the task, checks it's not already running (409 if so), then fires `runTask()` in background. UI has a "Run" button per task in the task settings list. +- **SDK session resume for scheduler** — The Claude Agent SDK's `session_id` (needed for `options.resume`) is captured in `ClaudeCodeService.processSDKQuery()` from the `system/init` message and stored on the `AgentStream.sdkSessionId` property. `SessionMessageService` reads it on stream complete and persists it as `agent_session_id` in `sessionMessagesTable` via `persistHeadlessExchange()`. On the next scheduler run with `context_mode: 'session'`, `getLastAgentSessionId()` finds the stored value and passes it as `options.resume`, enabling multi-turn conversation continuity. +- **Claw MCP tool auto-allow** — `CherryClawService.invoke()` appends `mcp__claw__cron`, `mcp__claw__notify`, `mcp__claw__skills`, and `mcp__claw__memory` to `allowed_tools` when the agent has an explicit tool whitelist. This ensures the SDK doesn't filter out the claw MCP tools. When `allowed_tools` is undefined (default), all tools are already available and no injection is needed. +- **Skills MCP tool** — `skills` tool with `search`/`install`/`remove`/`list` actions. Reuses `PluginService` (singleton) for install, uninstall, and list operations — `PluginService` internally resolves workspace path from `agent.accessible_paths[0]` via `AgentService`. Search queries the public marketplace API (`claude-plugins.dev/api/skills`) via Electron's `net.fetch`. The `buildSkillIdentifier()` helper constructs `owner/repo/name` identifiers from marketplace response metadata, matching the renderer's `buildSkillSourceKey()` logic. +- **Custom system prompt (replaces Claude Code preset)** — CherryClaw no longer uses the Claude Code preset system prompt with `append`. Instead, `PromptBuilder` assembles a complete custom system prompt from workspace files and passes it as a plain string via `_systemPrompt` on the enhanced session. `ClaudeCodeService` checks for `_systemPrompt` first; when set, it becomes the full `systemPrompt` (with language instruction appended). When not set, falls back to the existing preset+append behavior for regular claude-code agents. This allows CherryClaw to have its own identity, guidelines, and tool documentation independent of the Claude Code defaults. +- **Anna-inspired memory model** — strict 3-file model with exclusive scopes: `soul.md` (WHO you are — personality), `user.md` (WHO the user is — preferences), `memory/FACT.md` (WHAT you know — durable knowledge). Each file has XML-tagged sections (``, ``, ``) in the system prompt with scope documentation to prevent cross-file duplication. `memory/JOURNAL.jsonl` provides an append-only event log for ephemeral events. The system prompt instructs the agent to "update autonomously — never ask for approval". SOUL.md and USER.md are edited directly via Read/Write tools; FACT.md and JOURNAL are managed exclusively via the `memory` MCP tool. +- **Memory MCP tool** — `memory` tool with 3 actions: `update` (atomically overwrites `memory/FACT.md` via temp file + rename), `append` (adds timestamped JSON entry to `memory/JOURNAL.jsonl` with optional tags), `search` (case-insensitive substring search on journal, filtered by tag, reverse-chronological order, configurable limit). The tool description encodes the memory decision logic: "Before writing to FACT.md, ask: will this still matter in 6 months? If not, use append instead." +- **system.md workspace override** — `PromptBuilder` checks for `system.md` in the workspace root. If present, it replaces the default basic prompt (CherryClaw identity + guidelines). This allows per-workspace customization of the agent's base behavior without modifying code. +- **Disallowed builtin tools** — CherryClaw disables SDK builtin tools not suited for autonomous operation via `_disallowedTools`: `CronCreate`/`CronDelete`/`CronList` (replaced by claw MCP cron tool), `TodoWrite`, `AskUserQuestion`, `EnterPlanMode`, `ExitPlanMode`, `EnterWorktree`, `NotebookEdit`. Mapped to `options.disallowedTools` in the SDK. Note: `disallowedTools` only affects tools, not skills — skills are invoked via the `Skill` tool and cannot be blocked this way. +- **Basic sandbox (not a real security sandbox)** — When `sandbox_enabled` is true, two layers restrict filesystem access: (1) a `PreToolUse` hook in `ClaudeCodeService` that inspects every tool call's target paths and denies access outside `_sandboxAllowedPaths`, and (2) the SDK's OS-level `sandbox.enabled` option. The hook approach works regardless of `permissionMode` (including `bypassPermissions`) because PreToolUse hooks always fire before permission checks. Bash commands are checked via regex extraction of absolute paths from the command string — this is **best-effort, not secure**: commands like `cd / && cat etc/passwd` or variable expansion can bypass it. The OS sandbox (`sandbox.enabled: true`, `allowUnsandboxedCommands: false`) is meant to be the fallback but does not reliably restrict reads on macOS. This is a basic restriction for well-behaved agents, not a security boundary. +- **Channel abstraction layer** — `ChannelAdapter` (abstract EventEmitter), `ChannelManager` (singleton lifecycle), `ChannelMessageHandler` (stateless message routing + stream collection). Adapters are registered via `registerAdapterFactory(type, factory)` and auto-created from agent config on startup. Future channels (Discord, Slack) plug in by implementing `ChannelAdapter` and registering a factory. +- **Stream response collection** — `text-delta` events from the transform layer are cumulative within a text block. `ChannelMessageHandler` tracks per-block text (`text = value.text`) and commits on `text-end` to accumulate across multi-turn agent responses. Drafts are streamed to the chat via `sendMessageDraft` (throttled at 500ms) while `sendTypingIndicator` runs every 4s throughout the request. +- **Channel config in agent settings** — stored in `CherryClawConfiguration.channels[]`. UI is a catalog of available channel types with inline config (enable switch, bot token, allowed chat IDs). No DB migration needed. +- **grammY library** — Telegram Bot API client, long polling only (desktop app behind NAT). `sendMessageDraft` is Telegram's native streaming draft API. +- **QQ Bot API (ws package)** — QQ channel adapter uses official QQ Bot API with WebSocket gateway for receiving messages and REST API for sending. Supports c2c (private), group, guild (channel), and dm message types. Uses AppID + ClientSecret authentication with access token caching. No native draft/streaming API, so `sendMessageDraft` is a no-op. + +## Scheduler Architecture + +``` +SchedulerService (singleton, poll loop) + startLoop() → polls every 60s + tick() → taskService.getDueTasks() → for each due task: + runTask(task) + 1. Load agent config + 2. If task.name === 'heartbeat': + - Check heartbeat_enabled config + read heartbeat.md from workspace + - If disabled or file missing → skip (update next_run, return) + - Otherwise use file content as prompt + 3. Find/create session based on context_mode + 4. sessionMessageService.createSessionMessage() + 5. Log run to task_run_logs + 6. computeNextRun() → updateTaskAfterRun() + ensureHeartbeatTask(agentId, intervalMinutes) → creates/updates heartbeat task + stopLoop() → clears timer, aborts active tasks + +TaskService (CRUD + scheduling logic) + createTask / getTask / listTasks / updateTask / deleteTask + listTasks(agentId, { includeHeartbeat? }) → excludes heartbeat tasks by default + getDueTasks() → SELECT WHERE status='active' AND next_run <= now() + computeNextRun(task) → drift-resistant next run calculation + updateTaskAfterRun() → updates next_run, last_run, last_result + logTaskRun() → inserts into task_run_logs +``` + +API: `GET/POST /v1/agents/:agentId/tasks`, `GET/PATCH/DELETE /v1/agents/:agentId/tasks/:taskId`, `POST /v1/agents/:agentId/tasks/:taskId/run`, `GET /v1/agents/:agentId/tasks/:taskId/logs` + +## Claw MCP Architecture + +``` +CherryClawService.invoke() + → builds _systemPrompt via PromptBuilder (system.md + memories section) + → injects _internalMcpServers = { 'claw': { url: /v1/claw/:agentId/claw-mcp } } + → delegates to ClaudeCodeService.invoke() + → uses _systemPrompt as full replacement (not preset+append) + → merges _internalMcpServers into options.mcpServers + → Claude SDK auto-discovers cron, notify, skills, and memory tools + +PromptBuilder (src/main/services/agents/services/cherryclaw/prompt.ts) + buildSystemPrompt(workspacePath): + 1. Load basic prompt: workspace system.md > embedded default (CherryClaw identity) + 2. Load memories: soul.md, user.md, memory/FACT.md (all mtime-cached) + 3. Assemble: basic prompt + memories section (XML-tagged , , ) + +ClawServer (per-agent instance, src/main/mcpServers/claw.ts) + cron tool: + add → validates schedule (cron/every/at), maps to TaskService.createTask() + list → TaskService.listTasks() + remove → TaskService.deleteTask() + notify tool: + message → channelManager.getNotifyAdapters() → adapter.sendMessage() to all notifyChatIds + channel_id (optional) → filter to specific channel + skills tool: + search → queries marketplace API (claude-plugins.dev/api/skills?q=...) via net.fetch + install → PluginService.install({ sourcePath: 'marketplace:skill:owner/repo/name' }) + remove → PluginService.uninstall({ filename, type: 'skill' }) + list → PluginService.listInstalled() filtered to type === 'skill' + memory tool: + update → atomically overwrites memory/FACT.md (temp file + rename) + append → adds timestamped JSON entry to memory/JOURNAL.jsonl with optional tags + search → case-insensitive substring search on journal (tag filter, limit, reverse-chronological) + +Route: /v1/claw/:agentId/claw-mcp (Streamable HTTP MCP transport) + Per-session ClawServer + Transport pairs (MCP SDK Server only supports one transport) + sessions Map with cleanup on close +``` + +## Memory File Layout + +``` +{workspace}/ + system.md — optional system prompt override (replaces default CherryClaw identity) + soul.md — WHO you are: personality, tone, communication style + user.md — WHO the user is: name, preferences, personal context + heartbeat.md — standing instructions for periodic execution (e.g., "check my email") + memory/ + FACT.md — WHAT you know: durable project knowledge, technical decisions (6+ months) + JOURNAL.jsonl — event log: one-time events, completed tasks, session notes (append-only) +``` + +Rules enforced via system prompt: +- Each file has an exclusive scope — never duplicate information across files +- soul.md and user.md: edited directly via Read/Write tools +- FACT.md and JOURNAL.jsonl: managed exclusively via the `memory` MCP tool +- Updates are autonomous — agent never asks for approval + +## Channel Architecture + +``` +ChannelManager (singleton, lifecycle) + start() → loads all CherryClaw agents, creates adapters for enabled channels + stop() → disconnects all adapters + syncAgent(agentId) → disconnect old adapters, re-create from current config + +ChannelAdapter (abstract EventEmitter) + connect() / disconnect() + sendMessage(chatId, text, opts?) + sendMessageDraft(chatId, draftId, text) → stream partial response + sendTypingIndicator(chatId) + Events: 'message' → ChannelMessageEvent, 'command' → ChannelCommandEvent + +ChannelMessageHandler (singleton, stateless routing) + handleIncoming(adapter, message): + 1. resolveSession(agentId) → get/create session (tracked per agent) + 2. Start typing indicator interval (every 4s) + 3. Generate random draftId + 4. collectStreamResponse(session, text, abort, onDraft): + - Read stream, track completedText + currentBlockText + - text-delta → update currentBlockText (cumulative within block) + - text-end → commit block to completedText, reset for next turn + - Throttled onDraft(fullText) via sendMessageDraft every 500ms + 5. sendMessage(chatId, finalText) with chunking for >4096 chars + + handleCommand(adapter, command): + /new → create new session, update tracker + /compact → send '/compact' to session, collect response + /help → static help text + + Session tracking: Map + resolveSession: tracker → first existing session → create new +``` + +Adapter registration: adapters self-register via `registerAdapterFactory(type, factory)` as a side effect of importing their module. `ChannelManager` imports all adapter modules from the index. + +Wiring: `channelManager.start()` called alongside scheduler on app ready; `channelManager.stop()` on quit. `channelManager.syncAgent()` called on agent update/delete. + +## Files Changed + +### Type System & Config +- `src/renderer/src/types/agent.ts` — added `cherry-claw` to `AgentTypeSchema`, `CherryClawConfiguration`, `SchedulerType`, `CherryClawChannel` types; added `ScheduledTaskEntity`, `TaskRunLogEntity`, `CreateTaskRequest`, `UpdateTaskRequest`, `ListTasksResponse`, `ListTaskLogsResponse`, `TaskIdParamSchema` +- `src/renderer/src/config/agent.ts` — added `DEFAULT_CHERRY_CLAW_CONFIG`, `CherryClawAvatar`, updated `getAgentTypeAvatar` +- `src/main/apiServer/generated/openapi-spec.json` — added `cherry-claw` to AgentType enum +- `src/main/apiServer/routes/agents/index.ts` — updated Swagger enum, mounted task routes + +### Database Schema +- `src/main/services/agents/database/schema/tasks.schema.ts` — NEW: `scheduledTasksTable` + `taskRunLogsTable` with FK cascades, indexes +- `src/main/services/agents/database/schema/index.ts` — added tasks schema export +- `resources/database/drizzle/0003_wise_meltdown.sql` — NEW: migration for scheduled_tasks + task_run_logs tables + +### Backend Services +- `src/main/services/agents/services/AgentServiceRegistry.ts` — NEW: maps AgentType → AgentServiceInterface +- `src/main/services/agents/services/SessionMessageService.ts` — refactored to use registry; added `CreateMessageOptions.persist`, `TextStreamAccumulator.getText()`, `persistHeadlessExchange()` for headless message persistence; fixed cumulative text-delta `+=` → `=`; reads `claudeStream.sdkSessionId` on complete for resume persistence +- `src/main/services/agents/services/cherryclaw/index.ts` — CherryClawService (custom system prompt via PromptBuilder + claw MCP injection + disallowed builtin tools + sandbox path injection + claw tool auto-allow) +- `src/main/services/agents/services/cherryclaw/prompt.ts` — NEW: `PromptBuilder` assembles full system prompt from workspace files (system.md override, soul.md, user.md, memory/FACT.md) with mtime-based caching and anna-style XML-tagged memories section +- `src/main/services/agents/services/claudecode/enhanced-session.ts` — NEW: `EnhancedSessionFields` type for `_sandbox`, `_settings`, `_sandboxAllowedPaths`, `_systemPrompt`, etc. +- `src/main/services/agents/services/claudecode/index.ts` — reads enhanced session fields; when `_systemPrompt` is set, uses it as full replacement (plain string) instead of preset+append; PreToolUse hook enforces `_sandboxAllowedPaths` via path checking for all filesystem tools + Bash regex; captures SDK session_id from init message onto `AgentStream.sdkSessionId` +- `src/main/services/agents/interfaces/AgentStreamInterface.ts` — added `sdkSessionId?: string` to `AgentStream` interface for SDK session resume +- `src/main/services/agents/services/cherryclaw/soul.ts` — NEW: SoulReader with mtime cache +- `src/main/services/agents/services/cherryclaw/heartbeat.ts` — HeartbeatReader: reads `heartbeat.md` from workspace with path traversal protection. Simplified API (no filename param, always reads `heartbeat.md`, returns trimmed content or undefined for empty/missing files). +- `src/main/services/agents/services/TaskService.ts` — task CRUD, getDueTasks, computeNextRun (drift-resistant), run logging. `listTasks` now filters out heartbeat tasks by default (pass `{ includeHeartbeat: true }` to include). +- `src/main/services/agents/services/SchedulerService.ts` — poll-loop based, queries DB for due tasks, backward-compatible stopScheduler/startScheduler stubs; passes `{ persist: true }` and drains stream for completion; `runTaskNow()` for manual trigger; `notifyTaskResult()` for channel notifications; `ensureHeartbeatTask(agentId, intervalMinutes)` creates/updates the heartbeat scheduled task. `runTask()` detects heartbeat tasks (`task.name === 'heartbeat'`) and reads `heartbeat.md` from workspace instead of using stored prompt. +- `src/main/services/agents/services/index.ts` — registers claude-code + cherry-claw services, exports TaskService +- `src/main/services/agents/BaseService.ts` — added `cherry-claw` to tool/command dispatch +- `src/main/services/agents/services/SessionService.ts` — added `cherry-claw` to command dispatch +- `src/main/index.ts` — wired scheduler restore on startup, stopAll on quit +- `src/main/apiServer/routes/agents/handlers/agents.ts` — stop/restart scheduler on agent delete/update; sync heartbeat task on CherryClaw create/update/patch via `ensureHeartbeatTask()` + +### Claw MCP Server +- `src/main/mcpServers/claw.ts` — NEW: ClawServer with `cron` tool (add/list/remove actions) + `notify` tool (send messages to channels) + `memory` tool (update/append/search for FACT.md + JOURNAL.jsonl) + `skills` tool (marketplace search/install/remove/list), duration parsing, TaskService + ChannelManager + AgentService delegation +- `src/main/apiServer/routes/claw-mcp.ts` — NEW: Express route for Streamable HTTP MCP protocol, per-agent server caching, per-session transport management +- `src/main/apiServer/app.ts` — mounted claw MCP route at `/v1/claw` +- `src/main/services/agents/services/claudecode/internal-mcp.ts` — NEW: `InternalMcpServerConfig` type for injecting internal MCP servers +- `src/main/services/agents/services/claudecode/index.ts` — merges `_internalMcpServers` from session into SDK `options.mcpServers` + +### Channel Layer +- `src/main/services/agents/services/channels/ChannelAdapter.ts` — abstract interface + event types + `sendMessageDraft` + `notifyChatIds` property +- `src/main/services/agents/services/channels/ChannelMessageHandler.ts` — message routing, multi-turn stream collection, draft streaming, typing indicators; passes `{ persist: true }` for headless persistence +- `src/main/services/agents/services/channels/ChannelManager.ts` — singleton lifecycle, adapter factory registry, agent sync + `getNotifyAdapters()` + `notifyChannels` tracking +- `src/main/services/agents/services/channels/index.ts` — public exports + adapter module imports +- `src/main/services/agents/services/channels/adapters/TelegramAdapter.ts` — grammY-based adapter (long polling, auth guard, `sendMessageDraft`, message chunking, sets `notifyChatIds`) +- `src/main/services/agents/services/channels/adapters/QQAdapter.ts` — NEW: QQ Bot API adapter (WebSocket gateway, REST messaging, c2c/group/guild/dm support, access token caching) + +### Channel UI +- `src/renderer/src/pages/settings/AgentSettings/components/ChannelsSettings.tsx` — catalog-based card layout with inline config (blur-to-save), TelegramChannelCard + QQChannelCard +- `src/renderer/src/pages/settings/AgentSettings/AgentSettingsPopup.tsx` — channels tab for CherryClaw +- `src/renderer/src/types/agent.ts` — `TelegramChannelConfigSchema`, `QQChannelConfigSchema`, `CherryClawChannelSchema` with typed config + enabled flag + +### API Routes (Tasks) +- `src/main/apiServer/routes/agents/handlers/tasks.ts` — NEW: createTask, listTasks, getTask, updateTask, deleteTask, runTask, getTaskLogs +- `src/main/apiServer/routes/agents/validators/tasks.ts` — NEW: Zod validators for task routes +- `src/main/apiServer/routes/agents/handlers/index.ts` — added taskHandlers export +- `src/main/apiServer/routes/agents/validators/index.ts` — added tasks validators export + +### Frontend API Client & Hooks +- `src/renderer/src/api/agent.ts` — added task path helpers, listTasks, createTask, getTask, updateTask, deleteTask, runTask, getTaskLogs methods +- `src/renderer/src/hooks/agents/useTasks.ts` — NEW: useTasks, useCreateTask, useUpdateTask, useDeleteTask, useRunTask, useTaskLogs SWR hooks + +### Frontend UI +- `src/renderer/src/components/Popups/agent/AgentModal.tsx` — agent type selector, CherryClaw defaults, bypass warning +- `src/renderer/src/pages/settings/AgentSettings/AgentSettingsPopup.tsx` — replaced Channels tab with Tasks tab for CherryClaw agents +- `src/renderer/src/pages/settings/AgentSettings/BaseSettingsPopup.tsx` — added `'tasks'` to SettingsPopupTab union +- `src/renderer/src/pages/settings/AgentSettings/components/TasksSettings.tsx` — task list with add/edit/pause/delete/run/logs + HeartbeatSection (toggle + interval input) at top +- `src/renderer/src/pages/settings/AgentSettings/components/TaskListItem.tsx` — NEW: task row with status badge, schedule info, action buttons +- `src/renderer/src/pages/settings/AgentSettings/components/TaskFormModal.tsx` — NEW: add/edit modal (name, prompt, schedule type/value, context mode) +- `src/renderer/src/pages/settings/AgentSettings/components/TaskLogsModal.tsx` — NEW: run history table (run_at, duration, status, result/error) +- `src/renderer/src/pages/settings/AgentSettings/components/SoulSettings.tsx` — NEW +- `src/renderer/src/pages/settings/AgentSettings/components/ChannelsSettings.tsx` — placeholder (no longer in CherryClaw tab menu) +- `src/renderer/src/pages/settings/AgentSettings/shared.tsx` — CherryClaw default emoji +- `src/renderer/src/i18n/label.ts` — added CherryClaw label + +### i18n +- `src/renderer/src/i18n/locales/en-us.json` + 10 other locale files — CherryClaw + task UI strings (properly nested) + +### Tests +- `src/main/services/agents/services/__tests__/AgentServiceRegistry.test.ts` — 4 tests +- `src/main/services/agents/services/__tests__/SchedulerService.test.ts` — 7 tests (rewritten for poll-loop API) +- `src/main/services/agents/services/cherryclaw/__tests__/soul.test.ts` — 4 tests +- `src/main/services/agents/services/cherryclaw/__tests__/heartbeat.test.ts` — 4 tests (simplified: reads heartbeat.md, handles missing/empty, trims content) +- `src/main/services/agents/services/cherryclaw/__tests__/prompt.test.ts` — 7 tests (default prompt, system.md override, individual memory files, combined memories, caching) +- `src/main/mcpServers/__tests__/claw.test.ts` — 37 tests (cron tool add/list/remove, duration parsing, validation, notify tool send/filter/errors, skills tool search/install/remove/list, memory tool update/append/search) +- `src/main/services/agents/services/channels/__tests__/ChannelMessageHandler.test.ts` — 7 tests (multi-turn accumulation, chunking, commands, session tracking) +- `src/main/services/agents/services/channels/__tests__/ChannelManager.test.ts` — 6 tests (lifecycle, sync, adapter management) +- `src/main/services/agents/services/channels/adapters/__tests__/TelegramAdapter.test.ts` — 8 tests (connect, auth guard, message handling, chunking) + +### Dependencies +- `package.json` / `pnpm-lock.yaml` — added `cron-parser` ^5.5.0, `grammy` ^1.41, `ws` ^8.19.0 (QQ channel WebSocket) + +## Current State + +- Branch: `feat/claw-channel-qq` +- All lint/test/format checks pass (main process: 38 test files, 528 tests) +- Feature is code-complete including task-based scheduler, heartbeat as scheduled task, claw MCP tools (cron + notify + skills + memory), channel layer with Telegram and QQ adapters, custom system prompt with memory system, and manual task run +- Renderer tests have pre-existing environment issue (vitest web-worker module resolution) + +## Blockers / Gotchas + +- **Placeholder avatar** — `cherry-claw.png` is a copy of `claude.png`. Needs a proper distinct image. +- **Channel streaming behavior** — `text-delta` events from the transform layer are cumulative within a text block (each contains full text so far, not just the new portion). The UI relies on this. `ChannelMessageHandler` uses `text = value.text` (replace) within a block, and commits on `text-end` across turns. Do not change the transform layer's cumulative behavior. +- **Headless message persistence (FIXED)** — `SessionMessageService.createSessionMessage()` does NOT persist messages itself; persistence was entirely UI-driven via IPC (`AgentMessage_PersistExchange`). Channel and scheduler callers had no UI, so messages were lost. Fix: added `{ persist: true }` option to `createSessionMessage()` that triggers `persistHeadlessExchange()` on stream complete. Two bugs were found and fixed: + 1. **Missing persistence** — headless callers never saved user/assistant messages to `sessionMessagesTable`. Fixed by calling `agentMessageRepository.persistExchange()` when `persist: true`. + 2. **Cumulative delta corruption** — `TextStreamAccumulator` used `+=` for text-delta, but deltas are cumulative (full text so far). This caused persisted text to contain all intermediate states concatenated. Fixed by using `=` (replace). The `ChannelMessageHandler` already used `=` correctly. + 3. **topicId prefix** — `Message.topicId` must use `agent-session:` prefix, not raw session ID. Without the prefix, the UI's `DbService.getDataSource()` routes to Dexie instead of the agent SQLite data source, breaking message updates and rendering. +- **Telegram rate limits** — `sendMessageDraft` has no documented rate limit, but `sendMessage` is 30/s globally, 1/s per chat. Draft throttle is 500ms; typing indicator is 4s. +- **Telegram MarkdownV2** — agent responses sent as plain text (no `parse_mode`) to avoid escaping issues. Proper GFM→MarkdownV2 conversion is a follow-up. +- **QQ no streaming** — QQ Bot API has no native draft/streaming API like Telegram, so `sendMessageDraft` is a no-op. Full responses are sent as final messages only. +- **QQ no typing indicator** — QQ Bot API does not support typing indicators for most message types. `sendTypingIndicator` is a no-op. +- ~~**Memory system**~~ — DONE: anna-inspired 3-file model (soul.md, user.md, memory/FACT.md) + JOURNAL.jsonl, with `memory` MCP tool and `PromptBuilder` for system prompt assembly. +- **Non-Anthropic models** — CherryClaw only supports Anthropic provider models (inherits from Claude Agent SDK). +- **Session settings** — `SessionSettingsPopup.tsx` was NOT updated with CherryClaw tabs (only `AgentSettingsPopup` was). May want to add soul/task tabs there too if sessions need per-session overrides. +- **Scheduler backward compat** — `startScheduler(agent)` and `stopScheduler(agentId)` are now no-ops (the poll loop handles everything via DB state). Agent handler code in `agents.ts` still calls them but they just ensure the loop is running. +- **Task consecutive errors** — after 3 consecutive errors, a task is auto-paused. The error count resets on the next successful run. This is tracked per-task in the running task state (not persisted). +- **Claw MCP server lifecycle (FIXED)** — per-session ClawServer + Transport pairs. The MCP SDK `Server` class only supports one transport at a time (`connect()` throws "Already connected" if called twice). Previous per-agent caching caused sessions to break on reconnect. Now each MCP session gets its own `ClawServer` + `StreamableHTTPServerTransport` pair, stored in a `sessions` Map keyed by MCP session ID. `cleanupClawServer(agentId)` removes all sessions for that agent. Should be wired into agent delete handler. +- **Claw MCP tool allowlist (FIXED)** — the claw MCP server is registered as `claw`, so tools appear as `mcp__claw__cron`, `mcp__claw__notify`, `mcp__claw__skills`, and `mcp__claw__memory`. `CherryClawService.invoke()` now auto-appends these to `allowed_tools` when the agent has an explicit whitelist. When `allowed_tools` is undefined (no restriction), all tools are already available. +- **Sandbox is basic restriction only (NOT a security boundary)** — The PreToolUse hook path check has known bypasses: (1) Bash regex misses relative path tricks (`cd / && cat etc/passwd`), variable expansion (`$HOME`), subshells, heredocs, etc. (2) The SDK OS-level sandbox (`sandbox.enabled`) does not reliably restrict reads on macOS. (3) MCP tools and agent sub-tools are not checked. This is sufficient for well-behaved autonomous agents but should not be relied upon as a security sandbox. Future work: integrate proper OS sandbox enforcement, or restrict Bash to a vetted allowlist of commands. + +## Next Steps + +1. **Create PR** — use `gh-create-pr` skill to create a pull request from `feat/cherry-claw-agent` → `main` +2. **Replace avatar** — design/source a proper CherryClaw avatar image to replace the placeholder +3. **E2E testing** — manually test the full flow: create CherryClaw agent → verify cron tool is available → agent creates a scheduled task → verify task execution and run logging +4. **Wire cleanup** — call `cleanupClawServer(agentId)` in the agent delete handler to free per-agent MCP server instances +5. ~~**Tool allowlist**~~ — DONE: `mcp__claw__cron`, `mcp__claw__notify`, `mcp__claw__skills`, and `mcp__claw__memory` auto-added to `allowed_tools` in `CherryClawService.invoke()` +6. **TaskService tests** — add unit tests for TaskService CRUD and computeNextRun +7. **SessionSettingsPopup** — consider adding CherryClaw tabs to session-level settings if per-session overrides are needed +8. **GFM→MarkdownV2 conversion** — proper markdown formatting for Telegram responses +9. **Additional channel adapters** — Discord, Slack using the same `ChannelAdapter` + `registerAdapterFactory` pattern +10. **Harden sandbox** — current sandbox is basic path checking only. Needs: (a) proper OS sandbox enforcement for Bash reads, (b) Bash command allowlist or AST-based path extraction, (c) MCP tool path checking, (d) block relative path traversal tricks in Bash commands diff --git a/package.json b/package.json index 233f8045469..dbaf91e1454 100644 --- a/package.json +++ b/package.json @@ -77,12 +77,15 @@ "dependencies": { "@anthropic-ai/claude-agent-sdk": "0.2.71", "@expo/sudo-prompt": "^9.3.2", + "@larksuiteoapi/node-sdk": "^1.59.0", "@libsql/client": "0.14.0", "@napi-rs/system-ocr": "1.0.2", "@paymoapp/electron-shutdown-handler": "1.1.2", + "cron-parser": "^5.5.0", "express": "5.1.0", "font-list": "2.0.0", "graceful-fs": "4.2.11", + "grammy": "^1.41.1", "gray-matter": "4.0.3", "jsdom": "26.1.0", "node-stream-zip": "1.15.0", @@ -92,7 +95,8 @@ "sharp": "0.34.3", "swagger-ui-express": "5.0.1", "tesseract.js": "6.0.1", - "turndown": "7.2.0" + "turndown": "7.2.0", + "ws": "^8.19.0" }, "devDependencies": { "@agentic/exa": "^7.3.3", @@ -243,6 +247,7 @@ "@types/unist": "3.0.3", "@types/uuid": "^10.0.0", "@types/word-extractor": "^1", + "@types/ws": "^8.18.1", "@typescript/native-preview": "7.0.0-dev.20260204.1", "@uiw/codemirror-extensions-langs": "4.25.7", "@uiw/codemirror-themes-all": "4.25.7", diff --git a/packages/aiCore/src/core/options/factory.ts b/packages/aiCore/src/core/options/factory.ts index 1e493b2337e..d04e7edf89f 100644 --- a/packages/aiCore/src/core/options/factory.ts +++ b/packages/aiCore/src/core/options/factory.ts @@ -114,10 +114,3 @@ export function createGoogleOptions(options: ExtractProviderOptions<'google'>) { export function createOpenRouterOptions(options: ExtractProviderOptions<'openrouter'> | Record) { return createProviderOptions('openrouter', options) } - -/** - * 创建XAI供应商选项的便捷函数 - */ -export function createXaiOptions(options: ExtractProviderOptions<'xai'>) { - return createProviderOptions('xai', options) -} diff --git a/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/helper.ts b/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/helper.ts index 6e313bdd27d..79ee9e52c9a 100644 --- a/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/helper.ts +++ b/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/helper.ts @@ -1,11 +1,11 @@ import { anthropic } from '@ai-sdk/anthropic' import { google } from '@ai-sdk/google' import { openai } from '@ai-sdk/openai' +import { xai } from '@ai-sdk/xai' import type { InferToolInput, InferToolOutput } from 'ai' import { type Tool } from 'ai' -import { createOpenRouterOptions, createXaiOptions, mergeProviderOptions } from '../../../options' -import type { ProviderOptionsMap } from '../../../options/types' +import { createOpenRouterOptions, mergeProviderOptions } from '../../../options' import type { AiRequestContext } from '../../' import type { OpenRouterSearchConfig } from './openrouter' @@ -16,7 +16,8 @@ export type OpenAISearchConfig = NonNullable[0]> export type AnthropicSearchConfig = NonNullable[0]> export type GoogleSearchConfig = NonNullable[0]> -export type XAISearchConfig = NonNullable +export type XAIWebSearchConfig = NonNullable[0]> +export type XAIXSearchConfig = NonNullable[0]> type NormalizeTool = T extends Tool ? Tool : Tool @@ -24,6 +25,8 @@ type AnthropicWebSearchTool = NormalizeTool> type OpenAIChatWebSearchTool = NormalizeTool> type GoogleWebSearchTool = NormalizeTool> +type XAIWebSearchTool = NormalizeTool> +type XAIXSearchTool = NormalizeTool> /** * 插件初始化时接收的完整配置对象 @@ -34,7 +37,8 @@ export interface WebSearchPluginConfig { openai?: OpenAISearchConfig 'openai-chat'?: OpenAISearchPreviewConfig anthropic?: AnthropicSearchConfig - xai?: ProviderOptionsMap['xai']['searchParameters'] + xai?: XAIWebSearchConfig + 'xai-xsearch'?: XAIXSearchConfig google?: GoogleSearchConfig openrouter?: OpenRouterSearchConfig } @@ -47,10 +51,10 @@ export const DEFAULT_WEB_SEARCH_CONFIG: WebSearchPluginConfig = { openai: {}, 'openai-chat': {}, xai: { - mode: 'on', - returnCitations: true, - maxSearchResults: 5, - sources: [{ type: 'web' }, { type: 'x' }, { type: 'news' }] + enableImageUnderstanding: true + }, + 'xai-xsearch': { + enableImageUnderstanding: true }, anthropic: { maxUses: 5 @@ -87,6 +91,9 @@ export type WebSearchToolOutputSchema = { web?: { uri: string; title: string } }> } + // xAI 工具 + xai: InferToolOutput + 'xai-xsearch': InferToolOutput } export type WebSearchToolInputSchema = { @@ -94,6 +101,8 @@ export type WebSearchToolInputSchema = { openai: InferToolInput google: InferToolInput 'openai-chat': InferToolInput + xai: InferToolInput + 'xai-xsearch': InferToolInput } /** @@ -141,8 +150,9 @@ export const switchWebSearchTool = (config: WebSearchPluginConfig, params: any, }, xai: () => { const cfg = config.xai ?? DEFAULT_WEB_SEARCH_CONFIG.xai - const searchOptions = createXaiOptions({ searchParameters: { ...cfg, mode: 'on' } }) - applyProviderOptionsSearch(params, searchOptions) + applyToolBasedSearch(params, 'web_search', xai.tools.webSearch(cfg)) + const xSearchCfg = config['xai-xsearch'] ?? DEFAULT_WEB_SEARCH_CONFIG['xai-xsearch'] + applyToolBasedSearch(params, 'x_search', xai.tools.xSearch(xSearchCfg)) }, openrouter: () => { const cfg = (config.openrouter ?? DEFAULT_WEB_SEARCH_CONFIG.openrouter) as OpenRouterSearchConfig diff --git a/packages/aiCore/src/core/providers/schemas.ts b/packages/aiCore/src/core/providers/schemas.ts index 8c038bebbfd..2765a120074 100644 --- a/packages/aiCore/src/core/providers/schemas.ts +++ b/packages/aiCore/src/core/providers/schemas.ts @@ -101,7 +101,15 @@ export const baseProviders = [ { id: 'xai', name: 'xAI (Grok)', - creator: createXai, + creator: (options: Parameters[0]) => { + const provider = createXai(options) + return customProvider({ + fallbackProvider: { + ...provider, + languageModel: (modelId: string) => provider.responses(modelId) + } + }) + }, supportsImageGeneration: true }, { diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 188da65865c..2f1ebd9e270 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -94,6 +94,9 @@ importers: '@expo/sudo-prompt': specifier: ^9.3.2 version: 9.3.2 + '@larksuiteoapi/node-sdk': + specifier: ^1.59.0 + version: 1.59.0 '@libsql/client': specifier: 0.14.0 version: 0.14.0 @@ -103,6 +106,9 @@ importers: '@paymoapp/electron-shutdown-handler': specifier: 1.1.2 version: 1.1.2 + cron-parser: + specifier: ^5.5.0 + version: 5.5.0 express: specifier: 5.1.0 version: 5.1.0 @@ -112,6 +118,9 @@ importers: graceful-fs: specifier: 4.2.11 version: 4.2.11 + grammy: + specifier: ^1.41.1 + version: 1.41.1(encoding@0.1.13) gray-matter: specifier: 4.0.3 version: 4.0.3 @@ -142,6 +151,9 @@ importers: turndown: specifier: 7.2.0 version: 7.2.0 + ws: + specifier: ^8.19.0 + version: 8.19.0 devDependencies: '@agentic/exa': specifier: ^7.3.3 @@ -205,7 +217,7 @@ importers: version: 5.6.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3) '@ant-design/v5-patch-for-react-19': specifier: ^1.0.3 - version: 1.0.3(antd@5.27.0(patch_hash=cdc383bd0d9b9fe0df2ce7b1f1d4ead200012b7f9517d9257b4ea0a5b324e243)(moment@2.30.1)(react-dom@19.2.3(react@19.2.3))(react@19.2.3))(react-dom@19.2.3(react@19.2.3))(react@19.2.3) + version: 1.0.3(antd@5.27.0(patch_hash=cdc383bd0d9b9fe0df2ce7b1f1d4ead200012b7f9517d9257b4ea0a5b324e243)(luxon@3.7.2)(moment@2.30.1)(react-dom@19.2.3(react@19.2.3))(react@19.2.3))(react-dom@19.2.3(react@19.2.3))(react@19.2.3) '@anthropic-ai/sdk': specifier: ^0.41.0 version: 0.41.0(encoding@0.1.13) @@ -587,6 +599,9 @@ importers: '@types/word-extractor': specifier: ^1 version: 1.0.6 + '@types/ws': + specifier: ^8.18.1 + version: 8.18.1 '@typescript/native-preview': specifier: 7.0.0-dev.20260204.1 version: 7.0.0-dev.20260204.1 @@ -631,7 +646,7 @@ importers: version: 6.0.103(zod@4.3.4) antd: specifier: 5.27.0 - version: 5.27.0(patch_hash=cdc383bd0d9b9fe0df2ce7b1f1d4ead200012b7f9517d9257b4ea0a5b324e243)(moment@2.30.1)(react-dom@19.2.3(react@19.2.3))(react@19.2.3) + version: 5.27.0(patch_hash=cdc383bd0d9b9fe0df2ce7b1f1d4ead200012b7f9517d9257b4ea0a5b324e243)(luxon@3.7.2)(moment@2.30.1)(react-dom@19.2.3(react@19.2.3))(react@19.2.3) archiver: specifier: ^7.0.1 version: 7.0.1 @@ -2505,6 +2520,9 @@ packages: peerDependencies: '@modelcontextprotocol/sdk': ^1.11.0 + '@grammyjs/types@3.25.0': + resolution: {integrity: sha512-iN9i5p+8ZOu9OMxWNcguojQfz4K/PDyMPOnL7PPCON+SoA/F8OKMH3uR7CVUkYfdNe0GCz8QOzAWrnqusQYFOg==} + '@hello-pangea/dnd@18.0.1': resolution: {integrity: sha512-xojVWG8s/TGrKT1fC8K2tIWeejJYTAeJuj36zM//yEm/ZrnZUSFGS15BpO+jGZT1ybWvyXmeDJwPYb4dhWlbZQ==} peerDependencies: @@ -3163,6 +3181,9 @@ packages: peerDependencies: '@langchain/core': 1.0.2 + '@larksuiteoapi/node-sdk@1.59.0': + resolution: {integrity: sha512-sBpkruTvZDOxnVtoTbepWKRX0j1Y1ZElQYu0x7+v088sI9pcpbVp6ZzCGn62dhrKPatzNyCJyzYCPXPYQWccrA==} + '@leichtgewicht/ip-codec@2.0.5': resolution: {integrity: sha512-Vo+PSpZG2/fmgmiNzYK9qWRh8h/CHrwD0mo1h1DzL4yzHNSfWYujGTYsWGreD000gcgmZ7K4Ys6Tx9TxtsKdDw==} @@ -6366,6 +6387,10 @@ packages: crelt@1.0.6: resolution: {integrity: sha512-VQ2MBenTq1fWZUH9DJNGti7kKv6EeAuYr3cLwxUWhIu1baTaXh4Ib5W2CqHVqib4/MqbYGJqiL3Zb8GJZr3l4g==} + cron-parser@5.5.0: + resolution: {integrity: sha512-oML4lKUXxizYswqmxuOCpgFS8BNUJpIu6k/2HVHyaL8Ynnf3wdf9tkns0yRdJLSIjkJ+b0DXHMZEHGpMwjnPww==} + engines: {node: '>=18'} + cross-dirname@0.1.0: resolution: {integrity: sha512-+R08/oI0nl3vfPcqftZRpytksBXDzOUveBq/NBVx0sUp1axwzPQrKinNx5yd5sxPu8j1wIy8AfnVQ+5eFdha6Q==} @@ -7705,11 +7730,11 @@ packages: glob@7.1.6: resolution: {integrity: sha512-LwaxwyZ72Lk7vZINtNNrywX0ZuLyStrdDtabefZKAY5ZGJhVtgdznluResxNmPitE0SAO+O26sWTHeKSI2wMBA==} - deprecated: Glob versions prior to v9 are no longer supported + deprecated: Old versions of glob are not supported, and contain widely publicized security vulnerabilities, which have been fixed in the current version. Please update. Support for old versions may be purchased (at exorbitant rates) by contacting i@izs.me glob@7.2.3: resolution: {integrity: sha512-nFR0zLpU2YCaRxwoCJvL6UvCH2JFyFVIvwTLsIf21AuHlMskA1hhTdk+LlYJtOlYt9v6dvszD2BGRqBL+iQK9Q==} - deprecated: Glob versions prior to v9 are no longer supported + deprecated: Old versions of glob are not supported, and contain widely publicized security vulnerabilities, which have been fixed in the current version. Please update. Support for old versions may be purchased (at exorbitant rates) by contacting i@izs.me global-agent@3.0.0: resolution: {integrity: sha512-PT6XReJ+D07JvGoxQMkT6qji/jVNfX/h364XHZOWeRzy64sSFr+xJ5OX7LI3b4MPQzdL4H8Y8M0xzPpsVMwA8Q==} @@ -7758,6 +7783,10 @@ packages: graceful-fs@4.2.11: resolution: {integrity: sha512-RbJ5/jmFcNNCcDV5o9eTnBLJ/HszWV0P73bc+Ff4nS/rJj+YaS6IGyiOL0VoBYX+l1Wrl3k63h/KrH+nhJ0XvQ==} + grammy@1.41.1: + resolution: {integrity: sha512-wcHAQ1e7svL3fJMpDchcQVcWUmywhuepOOjHUHmMmWAwUJEIyK5ea5sbSjZd+Gy1aMpZeP8VYJa+4tP+j1YptQ==} + engines: {node: ^12.20.0 || >=14.13.1} + graphql@16.12.0: resolution: {integrity: sha512-DKKrynuQRne0PNpEbzuEdHlYOMksHSUI8Zc9Unei5gTsMNA2/vMpoMz/yKba50pejK56qj98qM0SjYxAKi13gQ==} engines: {node: ^12.22.0 || ^14.16.0 || ^16.0.0 || >=17.0.0} @@ -8611,6 +8640,9 @@ packages: resolution: {integrity: sha512-z+Uw/vLuy6gQe8cfaFWD7p0wVv8fJl3mbzXh33RS+0oW2wvUqiRXiQ69gLWSLpgB5/6sU+r6BlQR0MBILadqTQ==} deprecated: This package is deprecated. Use the optional chaining (?.) operator instead. + lodash.identity@3.0.0: + resolution: {integrity: sha512-AupTIzdLQxJS5wIYUQlgGyk2XRTfGXA+MCghDHqZk0pzUNYvd3EESS6dkChNauNYVIutcb0dfHw1ri9Q1yPV8Q==} + lodash.includes@4.3.0: resolution: {integrity: sha512-W3Bx6mdkRTGtlJISOvVD/lbqjTlPPUDTMnlXZFnVwi9NKJ6tiAk6LVdlhZMm17VZisqhKcgzpO5Wz91PCt5b0w==} @@ -8642,6 +8674,9 @@ packages: lodash.once@4.1.1: resolution: {integrity: sha512-Sb487aTOCr9drQVL8pIxOzVhafOjZN9UU54hiN8PU3uAiSV7lx1yYNpbNmex2PK6dSJoNTSJUUswT651yww3Mg==} + lodash.pickby@4.6.0: + resolution: {integrity: sha512-AZV+GsS/6ckvPOVQPXSiFFacKvKB4kOQu6ynt9wz0F3LO4R9Ij4K1ddYsIytDpSgLz88JHd9P+oaLeej5/Sl7Q==} + lodash@4.17.21: resolution: {integrity: sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==} @@ -8706,6 +8741,10 @@ packages: peerDependencies: react: ^16.5.1 || ^17.0.0 || ^18.0.0 || ^19.0.0 + luxon@3.7.2: + resolution: {integrity: sha512-vtEhXh/gNjI9Yg1u4jX/0YVPMvxzHuGgCm6tC5kZyb08yjGWGnqAjGJvcXbqQR2P3MyMEFnRbpcdFS6PBcLqew==} + engines: {node: '>=12'} + lz-string@1.5.0: resolution: {integrity: sha512-h5bgJWpxJNswbU7qCrV0tIKQCaS3blPDrqKWx+QxzuzL1zGUzij9XCWLrSLsJPu5t+eWA/ycetzYAO5IOMcWAQ==} hasBin: true @@ -9885,6 +9924,10 @@ packages: resolution: {integrity: sha512-4EK3+xJl8Ts67nLYNwqw/dsFVnCf+qR7RgXSK9jEEm9unao3njwMDdmsdvoKBKHzxd7tCYz5e5M+SnMjdtXGQQ==} engines: {node: '>=0.6'} + qs@6.15.0: + resolution: {integrity: sha512-mAZTtNCeetKMH+pSjrb76NAM8V9a05I9aBZOHztWy/UqcJdQYNsf59vrRKWnojAT9Y+GbIvoTBC++CPHqpDBhQ==} + engines: {node: '>=0.6'} + quansync@1.0.0: resolution: {integrity: sha512-5xZacEEufv3HSTPQuchrvV6soaiACMFnq1H8wkVioctoH3TRha9Sz66lOxRwPK/qZj7HPiSveih9yAyh98gvqA==} @@ -11783,18 +11826,6 @@ packages: wrappy@1.0.2: resolution: {integrity: sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ==} - ws@8.18.3: - resolution: {integrity: sha512-PEIGCY5tSlUt50cqyMXfCzX+oOPqN0vuGqWzbcJ2xvnkzkq46oOpz7dQaTDBdfICb4N14+GARUDw2XV2N4tvzg==} - engines: {node: '>=10.0.0'} - peerDependencies: - bufferutil: ^4.0.1 - utf-8-validate: '>=5.0.2' - peerDependenciesMeta: - bufferutil: - optional: true - utf-8-validate: - optional: true - ws@8.19.0: resolution: {integrity: sha512-blAT2mjOEIi0ZzruJfIhb3nps74PRWTCz1IjglWEEpQl5XS/UNama6u2/rjFkDDouqr4L67ry+1aGIALViWjDg==} engines: {node: '>=10.0.0'} @@ -12211,9 +12242,9 @@ snapshots: resize-observer-polyfill: 1.5.1 throttle-debounce: 5.0.2 - '@ant-design/v5-patch-for-react-19@1.0.3(antd@5.27.0(patch_hash=cdc383bd0d9b9fe0df2ce7b1f1d4ead200012b7f9517d9257b4ea0a5b324e243)(moment@2.30.1)(react-dom@19.2.3(react@19.2.3))(react@19.2.3))(react-dom@19.2.3(react@19.2.3))(react@19.2.3)': + '@ant-design/v5-patch-for-react-19@1.0.3(antd@5.27.0(patch_hash=cdc383bd0d9b9fe0df2ce7b1f1d4ead200012b7f9517d9257b4ea0a5b324e243)(luxon@3.7.2)(moment@2.30.1)(react-dom@19.2.3(react@19.2.3))(react@19.2.3))(react-dom@19.2.3(react@19.2.3))(react@19.2.3)': dependencies: - antd: 5.27.0(patch_hash=cdc383bd0d9b9fe0df2ce7b1f1d4ead200012b7f9517d9257b4ea0a5b324e243)(moment@2.30.1)(react-dom@19.2.3(react@19.2.3))(react@19.2.3) + antd: 5.27.0(patch_hash=cdc383bd0d9b9fe0df2ce7b1f1d4ead200012b7f9517d9257b4ea0a5b324e243)(luxon@3.7.2)(moment@2.30.1)(react-dom@19.2.3(react@19.2.3))(react@19.2.3) react: 19.2.3 react-dom: 19.2.3(react@19.2.3) @@ -14081,7 +14112,7 @@ snapshots: dependencies: '@modelcontextprotocol/sdk': 1.27.1(@cfworker/json-schema@4.1.1)(zod@4.3.4) google-auth-library: 9.15.1(encoding@0.1.13) - ws: 8.18.3 + ws: 8.19.0 zod: 3.25.76 zod-to-json-schema: 3.25.1(zod@3.25.76) transitivePeerDependencies: @@ -14090,6 +14121,8 @@ snapshots: - supports-color - utf-8-validate + '@grammyjs/types@3.25.0': {} + '@hello-pangea/dnd@18.0.1(@types/react@19.2.7)(react-dom@19.2.3(react@19.2.3))(react@19.2.3)': dependencies: '@babel/runtime': 7.28.4 @@ -14436,6 +14469,20 @@ snapshots: '@langchain/core': 1.0.2(patch_hash=8dc787a82cebafe8b23c8826f25f29aca64fc8b43a0a1878e0010782e4da96ed)(@opentelemetry/api@1.9.0)(@opentelemetry/sdk-trace-base@2.2.0(@opentelemetry/api@1.9.0))(openai@6.15.0(ws@8.19.0)(zod@4.3.4)) js-tiktoken: 1.0.21 + '@larksuiteoapi/node-sdk@1.59.0': + dependencies: + axios: 1.13.6(debug@4.4.3) + lodash.identity: 3.0.0 + lodash.merge: 4.6.2 + lodash.pickby: 4.6.0 + protobufjs: 7.5.4 + qs: 6.15.0 + ws: 8.19.0 + transitivePeerDependencies: + - bufferutil + - debug + - utf-8-validate + '@leichtgewicht/ip-codec@2.0.5': {} '@lezer/common@1.5.1': {} @@ -14571,7 +14618,7 @@ snapshots: '@libsql/isomorphic-ws@0.1.5': dependencies: '@types/ws': 8.18.1 - ws: 8.18.3 + ws: 8.19.0 transitivePeerDependencies: - bufferutil - utf-8-validate @@ -17246,7 +17293,7 @@ snapshots: sirv: 3.0.2 tinyrainbow: 2.0.0 vitest: 3.2.4(@types/debug@4.1.12)(@types/node@24.10.4)(@vitest/browser@3.2.4)(@vitest/ui@3.2.4)(esbuild@0.25.12)(jiti@2.6.1)(jsdom@26.1.0)(msw@2.12.7(@types/node@24.10.4)(typescript@5.8.3))(tsx@4.21.0)(yaml@2.8.2) - ws: 8.18.3 + ws: 8.19.0 optionalDependencies: playwright: 1.57.0 transitivePeerDependencies: @@ -17464,7 +17511,7 @@ snapshots: ansis@4.2.0: {} - antd@5.27.0(patch_hash=cdc383bd0d9b9fe0df2ce7b1f1d4ead200012b7f9517d9257b4ea0a5b324e243)(moment@2.30.1)(react-dom@19.2.3(react@19.2.3))(react@19.2.3): + antd@5.27.0(patch_hash=cdc383bd0d9b9fe0df2ce7b1f1d4ead200012b7f9517d9257b4ea0a5b324e243)(luxon@3.7.2)(moment@2.30.1)(react-dom@19.2.3(react@19.2.3))(react@19.2.3): dependencies: '@ant-design/colors': 7.2.1 '@ant-design/cssinjs': 1.23.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3) @@ -17496,7 +17543,7 @@ snapshots: rc-motion: 2.9.5(react-dom@19.2.3(react@19.2.3))(react@19.2.3) rc-notification: 5.6.4(react-dom@19.2.3(react@19.2.3))(react@19.2.3) rc-pagination: 5.1.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3) - rc-picker: 4.11.3(dayjs@1.11.19)(moment@2.30.1)(react-dom@19.2.3(react@19.2.3))(react@19.2.3) + rc-picker: 4.11.3(dayjs@1.11.19)(luxon@3.7.2)(moment@2.30.1)(react-dom@19.2.3(react@19.2.3))(react@19.2.3) rc-progress: 4.0.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3) rc-rate: 2.13.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3) rc-resize-observer: 1.4.3(react-dom@19.2.3(react@19.2.3))(react@19.2.3) @@ -18348,6 +18395,10 @@ snapshots: crelt@1.0.6: {} + cron-parser@5.5.0: + dependencies: + luxon: 3.7.2 + cross-dirname@0.1.0: optional: true @@ -19998,6 +20049,16 @@ snapshots: graceful-fs@4.2.11: {} + grammy@1.41.1(encoding@0.1.13): + dependencies: + '@grammyjs/types': 3.25.0 + abort-controller: 3.0.0 + debug: 4.4.3 + node-fetch: 2.7.0(encoding@0.1.13) + transitivePeerDependencies: + - encoding + - supports-color + graphql@16.12.0: {} gray-matter@4.0.3: @@ -20569,7 +20630,7 @@ snapshots: whatwg-encoding: 3.1.1 whatwg-mimetype: 4.0.0 whatwg-url: 14.2.0 - ws: 8.18.3 + ws: 8.19.0 xml-name-validator: 5.0.0 transitivePeerDependencies: - bufferutil @@ -20869,6 +20930,8 @@ snapshots: lodash.get@4.4.2: {} + lodash.identity@3.0.0: {} + lodash.includes@4.3.0: {} lodash.isboolean@3.0.3: {} @@ -20889,6 +20952,8 @@ snapshots: lodash.once@4.1.1: {} + lodash.pickby@4.6.0: {} + lodash@4.17.21: {} lodash@4.17.23: {} @@ -20947,6 +21012,8 @@ snapshots: dependencies: react: 19.2.3 + luxon@3.7.2: {} + lz-string@1.5.0: {} mac-system-proxy@1.0.4: {} @@ -22534,6 +22601,10 @@ snapshots: dependencies: side-channel: 1.1.0 + qs@6.15.0: + dependencies: + side-channel: 1.1.0 + quansync@1.0.0: {} querystringify@2.2.0: {} @@ -22701,7 +22772,7 @@ snapshots: react: 19.2.3 react-dom: 19.2.3(react@19.2.3) - rc-picker@4.11.3(dayjs@1.11.19)(moment@2.30.1)(react-dom@19.2.3(react@19.2.3))(react@19.2.3): + rc-picker@4.11.3(dayjs@1.11.19)(luxon@3.7.2)(moment@2.30.1)(react-dom@19.2.3(react@19.2.3))(react@19.2.3): dependencies: '@babel/runtime': 7.28.4 '@rc-component/trigger': 2.3.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3) @@ -22713,6 +22784,7 @@ snapshots: react-dom: 19.2.3(react@19.2.3) optionalDependencies: dayjs: 1.11.19 + luxon: 3.7.2 moment: 2.30.1 rc-progress@4.0.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3): @@ -24792,8 +24864,6 @@ snapshots: wrappy@1.0.2: {} - ws@8.18.3: {} - ws@8.19.0: {} xlsx@https://cdn.sheetjs.com/xlsx-0.20.2/xlsx-0.20.2.tgz: {} diff --git a/resources/database/drizzle/0003_wise_meltdown.sql b/resources/database/drizzle/0003_wise_meltdown.sql new file mode 100644 index 00000000000..20407cdc613 --- /dev/null +++ b/resources/database/drizzle/0003_wise_meltdown.sql @@ -0,0 +1,25 @@ +CREATE TABLE `scheduled_tasks` ( + `id` text PRIMARY KEY NOT NULL, + `agent_id` text NOT NULL, + `name` text NOT NULL, + `prompt` text NOT NULL, + `schedule_type` text NOT NULL, + `schedule_value` text NOT NULL, + `context_mode` text DEFAULT 'session' NOT NULL, + `next_run` text, + `last_run` text, + `last_result` text, + `status` text DEFAULT 'active' NOT NULL, + `created_at` text NOT NULL, + `updated_at` text NOT NULL +); +--> statement-breakpoint +CREATE TABLE `task_run_logs` ( + `id` integer PRIMARY KEY AUTOINCREMENT NOT NULL, + `task_id` text NOT NULL, + `run_at` text NOT NULL, + `duration_ms` integer NOT NULL, + `status` text NOT NULL, + `result` text, + `error` text +); diff --git a/resources/database/drizzle/meta/0003_snapshot.json b/resources/database/drizzle/meta/0003_snapshot.json new file mode 100644 index 00000000000..8165f51e5fd --- /dev/null +++ b/resources/database/drizzle/meta/0003_snapshot.json @@ -0,0 +1,508 @@ +{ + "version": "6", + "dialect": "sqlite", + "id": "45248bfb-356b-400b-a48e-858fb2928c3a", + "prevId": "0cf3d79e-69bf-4dba-8df4-996b9b67d2e8", + "tables": { + "agents": { + "name": "agents", + "columns": { + "id": { + "name": "id", + "type": "text", + "primaryKey": true, + "notNull": true, + "autoincrement": false + }, + "type": { + "name": "type", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "name": { + "name": "name", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "description": { + "name": "description", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "accessible_paths": { + "name": "accessible_paths", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "instructions": { + "name": "instructions", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "model": { + "name": "model", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "plan_model": { + "name": "plan_model", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "small_model": { + "name": "small_model", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "mcps": { + "name": "mcps", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "allowed_tools": { + "name": "allowed_tools", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "configuration": { + "name": "configuration", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "created_at": { + "name": "created_at", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "updated_at": { + "name": "updated_at", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + } + }, + "indexes": {}, + "foreignKeys": {}, + "compositePrimaryKeys": {}, + "uniqueConstraints": {}, + "checkConstraints": {} + }, + "session_messages": { + "name": "session_messages", + "columns": { + "id": { + "name": "id", + "type": "integer", + "primaryKey": true, + "notNull": true, + "autoincrement": true + }, + "session_id": { + "name": "session_id", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "role": { + "name": "role", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "content": { + "name": "content", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "agent_session_id": { + "name": "agent_session_id", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false, + "default": "''" + }, + "metadata": { + "name": "metadata", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "created_at": { + "name": "created_at", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "updated_at": { + "name": "updated_at", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + } + }, + "indexes": {}, + "foreignKeys": {}, + "compositePrimaryKeys": {}, + "uniqueConstraints": {}, + "checkConstraints": {} + }, + "migrations": { + "name": "migrations", + "columns": { + "version": { + "name": "version", + "type": "integer", + "primaryKey": true, + "notNull": true, + "autoincrement": false + }, + "tag": { + "name": "tag", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "executed_at": { + "name": "executed_at", + "type": "integer", + "primaryKey": false, + "notNull": true, + "autoincrement": false + } + }, + "indexes": {}, + "foreignKeys": {}, + "compositePrimaryKeys": {}, + "uniqueConstraints": {}, + "checkConstraints": {} + }, + "sessions": { + "name": "sessions", + "columns": { + "id": { + "name": "id", + "type": "text", + "primaryKey": true, + "notNull": true, + "autoincrement": false + }, + "agent_type": { + "name": "agent_type", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "agent_id": { + "name": "agent_id", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "name": { + "name": "name", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "description": { + "name": "description", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "accessible_paths": { + "name": "accessible_paths", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "instructions": { + "name": "instructions", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "model": { + "name": "model", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "plan_model": { + "name": "plan_model", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "small_model": { + "name": "small_model", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "mcps": { + "name": "mcps", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "allowed_tools": { + "name": "allowed_tools", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "slash_commands": { + "name": "slash_commands", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "configuration": { + "name": "configuration", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "created_at": { + "name": "created_at", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "updated_at": { + "name": "updated_at", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + } + }, + "indexes": {}, + "foreignKeys": {}, + "compositePrimaryKeys": {}, + "uniqueConstraints": {}, + "checkConstraints": {} + }, + "scheduled_tasks": { + "name": "scheduled_tasks", + "columns": { + "id": { + "name": "id", + "type": "text", + "primaryKey": true, + "notNull": true, + "autoincrement": false + }, + "agent_id": { + "name": "agent_id", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "name": { + "name": "name", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "prompt": { + "name": "prompt", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "schedule_type": { + "name": "schedule_type", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "schedule_value": { + "name": "schedule_value", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "context_mode": { + "name": "context_mode", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false, + "default": "'session'" + }, + "next_run": { + "name": "next_run", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "last_run": { + "name": "last_run", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "last_result": { + "name": "last_result", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "status": { + "name": "status", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false, + "default": "'active'" + }, + "created_at": { + "name": "created_at", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "updated_at": { + "name": "updated_at", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + } + }, + "indexes": {}, + "foreignKeys": {}, + "compositePrimaryKeys": {}, + "uniqueConstraints": {}, + "checkConstraints": {} + }, + "task_run_logs": { + "name": "task_run_logs", + "columns": { + "id": { + "name": "id", + "type": "integer", + "primaryKey": true, + "notNull": true, + "autoincrement": true + }, + "task_id": { + "name": "task_id", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "run_at": { + "name": "run_at", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "duration_ms": { + "name": "duration_ms", + "type": "integer", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "status": { + "name": "status", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "result": { + "name": "result", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "error": { + "name": "error", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + } + }, + "indexes": {}, + "foreignKeys": {}, + "compositePrimaryKeys": {}, + "uniqueConstraints": {}, + "checkConstraints": {} + } + }, + "views": {}, + "enums": {}, + "_meta": { + "schemas": {}, + "tables": {}, + "columns": {} + }, + "internal": { + "indexes": {} + } +} diff --git a/resources/database/drizzle/meta/_journal.json b/resources/database/drizzle/meta/_journal.json index ac026637aa0..d8a70fad26e 100644 --- a/resources/database/drizzle/meta/_journal.json +++ b/resources/database/drizzle/meta/_journal.json @@ -22,6 +22,13 @@ "when": 1762526423527, "tag": "0002_wealthy_naoko", "breakpoints": true + }, + { + "idx": 3, + "version": "6", + "when": 1773111833225, + "tag": "0003_wise_meltdown", + "breakpoints": true } ] } diff --git a/src/main/apiServer/app.ts b/src/main/apiServer/app.ts index a3f8d3dd04e..0a0708361b3 100644 --- a/src/main/apiServer/app.ts +++ b/src/main/apiServer/app.ts @@ -9,6 +9,7 @@ import { errorHandler } from './middleware/error' import { setupOpenAPIDocumentation } from './middleware/openapi' import { agentsRoutes } from './routes/agents' import { chatRoutes } from './routes/chat' +import { clawMcpRoutes } from './routes/claw-mcp' import { mcpRoutes } from './routes/mcp' import { messagesProviderRoutes, messagesRoutes } from './routes/messages' import { modelsRoutes } from './routes/models' @@ -152,6 +153,7 @@ apiRouter.use('/mcps', mcpRoutes) apiRouter.use('/messages', extendMessagesTimeout, messagesRoutes) apiRouter.use('/models', modelsRoutes) apiRouter.use('/agents', agentsRoutes) +apiRouter.use('/claw', clawMcpRoutes) app.use('/v1', apiRouter) // Error handling (must be last) diff --git a/src/main/apiServer/generated/openapi-spec.json b/src/main/apiServer/generated/openapi-spec.json index 353e156511c..3d198a76101 100644 --- a/src/main/apiServer/generated/openapi-spec.json +++ b/src/main/apiServer/generated/openapi-spec.json @@ -216,7 +216,7 @@ }, "AgentType": { "type": "string", - "enum": ["claude-code"], + "enum": ["claude-code", "cherry-claw"], "description": "Type of agent" }, "AgentConfiguration": { diff --git a/src/main/apiServer/routes/agents/handlers/agents.ts b/src/main/apiServer/routes/agents/handlers/agents.ts index 53e5f9433e7..695523a8943 100644 --- a/src/main/apiServer/routes/agents/handlers/agents.ts +++ b/src/main/apiServer/routes/agents/handlers/agents.ts @@ -1,6 +1,8 @@ import { loggerService } from '@logger' import { AgentModelValidationError, agentService, sessionService } from '@main/services/agents' -import type { ListAgentsResponse } from '@types' +import { channelManager } from '@main/services/agents/services/channels' +import { schedulerService } from '@main/services/agents/services/SchedulerService' +import type { CherryClawConfiguration, ListAgentsResponse } from '@types' import { type ReplaceAgentRequest, type UpdateAgentRequest } from '@types' import type { Request, Response } from 'express' @@ -62,8 +64,14 @@ export const createAgent = async (req: Request, res: Response): Promise { + logger.warn('Failed to sync heartbeat task', { + agentId, + error: err instanceof Error ? err.message : String(err) + }) + }) + } + logger.info('Agent updated', { agentId }) return res.json(agent) } catch (error: any) { @@ -496,6 +518,20 @@ export const patchAgent = async (req: Request, res: Response): Promise }) } + // Restart scheduler if this is a cherry-claw agent with scheduler config changes + if (agent.type === 'cherry-claw') { + schedulerService.stopScheduler(agentId) + schedulerService.startScheduler(agent) + channelManager.syncAgent(agentId) + const config = (agent.configuration ?? {}) as CherryClawConfiguration + schedulerService.ensureHeartbeatTask(agentId, config.heartbeat_interval ?? 30).catch((err) => { + logger.warn('Failed to sync heartbeat task', { + agentId, + error: err instanceof Error ? err.message : String(err) + }) + }) + } + logger.info('Agent patched', { agentId }) return res.json(agent) } catch (error: any) { @@ -556,6 +592,9 @@ export const deleteAgent = async (req: Request, res: Response): Promise => { + const { agentId } = req.params + try { + logger.debug('Creating task', { agentId }) + const task = await taskService.createTask(agentId, req.body) + logger.info('Task created', { agentId, taskId: task.id }) + return res.status(201).json(task) + } catch (error: any) { + logger.error('Error creating task', { error, agentId }) + return res.status(500).json({ + error: { + message: `Failed to create task: ${error.message}`, + type: 'internal_error', + code: 'task_creation_failed' + } + }) + } +} + +export const listTasks = async (req: Request, res: Response): Promise => { + const { agentId } = req.params + try { + const limit = req.query.limit ? parseInt(req.query.limit as string) : 20 + const offset = req.query.offset ? parseInt(req.query.offset as string) : 0 + + logger.debug('Listing tasks', { agentId, limit, offset }) + const result = await taskService.listTasks(agentId, { limit, offset }) + + return res.json({ + data: result.tasks, + total: result.total, + limit, + offset + } satisfies ListTasksResponse) + } catch (error: any) { + logger.error('Error listing tasks', { error, agentId }) + return res.status(500).json({ + error: { + message: 'Failed to list tasks', + type: 'internal_error', + code: 'task_list_failed' + } + }) + } +} + +export const getTask = async (req: Request, res: Response): Promise => { + const { agentId, taskId } = req.params + try { + logger.debug('Getting task', { agentId, taskId }) + const task = await taskService.getTask(agentId, taskId) + + if (!task) { + return res.status(404).json({ + error: { + message: 'Task not found', + type: 'not_found', + code: 'task_not_found' + } + }) + } + + return res.json(task) + } catch (error: any) { + logger.error('Error getting task', { error, agentId, taskId }) + return res.status(500).json({ + error: { + message: 'Failed to get task', + type: 'internal_error', + code: 'task_get_failed' + } + }) + } +} + +export const updateTask = async (req: Request, res: Response): Promise => { + const { agentId, taskId } = req.params + try { + logger.debug('Updating task', { agentId, taskId }) + const task = await taskService.updateTask(agentId, taskId, req.body) + + if (!task) { + return res.status(404).json({ + error: { + message: 'Task not found', + type: 'not_found', + code: 'task_not_found' + } + }) + } + + logger.info('Task updated', { agentId, taskId }) + return res.json(task) + } catch (error: any) { + logger.error('Error updating task', { error, agentId, taskId }) + return res.status(500).json({ + error: { + message: `Failed to update task: ${error.message}`, + type: 'internal_error', + code: 'task_update_failed' + } + }) + } +} + +export const deleteTask = async (req: Request, res: Response): Promise => { + const { agentId, taskId } = req.params + try { + logger.debug('Deleting task', { agentId, taskId }) + const deleted = await taskService.deleteTask(agentId, taskId) + + if (!deleted) { + return res.status(404).json({ + error: { + message: 'Task not found', + type: 'not_found', + code: 'task_not_found' + } + }) + } + + logger.info('Task deleted', { agentId, taskId }) + return res.status(204).send() + } catch (error: any) { + logger.error('Error deleting task', { error, agentId, taskId }) + return res.status(500).json({ + error: { + message: 'Failed to delete task', + type: 'internal_error', + code: 'task_delete_failed' + } + }) + } +} + +export const runTask = async (req: Request, res: Response): Promise => { + const { agentId, taskId } = req.params + try { + logger.debug('Manually running task', { agentId, taskId }) + await schedulerService.runTaskNow(agentId, taskId) + logger.info('Task triggered manually', { agentId, taskId }) + return res.json({ status: 'triggered' }) + } catch (error: any) { + const status = error.message?.includes('not found') ? 404 : error.message?.includes('already running') ? 409 : 500 + logger.error('Error running task', { error, agentId, taskId }) + return res.status(status).json({ + error: { + message: `Failed to run task: ${error.message}`, + type: status === 409 ? 'conflict' : status === 404 ? 'not_found' : 'internal_error', + code: 'task_run_failed' + } + }) + } +} + +export const getTaskLogs = async (req: Request, res: Response): Promise => { + const { agentId, taskId } = req.params + try { + const limit = req.query.limit ? parseInt(req.query.limit as string) : 20 + const offset = req.query.offset ? parseInt(req.query.offset as string) : 0 + + // Verify the task belongs to this agent + const task = await taskService.getTask(agentId, taskId) + if (!task) { + return res.status(404).json({ + error: { + message: 'Task not found', + type: 'not_found', + code: 'task_not_found' + } + }) + } + + logger.debug('Getting task logs', { taskId, limit, offset }) + const result = await taskService.getTaskLogs(taskId, { limit, offset }) + + return res.json({ + data: result.logs, + total: result.total, + limit, + offset + } satisfies ListTaskLogsResponse) + } catch (error: any) { + logger.error('Error getting task logs', { error, taskId }) + return res.status(500).json({ + error: { + message: 'Failed to get task logs', + type: 'internal_error', + code: 'task_logs_failed' + } + }) + } +} diff --git a/src/main/apiServer/routes/agents/index.ts b/src/main/apiServer/routes/agents/index.ts index 42843b72018..22aabd2f55a 100644 --- a/src/main/apiServer/routes/agents/index.ts +++ b/src/main/apiServer/routes/agents/index.ts @@ -1,6 +1,6 @@ import express from 'express' -import { agentHandlers, messageHandlers, sessionHandlers } from './handlers' +import { agentHandlers, messageHandlers, sessionHandlers, taskHandlers } from './handlers' import { checkAgentExists, handleValidationErrors } from './middleware' import { validateAgent, @@ -13,7 +13,11 @@ import { validateSessionMessage, validateSessionMessageId, validateSessionReplace, - validateSessionUpdate + validateSessionUpdate, + validateTask, + validateTaskId, + validateTaskPagination, + validateTaskUpdate } from './validators' // Create main agents router @@ -30,7 +34,7 @@ const agentsRouter = express.Router() * * AgentType: * type: string - * enum: [claude-code] + * enum: [claude-code, cherry-claw] * description: Type of agent * * AgentConfiguration: @@ -945,9 +949,31 @@ const createMessagesRouter = (): express.Router => { return messagesRouter } +// Create tasks router with agent context +const createTasksRouter = (): express.Router => { + const tasksRouter = express.Router({ mergeParams: true }) + + tasksRouter.post('/', validateTask, handleValidationErrors, taskHandlers.createTask) + tasksRouter.get('/', validateTaskPagination, handleValidationErrors, taskHandlers.listTasks) + tasksRouter.get('/:taskId', validateTaskId, handleValidationErrors, taskHandlers.getTask) + tasksRouter.patch('/:taskId', validateTaskId, validateTaskUpdate, handleValidationErrors, taskHandlers.updateTask) + tasksRouter.delete('/:taskId', validateTaskId, handleValidationErrors, taskHandlers.deleteTask) + tasksRouter.post('/:taskId/run', validateTaskId, handleValidationErrors, taskHandlers.runTask) + tasksRouter.get( + '/:taskId/logs', + validateTaskId, + validateTaskPagination, + handleValidationErrors, + taskHandlers.getTaskLogs + ) + + return tasksRouter +} + // Mount nested resources with clear hierarchy const sessionsRouter = createSessionsRouter() const messagesRouter = createMessagesRouter() +const tasksRouter = createTasksRouter() // Mount sessions under specific agent agentsRouter.use('/:agentId/sessions', validateAgentId, checkAgentExists, handleValidationErrors, sessionsRouter) @@ -961,5 +987,8 @@ agentsRouter.use( messagesRouter ) +// Mount tasks under specific agent +agentsRouter.use('/:agentId/tasks', validateAgentId, checkAgentExists, handleValidationErrors, tasksRouter) + // Export main router and convenience router export const agentsRoutes = agentsRouter diff --git a/src/main/apiServer/routes/agents/validators/index.ts b/src/main/apiServer/routes/agents/validators/index.ts index 7bba43e3b7c..e16d0182e00 100644 --- a/src/main/apiServer/routes/agents/validators/index.ts +++ b/src/main/apiServer/routes/agents/validators/index.ts @@ -2,3 +2,4 @@ export * from './agents' export * from './common' export * from './messages' export * from './sessions' +export * from './tasks' diff --git a/src/main/apiServer/routes/agents/validators/tasks.ts b/src/main/apiServer/routes/agents/validators/tasks.ts new file mode 100644 index 00000000000..3643ce90de8 --- /dev/null +++ b/src/main/apiServer/routes/agents/validators/tasks.ts @@ -0,0 +1,19 @@ +import { CreateTaskRequestSchema, PaginationQuerySchema, TaskIdParamSchema, UpdateTaskRequestSchema } from '@types' + +import { createZodValidator } from './zodValidator' + +export const validateTask = createZodValidator({ + body: CreateTaskRequestSchema +}) + +export const validateTaskUpdate = createZodValidator({ + body: UpdateTaskRequestSchema +}) + +export const validateTaskId = createZodValidator({ + params: TaskIdParamSchema +}) + +export const validateTaskPagination = createZodValidator({ + query: PaginationQuerySchema +}) diff --git a/src/main/apiServer/routes/claw-mcp.ts b/src/main/apiServer/routes/claw-mcp.ts new file mode 100644 index 00000000000..6bd2028e615 --- /dev/null +++ b/src/main/apiServer/routes/claw-mcp.ts @@ -0,0 +1,121 @@ +import ClawServer from '@main/mcpServers/claw' +import { loggerService } from '@main/services/LoggerService' +import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp' +import type { JSONRPCMessage } from '@modelcontextprotocol/sdk/types' +import { isJSONRPCRequest, JSONRPCMessageSchema } from '@modelcontextprotocol/sdk/types' +import { randomUUID } from 'crypto' +import type { Request, Response } from 'express' +import express from 'express' +import type { IncomingMessage, ServerResponse } from 'http' + +const logger = loggerService.withContext('ClawMCPRoute') + +// Per-session state: each MCP session gets its own Server + Transport pair. +// The MCP SDK Server class only supports one transport at a time, so sharing +// a Server across sessions causes "Already connected" errors on reconnect. +type SessionEntry = { + server: ClawServer + transport: StreamableHTTPServerTransport + agentId: string +} + +const sessions = new Map() + +function createSessionEntry(agentId: string): SessionEntry { + const server = new ClawServer(agentId) + const transport = new StreamableHTTPServerTransport({ + sessionIdGenerator: () => randomUUID(), + onsessioninitialized: (newSessionId) => { + sessions.set(newSessionId, entry) + logger.debug('Claw MCP session initialized', { sessionId: newSessionId, agentId }) + } + }) + + const entry: SessionEntry = { server, transport, agentId } + + transport.onclose = () => { + if (transport.sessionId) { + sessions.delete(transport.sessionId) + logger.debug('Claw MCP session closed', { sessionId: transport.sessionId, agentId }) + } + } + + return entry +} + +const router = express.Router({ mergeParams: true }) + +router.all('/:agentId/claw-mcp', async (req: Request, res: Response): Promise => { + const { agentId } = req.params + if (!agentId) { + res.status(400).json({ error: 'agentId is required' }) + return + } + + const sessionId = req.headers['mcp-session-id'] as string | undefined + + let entry: SessionEntry + if (sessionId && sessions.has(sessionId)) { + entry = sessions.get(sessionId)! + } else { + entry = createSessionEntry(agentId) + await entry.server.mcpServer.connect(entry.transport) + } + + // Only parse JSON-RPC body for POST requests. + // GET (SSE streaming) and DELETE (session close) have no body. + if (req.method === 'POST') { + const jsonPayload = req.body + const messages: JSONRPCMessage[] = [] + + if (Array.isArray(jsonPayload)) { + for (const payload of jsonPayload) { + messages.push(JSONRPCMessageSchema.parse(payload)) + } + } else { + messages.push(JSONRPCMessageSchema.parse(jsonPayload)) + } + + for (const message of messages) { + if (isJSONRPCRequest(message)) { + if (!message.params) { + message.params = {} + } + if (!message.params._meta) { + message.params._meta = {} + } + message.params._meta.agentId = agentId + } + } + + logger.debug('Dispatching claw MCP POST request', { + agentId, + sessionId: entry.transport.sessionId ?? sessionId, + messageCount: messages.length + }) + + await entry.transport.handleRequest(req as IncomingMessage, res as ServerResponse, messages) + } else { + // GET / DELETE — let the transport handle directly without body parsing + logger.debug('Dispatching claw MCP request', { + method: req.method, + agentId, + sessionId: entry.transport.sessionId ?? sessionId + }) + + await entry.transport.handleRequest(req as IncomingMessage, res as ServerResponse) + } +}) + +/** + * Clean up all claw sessions for a specific agent (e.g. on agent deletion). + */ +export function cleanupClawServer(agentId: string): void { + for (const [sessionId, entry] of sessions) { + if (entry.agentId === agentId) { + sessions.delete(sessionId) + } + } +} + +export { router as clawMcpRoutes } diff --git a/src/main/index.ts b/src/main/index.ts index 24a8c81fad5..3f128db740e 100644 --- a/src/main/index.ts +++ b/src/main/index.ts @@ -16,6 +16,8 @@ import process from 'node:process' import { registerIpc } from './ipc' import { agentService } from './services/agents' +import { schedulerService } from './services/agents/services/SchedulerService' +import { channelManager } from './services/agents/services/channels' import { analyticsService } from './services/AnalyticsService' import { apiServerService } from './services/ApiServerService' import { appMenuService } from './services/AppMenuService' @@ -210,6 +212,12 @@ if (!app.requestSingleInstanceLock()) { if (shouldStart) { await apiServerService.start() } + + // Restore CherryClaw schedulers after services are ready + await schedulerService.restoreSchedulers() + + // Start CherryClaw channel adapters (Telegram, etc.) + await channelManager.start() } catch (error: any) { logger.error('Failed to check/start API server:', error) } @@ -270,6 +278,8 @@ if (!app.requestSingleInstanceLock()) { } try { + schedulerService.stopAll() + await channelManager.stop() await analyticsService.destroy() await openClawService.stopGateway() await mcpService.cleanup() diff --git a/src/main/mcpServers/__tests__/claw.test.ts b/src/main/mcpServers/__tests__/claw.test.ts new file mode 100644 index 00000000000..e55206457fd --- /dev/null +++ b/src/main/mcpServers/__tests__/claw.test.ts @@ -0,0 +1,609 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' + +// Mock TaskService before importing ClawServer +const mockCreateTask = vi.fn() +const mockListTasks = vi.fn() +const mockDeleteTask = vi.fn() +const mockGetNotifyAdapters = vi.fn() +const mockSendMessage = vi.fn() +const mockPluginInstall = vi.fn() +const mockPluginUninstall = vi.fn() +const mockPluginListInstalled = vi.fn() +const mockNetFetch = vi.fn() +const mockGetAgent = vi.fn() +const mockMkdir = vi.fn() +const mockWriteFile = vi.fn() +const mockRename = vi.fn() +const mockAppendFile = vi.fn() +const mockReadFile = vi.fn() +const mockReaddir = vi.fn() +const mockStat = vi.fn() + +vi.mock('node:fs/promises', () => ({ + mkdir: (...args: unknown[]) => mockMkdir(...args), + writeFile: (...args: unknown[]) => mockWriteFile(...args), + rename: (...args: unknown[]) => mockRename(...args), + appendFile: (...args: unknown[]) => mockAppendFile(...args), + readFile: (...args: unknown[]) => mockReadFile(...args), + readdir: (...args: unknown[]) => mockReaddir(...args), + stat: (...args: unknown[]) => mockStat(...args) +})) + +vi.mock('@main/services/agents/services/TaskService', () => ({ + taskService: { + createTask: mockCreateTask, + listTasks: mockListTasks, + deleteTask: mockDeleteTask + } +})) + +vi.mock('@main/services/agents/services/AgentService', () => ({ + agentService: { + getAgent: mockGetAgent + } +})) + +vi.mock('@main/services/agents/services/channels/ChannelManager', () => ({ + channelManager: { + getNotifyAdapters: mockGetNotifyAdapters + } +})) + +vi.mock('@main/services/agents/plugins/PluginService', () => ({ + PluginService: { + getInstance: () => ({ + install: mockPluginInstall, + uninstall: mockPluginUninstall, + listInstalled: mockPluginListInstalled + }) + } +})) + +vi.mock('electron', () => ({ + net: { + fetch: mockNetFetch + } +})) + +// Import after mocks +const { default: ClawServer } = await import('../claw') +type ClawServerInstance = InstanceType + +function createServer(agentId = 'agent_test') { + return new ClawServer(agentId) +} + +// Helper to call tools via the Server's request handlers +async function callTool(server: ClawServerInstance, args: Record, toolName = 'cron') { + // Use the server's internal handler by simulating a CallTool request + const handlers = (server.mcpServer.server as any)._requestHandlers + const callToolHandler = handlers?.get('tools/call') + if (!callToolHandler) { + throw new Error('No tools/call handler registered') + } + + return callToolHandler( + { method: 'tools/call', params: { name: toolName, arguments: args } }, + {} // extra + ) +} + +async function listTools(server: ClawServerInstance) { + const handlers = (server.mcpServer.server as any)._requestHandlers + const listHandler = handlers?.get('tools/list') + if (!listHandler) { + throw new Error('No tools/list handler registered') + } + return listHandler({ method: 'tools/list', params: {} }, {}) +} + +describe('ClawServer', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('should list the cron, notify, skills, and memory tools', async () => { + const server = createServer() + const result = await listTools(server) + expect(result.tools).toHaveLength(4) + expect(result.tools.map((t: any) => t.name)).toEqual(['cron', 'notify', 'skills', 'memory']) + }) + + describe('add action', () => { + it('should create a task with cron schedule', async () => { + const task = { id: 'task_1', name: 'test', schedule_type: 'cron', schedule_value: '0 9 * * 1-5' } + mockCreateTask.mockResolvedValue(task) + + const server = createServer('agent_1') + const result = await callTool(server, { + action: 'add', + name: 'Daily standup', + message: 'Run standup check', + cron: '0 9 * * 1-5' + }) + + expect(mockCreateTask).toHaveBeenCalledWith('agent_1', { + name: 'Daily standup', + prompt: 'Run standup check', + schedule_type: 'cron', + schedule_value: '0 9 * * 1-5', + context_mode: 'session' + }) + expect(result.content[0].text).toContain('Job created') + }) + + it('should create a task with interval schedule', async () => { + const task = { id: 'task_2', name: 'check', schedule_type: 'interval', schedule_value: '30' } + mockCreateTask.mockResolvedValue(task) + + const server = createServer('agent_2') + await callTool(server, { + action: 'add', + name: 'Health check', + message: 'Check system health', + every: '30m' + }) + + expect(mockCreateTask).toHaveBeenCalledWith('agent_2', { + name: 'Health check', + prompt: 'Check system health', + schedule_type: 'interval', + schedule_value: '30', + context_mode: 'session' + }) + }) + + it('should parse hour+minute durations', async () => { + mockCreateTask.mockResolvedValue({ id: 'task_3' }) + + const server = createServer() + await callTool(server, { + action: 'add', + name: 'test', + message: 'test', + every: '1h30m' + }) + + expect(mockCreateTask).toHaveBeenCalledWith( + 'agent_test', + expect.objectContaining({ + schedule_type: 'interval', + schedule_value: '90' + }) + ) + }) + + it('should create a one-time task with at', async () => { + mockCreateTask.mockResolvedValue({ id: 'task_4' }) + + const server = createServer() + await callTool(server, { + action: 'add', + name: 'Deploy', + message: 'Deploy to prod', + at: '2024-01-15T14:30:00+08:00' + }) + + expect(mockCreateTask).toHaveBeenCalledWith( + 'agent_test', + expect.objectContaining({ + schedule_type: 'once' + }) + ) + }) + + it('should use isolated context mode when session_mode is new', async () => { + mockCreateTask.mockResolvedValue({ id: 'task_5' }) + + const server = createServer() + await callTool(server, { + action: 'add', + name: 'test', + message: 'test', + cron: '* * * * *', + session_mode: 'new' + }) + + expect(mockCreateTask).toHaveBeenCalledWith( + 'agent_test', + expect.objectContaining({ + context_mode: 'isolated' + }) + ) + }) + + it('should reject when no schedule is provided', async () => { + const server = createServer() + const result = await callTool(server, { + action: 'add', + name: 'test', + message: 'test' + }) + + expect(result.isError).toBe(true) + expect(mockCreateTask).not.toHaveBeenCalled() + }) + + it('should reject when multiple schedules are provided', async () => { + const server = createServer() + const result = await callTool(server, { + action: 'add', + name: 'test', + message: 'test', + cron: '* * * * *', + every: '30m' + }) + + expect(result.isError).toBe(true) + expect(mockCreateTask).not.toHaveBeenCalled() + }) + }) + + describe('list action', () => { + it('should list tasks', async () => { + const tasks = [{ id: 'task_1', name: 'Job 1' }] + mockListTasks.mockResolvedValue({ tasks, total: 1 }) + + const server = createServer('agent_1') + const result = await callTool(server, { action: 'list' }) + + expect(mockListTasks).toHaveBeenCalledWith('agent_1', { limit: 100 }) + expect(result.content[0].text).toContain('Job 1') + }) + + it('should handle empty task list', async () => { + mockListTasks.mockResolvedValue({ tasks: [], total: 0 }) + + const server = createServer() + const result = await callTool(server, { action: 'list' }) + + expect(result.content[0].text).toBe('No scheduled jobs.') + }) + }) + + describe('remove action', () => { + it('should remove a task', async () => { + mockDeleteTask.mockResolvedValue(true) + + const server = createServer('agent_1') + const result = await callTool(server, { action: 'remove', id: 'task_1' }) + + expect(mockDeleteTask).toHaveBeenCalledWith('agent_1', 'task_1') + expect(result.content[0].text).toContain('removed') + }) + + it('should error when task not found', async () => { + mockDeleteTask.mockResolvedValue(false) + + const server = createServer() + const result = await callTool(server, { action: 'remove', id: 'nonexistent' }) + + expect(result.isError).toBe(true) + }) + }) + + describe('notify tool', () => { + function makeAdapter(channelId: string, chatIds: string[]) { + return { + channelId, + notifyChatIds: chatIds, + sendMessage: mockSendMessage + } + } + + it('should send notification to all notify adapters', async () => { + mockSendMessage.mockResolvedValue(undefined) + mockGetNotifyAdapters.mockReturnValue([makeAdapter('ch1', ['100', '200'])]) + + const server = createServer('agent_1') + const result = await callTool(server, { message: 'Hello user!' }, 'notify') + + expect(mockGetNotifyAdapters).toHaveBeenCalledWith('agent_1') + expect(mockSendMessage).toHaveBeenCalledTimes(2) + expect(mockSendMessage).toHaveBeenCalledWith('100', 'Hello user!') + expect(mockSendMessage).toHaveBeenCalledWith('200', 'Hello user!') + expect(result.content[0].text).toContain('2 chat(s)') + }) + + it('should filter by channel_id when provided', async () => { + mockSendMessage.mockResolvedValue(undefined) + mockGetNotifyAdapters.mockReturnValue([makeAdapter('ch1', ['100']), makeAdapter('ch2', ['200'])]) + + const server = createServer('agent_1') + const result = await callTool(server, { message: 'Targeted', channel_id: 'ch2' }, 'notify') + + expect(mockSendMessage).toHaveBeenCalledTimes(1) + expect(mockSendMessage).toHaveBeenCalledWith('200', 'Targeted') + expect(result.content[0].text).toContain('1 chat(s)') + }) + + it('should return message when no notify channels found', async () => { + mockGetNotifyAdapters.mockReturnValue([]) + + const server = createServer('agent_1') + const result = await callTool(server, { message: 'Hello' }, 'notify') + + expect(result.content[0].text).toContain('No notify-enabled channels') + expect(mockSendMessage).not.toHaveBeenCalled() + }) + + it('should error when message is missing', async () => { + const server = createServer() + const result = await callTool(server, {}, 'notify') + + expect(result.isError).toBe(true) + expect(result.content[0].text).toContain("'message' is required") + }) + + it('should report partial failures', async () => { + mockSendMessage.mockResolvedValueOnce(undefined).mockRejectedValueOnce(new Error('rate limited')) + mockGetNotifyAdapters.mockReturnValue([makeAdapter('ch1', ['100', '200'])]) + + const server = createServer('agent_1') + const result = await callTool(server, { message: 'Test' }, 'notify') + + expect(result.content[0].text).toContain('1 chat(s)') + expect(result.content[0].text).toContain('rate limited') + }) + }) + + describe('skills tool', () => { + it('should search marketplace skills', async () => { + const mockResponse = { + ok: true, + status: 200, + json: vi.fn().mockResolvedValue({ + skills: [ + { + name: 'gh-create-pr', + description: 'Create GitHub PRs', + author: 'test-author', + namespace: '@test-owner/test-repo', + installs: 42, + metadata: { repoOwner: 'test-owner', repoName: 'test-repo' } + } + ], + total: 1 + }) + } + mockNetFetch.mockResolvedValue(mockResponse) + + const server = createServer('agent_1') + const result = await callTool(server, { action: 'search', query: 'github pr' }, 'skills') + + expect(mockNetFetch).toHaveBeenCalledWith(expect.stringContaining('/api/skills'), { method: 'GET' }) + expect(result.content[0].text).toContain('gh-create-pr') + expect(result.content[0].text).toContain('test-owner/test-repo/gh-create-pr') + }) + + it('should handle empty search results', async () => { + mockNetFetch.mockResolvedValue({ + ok: true, + json: vi.fn().mockResolvedValue({ skills: [], total: 0 }) + }) + + const server = createServer() + const result = await callTool(server, { action: 'search', query: 'nonexistent' }, 'skills') + + expect(result.content[0].text).toContain('No skills found') + }) + + it('should error when query is missing for search', async () => { + const server = createServer() + const result = await callTool(server, { action: 'search' }, 'skills') + + expect(result.isError).toBe(true) + expect(result.content[0].text).toContain("'query' is required") + }) + + it('should install a marketplace skill', async () => { + mockPluginInstall.mockResolvedValue({ + name: 'gh-create-pr', + description: 'Create PRs', + filename: 'gh-create-pr' + }) + + const server = createServer('agent_1') + const result = await callTool(server, { action: 'install', identifier: 'owner/repo/gh-create-pr' }, 'skills') + + expect(mockPluginInstall).toHaveBeenCalledWith({ + agentId: 'agent_1', + sourcePath: 'marketplace:skill:owner/repo/gh-create-pr', + type: 'skill' + }) + expect(result.content[0].text).toContain('Skill installed') + expect(result.content[0].text).toContain('gh-create-pr') + }) + + it('should error when identifier is missing for install', async () => { + const server = createServer() + const result = await callTool(server, { action: 'install' }, 'skills') + + expect(result.isError).toBe(true) + expect(result.content[0].text).toContain("'identifier' is required") + }) + + it('should remove an installed skill', async () => { + mockPluginUninstall.mockResolvedValue(undefined) + + const server = createServer('agent_1') + const result = await callTool(server, { action: 'remove', name: 'gh-create-pr' }, 'skills') + + expect(mockPluginUninstall).toHaveBeenCalledWith({ + agentId: 'agent_1', + filename: 'gh-create-pr', + type: 'skill' + }) + expect(result.content[0].text).toContain('removed') + }) + + it('should error when name is missing for remove', async () => { + const server = createServer() + const result = await callTool(server, { action: 'remove' }, 'skills') + + expect(result.isError).toBe(true) + expect(result.content[0].text).toContain("'name' is required") + }) + + it('should list installed skills', async () => { + mockPluginListInstalled.mockResolvedValue([ + { type: 'skill', filename: 'gh-create-pr', metadata: { name: 'gh-create-pr', description: 'Create PRs' } }, + { type: 'agent', filename: 'some-agent.md', metadata: { name: 'some-agent', description: 'An agent' } }, + { type: 'skill', filename: 'code-review', metadata: { name: 'code-review', description: 'Review code' } } + ]) + + const server = createServer('agent_1') + const result = await callTool(server, { action: 'list' }, 'skills') + + expect(mockPluginListInstalled).toHaveBeenCalledWith('agent_1') + expect(result.content[0].text).toContain('gh-create-pr') + expect(result.content[0].text).toContain('code-review') + // Should not include the agent + expect(result.content[0].text).not.toContain('some-agent') + }) + + it('should handle empty skills list', async () => { + mockPluginListInstalled.mockResolvedValue([]) + + const server = createServer() + const result = await callTool(server, { action: 'list' }, 'skills') + + expect(result.content[0].text).toBe('No skills installed.') + }) + + it('should handle unknown skills action', async () => { + const server = createServer() + const result = await callTool(server, { action: 'unknown' }, 'skills') + + expect(result.isError).toBe(true) + expect(result.content[0].text).toContain('Unknown action') + }) + }) + + describe('memory tool', () => { + const agentWithWorkspace = { accessible_paths: ['/workspace/test'] } + + beforeEach(() => { + mockGetAgent.mockResolvedValue(agentWithWorkspace) + mockMkdir.mockResolvedValue(undefined) + mockWriteFile.mockResolvedValue(undefined) + mockRename.mockResolvedValue(undefined) + mockAppendFile.mockResolvedValue(undefined) + // resolveFileCI: exact path always found (case-sensitive match) + mockStat.mockResolvedValue({ mtimeMs: 1000 }) + }) + + it('should update FACT.md atomically', async () => { + const server = createServer('agent_1') + const result = await callTool(server, { action: 'update', content: '# Facts\n\nNew knowledge' }, 'memory') + + expect(mockMkdir).toHaveBeenCalledWith('/workspace/test/memory', { recursive: true }) + expect(mockWriteFile).toHaveBeenCalledWith( + expect.stringContaining('FACT.md.'), + '# Facts\n\nNew knowledge', + 'utf-8' + ) + expect(mockRename).toHaveBeenCalled() + expect(result.content[0].text).toBe('Memory updated.') + }) + + it('should error when content is missing for update', async () => { + const server = createServer('agent_1') + const result = await callTool(server, { action: 'update' }, 'memory') + + expect(result.isError).toBe(true) + expect(result.content[0].text).toContain("'content' is required") + }) + + it('should append journal entry with tags', async () => { + const server = createServer('agent_1') + const result = await callTool( + server, + { action: 'append', text: 'Deployed v2.0', tags: ['deploy', 'release'] }, + 'memory' + ) + + expect(mockAppendFile).toHaveBeenCalledWith( + '/workspace/test/memory/JOURNAL.jsonl', + expect.stringContaining('"text":"Deployed v2.0"'), + 'utf-8' + ) + expect(result.content[0].text).toContain('Journal entry added') + }) + + it('should error when text is missing for append', async () => { + const server = createServer('agent_1') + const result = await callTool(server, { action: 'append' }, 'memory') + + expect(result.isError).toBe(true) + expect(result.content[0].text).toContain("'text' is required") + }) + + it('should search journal entries', async () => { + const entries = [ + '{"ts":"2024-01-01T00:00:00Z","tags":["deploy"],"text":"Deployed v1.0"}', + '{"ts":"2024-01-02T00:00:00Z","tags":["bugfix"],"text":"Fixed login bug"}', + '{"ts":"2024-01-03T00:00:00Z","tags":["deploy"],"text":"Deployed v2.0"}' + ].join('\n') + mockReadFile.mockResolvedValue(entries) + + const server = createServer('agent_1') + const result = await callTool(server, { action: 'search', tag: 'deploy' }, 'memory') + + const parsed = JSON.parse(result.content[0].text) + expect(parsed).toHaveLength(2) + expect(parsed[0].text).toBe('Deployed v2.0') // reverse chronological + }) + + it('should search journal with text query', async () => { + const entries = [ + '{"ts":"2024-01-01T00:00:00Z","tags":[],"text":"Setup project"}', + '{"ts":"2024-01-02T00:00:00Z","tags":[],"text":"Fixed login bug"}' + ].join('\n') + mockReadFile.mockResolvedValue(entries) + + const server = createServer('agent_1') + const result = await callTool(server, { action: 'search', query: 'login' }, 'memory') + + const parsed = JSON.parse(result.content[0].text) + expect(parsed).toHaveLength(1) + expect(parsed[0].text).toBe('Fixed login bug') + }) + + it('should return message when journal has no matches', async () => { + mockReadFile.mockResolvedValue('{"ts":"2024-01-01T00:00:00Z","tags":[],"text":"hello"}\n') + + const server = createServer('agent_1') + const result = await callTool(server, { action: 'search', query: 'nonexistent' }, 'memory') + + expect(result.content[0].text).toBe('No matching journal entries found.') + }) + + it('should return message when journal does not exist', async () => { + mockReadFile.mockRejectedValue(new Error('ENOENT')) + + const server = createServer('agent_1') + const result = await callTool(server, { action: 'search' }, 'memory') + + expect(result.content[0].text).toBe('No journal entries found.') + }) + + it('should error when agent has no workspace', async () => { + mockGetAgent.mockResolvedValue({ accessible_paths: [] }) + + const server = createServer('agent_1') + const result = await callTool(server, { action: 'update', content: 'test' }, 'memory') + + expect(result.isError).toBe(true) + expect(result.content[0].text).toContain('no workspace path') + }) + + it('should handle unknown memory action', async () => { + const server = createServer() + const result = await callTool(server, { action: 'unknown' }, 'memory') + + expect(result.isError).toBe(true) + expect(result.content[0].text).toContain('Unknown action') + }) + }) +}) diff --git a/src/main/mcpServers/claw.ts b/src/main/mcpServers/claw.ts new file mode 100644 index 00000000000..fee399527d7 --- /dev/null +++ b/src/main/mcpServers/claw.ts @@ -0,0 +1,678 @@ +import { appendFile, mkdir, readdir, readFile, rename, stat, writeFile } from 'node:fs/promises' +import path from 'node:path' + +import { loggerService } from '@logger' +import { PluginService } from '@main/services/agents/plugins/PluginService' +import { agentService } from '@main/services/agents/services/AgentService' +import { channelManager } from '@main/services/agents/services/channels/ChannelManager' +import { taskService } from '@main/services/agents/services/TaskService' +import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js' +import type { Tool } from '@modelcontextprotocol/sdk/types.js' +import { CallToolRequestSchema, ErrorCode, ListToolsRequestSchema, McpError } from '@modelcontextprotocol/sdk/types.js' +import type { TaskContextMode, TaskScheduleType } from '@types' +import { net } from 'electron' + +const logger = loggerService.withContext('MCPServer:Claw') + +/** + * Parse a human-friendly duration string (e.g. '30m', '2h', '1h30m') into minutes. + */ +function parseDurationToMinutes(duration: string): number { + let totalMinutes = 0 + const hourMatch = duration.match(/(\d+)\s*h/i) + const minMatch = duration.match(/(\d+)\s*m/i) + + if (hourMatch) totalMinutes += parseInt(hourMatch[1], 10) * 60 + if (minMatch) totalMinutes += parseInt(minMatch[1], 10) + + if (totalMinutes === 0) { + const raw = parseInt(duration, 10) + if (!isNaN(raw) && raw > 0) return raw + throw new Error(`Invalid duration: "${duration}". Use formats like '30m', '2h', '1h30m'.`) + } + + return totalMinutes +} + +type SkillSearchResult = { + name: string + namespace?: string + description?: string | null + author?: string | null + installs?: number + metadata?: { + repoOwner?: string + repoName?: string + } +} + +function buildSkillIdentifier(skill: SkillSearchResult): string { + const { name, namespace, metadata } = skill + const repoOwner = metadata?.repoOwner + const repoName = metadata?.repoName + + if (repoOwner && repoName) { + return `${repoOwner}/${repoName}/${name}` + } + + if (namespace) { + const cleanNamespace = namespace.replace(/^@/, '') + const parts = cleanNamespace.split('/').filter(Boolean) + if (parts.length >= 2) { + return `${parts[0]}/${parts[1]}/${name}` + } + return `${cleanNamespace}/${name}` + } + + return name +} + +const CRON_TOOL: Tool = { + name: 'cron', + description: + "Manage scheduled tasks. Use action 'add' to create a recurring or one-time job, 'list' to see all jobs, or 'remove' to delete a job. For one-time jobs, use the 'at' field with an RFC3339 timestamp.", + inputSchema: { + type: 'object', + properties: { + action: { + type: 'string', + enum: ['add', 'list', 'remove'], + description: 'The action to perform' + }, + name: { + type: 'string', + description: 'Name of the job (required for add)' + }, + message: { + type: 'string', + description: 'The prompt/instruction to execute on schedule (required for add)' + }, + cron: { + type: 'string', + description: "Cron expression, e.g. '0 9 * * 1-5' for weekdays at 9am (use cron OR every, not both)" + }, + every: { + type: 'string', + description: "Duration, e.g. '30m', '2h', '24h' (use every OR cron, not both)" + }, + at: { + type: 'string', + description: + "RFC3339 timestamp for a one-time job, e.g. '2024-01-15T14:30:00+08:00' (use at OR cron OR every, not combined)" + }, + session_mode: { + type: 'string', + enum: ['reuse', 'new'], + description: + "Session behavior: 'reuse' (default) keeps conversation history across executions, 'new' starts a fresh session each time" + }, + id: { + type: 'string', + description: 'Job ID (required for remove)' + } + }, + required: ['action'] + } +} + +const NOTIFY_TOOL: Tool = { + name: 'notify', + description: + 'Send a notification message to the user through connected channels (e.g. Telegram). Use this to proactively inform the user about task results, status updates, or any important information.', + inputSchema: { + type: 'object', + properties: { + message: { + type: 'string', + description: 'The notification message to send to the user' + }, + channel_id: { + type: 'string', + description: 'Optional: send to a specific channel only (omit to send to all notify-enabled channels)' + } + }, + required: ['message'] + } +} + +const MARKETPLACE_BASE_URL = 'https://claude-plugins.dev' + +const SKILLS_TOOL: Tool = { + name: 'skills', + description: + "Manage Claude skills in the agent's workspace. Use action 'search' to find skills from the marketplace, 'install' to install a skill, 'remove' to uninstall a skill, or 'list' to see installed skills.", + inputSchema: { + type: 'object', + properties: { + action: { + type: 'string', + enum: ['search', 'install', 'remove', 'list'], + description: 'The action to perform' + }, + query: { + type: 'string', + description: "Search query for finding skills in the marketplace (required for 'search')" + }, + identifier: { + type: 'string', + description: + "Marketplace skill identifier in 'owner/repo/skill-name' format (required for 'install'). Get this from the search results." + }, + name: { + type: 'string', + description: "Skill folder name to remove (required for 'remove'). Get this from the list results." + } + }, + required: ['action'] + } +} + +/** + * Resolve a filename within a directory using case-insensitive matching. + * Returns the full path if found (preferring exact match), or the canonical path as fallback. + */ +async function resolveFileCI(dir: string, name: string): Promise { + const exact = path.join(dir, name) + try { + await stat(exact) + return exact + } catch { + // exact match not found, try case-insensitive + } + + try { + const entries = await readdir(dir) + const target = name.toLowerCase() + const match = entries.find((e) => e.toLowerCase() === target) + return match ? path.join(dir, match) : exact + } catch { + return exact + } +} + +type JournalEntry = { + ts: string + tags: string[] + text: string +} + +const MEMORY_TOOL: Tool = { + name: 'memory', + description: + "Manage persistent memory across sessions. Actions: 'update' overwrites memory/FACT.md (only durable project knowledge and decisions — not user preferences or personality, those belong in user.md and soul.md). 'append' logs to memory/JOURNAL.jsonl (one-time events, completed tasks, session notes). 'search' queries the journal. Before writing to FACT.md, ask: will this still matter in 6 months? If not, use append instead.", + inputSchema: { + type: 'object', + properties: { + action: { + type: 'string', + enum: ['update', 'append', 'search'], + description: + "Action to perform: 'update' overwrites FACT.md (durable project knowledge only), 'append' adds a JOURNAL entry, 'search' queries the journal" + }, + content: { + type: 'string', + description: 'Full markdown content for FACT.md (required for update)' + }, + text: { + type: 'string', + description: 'Journal entry text (required for append)' + }, + tags: { + type: 'array', + items: { type: 'string' }, + description: 'Tags for the journal entry (optional, for append)' + }, + query: { + type: 'string', + description: 'Search query — case-insensitive substring match (for search)' + }, + tag: { + type: 'string', + description: 'Filter by tag (optional, for search)' + }, + limit: { + type: 'integer', + description: 'Max results to return (default 20, for search)' + } + }, + required: ['action'] + } +} + +class ClawServer { + public mcpServer: McpServer + private agentId: string + + constructor(agentId: string) { + this.agentId = agentId + this.mcpServer = new McpServer( + { + name: 'claw', + version: '1.0.0' + }, + { + capabilities: { + tools: {} + } + } + ) + this.setupHandlers() + } + + private setupHandlers() { + this.mcpServer.server.setRequestHandler(ListToolsRequestSchema, async () => ({ + tools: [CRON_TOOL, NOTIFY_TOOL, SKILLS_TOOL, MEMORY_TOOL] + })) + + this.mcpServer.server.setRequestHandler(CallToolRequestSchema, async (request) => { + const toolName = request.params.name + const args = (request.params.arguments ?? {}) as Record + + try { + switch (toolName) { + case 'cron': { + const action = args.action + switch (action) { + case 'add': + return await this.addJob(args) + case 'list': + return await this.listJobs() + case 'remove': + return await this.removeJob(args) + default: + throw new McpError(ErrorCode.InvalidParams, `Unknown action "${action}", expected add/list/remove`) + } + } + case 'notify': + return await this.sendNotification(args) + case 'skills': { + const action = args.action + switch (action) { + case 'search': + return await this.searchSkills(args) + case 'install': + return await this.installSkill(args) + case 'remove': + return await this.removeSkill(args) + case 'list': + return await this.listSkills() + default: + throw new McpError( + ErrorCode.InvalidParams, + `Unknown action "${action}", expected search/install/remove/list` + ) + } + } + case 'memory': { + const action = args.action + switch (action) { + case 'update': + return await this.memoryUpdate(args) + case 'append': + return await this.memoryAppend(args) + case 'search': + return await this.memorySearch(args) + default: + throw new McpError(ErrorCode.InvalidParams, `Unknown action "${action}", expected update/append/search`) + } + } + default: + throw new McpError(ErrorCode.MethodNotFound, `Unknown tool: ${toolName}`) + } + } catch (error) { + const message = error instanceof Error ? error.message : String(error) + logger.error(`Tool error: ${toolName}`, { agentId: this.agentId, error: message }) + return { + content: [{ type: 'text' as const, text: `Error: ${message}` }], + isError: true + } + } + }) + } + + private async addJob(args: Record) { + const name = args.name + const message = args.message + const cronExpr = args.cron + const every = args.every + const at = args.at + const sessionMode = args.session_mode + + if (!name) throw new McpError(ErrorCode.InvalidParams, "'name' is required for add") + if (!message) throw new McpError(ErrorCode.InvalidParams, "'message' is required for add") + + // Determine schedule type and value + const scheduleCount = [cronExpr, every, at].filter(Boolean).length + if (scheduleCount === 0) throw new McpError(ErrorCode.InvalidParams, "One of 'cron', 'every', or 'at' is required") + if (scheduleCount > 1) throw new McpError(ErrorCode.InvalidParams, "Use only one of 'cron', 'every', or 'at'") + + let scheduleType: TaskScheduleType + let scheduleValue: string + + if (cronExpr) { + scheduleType = 'cron' + scheduleValue = cronExpr + } else if (every) { + scheduleType = 'interval' + scheduleValue = String(parseDurationToMinutes(every)) + } else { + scheduleType = 'once' + // Validate and normalize to ISO string + const date = new Date(at!) + if (isNaN(date.getTime())) throw new McpError(ErrorCode.InvalidParams, `Invalid timestamp: "${at}"`) + scheduleValue = date.toISOString() + } + + const contextMode: TaskContextMode = sessionMode === 'new' ? 'isolated' : 'session' + + const task = await taskService.createTask(this.agentId, { + name, + prompt: message, + schedule_type: scheduleType, + schedule_value: scheduleValue, + context_mode: contextMode + }) + + logger.info('Cron job created via tool', { agentId: this.agentId, taskId: task.id }) + return { + content: [{ type: 'text' as const, text: `Job created:\n${JSON.stringify(task, null, 2)}` }] + } + } + + private async listJobs() { + const { tasks } = await taskService.listTasks(this.agentId, { limit: 100 }) + + if (tasks.length === 0) { + return { content: [{ type: 'text' as const, text: 'No scheduled jobs.' }] } + } + + return { + content: [{ type: 'text' as const, text: JSON.stringify(tasks, null, 2) }] + } + } + + private async sendNotification(args: Record) { + const message = args.message + if (!message) throw new McpError(ErrorCode.InvalidParams, "'message' is required for notify") + + const targetChannelId = args.channel_id + let adapters = channelManager.getNotifyAdapters(this.agentId) + + if (targetChannelId) { + adapters = adapters.filter((a) => a.channelId === targetChannelId) + } + + if (adapters.length === 0) { + return { + content: [ + { + type: 'text' as const, + text: 'No notify-enabled channels found. Enable `is_notify_receiver` on at least one channel in agent settings.' + } + ] + } + } + + let sent = 0 + const errors: string[] = [] + + for (const adapter of adapters) { + for (const chatId of adapter.notifyChatIds) { + try { + await adapter.sendMessage(chatId, message) + sent++ + } catch (err) { + const errMsg = err instanceof Error ? err.message : String(err) + errors.push(`${adapter.channelId}/${chatId}: ${errMsg}`) + logger.warn('Failed to send notification', { + agentId: this.agentId, + channelId: adapter.channelId, + chatId, + error: errMsg + }) + } + } + } + + const parts = [`Notification sent to ${sent} chat(s).`] + if (errors.length > 0) { + parts.push(`Errors: ${errors.join('; ')}`) + } + + logger.info('Notification sent via notify tool', { agentId: this.agentId, sent, errors: errors.length }) + return { + content: [{ type: 'text' as const, text: parts.join(' ') }] + } + } + + private async searchSkills(args: Record) { + const query = args.query + if (!query) throw new McpError(ErrorCode.InvalidParams, "'query' is required for search") + + const url = new URL(`${MARKETPLACE_BASE_URL}/api/skills`) + url.searchParams.set('q', query.replace(/[-_]+/g, ' ').trim()) + url.searchParams.set('limit', '20') + url.searchParams.set('offset', '0') + + const response = await net.fetch(url.toString(), { method: 'GET' }) + if (!response.ok) { + throw new Error(`Marketplace API returned ${response.status}: ${response.statusText}`) + } + + const json = (await response.json()) as { skills?: SkillSearchResult[]; total?: number } + const skills = json.skills ?? [] + + if (skills.length === 0) { + return { content: [{ type: 'text' as const, text: `No skills found for "${query}".` }] } + } + + const results = skills.map((s) => ({ + name: s.name, + description: s.description ?? null, + author: s.author ?? null, + identifier: buildSkillIdentifier(s), + installs: s.installs ?? 0 + })) + + logger.info('Skills search via tool', { agentId: this.agentId, query, resultCount: results.length }) + return { + content: [ + { + type: 'text' as const, + text: `Found ${results.length} skill(s) for "${query}":\n${JSON.stringify(results, null, 2)}\n\nUse the 'identifier' field with action 'install' to install a skill.` + } + ] + } + } + + private async installSkill(args: Record) { + const identifier = args.identifier + if (!identifier) { + throw new McpError( + ErrorCode.InvalidParams, + "'identifier' is required for install (format: 'owner/repo/skill-name')" + ) + } + + const pluginService = PluginService.getInstance() + const sourcePath = `marketplace:skill:${identifier}` + + const metadata = await pluginService.install({ + agentId: this.agentId, + sourcePath, + type: 'skill' + }) + + logger.info('Skill installed via tool', { agentId: this.agentId, identifier, name: metadata.name }) + return { + content: [ + { + type: 'text' as const, + text: `Skill installed:\n Name: ${metadata.name}\n Description: ${metadata.description ?? 'N/A'}\n Folder: ${metadata.filename}` + } + ] + } + } + + private async removeSkill(args: Record) { + const name = args.name + if (!name) throw new McpError(ErrorCode.InvalidParams, "'name' is required for remove (skill folder name)") + + const pluginService = PluginService.getInstance() + + await pluginService.uninstall({ + agentId: this.agentId, + filename: name, + type: 'skill' + }) + + logger.info('Skill removed via tool', { agentId: this.agentId, name }) + return { + content: [{ type: 'text' as const, text: `Skill "${name}" removed.` }] + } + } + + private async listSkills() { + const pluginService = PluginService.getInstance() + const allPlugins = await pluginService.listInstalled(this.agentId) + const skills = allPlugins.filter((p) => p.type === 'skill') + + if (skills.length === 0) { + return { content: [{ type: 'text' as const, text: 'No skills installed.' }] } + } + + const results = skills.map((s) => ({ + name: s.metadata.name, + folder: s.filename, + description: s.metadata.description ?? null + })) + + logger.info('Skills list via tool', { agentId: this.agentId, count: results.length }) + return { + content: [{ type: 'text' as const, text: JSON.stringify(results, null, 2) }] + } + } + + private async getWorkspacePath(): Promise { + const agent = await agentService.getAgent(this.agentId) + if (!agent) throw new McpError(ErrorCode.InternalError, `Agent not found: ${this.agentId}`) + const workspace = agent.accessible_paths?.[0] + if (!workspace) throw new McpError(ErrorCode.InternalError, 'Agent has no workspace path configured') + return workspace + } + + private async memoryUpdate(args: Record) { + const content = args.content + if (!content) throw new McpError(ErrorCode.InvalidParams, "'content' is required for update action") + + const workspace = await this.getWorkspacePath() + const memoryDir = path.join(workspace, 'memory') + const factPath = await resolveFileCI(memoryDir, 'FACT.md') + + await mkdir(memoryDir, { recursive: true }) + + // Atomic write via temp file + rename + const tmpPath = `${factPath}.${Date.now()}.tmp` + await writeFile(tmpPath, content, 'utf-8') + await rename(tmpPath, factPath) + + logger.info('Memory FACT.md updated via tool', { agentId: this.agentId, length: content.length }) + return { + content: [{ type: 'text' as const, text: 'Memory updated.' }] + } + } + + private async memoryAppend(args: Record) { + const text = args.text + if (!text) throw new McpError(ErrorCode.InvalidParams, "'text' is required for append action") + + const tags: string[] = [] + const rawTags = (args as Record).tags + if (Array.isArray(rawTags)) { + for (const item of rawTags) { + if (typeof item === 'string') tags.push(item) + } + } + + const workspace = await this.getWorkspacePath() + const memoryDir = path.join(workspace, 'memory') + + await mkdir(memoryDir, { recursive: true }) + + const journalPath = await resolveFileCI(memoryDir, 'JOURNAL.jsonl') + + const entry: JournalEntry = { + ts: new Date().toISOString(), + tags, + text + } + + await appendFile(journalPath, JSON.stringify(entry) + '\n', 'utf-8') + + logger.info('Journal entry appended via tool', { agentId: this.agentId, tags }) + return { + content: [{ type: 'text' as const, text: `Journal entry added at ${entry.ts}.` }] + } + } + + private async memorySearch(args: Record) { + const query = args.query ?? '' + const tagFilter = args.tag ?? '' + const limit = Math.max(1, parseInt(args.limit ?? '20', 10) || 20) + + const workspace = await this.getWorkspacePath() + const memoryDir = path.join(workspace, 'memory') + const journalPath = await resolveFileCI(memoryDir, 'JOURNAL.jsonl') + + let fileContent: string + try { + fileContent = await readFile(journalPath, 'utf-8') + } catch { + return { content: [{ type: 'text' as const, text: 'No journal entries found.' }] } + } + + const queryLower = query.toLowerCase() + const tagLower = tagFilter.toLowerCase() + const matches: JournalEntry[] = [] + + for (const line of fileContent.split('\n')) { + if (!line.trim()) continue + let entry: JournalEntry + try { + entry = JSON.parse(line) + } catch { + continue + } + if (tagFilter && !entry.tags?.some((t) => t.toLowerCase() === tagLower)) continue + if (query && !entry.text.toLowerCase().includes(queryLower)) continue + matches.push(entry) + } + + // Return last N entries in reverse-chronological order + const result = matches.slice(-limit).reverse() + + if (result.length === 0) { + return { content: [{ type: 'text' as const, text: 'No matching journal entries found.' }] } + } + + logger.info('Journal search via tool', { agentId: this.agentId, query, tag: tagFilter, resultCount: result.length }) + return { + content: [{ type: 'text' as const, text: JSON.stringify(result, null, 2) }] + } + } + + private async removeJob(args: Record) { + const id = args.id + if (!id) throw new McpError(ErrorCode.InvalidParams, "'id' is required for remove") + + const deleted = await taskService.deleteTask(this.agentId, id) + if (!deleted) throw new McpError(ErrorCode.InvalidParams, `Job "${id}" not found`) + + logger.info('Cron job removed via tool', { agentId: this.agentId, taskId: id }) + return { + content: [{ type: 'text' as const, text: `Job "${id}" removed.` }] + } + } +} + +export default ClawServer diff --git a/src/main/services/CodeToolsService.ts b/src/main/services/CodeToolsService.ts index 65bddf6d3d0..8721606de5a 100644 --- a/src/main/services/CodeToolsService.ts +++ b/src/main/services/CodeToolsService.ts @@ -1201,7 +1201,8 @@ class CodeToolsService { detached: true, stdio: 'ignore', cwd: directory, - env: processEnv + env: processEnv, + shell: isWin }) const successMessage = `Launched ${cliTool} in new terminal window` diff --git a/src/main/services/OpenClawService.ts b/src/main/services/OpenClawService.ts index 43209ab826a..4348da2b8b7 100644 --- a/src/main/services/OpenClawService.ts +++ b/src/main/services/OpenClawService.ts @@ -244,8 +244,14 @@ class OpenClawService { const npmArgs = ['install', '-g', packageName] if (registryArg) npmArgs.push(registryArg) - // Keep the command string for logging and sudo retry - const npmCommand = `"${npmPath}" install -g ${packageName} ${registryArg}`.trim() + // Keep the command string for logging and sudo retry. + // On macOS/Linux, prepend npm's parent dir to PATH so that sudo (which runs in a + // clean environment without user PATH) can resolve `node` via npm's shebang + // (#!/usr/bin/env node). + const nodeDir = path.dirname(npmPath) + const npmCommand = isWin + ? `"${npmPath}" install -g ${packageName} ${registryArg}`.trim() + : `PATH="${nodeDir}:$PATH" "${npmPath}" install -g ${packageName} ${registryArg}`.trim() // On Windows, wrap npm path in quotes if it contains spaces and is not already quoted const needsQuotes = isWin && npmPath.includes(' ') && !npmPath.startsWith('"') @@ -347,8 +353,12 @@ class OpenClawService { const npmArgs = ['uninstall', '-g', 'openclaw', '@qingchencloud/openclaw-zh'] - // Keep the command string for logging and sudo retry - const npmCommand = `"${npmPath}" uninstall -g openclaw @qingchencloud/openclaw-zh` + // Keep the command string for logging and sudo retry. + // On macOS/Linux, prepend npm's parent dir to PATH so that sudo can resolve `node`. + const nodeDir = path.dirname(npmPath) + const npmCommand = isWin + ? `"${npmPath}" uninstall -g openclaw @qingchencloud/openclaw-zh` + : `PATH="${nodeDir}:$PATH" "${npmPath}" uninstall -g openclaw @qingchencloud/openclaw-zh` // On Windows, wrap npm path in quotes if it contains spaces and is not already quoted const needsQuotes = isWin && npmPath.includes(' ') && !npmPath.startsWith('"') diff --git a/src/main/services/agents/BaseService.ts b/src/main/services/agents/BaseService.ts index f88edcc9a4d..4bb24629f9f 100644 --- a/src/main/services/agents/BaseService.ts +++ b/src/main/services/agents/BaseService.ts @@ -54,7 +54,7 @@ export abstract class BaseService { ): Promise<{ tools: Tool[]; legacyIdMap: Map }> { const tools: Tool[] = [] const legacyIdMap = new Map() - if (agentType === 'claude-code') { + if (agentType === 'claude-code' || agentType === 'cherry-claw') { tools.push(...builtinTools) } if (ids && ids.length > 0) { @@ -139,7 +139,7 @@ export abstract class BaseService { } public async listSlashCommands(agentType: AgentType): Promise { - if (agentType === 'claude-code') { + if (agentType === 'claude-code' || agentType === 'cherry-claw') { return builtinSlashCommands } return [] diff --git a/src/main/services/agents/database/schema/index.ts b/src/main/services/agents/database/schema/index.ts index 553f94a038c..7e2c6598480 100644 --- a/src/main/services/agents/database/schema/index.ts +++ b/src/main/services/agents/database/schema/index.ts @@ -6,3 +6,4 @@ export * from './agents.schema' export * from './messages.schema' export * from './migrations.schema' export * from './sessions.schema' +export * from './tasks.schema' diff --git a/src/main/services/agents/database/schema/tasks.schema.ts b/src/main/services/agents/database/schema/tasks.schema.ts new file mode 100644 index 00000000000..096697ee714 --- /dev/null +++ b/src/main/services/agents/database/schema/tasks.schema.ts @@ -0,0 +1,60 @@ +/** + * Drizzle ORM schema for scheduled tasks tables + */ + +import { foreignKey, index, integer, sqliteTable, text } from 'drizzle-orm/sqlite-core' + +import { agentsTable } from './agents.schema' + +export const scheduledTasksTable = sqliteTable('scheduled_tasks', { + id: text('id').primaryKey(), + agent_id: text('agent_id').notNull(), + name: text('name').notNull(), + prompt: text('prompt').notNull(), + schedule_type: text('schedule_type').notNull(), // 'cron' | 'interval' | 'once' + schedule_value: text('schedule_value').notNull(), // cron expression, milliseconds as string, or ISO timestamp + context_mode: text('context_mode').notNull().default('session'), // 'session' | 'isolated' + next_run: text('next_run'), + last_run: text('last_run'), + last_result: text('last_result'), + status: text('status').notNull().default('active'), // 'active' | 'paused' | 'completed' + created_at: text('created_at').notNull(), + updated_at: text('updated_at').notNull() +}) + +export const taskRunLogsTable = sqliteTable('task_run_logs', { + id: integer('id').primaryKey({ autoIncrement: true }), + task_id: text('task_id').notNull(), + run_at: text('run_at').notNull(), + duration_ms: integer('duration_ms').notNull(), + status: text('status').notNull(), // 'success' | 'error' + result: text('result'), + error: text('error') +}) + +// Foreign keys +export const scheduledTasksFkAgent = foreignKey({ + columns: [scheduledTasksTable.agent_id], + foreignColumns: [agentsTable.id], + name: 'fk_scheduled_tasks_agent_id' +}).onDelete('cascade') + +export const taskRunLogsFkTask = foreignKey({ + columns: [taskRunLogsTable.task_id], + foreignColumns: [scheduledTasksTable.id], + name: 'fk_task_run_logs_task_id' +}).onDelete('cascade') + +// Indexes for scheduled_tasks table +export const tasksAgentIdIdx = index('idx_tasks_agent_id').on(scheduledTasksTable.agent_id) +export const tasksNextRunIdx = index('idx_tasks_next_run').on(scheduledTasksTable.next_run) +export const tasksStatusIdx = index('idx_tasks_status').on(scheduledTasksTable.status) + +// Indexes for task_run_logs table +export const taskRunLogsTaskIdIdx = index('idx_task_run_logs_task_id').on(taskRunLogsTable.task_id) + +// Type exports +export type TaskRow = typeof scheduledTasksTable.$inferSelect +export type InsertTaskRow = typeof scheduledTasksTable.$inferInsert +export type TaskRunLogRow = typeof taskRunLogsTable.$inferSelect +export type InsertTaskRunLogRow = typeof taskRunLogsTable.$inferInsert diff --git a/src/main/services/agents/drizzle.config.ts b/src/main/services/agents/drizzle.config.ts index 8eb8ce3a73c..dbda29b72f8 100644 --- a/src/main/services/agents/drizzle.config.ts +++ b/src/main/services/agents/drizzle.config.ts @@ -18,16 +18,12 @@ * Drizzle Kit configuration for agents database */ -import os from 'node:os' import path from 'node:path' import { defineConfig } from 'drizzle-kit' import { app } from 'electron' function getDbPath() { - if (process.env.NODE_ENV === 'development') { - return path.join(os.homedir(), '.cherrystudio', 'data', 'agents.db') - } return path.join(app.getPath('userData'), 'Data', 'agents.db') } diff --git a/src/main/services/agents/interfaces/AgentStreamInterface.ts b/src/main/services/agents/interfaces/AgentStreamInterface.ts index 485b3584fc8..7a584497c4c 100644 --- a/src/main/services/agents/interfaces/AgentStreamInterface.ts +++ b/src/main/services/agents/interfaces/AgentStreamInterface.ts @@ -19,6 +19,8 @@ export interface AgentStream extends EventEmitter { emit(event: 'data', data: AgentStreamEvent): boolean on(event: 'data', listener: (data: AgentStreamEvent) => void): this once(event: 'data', listener: (data: AgentStreamEvent) => void): this + /** SDK session_id captured from the init message, used for resume. */ + sdkSessionId?: string } export interface AgentThinkingOptions { diff --git a/src/main/services/agents/services/AgentService.ts b/src/main/services/agents/services/AgentService.ts index 661cd4200f4..a2dc6a43387 100644 --- a/src/main/services/agents/services/AgentService.ts +++ b/src/main/services/agents/services/AgentService.ts @@ -1,4 +1,5 @@ import { loggerService } from '@logger' +import { modelsService } from '@main/apiServer/services/models' import { pluginService } from '@main/services/agents/plugins/PluginService' import type { AgentEntity, @@ -104,6 +105,9 @@ export class AgentService extends BaseService { } async listAgents(options: ListOptions = {}): Promise<{ agents: AgentEntity[]; total: number }> { + // Ensure a default CherryClaw agent exists + await this.ensureDefaultCherryClaw() + // Build query with pagination const database = await this.getDatabase() const totalResult = await database.select({ count: count() }).from(agentsTable) @@ -131,9 +135,70 @@ export class AgentService extends BaseService { agent.allowed_tools = this.normalizeAllowedTools(agent.allowed_tools, agent.tools, legacyIdMap) } + // Sort cherry-claw agents to the top + agents.sort((a, b) => { + if (a.type === 'cherry-claw' && b.type !== 'cherry-claw') return -1 + if (a.type !== 'cherry-claw' && b.type === 'cherry-claw') return 1 + return 0 + }) + return { agents, total: totalResult[0].count } } + private async ensureDefaultCherryClaw(): Promise { + try { + const database = await this.getDatabase() + const existing = await database + .select({ id: agentsTable.id }) + .from(agentsTable) + .where(eq(agentsTable.type, 'cherry-claw')) + .limit(1) + + if (existing.length > 0) return + + // Find the first available Anthropic model + const modelsRes = await modelsService.getModels({ providerType: 'anthropic', limit: 1 }) + const firstModel = modelsRes.data?.[0] + if (!firstModel) { + logger.warn('No Anthropic models available — skipping default CherryClaw creation') + return + } + + const req: CreateAgentRequest = { + type: 'cherry-claw', + name: 'CherryClaw', + description: 'Default autonomous CherryClaw agent', + model: firstModel.id, + accessible_paths: [], + configuration: { + permission_mode: 'bypassPermissions', + max_turns: 100, + soul_enabled: true, + scheduler_enabled: false, + scheduler_type: 'interval', + heartbeat_enabled: true, + heartbeat_interval: 30 + } + } + + const agent = await this.createAgent(req) + logger.info('Auto-created default CherryClaw agent', { id: agent.id }) + + // Create a default session for the auto-created agent + const { SessionService } = await import('./SessionService') + const sessionSvc = SessionService.getInstance() + await sessionSvc.createSession(agent.id, {}) + logger.info('Default session created for CherryClaw agent', { agentId: agent.id }) + + // Create the heartbeat scheduled task + const { schedulerService } = await import('./SchedulerService') + await schedulerService.ensureHeartbeatTask(agent.id, 30) + logger.info('Heartbeat task created for CherryClaw agent', { agentId: agent.id }) + } catch (error) { + logger.error('Failed to ensure default CherryClaw agent', error as Error) + } + } + async updateAgent( id: string, updates: UpdateAgentRequest, diff --git a/src/main/services/agents/services/AgentServiceRegistry.ts b/src/main/services/agents/services/AgentServiceRegistry.ts new file mode 100644 index 00000000000..4dad3640c59 --- /dev/null +++ b/src/main/services/agents/services/AgentServiceRegistry.ts @@ -0,0 +1,41 @@ +import { loggerService } from '@logger' +import type { AgentType } from '@types' + +import type { AgentServiceInterface } from '../interfaces/AgentStreamInterface' + +const logger = loggerService.withContext('AgentServiceRegistry') + +/** + * Registry mapping AgentType to the service that handles invocations for that type. + * Used by SessionMessageService to dispatch to the correct agent service. + */ +class AgentServiceRegistry { + private static instance: AgentServiceRegistry | null = null + private readonly services = new Map() + + static getInstance(): AgentServiceRegistry { + if (!AgentServiceRegistry.instance) { + AgentServiceRegistry.instance = new AgentServiceRegistry() + } + return AgentServiceRegistry.instance + } + + register(agentType: AgentType, service: AgentServiceInterface): void { + logger.info('Registering agent service', { agentType }) + this.services.set(agentType, service) + } + + getService(agentType: AgentType): AgentServiceInterface { + const service = this.services.get(agentType) + if (!service) { + throw new Error(`No agent service registered for type: ${agentType}`) + } + return service + } + + hasService(agentType: AgentType): boolean { + return this.services.has(agentType) + } +} + +export const agentServiceRegistry = AgentServiceRegistry.getInstance() diff --git a/src/main/services/agents/services/SchedulerService.ts b/src/main/services/agents/services/SchedulerService.ts new file mode 100644 index 00000000000..ff175fc6380 --- /dev/null +++ b/src/main/services/agents/services/SchedulerService.ts @@ -0,0 +1,335 @@ +import { loggerService } from '@logger' +import type { CherryClawConfiguration, ScheduledTaskEntity } from '@types' + +import { agentService } from './AgentService' +import { channelManager } from './channels/ChannelManager' +import { CherryClawService } from './cherryclaw' +import { sessionMessageService } from './SessionMessageService' +import { sessionService } from './SessionService' +import { taskService } from './TaskService' + +const logger = loggerService.withContext('SchedulerService') + +const POLL_INTERVAL_MS = 60_000 +const MAX_CONSECUTIVE_ERRORS = 3 + +type RunningTask = { + taskId: string + agentId: string + abortController: AbortController + consecutiveErrors: number +} + +class SchedulerService { + private static instance: SchedulerService | null = null + private pollTimer: ReturnType | null = null + private running = false + private readonly activeTasks = new Map() + private cherryClawService: CherryClawService | null = null + + static getInstance(): SchedulerService { + if (!SchedulerService.instance) { + SchedulerService.instance = new SchedulerService() + } + return SchedulerService.instance + } + + private getCherryClawService(): CherryClawService { + if (!this.cherryClawService) { + this.cherryClawService = new CherryClawService() + } + return this.cherryClawService + } + + startLoop(): void { + if (this.running) { + logger.debug('Scheduler loop already running') + return + } + this.running = true + logger.info('Scheduler poll loop started') + this.poll() + } + + stopLoop(): void { + this.running = false + if (this.pollTimer) { + clearTimeout(this.pollTimer) + this.pollTimer = null + } + // Abort all running tasks + for (const [taskId, rt] of this.activeTasks) { + rt.abortController.abort() + logger.info('Aborted running task on shutdown', { taskId }) + } + this.activeTasks.clear() + logger.info('Scheduler poll loop stopped') + } + + // Keep backward-compatible aliases used by agent handlers and main/index.ts + stopScheduler(_agentId: string): void { + // No-op — the poll loop handles everything via DB state. + // Individual task abort is handled by stopLoop or task deletion. + } + + startScheduler(_agent: any): void { + // No-op — the poll loop picks up tasks from DB automatically. + // Just ensure the loop is running. + this.startLoop() + } + + stopAll(): void { + this.stopLoop() + } + + async restoreSchedulers(): Promise { + this.startLoop() + } + + /** + * Ensure a heartbeat task exists for the given agent. + * Creates one if missing, or updates the interval if it changed. + */ + async ensureHeartbeatTask(agentId: string, intervalMinutes: number = 30): Promise { + const { tasks } = await taskService.listTasks(agentId, { includeHeartbeat: true }) + const existing = tasks.find((t) => t.name === 'heartbeat') + + if (existing) { + const currentInterval = existing.schedule_value + const newInterval = String(intervalMinutes) + if (currentInterval !== newInterval) { + await taskService.updateTask(agentId, existing.id, { schedule_value: newInterval }) + logger.info('Updated heartbeat task interval', { agentId, interval: intervalMinutes }) + } + } else { + await taskService.createTask(agentId, { + name: 'heartbeat', + prompt: '__heartbeat__', + schedule_type: 'interval', + schedule_value: String(intervalMinutes), + context_mode: 'session' + }) + logger.info('Created heartbeat task', { agentId, interval: intervalMinutes }) + } + } + + /** Manually trigger a task run (from UI). Returns immediately; task runs in background. */ + async runTaskNow(agentId: string, taskId: string): Promise { + const task = await taskService.getTask(agentId, taskId) + if (!task) throw new Error(`Task not found: ${taskId}`) + if (this.activeTasks.has(task.id)) throw new Error('Task is already running') + + // Fire and forget + this.runTask(task).catch((error) => { + logger.error('Unhandled error in manual runTask', { + taskId: task.id, + error: error instanceof Error ? error.message : String(error) + }) + }) + } + + private poll(): void { + if (!this.running) return + + this.tick() + .catch((error) => { + logger.error('Error in scheduler tick', { + error: error instanceof Error ? error.message : String(error) + }) + }) + .finally(() => { + if (this.running) { + this.pollTimer = setTimeout(() => this.poll(), POLL_INTERVAL_MS) + } + }) + } + + private async tick(): Promise { + const dueTasks = await taskService.getDueTasks() + if (dueTasks.length > 0) { + logger.info('Found due tasks', { count: dueTasks.length }) + } + + for (const task of dueTasks) { + // Skip if already running + if (this.activeTasks.has(task.id)) { + logger.debug('Task already running, skipping', { taskId: task.id }) + continue + } + + // Fire and forget — don't block the poll loop + this.runTask(task).catch((error) => { + logger.error('Unhandled error in runTask', { + taskId: task.id, + error: error instanceof Error ? error.message : String(error) + }) + }) + } + } + + private async runTask(task: ScheduledTaskEntity): Promise { + const startTime = Date.now() + const abortController = new AbortController() + const runningTask: RunningTask = { + taskId: task.id, + agentId: task.agent_id, + abortController, + consecutiveErrors: 0 + } + this.activeTasks.set(task.id, runningTask) + + let result: string | null = null + let error: string | null = null + + try { + logger.info('Running scheduled task', { taskId: task.id, agentId: task.agent_id }) + + const agent = await agentService.getAgent(task.agent_id) + if (!agent) { + throw new Error(`Agent not found: ${task.agent_id}`) + } + + const config = (agent.configuration ?? {}) as CherryClawConfiguration + const workspacePath = agent.accessible_paths?.[0] + + // For heartbeat tasks, read prompt from workspace heartbeat.md file + let fullPrompt = task.prompt + if (task.name === 'heartbeat') { + if (config.heartbeat_enabled === false || !workspacePath) { + logger.debug('Heartbeat task skipped (disabled or no workspace)', { taskId: task.id }) + // Still update next_run so it doesn't fire again immediately + const nextRun = taskService.computeNextRun(task) + await taskService.updateTaskAfterRun(task.id, nextRun, 'Skipped (disabled)') + this.activeTasks.delete(task.id) + return + } + const clawService = this.getCherryClawService() + const heartbeatContent = await clawService.heartbeatReader.readHeartbeat(workspacePath) + if (!heartbeatContent) { + logger.debug('Heartbeat task skipped (no heartbeat.md)', { taskId: task.id }) + const nextRun = taskService.computeNextRun(task) + await taskService.updateTaskAfterRun(task.id, nextRun, 'Skipped (no file)') + this.activeTasks.delete(task.id) + return + } + fullPrompt = [ + '[Heartbeat]', + 'This is a periodic heartbeat. The instructions below are from your heartbeat.md file.', + 'Process each item, take action where possible, and use the notify tool to alert the user of important results.', + '', + '---', + heartbeatContent + ].join('\n') + } + + // Find or create session based on context mode + let sessionId: string + if (task.context_mode === 'session') { + const { sessions } = await sessionService.listSessions(task.agent_id, { limit: 1 }) + if (sessions.length === 0) { + const newSession = await sessionService.createSession(task.agent_id, {}) + sessionId = newSession!.id + } else { + sessionId = sessions[0].id + } + } else { + const newSession = await sessionService.createSession(task.agent_id, {}) + sessionId = newSession!.id + } + + const session = await sessionService.getSession(task.agent_id, sessionId) + if (!session) { + throw new Error(`Session not found: ${sessionId}`) + } + + // Send as user message (triggers agent response) + const { stream, completion } = await sessionMessageService.createSessionMessage( + session, + { content: fullPrompt }, + abortController, + { persist: true } + ) + + // Drain the stream so completion resolves + const reader = stream.getReader() + while (!(await reader.read()).done) { + // discard chunks + } + await completion + + result = 'Completed' + logger.info('Task completed', { taskId: task.id, durationMs: Date.now() - startTime }) + } catch (err) { + error = err instanceof Error ? err.message : String(err) + logger.error('Task failed', { taskId: task.id, error }) + + // Track consecutive errors + runningTask.consecutiveErrors++ + if (runningTask.consecutiveErrors >= MAX_CONSECUTIVE_ERRORS) { + logger.warn('Pausing task after consecutive errors', { + taskId: task.id, + errors: runningTask.consecutiveErrors + }) + await taskService.updateTask(task.agent_id, task.id, { status: 'paused' }) + } + } finally { + this.activeTasks.delete(task.id) + } + + const durationMs = Date.now() - startTime + + // Log the run + await taskService.logTaskRun({ + task_id: task.id, + run_at: new Date().toISOString(), + duration_ms: durationMs, + status: error ? 'error' : 'success', + result, + error + }) + + // Compute next run and update task + const nextRun = taskService.computeNextRun(task) + const resultSummary = error ? `Error: ${error}` : result ? result.slice(0, 200) : 'Completed' + await taskService.updateTaskAfterRun(task.id, nextRun, resultSummary) + + // Send notification to notify-enabled channels + await this.notifyTaskResult(task, durationMs, error) + } + + private async notifyTaskResult(task: ScheduledTaskEntity, durationMs: number, error: string | null): Promise { + try { + const adapters = channelManager.getNotifyAdapters(task.agent_id) + if (adapters.length === 0) return + + const status = error ? 'failed' : 'completed' + const durationSec = Math.round(durationMs / 1000) + const lines = [ + `[Task ${status}] ${task.name}`, + `Duration: ${durationSec}s`, + ...(error ? [`Error: ${error}`] : []) + ] + const text = lines.join('\n') + + for (const adapter of adapters) { + for (const chatId of adapter.notifyChatIds) { + adapter.sendMessage(chatId, text).catch((err) => { + logger.warn('Failed to send task notification', { + taskId: task.id, + channelId: adapter.channelId, + chatId, + error: err instanceof Error ? err.message : String(err) + }) + }) + } + } + } catch (err) { + logger.warn('Error sending task notifications', { + taskId: task.id, + error: err instanceof Error ? err.message : String(err) + }) + } + } +} + +export const schedulerService = SchedulerService.getInstance() diff --git a/src/main/services/agents/services/SessionMessageService.ts b/src/main/services/agents/services/SessionMessageService.ts index e7d5a7bfe58..69da077032c 100644 --- a/src/main/services/agents/services/SessionMessageService.ts +++ b/src/main/services/agents/services/SessionMessageService.ts @@ -1,5 +1,8 @@ +import { randomUUID } from 'node:crypto' + import { loggerService } from '@logger' import type { + AgentPersistedMessage, AgentSessionMessageEntity, CreateSessionMessageRequest, GetAgentSessionResponse, @@ -10,8 +13,9 @@ import { and, desc, eq, not } from 'drizzle-orm' import { BaseService } from '../BaseService' import { sessionMessagesTable } from '../database/schema' +import { agentMessageRepository } from '../database/sessionMessageRepository' import type { AgentStreamEvent } from '../interfaces/AgentStreamInterface' -import ClaudeCodeService from './claudecode' +import { agentServiceRegistry } from './AgentServiceRegistry' const logger = loggerService.withContext('SessionMessageService') @@ -23,6 +27,11 @@ type SessionStreamResult = { }> } +export type CreateMessageOptions = { + /** When true, persist user+assistant messages to DB on stream complete. Use for headless callers (channels, scheduler) where no UI handles persistence. */ + persist?: boolean +} + // Ensure errors emitted through SSE are serializable function serializeError(error: unknown): { message: string; name?: string; stack?: string } { if (error instanceof Error) { @@ -55,7 +64,7 @@ class TextStreamAccumulator { break case 'text-delta': if (part.text) { - this.textBuffer += part.text + this.textBuffer = part.text } break case 'text-end': { @@ -91,11 +100,14 @@ class TextStreamAccumulator { break } } + + getText(): string { + return (this.totalText + this.textBuffer).replace(/\n+$/, '') + } } export class SessionMessageService extends BaseService { private static instance: SessionMessageService | null = null - private cc: ClaudeCodeService = new ClaudeCodeService() static getInstance(): SessionMessageService { if (!SessionMessageService.instance) { @@ -151,26 +163,28 @@ export class SessionMessageService extends BaseService { async createSessionMessage( session: GetAgentSessionResponse, messageData: CreateSessionMessageRequest, - abortController: AbortController + abortController: AbortController, + options?: CreateMessageOptions ): Promise { - return await this.startSessionMessageStream(session, messageData, abortController) + return await this.startSessionMessageStream(session, messageData, abortController, options) } private async startSessionMessageStream( session: GetAgentSessionResponse, req: CreateSessionMessageRequest, - abortController: AbortController + abortController: AbortController, + options?: CreateMessageOptions ): Promise { const agentSessionId = await this.getLastAgentSessionId(session.id) logger.debug('Session Message stream message data:', { message: req, session_id: agentSessionId }) - if (session.agent_type !== 'claude-code') { - // TODO: Implement support for other agent types + if (!agentServiceRegistry.hasService(session.agent_type)) { logger.error('Unsupported agent type for streaming:', { agent_type: session.agent_type }) - throw new Error('Unsupported agent type for streaming') + throw new Error(`Unsupported agent type for streaming: ${session.agent_type}`) } - const claudeStream = await this.cc.invoke(req.content, session, abortController, agentSessionId, { + const service = agentServiceRegistry.getService(session.agent_type) + const claudeStream = await service.invoke(req.content, session, abortController, agentSessionId, { effort: req.effort, thinking: req.thinking }) @@ -229,7 +243,23 @@ export class SessionMessageService extends BaseService { case 'complete': { cleanup() controller.close() - resolveCompletion({}) + if (options?.persist) { + // Read SDK session_id from the stream object (set by ClaudeCodeService on init) + const resolvedSessionId = claudeStream.sdkSessionId || agentSessionId + logger.debug('Persisting headless exchange with agent session ID', { + sdkSessionId: claudeStream.sdkSessionId, + fallback: agentSessionId, + resolved: resolvedSessionId + }) + this.persistHeadlessExchange(session, req.content, accumulator.getText(), resolvedSessionId) + .then(resolveCompletion) + .catch((err) => { + logger.error('Failed to persist headless exchange', err as Error) + resolveCompletion({}) + }) + } else { + resolveCompletion({}) + } break } @@ -263,6 +293,83 @@ export class SessionMessageService extends BaseService { return { stream, completion } } + /** + * Persist user + assistant messages for headless callers (channels, scheduler) + * that have no UI to handle persistence via IPC. + */ + private async persistHeadlessExchange( + session: GetAgentSessionResponse, + userContent: string, + assistantContent: string, + agentSessionId: string + ): Promise<{ userMessage?: AgentSessionMessageEntity; assistantMessage?: AgentSessionMessageEntity }> { + const now = new Date().toISOString() + const userMsgId = randomUUID() + const assistantMsgId = randomUUID() + const userBlockId = randomUUID() + const assistantBlockId = randomUUID() + const topicId = `agent-session:${session.id}` + + const userPayload = { + message: { + id: userMsgId, + role: 'user' as const, + assistantId: session.agent_id, + topicId, + createdAt: now, + status: 'success', + blocks: [userBlockId] + }, + blocks: [ + { + id: userBlockId, + messageId: userMsgId, + type: 'main_text', + createdAt: now, + status: 'success', + content: userContent + } + ] + } as AgentPersistedMessage + + const assistantPayload = { + message: { + id: assistantMsgId, + role: 'assistant' as const, + assistantId: session.agent_id, + topicId, + createdAt: now, + status: 'success', + blocks: [assistantBlockId] + }, + blocks: [ + { + id: assistantBlockId, + messageId: assistantMsgId, + type: 'main_text', + createdAt: now, + status: 'success', + content: assistantContent + } + ] + } as AgentPersistedMessage + + const result = await agentMessageRepository.persistExchange({ + sessionId: session.id, + agentSessionId, + user: { payload: userPayload, createdAt: now }, + assistant: { payload: assistantPayload, createdAt: now } + }) + + logger.info('Persisted headless exchange', { + sessionId: session.id, + userMessageId: userMsgId, + assistantMessageId: assistantMsgId + }) + + return result + } + private async getLastAgentSessionId(sessionId: string): Promise { try { const database = await this.getDatabase() diff --git a/src/main/services/agents/services/SessionService.ts b/src/main/services/agents/services/SessionService.ts index 846e9bae2f3..9c3c5417632 100644 --- a/src/main/services/agents/services/SessionService.ts +++ b/src/main/services/agents/services/SessionService.ts @@ -37,7 +37,7 @@ export class SessionService extends BaseService { const commands: SlashCommand[] = [] // Add builtin slash commands - if (agentType === 'claude-code') { + if (agentType === 'claude-code' || agentType === 'cherry-claw') { commands.push(...builtinSlashCommands) } diff --git a/src/main/services/agents/services/TaskService.ts b/src/main/services/agents/services/TaskService.ts new file mode 100644 index 00000000000..bcb66b28255 --- /dev/null +++ b/src/main/services/agents/services/TaskService.ts @@ -0,0 +1,281 @@ +import { loggerService } from '@logger' +import type { CreateTaskRequest, ListOptions, ScheduledTaskEntity, TaskRunLogEntity, UpdateTaskRequest } from '@types' +import { and, asc, count, desc, eq, lte, ne } from 'drizzle-orm' + +import { BaseService } from '../BaseService' +import { + type InsertTaskRow, + type InsertTaskRunLogRow, + scheduledTasksTable, + type TaskRow, + taskRunLogsTable +} from '../database/schema' + +const logger = loggerService.withContext('TaskService') + +export class TaskService extends BaseService { + private static instance: TaskService | null = null + + static getInstance(): TaskService { + if (!TaskService.instance) { + TaskService.instance = new TaskService() + } + return TaskService.instance + } + + async createTask(agentId: string, req: CreateTaskRequest): Promise { + const id = `task_${Date.now()}_${Math.random().toString(36).substring(2, 11)}` + const now = new Date().toISOString() + + const nextRun = this.computeInitialNextRun(req.schedule_type, req.schedule_value) + + const insertData: InsertTaskRow = { + id, + agent_id: agentId, + name: req.name, + prompt: req.prompt, + schedule_type: req.schedule_type, + schedule_value: req.schedule_value, + context_mode: req.context_mode ?? 'session', + next_run: nextRun, + status: 'active', + created_at: now, + updated_at: now + } + + const database = await this.getDatabase() + await database.insert(scheduledTasksTable).values(insertData) + const result = await database.select().from(scheduledTasksTable).where(eq(scheduledTasksTable.id, id)).limit(1) + + if (!result[0]) { + throw new Error('Failed to create task') + } + + logger.info('Task created', { taskId: id, agentId }) + return result[0] as ScheduledTaskEntity + } + + async getTask(agentId: string, taskId: string): Promise { + const database = await this.getDatabase() + const result = await database + .select() + .from(scheduledTasksTable) + .where(and(eq(scheduledTasksTable.id, taskId), eq(scheduledTasksTable.agent_id, agentId))) + .limit(1) + + return (result[0] as ScheduledTaskEntity) ?? null + } + + async listTasks( + agentId: string, + options: ListOptions & { includeHeartbeat?: boolean } = {} + ): Promise<{ tasks: ScheduledTaskEntity[]; total: number }> { + const database = await this.getDatabase() + const { includeHeartbeat = false, ...paginationOptions } = options + + // By default, exclude heartbeat tasks from the listing + const whereCondition = includeHeartbeat + ? eq(scheduledTasksTable.agent_id, agentId) + : and(eq(scheduledTasksTable.agent_id, agentId), ne(scheduledTasksTable.name, 'heartbeat')) + + const totalResult = await database.select({ count: count() }).from(scheduledTasksTable).where(whereCondition) + + const baseQuery = database + .select() + .from(scheduledTasksTable) + .where(whereCondition) + .orderBy(desc(scheduledTasksTable.created_at)) + + const result = + paginationOptions.limit !== undefined + ? paginationOptions.offset !== undefined + ? await baseQuery.limit(paginationOptions.limit).offset(paginationOptions.offset) + : await baseQuery.limit(paginationOptions.limit) + : await baseQuery + + return { + tasks: result as ScheduledTaskEntity[], + total: totalResult[0].count + } + } + + async updateTask(agentId: string, taskId: string, updates: UpdateTaskRequest): Promise { + const existing = await this.getTask(agentId, taskId) + if (!existing) return null + + const now = new Date().toISOString() + const updateData: Partial = { updated_at: now } + + if (updates.name !== undefined) updateData.name = updates.name + if (updates.prompt !== undefined) updateData.prompt = updates.prompt + if (updates.context_mode !== undefined) updateData.context_mode = updates.context_mode + if (updates.status !== undefined) updateData.status = updates.status + + // If schedule type or value changed, recompute next_run + if (updates.schedule_type !== undefined || updates.schedule_value !== undefined) { + const schedType = updates.schedule_type ?? existing.schedule_type + const schedValue = updates.schedule_value ?? existing.schedule_value + updateData.schedule_type = schedType + updateData.schedule_value = schedValue + updateData.next_run = this.computeInitialNextRun(schedType, schedValue) + } + + // If resuming from paused, recompute next_run + if (updates.status === 'active' && existing.status === 'paused') { + const schedType = updates.schedule_type ?? existing.schedule_type + const schedValue = updates.schedule_value ?? existing.schedule_value + updateData.next_run = this.computeInitialNextRun(schedType, schedValue) + } + + const database = await this.getDatabase() + await database + .update(scheduledTasksTable) + .set(updateData) + .where(and(eq(scheduledTasksTable.id, taskId), eq(scheduledTasksTable.agent_id, agentId))) + + logger.info('Task updated', { taskId, agentId }) + return this.getTask(agentId, taskId) + } + + async deleteTask(agentId: string, taskId: string): Promise { + const database = await this.getDatabase() + const result = await database + .delete(scheduledTasksTable) + .where(and(eq(scheduledTasksTable.id, taskId), eq(scheduledTasksTable.agent_id, agentId))) + + logger.info('Task deleted', { taskId, agentId }) + return result.rowsAffected > 0 + } + + // --- Due tasks (used by SchedulerService poll loop) --- + + async getDueTasks(): Promise { + const now = new Date().toISOString() + const database = await this.getDatabase() + const result = await database + .select() + .from(scheduledTasksTable) + .where(and(eq(scheduledTasksTable.status, 'active'), lte(scheduledTasksTable.next_run, now))) + .orderBy(asc(scheduledTasksTable.next_run)) + + return result as ScheduledTaskEntity[] + } + + async updateTaskAfterRun(taskId: string, nextRun: string | null, lastResult: string): Promise { + const now = new Date().toISOString() + const updateData: Partial = { + last_run: now, + last_result: lastResult, + next_run: nextRun, + updated_at: now + } + + // Mark one-time tasks as completed + if (nextRun === null) { + updateData.status = 'completed' + } + + const database = await this.getDatabase() + await database.update(scheduledTasksTable).set(updateData).where(eq(scheduledTasksTable.id, taskId)) + } + + // --- Task run logs --- + + async logTaskRun(log: Omit): Promise { + const database = await this.getDatabase() + await database.insert(taskRunLogsTable).values(log) + } + + async getTaskLogs(taskId: string, options: ListOptions = {}): Promise<{ logs: TaskRunLogEntity[]; total: number }> { + const database = await this.getDatabase() + + const totalResult = await database + .select({ count: count() }) + .from(taskRunLogsTable) + .where(eq(taskRunLogsTable.task_id, taskId)) + + const baseQuery = database + .select() + .from(taskRunLogsTable) + .where(eq(taskRunLogsTable.task_id, taskId)) + .orderBy(desc(taskRunLogsTable.run_at)) + + const result = + options.limit !== undefined + ? options.offset !== undefined + ? await baseQuery.limit(options.limit).offset(options.offset) + : await baseQuery.limit(options.limit) + : await baseQuery + + return { + logs: result as unknown as TaskRunLogEntity[], + total: totalResult[0].count + } + } + + // --- Next run computation (nanoclaw-inspired, drift-resistant) --- + + computeNextRun(task: ScheduledTaskEntity): string | null { + if (task.schedule_type === 'once') return null + + const now = Date.now() + + if (task.schedule_type === 'cron') { + try { + const { CronExpressionParser } = require('cron-parser') + const interval = CronExpressionParser.parse(task.schedule_value) + return interval.next().toISOString() + } catch { + logger.warn('Invalid cron expression', { taskId: task.id, cron: task.schedule_value }) + return null + } + } + + if (task.schedule_type === 'interval') { + const minutes = parseInt(task.schedule_value, 10) + const ms = minutes * 60_000 + if (!ms || ms <= 0) { + logger.warn('Invalid interval value', { taskId: task.id, value: task.schedule_value }) + return new Date(now + 60_000).toISOString() + } + + // Anchor to scheduled time to prevent drift + let next = new Date(task.next_run!).getTime() + ms + while (next <= now) { + next += ms + } + return new Date(next).toISOString() + } + + return null + } + + private computeInitialNextRun(scheduleType: string, scheduleValue: string): string | null { + const now = Date.now() + + switch (scheduleType) { + case 'cron': { + try { + const { CronExpressionParser } = require('cron-parser') + const interval = CronExpressionParser.parse(scheduleValue) + return interval.next().toISOString() + } catch { + return null + } + } + case 'interval': { + const minutes = parseInt(scheduleValue, 10) + if (!minutes || minutes <= 0) return null + return new Date(now + minutes * 60_000).toISOString() + } + case 'once': { + // schedule_value is an ISO timestamp for once + return scheduleValue + } + default: + return null + } + } +} + +export const taskService = TaskService.getInstance() diff --git a/src/main/services/agents/services/__tests__/AgentServiceRegistry.test.ts b/src/main/services/agents/services/__tests__/AgentServiceRegistry.test.ts new file mode 100644 index 00000000000..cd1a8b0148b --- /dev/null +++ b/src/main/services/agents/services/__tests__/AgentServiceRegistry.test.ts @@ -0,0 +1,61 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' + +import type { AgentServiceInterface } from '../../interfaces/AgentStreamInterface' + +vi.mock('@logger', () => ({ + loggerService: { + withContext: () => ({ info: vi.fn(), error: vi.fn(), warn: vi.fn(), debug: vi.fn() }) + } +})) + +describe('AgentServiceRegistry', () => { + let agentServiceRegistry: any + + beforeEach(async () => { + vi.resetModules() + const mod = await import('../AgentServiceRegistry') + agentServiceRegistry = mod.agentServiceRegistry + }) + + const createMockService = (): AgentServiceInterface => ({ + invoke: vi.fn() + }) + + it('should register and retrieve a service', () => { + const mockService = createMockService() + const agentType = 'cherry-claw' as any + + agentServiceRegistry.register(agentType, mockService) + + expect(agentServiceRegistry.getService(agentType)).toBe(mockService) + }) + + it('should throw when getting an unregistered service', () => { + expect(() => agentServiceRegistry.getService('unknown-type' as any)).toThrow( + 'No agent service registered for type: unknown-type' + ) + }) + + it('should return true for hasService when registered, false otherwise', () => { + const mockService = createMockService() + const agentType = 'cherry-claw' as any + + expect(agentServiceRegistry.hasService(agentType)).toBe(false) + + agentServiceRegistry.register(agentType, mockService) + + expect(agentServiceRegistry.hasService(agentType)).toBe(true) + }) + + it('should overwrite a previously registered service', () => { + const firstService = createMockService() + const secondService = createMockService() + const agentType = 'cherry-claw' as any + + agentServiceRegistry.register(agentType, firstService) + agentServiceRegistry.register(agentType, secondService) + + expect(agentServiceRegistry.getService(agentType)).toBe(secondService) + expect(agentServiceRegistry.getService(agentType)).not.toBe(firstService) + }) +}) diff --git a/src/main/services/agents/services/__tests__/SchedulerService.test.ts b/src/main/services/agents/services/__tests__/SchedulerService.test.ts new file mode 100644 index 00000000000..c1ed3f959b4 --- /dev/null +++ b/src/main/services/agents/services/__tests__/SchedulerService.test.ts @@ -0,0 +1,160 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' + +vi.mock('@logger', () => ({ + loggerService: { + withContext: () => ({ info: vi.fn(), error: vi.fn(), warn: vi.fn(), debug: vi.fn(), silly: vi.fn() }) + } +})) + +vi.mock('../AgentService', () => ({ + agentService: { + listAgents: vi.fn().mockResolvedValue({ agents: [], total: 0 }), + getAgent: vi.fn() + } +})) + +vi.mock('../SessionService', () => ({ + sessionService: { + listSessions: vi.fn().mockResolvedValue({ sessions: [], total: 0 }), + getSession: vi.fn(), + createSession: vi.fn().mockResolvedValue({ id: 'session-1' }) + } +})) + +vi.mock('../SessionMessageService', () => ({ + sessionMessageService: { + createSessionMessage: vi.fn() + } +})) + +vi.mock('../TaskService', () => ({ + taskService: { + getDueTasks: vi.fn().mockResolvedValue([]), + updateTaskAfterRun: vi.fn(), + logTaskRun: vi.fn(), + computeNextRun: vi.fn().mockReturnValue(null), + updateTask: vi.fn() + } +})) + +vi.mock('../cherryclaw', () => ({ + CherryClawService: vi.fn().mockImplementation(() => ({ + heartbeatReader: { readHeartbeat: vi.fn().mockResolvedValue(undefined) } + })) +})) + +describe('SchedulerService', () => { + let SchedulerServiceModule: any + + beforeEach(async () => { + vi.useFakeTimers() + vi.resetModules() + SchedulerServiceModule = await import('../SchedulerService') + }) + + afterEach(() => { + const service = SchedulerServiceModule.schedulerService + service.stopAll() + vi.useRealTimers() + }) + + it('startLoop starts the poll loop', () => { + const service = SchedulerServiceModule.schedulerService + service.startLoop() + // Running is tracked internally; stopAll should not throw + service.stopAll() + }) + + it('startLoop is idempotent', () => { + const service = SchedulerServiceModule.schedulerService + service.startLoop() + service.startLoop() // second call should be a no-op + service.stopAll() + }) + + it('stopAll stops the loop and aborts active tasks', () => { + const service = SchedulerServiceModule.schedulerService + service.startLoop() + service.stopAll() + // Should not throw, loop should be stopped + }) + + it('restoreSchedulers starts the poll loop', async () => { + const service = SchedulerServiceModule.schedulerService + await service.restoreSchedulers() + // The poll loop should be running + service.stopAll() + }) + + it('stopScheduler is a no-op (poll loop handles everything)', () => { + const service = SchedulerServiceModule.schedulerService + // Should not throw for any agent ID + service.stopScheduler('nonexistent') + }) + + it('startScheduler starts the poll loop', () => { + const service = SchedulerServiceModule.schedulerService + service.startScheduler({ id: 'agent-1' }) + service.stopAll() + }) + + it('tick processes due tasks', async () => { + const { taskService } = await import('../TaskService') + const { agentService } = await import('../AgentService') + const { sessionService } = await import('../SessionService') + const { sessionMessageService } = await import('../SessionMessageService') + + const mockTask = { + id: 'task-1', + agent_id: 'agent-1', + name: 'Test task', + prompt: 'Do something', + schedule_type: 'once' as const, + schedule_value: new Date().toISOString(), + context_mode: 'session' as const, + next_run: new Date(Date.now() - 1000).toISOString(), + last_run: null, + last_result: null, + status: 'active' as const, + created_at: new Date().toISOString(), + updated_at: new Date().toISOString() + } + + vi.mocked(taskService.getDueTasks).mockResolvedValueOnce([mockTask]) + vi.mocked(agentService.getAgent).mockResolvedValueOnce({ + id: 'agent-1', + type: 'cherry-claw', + name: 'Test', + model: 'claude-3', + accessible_paths: ['/tmp/test'], + configuration: { heartbeat_enabled: true }, + created_at: new Date().toISOString(), + updated_at: new Date().toISOString() + } as any) + vi.mocked(sessionService.listSessions).mockResolvedValueOnce({ + sessions: [{ id: 'session-1' }] as any, + total: 1 + }) + vi.mocked(sessionService.getSession).mockResolvedValueOnce({ + id: 'session-1', + agent_id: 'agent-1' + } as any) + vi.mocked(sessionMessageService.createSessionMessage).mockResolvedValueOnce({ + stream: new ReadableStream({ start: (c) => c.close() }), + completion: Promise.resolve({}) + } as any) + + const service = SchedulerServiceModule.schedulerService + service.startLoop() + + // Advance past the first tick + await vi.advanceTimersByTimeAsync(100) + + expect(taskService.getDueTasks).toHaveBeenCalled() + // Give the async task time to complete + await vi.advanceTimersByTimeAsync(1000) + + expect(taskService.logTaskRun).toHaveBeenCalled() + expect(taskService.updateTaskAfterRun).toHaveBeenCalledWith('task-1', null, 'Completed') + }) +}) diff --git a/src/main/services/agents/services/channels/ChannelAdapter.ts b/src/main/services/agents/services/channels/ChannelAdapter.ts new file mode 100644 index 00000000000..f5173d3e17a --- /dev/null +++ b/src/main/services/agents/services/channels/ChannelAdapter.ts @@ -0,0 +1,65 @@ +import { EventEmitter } from 'events' + +export type ChannelMessageEvent = { + chatId: string + userId: string + userName: string + text: string +} + +export type ChannelCommandEvent = { + chatId: string + userId: string + userName: string + command: 'new' | 'compact' | 'help' | 'whoami' + args?: string +} + +export type SendMessageOptions = { + parseMode?: 'MarkdownV2' | 'HTML' + replyToMessageId?: number +} + +export type ChannelAdapterConfig = { + channelId: string + agentId: string + channelConfig: Record +} + +export abstract class ChannelAdapter extends EventEmitter { + readonly channelId: string + readonly agentId: string + /** Chat IDs that this adapter can send notifications to (set by subclass in constructor). */ + notifyChatIds: string[] = [] + + constructor(protected readonly config: ChannelAdapterConfig) { + super() + this.channelId = config.channelId + this.agentId = config.agentId + } + + abstract connect(): Promise + abstract disconnect(): Promise + abstract sendMessage(chatId: string, text: string, opts?: SendMessageOptions): Promise + /** Stream a partial/draft message to the chat. Same draftId updates the existing draft in-place. */ + abstract sendMessageDraft(chatId: string, draftId: number, text: string): Promise + abstract sendTypingIndicator(chatId: string): Promise + async finalizeStream(_draftId: number, _finalText: string): Promise { + void _draftId + void _finalText + return false + } + + // Typed event emitter overrides + override emit(event: 'message', data: ChannelMessageEvent): boolean + override emit(event: 'command', data: ChannelCommandEvent): boolean + override emit(event: string, ...args: unknown[]): boolean { + return super.emit(event, ...args) + } + + override on(event: 'message', listener: (data: ChannelMessageEvent) => void): this + override on(event: 'command', listener: (data: ChannelCommandEvent) => void): this + override on(event: string, listener: (...args: any[]) => void): this { + return super.on(event, listener) + } +} diff --git a/src/main/services/agents/services/channels/ChannelManager.ts b/src/main/services/agents/services/channels/ChannelManager.ts new file mode 100644 index 00000000000..c31a3380f7d --- /dev/null +++ b/src/main/services/agents/services/channels/ChannelManager.ts @@ -0,0 +1,158 @@ +import { loggerService } from '@logger' +import type { CherryClawChannel, CherryClawConfiguration } from '@types' + +import { agentService } from '../AgentService' +import type { ChannelAdapter } from './ChannelAdapter' +import { channelMessageHandler } from './ChannelMessageHandler' + +const logger = loggerService.withContext('ChannelManager') + +// Adapter factory registry -- adapters register themselves here +type AdapterFactory = (channelConfig: CherryClawChannel, agentId: string) => ChannelAdapter +const adapterFactories = new Map() + +export function registerAdapterFactory(type: string, factory: AdapterFactory): void { + adapterFactories.set(type, factory) +} + +class ChannelManager { + private static instance: ChannelManager | null = null + private readonly adapters = new Map() // key: `${agentId}:${channelId}` + private readonly notifyChannels = new Set() // key: `${agentId}:${channelId}` + + static getInstance(): ChannelManager { + if (!ChannelManager.instance) { + ChannelManager.instance = new ChannelManager() + } + return ChannelManager.instance + } + + async start(): Promise { + logger.info('Starting channel manager') + try { + const { agents } = await agentService.listAgents() + const clawAgents = agents.filter((a) => a.type === 'cherry-claw') + + for (const agent of clawAgents) { + await this.startAgentChannels(agent.id, (agent.configuration as CherryClawConfiguration)?.channels) + } + + logger.info('Channel manager started', { adapterCount: this.adapters.size }) + } catch (error) { + logger.error('Failed to start channel manager', { + error: error instanceof Error ? error.message : String(error) + }) + } + } + + async stop(): Promise { + logger.info('Stopping channel manager') + const disconnects = Array.from(this.adapters.values()).map((adapter) => + adapter.disconnect().catch((err) => { + logger.warn('Error disconnecting adapter', { + agentId: adapter.agentId, + channelId: adapter.channelId, + error: err instanceof Error ? err.message : String(err) + }) + }) + ) + await Promise.all(disconnects) + this.adapters.clear() + this.notifyChannels.clear() + logger.info('Channel manager stopped') + } + + /** Return connected adapters for an agent whose channel has `is_notify_receiver: true`. */ + getNotifyAdapters(agentId: string): ChannelAdapter[] { + const result: ChannelAdapter[] = [] + for (const [key, adapter] of this.adapters) { + if (adapter.agentId !== agentId) continue + // Look up original channel config to check is_notify_receiver + const channelId = key.split(':')[1] + if (this.notifyChannels.has(`${agentId}:${channelId}`)) { + result.push(adapter) + } + } + return result + } + + async syncAgent(agentId: string): Promise { + // Disconnect existing adapters for this agent + for (const [key, adapter] of this.adapters) { + if (adapter.agentId === agentId) { + await adapter.disconnect().catch((err) => { + logger.warn('Error disconnecting adapter during sync', { + key, + error: err instanceof Error ? err.message : String(err) + }) + }) + this.adapters.delete(key) + this.notifyChannels.delete(key) + } + } + + channelMessageHandler.clearSessionTracker(agentId) + + // Re-create from current config (agent may have been deleted) + const agent = await agentService.getAgent(agentId) + if (!agent || agent.type !== 'cherry-claw') return + + const config = agent.configuration as CherryClawConfiguration + await this.startAgentChannels(agentId, config?.channels) + } + + private async startAgentChannels(agentId: string, channels?: CherryClawChannel[]): Promise { + if (!channels || channels.length === 0) return + + for (const channel of channels) { + if (channel.enabled === false) continue + + const factory = adapterFactories.get(channel.type) + if (!factory) { + logger.warn('No adapter factory for channel type', { type: channel.type, agentId }) + continue + } + + const key = `${agentId}:${channel.id}` + try { + const adapter = factory(channel, agentId) + + adapter.on('message', (msg) => { + channelMessageHandler.handleIncoming(adapter, msg).catch((err) => { + logger.error('Unhandled error in message handler', { + agentId, + channelId: channel.id, + error: err instanceof Error ? err.message : String(err) + }) + }) + }) + + adapter.on('command', (cmd) => { + channelMessageHandler.handleCommand(adapter, cmd).catch((err) => { + logger.error('Unhandled error in command handler', { + agentId, + channelId: channel.id, + error: err instanceof Error ? err.message : String(err) + }) + }) + }) + + await adapter.connect() + this.adapters.set(key, adapter) + if (channel.is_notify_receiver) { + this.notifyChannels.add(key) + } + logger.info('Channel adapter connected', { agentId, channelId: channel.id, type: channel.type }) + } catch (error) { + logger.error('Failed to connect channel adapter', { + agentId, + channelId: channel.id, + type: channel.type, + error: error instanceof Error ? error.message : String(error) + }) + } + } + } +} + +export const channelManager = ChannelManager.getInstance() diff --git a/src/main/services/agents/services/channels/ChannelMessageHandler.ts b/src/main/services/agents/services/channels/ChannelMessageHandler.ts new file mode 100644 index 00000000000..b88d6bfd68c --- /dev/null +++ b/src/main/services/agents/services/channels/ChannelMessageHandler.ts @@ -0,0 +1,289 @@ +import { loggerService } from '@logger' +import type { GetAgentSessionResponse } from '@types' + +import { agentService } from '../AgentService' +import { sessionMessageService } from '../SessionMessageService' +import { sessionService } from '../SessionService' +import type { ChannelAdapter, ChannelCommandEvent, ChannelMessageEvent } from './ChannelAdapter' + +const logger = loggerService.withContext('ChannelMessageHandler') + +const MAX_MESSAGE_LENGTH = 4096 +const DRAFT_THROTTLE_MS = 500 +const TYPING_INTERVAL_MS = 4000 + +export class ChannelMessageHandler { + private static instance: ChannelMessageHandler | null = null + private readonly sessionTracker = new Map() // agentId -> sessionId + + static getInstance(): ChannelMessageHandler { + if (!ChannelMessageHandler.instance) { + ChannelMessageHandler.instance = new ChannelMessageHandler() + } + return ChannelMessageHandler.instance + } + + async handleIncoming(adapter: ChannelAdapter, message: ChannelMessageEvent): Promise { + const { agentId } = adapter + try { + const session = await this.resolveSession(agentId) + if (!session) { + logger.error('Failed to resolve session', { agentId }) + return + } + + const abortController = new AbortController() + const draftId = Math.floor(Math.random() * 2_147_483_647) + 1 + + // Show typing indicator immediately and keep refreshing every 4s + adapter.sendTypingIndicator(message.chatId).catch(() => {}) + const typingInterval = setInterval( + () => adapter.sendTypingIndicator(message.chatId).catch(() => {}), + TYPING_INTERVAL_MS + ) + + try { + const responseText = await this.collectStreamResponse(session, message.text, abortController, (text) => + adapter.sendMessageDraft(message.chatId, draftId, text).catch(() => {}) + ) + + if (responseText) { + const finalized = await adapter.finalizeStream(draftId, responseText).catch(() => false) + if (!finalized) { + await this.sendChunked(adapter, message.chatId, responseText) + } + } + } finally { + clearInterval(typingInterval) + } + } catch (error) { + logger.error('Error handling incoming message', { + agentId, + chatId: message.chatId, + error: error instanceof Error ? error.message : String(error) + }) + } + } + + async handleCommand(adapter: ChannelAdapter, command: ChannelCommandEvent): Promise { + const { agentId } = adapter + try { + switch (command.command) { + case 'new': { + const newSession = await sessionService.createSession(agentId, {}) + if (newSession) { + this.sessionTracker.set(agentId, newSession.id) + await adapter.sendMessage(command.chatId, 'New session created.') + } + break + } + case 'compact': { + const session = await this.resolveSession(agentId) + if (!session) { + await adapter.sendMessage(command.chatId, 'No active session.') + return + } + const abortController = new AbortController() + adapter.sendTypingIndicator(command.chatId).catch(() => {}) + const typingInterval = setInterval( + () => adapter.sendTypingIndicator(command.chatId).catch(() => {}), + TYPING_INTERVAL_MS + ) + try { + const response = await this.collectStreamResponse(session, '/compact', abortController) + await adapter.sendMessage(command.chatId, response || 'Session compacted.') + } finally { + clearInterval(typingInterval) + } + break + } + case 'help': { + const agent = await agentService.getAgent(agentId) + const name = agent?.name ?? 'CherryClaw' + const description = agent?.description ?? '' + const helpText = [ + `*${name}*`, + description ? `_${description}_` : '', + '', + 'Available commands:', + '/new - Start a new conversation session', + '/compact - Compact current session context', + '/help - Show this help message', + '/whoami - Show the current chat ID for allow_ids' + ] + .filter(Boolean) + .join('\n') + await adapter.sendMessage(command.chatId, helpText) + break + } + case 'whoami': { + await adapter.sendMessage( + command.chatId, + [ + `Current chat ID: \`${command.chatId}\``, + '', + 'Add this value to `allow_ids` in settings to receive notifications.' + ].join('\n') + ) + break + } + } + } catch (error) { + logger.error('Error handling command', { + agentId, + command: command.command, + error: error instanceof Error ? error.message : String(error) + }) + } + } + + /** Clear session tracking for an agent (used when agent is deleted/updated) */ + clearSessionTracker(agentId: string): void { + this.sessionTracker.delete(agentId) + } + + private async resolveSession(agentId: string): Promise { + // Check tracker first + const trackedId = this.sessionTracker.get(agentId) + if (trackedId) { + const session = await sessionService.getSession(agentId, trackedId) + if (session) return session + // Tracked session gone, clear it + this.sessionTracker.delete(agentId) + } + + // Fall back to first existing session + const { sessions } = await sessionService.listSessions(agentId, { limit: 1 }) + if (sessions.length > 0) { + this.sessionTracker.set(agentId, sessions[0].id) + return sessionService.getSession(agentId, sessions[0].id) + } + + // Create new session + const newSession = await sessionService.createSession(agentId, {}) + if (newSession) { + this.sessionTracker.set(agentId, newSession.id) + return newSession + } + + return null + } + + private async collectStreamResponse( + session: GetAgentSessionResponse, + content: string, + abortController: AbortController, + onDraft?: (text: string) => void + ): Promise { + const { stream, completion } = await sessionMessageService.createSessionMessage( + session, + { content }, + abortController, + { persist: true } + ) + + const reader = stream.getReader() + let completedText = '' // text from finished blocks/turns + let currentBlockText = '' // cumulative text within the current block + let lastDraftTime = 0 + let draftTimer: ReturnType | undefined + + const emitDraft = () => { + if (!onDraft) return + const fullText = completedText + currentBlockText + if (fullText) onDraft(fullText) + } + + const throttledDraft = () => { + if (!onDraft) return + const now = Date.now() + if (now - lastDraftTime >= DRAFT_THROTTLE_MS) { + lastDraftTime = now + if (draftTimer) clearTimeout(draftTimer) + emitDraft() + } else if (!draftTimer) { + draftTimer = setTimeout( + () => { + draftTimer = undefined + lastDraftTime = Date.now() + emitDraft() + }, + DRAFT_THROTTLE_MS - (now - lastDraftTime) + ) + } + } + + try { + while (true) { + const { done, value } = await reader.read() + if (done) break + + switch (value.type) { + case 'text-delta': + // text-delta values are cumulative within a block + if (value.text) { + currentBlockText = value.text + throttledDraft() + } + break + case 'text-end': + // Block finished — commit current block text and reset for next turn + if (currentBlockText) { + completedText += currentBlockText + '\n\n' + currentBlockText = '' + } + break + } + } + + await completion + } finally { + if (draftTimer) clearTimeout(draftTimer) + } + + // Trim trailing separator + return (completedText + currentBlockText).replace(/\n+$/, '') + } + + private async sendChunked(adapter: ChannelAdapter, chatId: string, text: string): Promise { + if (text.length <= MAX_MESSAGE_LENGTH) { + await adapter.sendMessage(chatId, text) + return + } + + const chunks = this.chunkText(text, MAX_MESSAGE_LENGTH) + for (const chunk of chunks) { + await adapter.sendMessage(chatId, chunk) + } + } + + private chunkText(text: string, maxLength: number): string[] { + const chunks: string[] = [] + let remaining = text + + while (remaining.length > 0) { + if (remaining.length <= maxLength) { + chunks.push(remaining) + break + } + + // Try paragraph boundary + let splitIdx = remaining.lastIndexOf('\n\n', maxLength) + if (splitIdx <= 0) { + // Try line boundary + splitIdx = remaining.lastIndexOf('\n', maxLength) + } + if (splitIdx <= 0) { + // Hard split + splitIdx = maxLength + } + + chunks.push(remaining.slice(0, splitIdx)) + remaining = remaining.slice(splitIdx).replace(/^\n+/, '') + } + + return chunks + } +} + +export const channelMessageHandler = ChannelMessageHandler.getInstance() diff --git a/src/main/services/agents/services/channels/__tests__/ChannelManager.test.ts b/src/main/services/agents/services/channels/__tests__/ChannelManager.test.ts new file mode 100644 index 00000000000..af974bdf470 --- /dev/null +++ b/src/main/services/agents/services/channels/__tests__/ChannelManager.test.ts @@ -0,0 +1,199 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' + +import { agentService } from '../../AgentService' +import { ChannelAdapter, type ChannelAdapterConfig } from '../ChannelAdapter' +import { channelManager, registerAdapterFactory } from '../ChannelManager' +import { channelMessageHandler } from '../ChannelMessageHandler' + +vi.mock('@logger', () => ({ + loggerService: { + withContext: () => ({ info: vi.fn(), error: vi.fn(), warn: vi.fn(), debug: vi.fn(), silly: vi.fn() }) + } +})) + +vi.mock('../../AgentService', () => ({ + agentService: { + listAgents: vi.fn().mockResolvedValue({ agents: [], total: 0 }), + getAgent: vi.fn() + } +})) + +vi.mock('../ChannelMessageHandler', () => ({ + channelMessageHandler: { + handleIncoming: vi.fn(), + handleCommand: vi.fn(), + clearSessionTracker: vi.fn() + } +})) + +class MockAdapter extends ChannelAdapter { + connect = vi.fn().mockResolvedValue(undefined) + disconnect = vi.fn().mockResolvedValue(undefined) + sendMessage = vi.fn().mockResolvedValue(undefined) + sendMessageDraft = vi.fn().mockResolvedValue(undefined) + sendTypingIndicator = vi.fn().mockResolvedValue(undefined) + + constructor(config: ChannelAdapterConfig) { + super(config) + } +} + +// Track adapters created by the factory +let createdAdapters: MockAdapter[] = [] + +describe('ChannelManager', () => { + beforeEach(async () => { + // Defensively stop any leftover adapters from a previous failed test + await channelManager.stop() + vi.clearAllMocks() + createdAdapters = [] + // Re-register the mock factory (the map persists across tests since we don't resetModules) + registerAdapterFactory('telegram', (channel, agentId) => { + const adapter = new MockAdapter({ + channelId: channel.id, + agentId, + channelConfig: channel.config + }) + createdAdapters.push(adapter) + return adapter + }) + }) + + afterEach(async () => { + await channelManager.stop() + }) + + it('start() with no agents does not error', async () => { + vi.mocked(agentService.listAgents).mockResolvedValueOnce({ agents: [] as any[], total: 0 }) + await expect(channelManager.start()).resolves.not.toThrow() + expect(createdAdapters).toHaveLength(0) + }) + + it('start() connects adapters for agents with channels', async () => { + vi.mocked(agentService.listAgents).mockResolvedValueOnce({ + agents: [ + { + id: 'agent-1', + type: 'cherry-claw', + configuration: { + channels: [ + { + id: 'ch-1', + type: 'telegram', + enabled: true, + config: { bot_token: 'tok', allowed_chat_ids: [] } + } + ] + } + } + ] as any[], + total: 1 + }) + + await channelManager.start() + + expect(createdAdapters).toHaveLength(1) + expect(createdAdapters[0].connect).toHaveBeenCalledTimes(1) + }) + + it('stop() disconnects all adapters', async () => { + vi.mocked(agentService.listAgents).mockResolvedValueOnce({ + agents: [ + { + id: 'agent-1', + type: 'cherry-claw', + configuration: { + channels: [ + { id: 'ch-1', type: 'telegram', enabled: true, config: { bot_token: 'tok' } }, + { id: 'ch-2', type: 'telegram', enabled: true, config: { bot_token: 'tok2' } } + ] + } + } + ] as any[], + total: 1 + }) + + await channelManager.start() + expect(createdAdapters).toHaveLength(2) + createdAdapters.forEach((a) => expect(a.connect).toHaveBeenCalledTimes(1)) + + await channelManager.stop() + createdAdapters.forEach((a) => expect(a.disconnect).toHaveBeenCalledTimes(1)) + }) + + it('syncAgent disconnects old and reconnects', async () => { + vi.mocked(agentService.listAgents).mockResolvedValueOnce({ + agents: [ + { + id: 'agent-1', + type: 'cherry-claw', + configuration: { + channels: [{ id: 'ch-1', type: 'telegram', enabled: true, config: { bot_token: 'tok' } }] + } + } + ] as any[], + total: 1 + }) + + await channelManager.start() + expect(createdAdapters).toHaveLength(1) + + // Sync with updated config + vi.mocked(agentService.getAgent).mockResolvedValueOnce({ + id: 'agent-1', + type: 'cherry-claw', + configuration: { + channels: [{ id: 'ch-1', type: 'telegram', enabled: true, config: { bot_token: 'new-tok' } }] + } + } as any) + + await channelManager.syncAgent('agent-1') + + expect(createdAdapters[0].disconnect).toHaveBeenCalledTimes(1) + expect(createdAdapters).toHaveLength(2) // new adapter created + expect(createdAdapters[1].connect).toHaveBeenCalledTimes(1) + expect(channelMessageHandler.clearSessionTracker).toHaveBeenCalledWith('agent-1') + }) + + it('syncAgent for deleted agent disconnects without reconnecting', async () => { + vi.mocked(agentService.listAgents).mockResolvedValueOnce({ + agents: [ + { + id: 'agent-1', + type: 'cherry-claw', + configuration: { + channels: [{ id: 'ch-1', type: 'telegram', enabled: true, config: { bot_token: 'tok' } }] + } + } + ] as any[], + total: 1 + }) + + await channelManager.start() + expect(createdAdapters).toHaveLength(1) + + vi.mocked(agentService.getAgent).mockResolvedValueOnce(null as any) + await channelManager.syncAgent('agent-1') + + expect(createdAdapters[0].disconnect).toHaveBeenCalledTimes(1) + expect(createdAdapters).toHaveLength(1) // no new adapter + }) + + it('disabled channels are skipped', async () => { + vi.mocked(agentService.listAgents).mockResolvedValueOnce({ + agents: [ + { + id: 'agent-1', + type: 'cherry-claw', + configuration: { + channels: [{ id: 'ch-1', type: 'telegram', enabled: false, config: { bot_token: 'tok' } }] + } + } + ] as any[], + total: 1 + }) + + await channelManager.start() + expect(createdAdapters).toHaveLength(0) + }) +}) diff --git a/src/main/services/agents/services/channels/__tests__/ChannelMessageHandler.test.ts b/src/main/services/agents/services/channels/__tests__/ChannelMessageHandler.test.ts new file mode 100644 index 00000000000..0c23bfe10ad --- /dev/null +++ b/src/main/services/agents/services/channels/__tests__/ChannelMessageHandler.test.ts @@ -0,0 +1,286 @@ +import { EventEmitter } from 'events' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +import { agentService } from '../../AgentService' +import { sessionMessageService } from '../../SessionMessageService' +import { sessionService } from '../../SessionService' +import { channelMessageHandler } from '../ChannelMessageHandler' + +vi.mock('@logger', () => ({ + loggerService: { + withContext: () => ({ info: vi.fn(), error: vi.fn(), warn: vi.fn(), debug: vi.fn(), silly: vi.fn() }) + } +})) + +vi.mock('../../AgentService', () => ({ + agentService: { + getAgent: vi.fn() + } +})) + +vi.mock('../../SessionService', () => ({ + sessionService: { + listSessions: vi.fn().mockResolvedValue({ sessions: [], total: 0 }), + getSession: vi.fn(), + createSession: vi.fn() + } +})) + +vi.mock('../../SessionMessageService', () => ({ + sessionMessageService: { + createSessionMessage: vi.fn() + } +})) + +function createMockStream(parts: Array<{ type: string; text?: string }>) { + const stream = new ReadableStream({ + start(controller) { + for (const part of parts) { + controller.enqueue(part) + } + controller.close() + } + }) + return { stream, completion: Promise.resolve({}) } +} + +function createMockAdapter(overrides: Record = {}) { + const adapter = new EventEmitter() as any + adapter.agentId = overrides.agentId ?? 'agent-1' + adapter.channelId = overrides.channelId ?? 'channel-1' + adapter.sendMessage = vi.fn().mockResolvedValue(undefined) + adapter.sendMessageDraft = vi.fn().mockResolvedValue(undefined) + adapter.sendTypingIndicator = vi.fn().mockResolvedValue(undefined) + adapter.finalizeStream = vi.fn().mockResolvedValue(false) + return adapter +} + +describe('ChannelMessageHandler', () => { + beforeEach(() => { + vi.clearAllMocks() + // Reset the default mock for listSessions after clearAllMocks + vi.mocked(sessionService.listSessions).mockResolvedValue({ sessions: [] as any[], total: 0 }) + // Clear session tracker to ensure clean state + channelMessageHandler.clearSessionTracker('agent-1') + }) + + it('collectStreamResponse accumulates text across turns and sends via adapter', async () => { + const adapter = createMockAdapter() + const session = { id: 'session-1', agent_id: 'agent-1', agent_type: 'cherry-claw' } + + vi.mocked(sessionService.createSession).mockResolvedValueOnce(session as any) + vi.mocked(sessionMessageService.createSessionMessage).mockResolvedValueOnce( + createMockStream([ + // Turn 1: cumulative text-delta within block + { type: 'text-delta', text: 'Hello ' }, + { type: 'text-delta', text: 'Hello world!' }, + { type: 'text-end' }, + // Turn 2: new block after tool use + { type: 'text-delta', text: 'Done.' }, + { type: 'text-end' } + ]) as any + ) + + await channelMessageHandler.handleIncoming(adapter, { + chatId: 'chat-1', + userId: 'user-1', + userName: 'User', + text: 'Hi' + }) + + expect(adapter.sendMessage).toHaveBeenCalledWith('chat-1', 'Hello world!\n\nDone.') + }) + + it('skips final send when adapter finalizes the draft stream', async () => { + const adapter = createMockAdapter() + const session = { id: 'session-1', agent_id: 'agent-1', agent_type: 'cherry-claw' } + + adapter.finalizeStream.mockResolvedValueOnce(true) + vi.mocked(sessionService.createSession).mockResolvedValueOnce(session as any) + vi.mocked(sessionMessageService.createSessionMessage).mockResolvedValueOnce( + createMockStream([{ type: 'text-delta', text: 'Hello world!' }]) as any + ) + + await channelMessageHandler.handleIncoming(adapter, { + chatId: 'chat-1', + userId: 'user-1', + userName: 'User', + text: 'Hi' + }) + + expect(adapter.finalizeStream).toHaveBeenCalledWith(expect.any(Number), 'Hello world!') + expect(adapter.sendMessage).not.toHaveBeenCalled() + }) + + it('sends chunked messages for long responses', async () => { + const adapter = createMockAdapter() + const session = { id: 'session-1', agent_id: 'agent-1', agent_type: 'cherry-claw' } + + vi.mocked(sessionService.createSession).mockResolvedValueOnce(session as any) + + const longText = 'A'.repeat(5000) + vi.mocked(sessionMessageService.createSessionMessage).mockResolvedValueOnce( + createMockStream([{ type: 'text-delta', text: longText }]) as any + ) + + await channelMessageHandler.handleIncoming(adapter, { + chatId: 'chat-1', + userId: 'user-1', + userName: 'User', + text: 'Hi' + }) + + expect(adapter.sendMessage).toHaveBeenCalledTimes(2) + }) + + it('handleCommand /new creates a new session', async () => { + const adapter = createMockAdapter() + vi.mocked(sessionService.createSession).mockResolvedValueOnce({ id: 'new-session' } as any) + + await channelMessageHandler.handleCommand(adapter, { + chatId: 'chat-1', + userId: 'user-1', + userName: 'User', + command: 'new' + }) + + expect(sessionService.createSession).toHaveBeenCalledWith('agent-1', {}) + expect(adapter.sendMessage).toHaveBeenCalledWith('chat-1', 'New session created.') + }) + + it('handleCommand /compact sends /compact as message content', async () => { + const adapter = createMockAdapter() + const session = { id: 'session-1', agent_id: 'agent-1', agent_type: 'cherry-claw' } + + vi.mocked(sessionService.createSession).mockResolvedValueOnce(session as any) + vi.mocked(sessionMessageService.createSessionMessage).mockResolvedValueOnce( + createMockStream([{ type: 'text-delta', text: 'Compacted.' }]) as any + ) + + await channelMessageHandler.handleCommand(adapter, { + chatId: 'chat-1', + userId: 'user-1', + userName: 'User', + command: 'compact' + }) + + expect(sessionMessageService.createSessionMessage).toHaveBeenCalledWith( + session, + { content: '/compact' }, + expect.any(AbortController), + { persist: true } + ) + expect(adapter.sendMessage).toHaveBeenCalledWith('chat-1', 'Compacted.') + }) + + it('handleCommand /help sends help text with agent info', async () => { + const adapter = createMockAdapter() + vi.mocked(agentService.getAgent).mockResolvedValueOnce({ + name: 'TestAgent', + description: 'A test agent' + } as any) + + await channelMessageHandler.handleCommand(adapter, { + chatId: 'chat-1', + userId: 'user-1', + userName: 'User', + command: 'help' + }) + + expect(adapter.sendMessage).toHaveBeenCalledTimes(1) + const helpText = adapter.sendMessage.mock.calls[0][1] as string + expect(helpText).toContain('*TestAgent*') + expect(helpText).toContain('_A test agent_') + expect(helpText).toContain('/new') + expect(helpText).toContain('/compact') + expect(helpText).toContain('/help') + expect(helpText).toContain('/whoami') + }) + + it('handleCommand /whoami sends the current chat ID', async () => { + const adapter = createMockAdapter() + + await channelMessageHandler.handleCommand(adapter, { + chatId: 'oc_123', + userId: 'user-1', + userName: 'User', + command: 'whoami' + }) + + expect(adapter.sendMessage).toHaveBeenCalledWith( + 'oc_123', + 'Current chat ID: `oc_123`\n\nAdd this value to `allow_ids` in settings to receive notifications.' + ) + }) + + it('resolveSession tracks sessions after /new', async () => { + const adapter = createMockAdapter() + const newSession = { id: 'new-session', agent_id: 'agent-1', agent_type: 'cherry-claw' } + + vi.mocked(sessionService.createSession).mockResolvedValueOnce(newSession as any) + + await channelMessageHandler.handleCommand(adapter, { + chatId: 'chat-1', + userId: 'user-1', + userName: 'User', + command: 'new' + }) + + // Now send a message — should use the tracked session + vi.mocked(sessionService.getSession).mockResolvedValueOnce(newSession as any) + vi.mocked(sessionMessageService.createSessionMessage).mockResolvedValueOnce( + createMockStream([{ type: 'text-delta', text: 'OK' }]) as any + ) + + await channelMessageHandler.handleIncoming(adapter, { + chatId: 'chat-1', + userId: 'user-1', + userName: 'User', + text: 'test' + }) + + expect(sessionService.getSession).toHaveBeenCalledWith('agent-1', 'new-session') + }) + + it('clearSessionTracker causes fresh session resolution', async () => { + const adapter = createMockAdapter() + const session1 = { id: 'session-1', agent_id: 'agent-1', agent_type: 'cherry-claw' } + const session2 = { id: 'session-2', agent_id: 'agent-1', agent_type: 'cherry-claw' } + + // First interaction creates a session + vi.mocked(sessionService.createSession).mockResolvedValueOnce(session1 as any) + vi.mocked(sessionMessageService.createSessionMessage).mockResolvedValueOnce( + createMockStream([{ type: 'text-delta', text: 'R1' }]) as any + ) + + await channelMessageHandler.handleIncoming(adapter, { + chatId: 'chat-1', + userId: 'user-1', + userName: 'User', + text: 'msg1' + }) + + // Clear session tracker + channelMessageHandler.clearSessionTracker('agent-1') + + // Next interaction should resolve from scratch — listSessions returns session2 + vi.mocked(sessionService.listSessions).mockResolvedValueOnce({ + sessions: [session2] as any[], + total: 1 + }) + vi.mocked(sessionService.getSession).mockResolvedValueOnce(session2 as any) + vi.mocked(sessionMessageService.createSessionMessage).mockResolvedValueOnce( + createMockStream([{ type: 'text-delta', text: 'R2' }]) as any + ) + + await channelMessageHandler.handleIncoming(adapter, { + chatId: 'chat-1', + userId: 'user-1', + userName: 'User', + text: 'msg2' + }) + + expect(sessionService.listSessions).toHaveBeenCalledWith('agent-1', { limit: 1 }) + expect(sessionService.getSession).toHaveBeenCalledWith('agent-1', 'session-2') + }) +}) diff --git a/src/main/services/agents/services/channels/adapters/FeishuAdapter.ts b/src/main/services/agents/services/channels/adapters/FeishuAdapter.ts new file mode 100644 index 00000000000..2d0b65e2877 --- /dev/null +++ b/src/main/services/agents/services/channels/adapters/FeishuAdapter.ts @@ -0,0 +1,564 @@ +import * as Lark from '@larksuiteoapi/node-sdk' +import { loggerService } from '@logger' +import type { CherryClawChannel, FeishuDomain } from '@types' +import { net } from 'electron' + +import { ChannelAdapter, type ChannelAdapterConfig, type SendMessageOptions } from '../ChannelAdapter' +import { registerAdapterFactory } from '../ChannelManager' + +const logger = loggerService.withContext('FeishuAdapter') + +const FEISHU_MAX_LENGTH = 4000 + +type FeishuApiResponse = { + code?: number + msg?: string + message?: string + data?: T +} + +// Feishu message event shape (im.message.receive_v1) +type FeishuMessageEvent = { + sender: { + sender_id: { open_id?: string; user_id?: string; union_id?: string } + sender_type?: string + } + message: { + message_id: string + chat_id: string + chat_type: 'p2p' | 'group' + message_type: string + content: string // JSON-encoded + mentions?: Array<{ key: string; id: { open_id?: string }; name: string }> + } +} + +function resolveDomain(domain: FeishuDomain): Lark.Domain { + switch (domain) { + case 'lark': + return Lark.Domain.Lark + case 'feishu': + default: + return Lark.Domain.Feishu + } +} + +/** + * A lightweight HttpInstance adapter that routes requests through Electron's net.fetch, + * which respects system proxy settings. This ensures the Lark SDK works behind + * corporate proxies where raw Node.js fetch/axios would fail. + */ +function createElectronHttpInstance(): Lark.HttpInstance { + async function doRequest(method: string, url: string, data?: unknown, opts?: Record): Promise { + const headers: Record = { ...opts?.headers } + let body: string | FormData | undefined + + if (data !== undefined && data !== null) { + if (typeof data === 'string') { + body = data + } else if (data instanceof FormData) { + body = data + } else { + body = JSON.stringify(data) + if (!headers['Content-Type'] && !headers['content-type']) { + headers['Content-Type'] = 'application/json' + } + } + } + + const fetchUrl = new URL(url) + if (opts?.params) { + for (const [key, value] of Object.entries(opts.params)) { + fetchUrl.searchParams.set(key, String(value)) + } + } + + const res = await net.fetch(fetchUrl.toString(), { + method: method.toUpperCase(), + headers, + body + }) + + const isStream = opts?.responseType === 'stream' + const responseData = isStream + ? res.body + : await res.text().then((text) => { + if (!text) { + return '' + } + + try { + return JSON.parse(text) as unknown + } catch { + return text + } + }) + const responseHeaders = Object.fromEntries(res.headers.entries()) + + if (!res.ok) { + const detail = + typeof responseData === 'string' + ? responseData + : (responseData as { msg?: string; message?: string } | null)?.msg || + (responseData as { msg?: string; message?: string } | null)?.message || + res.statusText + const error = new Error(`Feishu HTTP ${res.status}: ${detail}`) + ;(error as Error & { response?: unknown }).response = { + data: responseData, + headers: responseHeaders, + status: res.status, + statusText: res.statusText + } + throw error + } + + if (opts?.$return_headers) { + return { + data: responseData, + headers: responseHeaders + } + } + + return responseData + } + + return { + request: (opts: any) => doRequest(opts.method || 'GET', opts.url, opts.data, opts), + get: (url: string, opts?: any) => doRequest('GET', url, undefined, opts), + delete: (url: string, opts?: any) => doRequest('DELETE', url, undefined, opts), + head: (url: string, opts?: any) => doRequest('HEAD', url, undefined, opts), + options: (url: string, opts?: any) => doRequest('OPTIONS', url, undefined, opts), + post: (url: string, data?: any, opts?: any) => doRequest('POST', url, data, opts), + put: (url: string, data?: any, opts?: any) => doRequest('PUT', url, data, opts), + patch: (url: string, data?: any, opts?: any) => doRequest('PATCH', url, data, opts) + } as Lark.HttpInstance +} + +function unwrapFeishuResponse(response: unknown): FeishuApiResponse { + if (response && typeof response === 'object' && 'code' in response) { + return response as FeishuApiResponse + } + + if ( + response && + typeof response === 'object' && + 'data' in response && + response.data && + typeof response.data === 'object' && + 'code' in response.data + ) { + return response.data as FeishuApiResponse + } + + return { code: -1, msg: 'Unexpected Feishu API response' } +} + +function ensureFeishuSuccess(response: unknown, action: string): FeishuApiResponse { + const unwrapped = unwrapFeishuResponse(response) + if (unwrapped.code === 0) { + return unwrapped + } + + throw new Error(`${action} failed: ${unwrapped.msg || unwrapped.message || `code=${String(unwrapped.code)}`}`) +} + +function splitMessage(text: string): string[] { + if (text.length <= FEISHU_MAX_LENGTH) { + return [text] + } + + const chunks: string[] = [] + let remaining = text + + while (remaining.length > 0) { + if (remaining.length <= FEISHU_MAX_LENGTH) { + chunks.push(remaining) + break + } + + let splitIndex = remaining.lastIndexOf('\n\n', FEISHU_MAX_LENGTH) + if (splitIndex <= 0) { + splitIndex = remaining.lastIndexOf('\n', FEISHU_MAX_LENGTH) + } + if (splitIndex <= 0) { + splitIndex = FEISHU_MAX_LENGTH + } + + chunks.push(remaining.slice(0, splitIndex)) + remaining = remaining.slice(splitIndex).replace(/^\n+/, '') + } + + return chunks +} + +/** + * Build a Feishu "post" message payload with markdown element. + * Feishu's post format with md tag renders markdown natively. + */ +function buildPostPayload(text: string): string { + return JSON.stringify({ + zh_cn: { + content: [[{ tag: 'md', text }]] + } + }) +} + +/** + * Build a Feishu interactive card with markdown content (schema 2.0). + */ +function buildMarkdownCard(text: string): string { + return JSON.stringify({ + schema: '2.0', + config: { wide_screen_mode: true }, + body: { + elements: [{ tag: 'markdown', content: text }] + } + }) +} + +const STREAMING_ELEMENT_ID = 'streaming_content' + +/** + * Manages a streaming card session using the Lark SDK's CardKit API. + * Creates a card with streaming_mode, updates content incrementally, and closes when done. + */ +class FeishuStreamingSession { + private cardId: string | null = null + private sequence = 0 + private lastUpdateTime = 0 + private updateQueue: Promise = Promise.resolve() + private readonly throttleMs = 150 + + constructor(private readonly client: Lark.Client) {} + + async create(): Promise { + try { + const res = ensureFeishuSuccess<{ card_id?: string }>( + await this.client.cardkit.v1.card.create({ + data: { + type: 'card_json', + data: JSON.stringify({ + schema: '2.0', + config: { wide_screen_mode: true, streaming_mode: true }, + body: { + elements: [{ tag: 'markdown', content: '...', element_id: STREAMING_ELEMENT_ID }] + } + }) + } + }), + 'Create streaming card' + ) + + if (res.data?.card_id) { + this.cardId = res.data.card_id + return this.cardId + } + + logger.warn('Failed to create streaming card', { code: res.code, msg: res.msg }) + return null + } catch (error) { + logger.error('Error creating streaming card', { + error: error instanceof Error ? error.message : String(error) + }) + return null + } + } + + getCardContent(): string { + return JSON.stringify({ type: 'card', data: { card_id: this.cardId } }) + } + + async update(text: string): Promise { + if (!this.cardId) return + + const now = Date.now() + if (now - this.lastUpdateTime < this.throttleMs) { + return + } + + this.updateQueue = this.updateQueue.then(async () => { + this.lastUpdateTime = Date.now() + this.sequence++ + try { + ensureFeishuSuccess( + await this.client.cardkit.v1.cardElement.content({ + path: { card_id: this.cardId!, element_id: STREAMING_ELEMENT_ID }, + data: { + content: JSON.stringify({ tag: 'markdown', content: text }), + sequence: this.sequence + } + }), + 'Update streaming card' + ) + } catch { + // Swallow update errors to avoid blocking the stream + } + }) + + await this.updateQueue + } + + async close(): Promise { + if (!this.cardId) return + + await this.updateQueue + + try { + this.sequence++ + ensureFeishuSuccess( + await this.client.cardkit.v1.card.settings({ + path: { card_id: this.cardId }, + data: { + settings: JSON.stringify({ streaming_mode: false }), + sequence: this.sequence + } + }), + 'Close streaming card' + ) + } catch (error) { + logger.warn('Error closing streaming card', { + error: error instanceof Error ? error.message : String(error) + }) + } + } +} + +class FeishuAdapter extends ChannelAdapter { + private client: Lark.Client | null = null + private wsClient: Lark.WSClient | null = null + private readonly appId: string + private readonly appSecret: string + private readonly encryptKey: string + private readonly verificationToken: string + private readonly allowedChatIds: string[] + private readonly domain: FeishuDomain + // Track active streaming sessions: draftId -> { session, chatId, messageId } + private readonly streamingSessions = new Map< + number, + { session: FeishuStreamingSession; chatId: string; messageId?: string } + >() + + constructor(config: ChannelAdapterConfig) { + super(config) + const { app_id, app_secret, encrypt_key, verification_token, allowed_chat_ids, domain } = config.channelConfig + this.appId = (app_id as string) ?? '' + this.appSecret = (app_secret as string) ?? '' + this.encryptKey = (encrypt_key as string) ?? '' + this.verificationToken = (verification_token as string) ?? '' + const rawIds = allowed_chat_ids as string[] | undefined + this.allowedChatIds = Array.isArray(rawIds) ? rawIds.map(String) : [] + this.domain = ((domain as string) ?? 'feishu') as FeishuDomain + this.notifyChatIds = [...this.allowedChatIds] + } + + async connect(): Promise { + if (!this.appId || !this.appSecret) { + throw new Error('Feishu app_id and app_secret are required') + } + + const larkDomain = resolveDomain(this.domain) + + this.client = new Lark.Client({ + appId: this.appId, + appSecret: this.appSecret, + appType: Lark.AppType.SelfBuild, + domain: larkDomain, + httpInstance: createElectronHttpInstance() + }) + + const eventDispatcher = new Lark.EventDispatcher({ + encryptKey: this.encryptKey || undefined, + verificationToken: this.verificationToken || undefined + }).register({ + 'im.message.receive_v1': async (data: unknown) => { + const event = data as FeishuMessageEvent + this.handleMessageEvent(event) + } + }) + + this.wsClient = new Lark.WSClient({ + appId: this.appId, + appSecret: this.appSecret, + domain: larkDomain, + loggerLevel: Lark.LoggerLevel.warn + }) + + await this.wsClient.start({ eventDispatcher }) + + logger.info('Feishu bot started (WebSocket)', { agentId: this.agentId, channelId: this.channelId }) + } + + async disconnect(): Promise { + for (const [, entry] of this.streamingSessions) { + await entry.session.close().catch(() => {}) + } + this.streamingSessions.clear() + + this.wsClient = null + this.client = null + logger.info('Feishu bot stopped', { agentId: this.agentId, channelId: this.channelId }) + } + + async sendMessage(chatId: string, text: string, _opts?: SendMessageOptions): Promise { + if (!this.client) { + throw new Error('Client is not connected') + } + void _opts + + const chunks = splitMessage(text) + + for (let i = 0; i < chunks.length; i++) { + ensureFeishuSuccess( + await this.client.im.message.create({ + params: { receive_id_type: 'chat_id' }, + data: { + receive_id: chatId, + msg_type: 'post', + content: buildPostPayload(chunks[i]) + } + }), + 'Send Feishu message' + ) + + if (i < chunks.length - 1) { + await new Promise((resolve) => setTimeout(resolve, 100)) + } + } + } + + async sendMessageDraft(chatId: string, draftId: number, text: string): Promise { + if (!this.client) { + throw new Error('Client is not connected') + } + + let entry = this.streamingSessions.get(draftId) + + if (!entry) { + const session = new FeishuStreamingSession(this.client) + const cardId = await session.create() + if (!cardId) return + + try { + const res = ensureFeishuSuccess<{ message_id?: string }>( + await this.client.im.message.create({ + params: { receive_id_type: 'chat_id' }, + data: { + receive_id: chatId, + msg_type: 'interactive', + content: session.getCardContent() + } + }), + 'Send streaming card message' + ) + const messageId = res.data?.message_id + entry = { session, chatId, messageId } + this.streamingSessions.set(draftId, entry) + } catch (error) { + logger.warn('Failed to send streaming card message', { + error: error instanceof Error ? error.message : String(error) + }) + return + } + } + + await entry.session.update(text) + } + + async sendTypingIndicator(_chatId: string): Promise { + void _chatId + // Feishu doesn't have a native typing indicator API. + // The streaming card itself serves as a visual indicator. + } + + /** + * Finalize a streaming session: close the streaming card and optionally + * update the message to a static markdown card for long-term readability. + */ + override async finalizeStream(draftId: number, finalText: string): Promise { + const entry = this.streamingSessions.get(draftId) + if (!entry) return false + + await entry.session.close() + this.streamingSessions.delete(draftId) + + if (entry.messageId && this.client) { + try { + ensureFeishuSuccess( + await this.client.im.message.update({ + path: { message_id: entry.messageId }, + data: { + msg_type: 'interactive', + content: buildMarkdownCard(finalText) + } + }), + 'Finalize Feishu streaming card' + ) + return true + } catch (error) { + logger.warn('Failed to finalize streaming card', { + error: error instanceof Error ? error.message : String(error) + }) + } + } + + return false + } + + private handleMessageEvent(event: FeishuMessageEvent): void { + const chatId = event.message.chat_id?.trim() + if (!chatId) return + + if (this.allowedChatIds.length > 0 && !this.allowedChatIds.includes(chatId)) { + logger.debug('Dropping message from unauthorized chat', { chatId }) + return + } + + if (event.message.message_type !== 'text') return + + let text: string + try { + const parsed = JSON.parse(event.message.content) as { text?: string } + text = parsed.text ?? '' + } catch { + return + } + + // Strip @mention tags (e.g., @_user_1 in group chats) + text = text.replace(/@_user_\d+/g, '').trim() + if (!text) return + + const userId = event.sender.sender_id.open_id ?? event.sender.sender_id.user_id ?? '' + + // Check for commands (Feishu doesn't have native bot commands, use text prefix) + if (text.startsWith('/')) { + const parts = text.split(/\s+/) + const cmd = parts[0].slice(1).toLowerCase() + if (cmd === 'new' || cmd === 'compact' || cmd === 'help' || cmd === 'whoami') { + this.emit('command', { + chatId, + userId, + userName: '', + command: cmd as 'new' | 'compact' | 'help' | 'whoami', + args: parts.slice(1).join(' ') || undefined + }) + return + } + } + + this.emit('message', { + chatId, + userId, + userName: '', + text + }) + } +} + +// Self-registration +registerAdapterFactory('feishu', (channel: CherryClawChannel, agentId: string) => { + return new FeishuAdapter({ + channelId: channel.id, + agentId, + channelConfig: channel.config + }) +}) diff --git a/src/main/services/agents/services/channels/adapters/QQAdapter.ts b/src/main/services/agents/services/channels/adapters/QQAdapter.ts new file mode 100644 index 00000000000..367263e76c0 --- /dev/null +++ b/src/main/services/agents/services/channels/adapters/QQAdapter.ts @@ -0,0 +1,612 @@ +import { loggerService } from '@logger' +import type { CherryClawChannel } from '@types' +import WebSocket from 'ws' + +import { ChannelAdapter, type ChannelAdapterConfig, type SendMessageOptions } from '../ChannelAdapter' +import { registerAdapterFactory } from '../ChannelManager' + +const logger = loggerService.withContext('QQAdapter') + +const QQ_MAX_LENGTH = 2000 +const QQ_API_BASE = 'https://api.sgroup.qq.com' + +// QQ Bot WebSocket opcodes +const OP_DISPATCH = 0 +const OP_HEARTBEAT = 1 +const OP_IDENTIFY = 2 +const OP_RESUME = 6 +const OP_RECONNECT = 7 +const OP_INVALID_SESSION = 9 +const OP_HELLO = 10 +const OP_HEARTBEAT_ACK = 11 + +// Intent flags +const INTENTS = { + PUBLIC_GUILD_MESSAGES: 1 << 30, + DIRECT_MESSAGE: 1 << 12, + GROUP_AND_C2C: 1 << 25 +} + +type QQTokenCache = { + accessToken: string + expiresAt: number +} + +type QQMessage = { + id: string + author: { + id: string + user_openid?: string + member_openid?: string + username?: string + } + content: string + timestamp: string + channel_id?: string + guild_id?: string + group_id?: string + group_openid?: string +} + +/** + * Split a long message into chunks that fit within QQ's character limit. + */ +function splitMessage(text: string): string[] { + if (text.length <= QQ_MAX_LENGTH) { + return [text] + } + + const chunks: string[] = [] + let remaining = text + + while (remaining.length > 0) { + if (remaining.length <= QQ_MAX_LENGTH) { + chunks.push(remaining) + break + } + + let splitIndex = remaining.lastIndexOf('\n\n', QQ_MAX_LENGTH) + if (splitIndex <= 0) { + splitIndex = remaining.lastIndexOf('\n', QQ_MAX_LENGTH) + } + if (splitIndex <= 0) { + splitIndex = remaining.lastIndexOf(' ', QQ_MAX_LENGTH) + } + if (splitIndex <= 0) { + splitIndex = QQ_MAX_LENGTH + } + + chunks.push(remaining.slice(0, splitIndex)) + remaining = remaining.slice(splitIndex).replace(/^\n+/, '').trimStart() + } + + return chunks +} + +class QQAdapter extends ChannelAdapter { + private ws: WebSocket | null = null + private readonly appId: string + private readonly clientSecret: string + private readonly allowedChatIds: string[] + + private tokenCache: QQTokenCache | null = null + private sessionId: string | null = null + private lastSeq: number | null = null + private heartbeatInterval: ReturnType | null = null + private reconnectAttempts = 0 + private isConnecting = false + private shouldStop = false + + private readonly reconnectDelays = [1000, 2000, 5000, 10000, 30000, 60000] + private readonly maxReconnectAttempts = 100 + + constructor(config: ChannelAdapterConfig) { + super(config) + const { app_id, client_secret, allowed_chat_ids } = config.channelConfig + this.appId = (app_id as string) ?? '' + this.clientSecret = (client_secret as string) ?? '' + const rawIds = allowed_chat_ids as string[] | undefined + this.allowedChatIds = Array.isArray(rawIds) ? rawIds.map(String) : [] + // Expose for notify tool + this.notifyChatIds = [...this.allowedChatIds] + } + + async connect(): Promise { + if (!this.appId || !this.clientSecret) { + throw new Error('QQ Bot AppID and ClientSecret are required') + } + + this.shouldStop = false + await this.startGateway() + + logger.info('QQ bot started', { agentId: this.agentId, channelId: this.channelId }) + } + + async disconnect(): Promise { + this.shouldStop = true + this.cleanup() + logger.info('QQ bot stopped', { agentId: this.agentId, channelId: this.channelId }) + } + + private async getAccessToken(): Promise { + // Check cache + if (this.tokenCache && Date.now() < this.tokenCache.expiresAt - 60000) { + return this.tokenCache.accessToken + } + + const response = await fetch('https://bots.qq.com/app/getAppAccessToken', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + appId: this.appId, + clientSecret: this.clientSecret + }) + }) + + if (!response.ok) { + throw new Error(`Failed to get access token: HTTP ${response.status}`) + } + + const data = (await response.json()) as { access_token?: string; expires_in?: number } + if (!data.access_token || !data.expires_in) { + const errorText = JSON.stringify(data) + throw new Error(`Invalid token response from QQ API: ${errorText}`) + } + + this.tokenCache = { + accessToken: data.access_token, + expiresAt: Date.now() + data.expires_in * 1000 + } + + return data.access_token + } + + private async apiRequest( + endpoint: string, + options?: { method?: string; body?: Record } + ): Promise { + const token = await this.getAccessToken() + const response = await fetch(endpoint, { + method: options?.method ?? 'GET', + headers: { + Authorization: `QQBot ${token}`, + 'Content-Type': 'application/json', + 'X-Union-Appid': this.appId + }, + ...(options?.body ? { body: JSON.stringify(options.body) } : {}) + }) + + if (!response.ok) { + const errorText = await response.text().catch(() => '') + throw new Error(`QQ API request failed ${endpoint}: HTTP ${response.status} - ${errorText}`) + } + + return response + } + + private async getGatewayUrl(): Promise { + const response = await this.apiRequest(`${QQ_API_BASE}/gateway`) + const data = (await response.json()) as { url: string } + return data.url + } + + private async startGateway(): Promise { + if (this.isConnecting || this.shouldStop) return + this.isConnecting = true + + try { + this.cleanup() + + const gatewayUrl = await this.getGatewayUrl() + logger.info('Connecting to QQ gateway', { agentId: this.agentId, url: gatewayUrl }) + + const ws = new WebSocket(gatewayUrl) + this.ws = ws + + ws.on('open', () => { + logger.info('QQ WebSocket connected', { agentId: this.agentId }) + }) + + ws.on('message', (data: Buffer) => { + this.handleWsMessage(data).catch((err) => { + logger.error('Error handling WS message', { + agentId: this.agentId, + error: err instanceof Error ? err.message : String(err) + }) + }) + }) + + ws.on('close', (code, reason) => { + logger.info('QQ WebSocket closed', { + agentId: this.agentId, + code, + reason: reason.toString() + }) + this.scheduleReconnect() + }) + + ws.on('error', (err) => { + logger.error('QQ WebSocket error', { + agentId: this.agentId, + error: err.message + }) + }) + } catch (error) { + logger.error('Failed to start QQ gateway', { + agentId: this.agentId, + error: error instanceof Error ? error.message : String(error) + }) + this.scheduleReconnect() + } finally { + this.isConnecting = false + } + } + + private async handleWsMessage(data: Buffer): Promise { + let payload: { op: number; d?: unknown; s?: number; t?: string } + try { + payload = JSON.parse(data.toString()) + } catch { + logger.warn('Invalid JSON from QQ WebSocket', { agentId: this.agentId }) + return + } + + if (payload.s !== undefined) { + this.lastSeq = payload.s + } + + switch (payload.op) { + case OP_HELLO: + await this.handleHello(payload.d as { heartbeat_interval: number }) + break + case OP_DISPATCH: + if (payload.t) { + await this.handleDispatch(payload.t, payload.d) + } + break + case OP_HEARTBEAT_ACK: + // Heartbeat acknowledged + break + case OP_RECONNECT: + logger.info('QQ gateway requested reconnect', { agentId: this.agentId }) + this.scheduleReconnect() + break + case OP_INVALID_SESSION: + logger.warn('QQ invalid session', { agentId: this.agentId }) + this.sessionId = null + this.lastSeq = null + this.scheduleReconnect() + break + } + } + + private async handleHello(data: { heartbeat_interval: number }): Promise { + // Start heartbeat + this.heartbeatInterval = setInterval(() => { + this.sendHeartbeat() + }, data.heartbeat_interval) + + // Identify or resume + if (this.sessionId && this.lastSeq !== null) { + await this.sendResume() + } else { + await this.sendIdentify() + } + } + + private async sendIdentify(): Promise { + const token = await this.getAccessToken() + const intents = INTENTS.PUBLIC_GUILD_MESSAGES | INTENTS.DIRECT_MESSAGE | INTENTS.GROUP_AND_C2C + + this.send({ + op: OP_IDENTIFY, + d: { + token: `QQBot ${token}`, + intents, + shard: [0, 1] + } + }) + } + + private async sendResume(): Promise { + const token = await this.getAccessToken() + + this.send({ + op: OP_RESUME, + d: { + token: `QQBot ${token}`, + session_id: this.sessionId, + seq: this.lastSeq + } + }) + } + + private sendHeartbeat(): void { + this.send({ + op: OP_HEARTBEAT, + d: this.lastSeq + }) + } + + private send(payload: object): void { + if (this.ws?.readyState === WebSocket.OPEN) { + this.ws.send(JSON.stringify(payload)) + } + } + + private async handleDispatch(eventType: string, data: unknown): Promise { + switch (eventType) { + case 'READY': { + const readyData = data as { session_id: string; user: { id: string; username: string } } + this.sessionId = readyData.session_id + this.reconnectAttempts = 0 + logger.info('QQ bot ready', { + agentId: this.agentId, + sessionId: this.sessionId, + botUser: readyData.user.username + }) + break + } + case 'RESUMED': + this.reconnectAttempts = 0 + logger.info('QQ session resumed', { agentId: this.agentId }) + break + case 'C2C_MESSAGE_CREATE': + await this.handleC2CMessage(data as QQMessage) + break + case 'GROUP_AT_MESSAGE_CREATE': + await this.handleGroupMessage(data as QQMessage) + break + case 'AT_MESSAGE_CREATE': + await this.handleGuildMessage(data as QQMessage) + break + case 'DIRECT_MESSAGE_CREATE': + await this.handleDirectMessage(data as QQMessage) + break + } + } + + private async handleC2CMessage(msg: QQMessage): Promise { + const chatId = `c2c:${msg.author.user_openid}` + + if (!this.isAllowed(chatId, msg.author.user_openid)) return + + const text = this.parseContent(msg.content) + if (this.isCommand(text)) { + if (text.startsWith('/whoami')) { + await this.sendWhoami(chatId) + return + } + this.emitCommand(chatId, msg.author.user_openid ?? '', '', text) + } else { + this.emit('message', { + chatId, + userId: msg.author.user_openid ?? msg.author.id, + userName: msg.author.username ?? '', + text + }) + } + } + + private async handleGroupMessage(msg: QQMessage): Promise { + const chatId = `group:${msg.group_openid}` + + if (!this.isAllowed(chatId, msg.group_openid)) return + + const text = this.parseContent(msg.content) + if (this.isCommand(text)) { + if (text.startsWith('/whoami')) { + await this.sendWhoami(chatId) + return + } + this.emitCommand(chatId, msg.author.member_openid ?? '', '', text) + } else { + this.emit('message', { + chatId, + userId: msg.author.member_openid ?? msg.author.id, + userName: msg.author.username ?? '', + text + }) + } + } + + private async handleGuildMessage(msg: QQMessage): Promise { + const chatId = `channel:${msg.channel_id}` + + if (!this.isAllowed(chatId, msg.channel_id)) return + + const text = this.parseContent(msg.content) + if (this.isCommand(text)) { + if (text.startsWith('/whoami')) { + await this.sendWhoami(chatId) + return + } + this.emitCommand(chatId, msg.author.id, msg.author.username ?? '', text) + } else { + this.emit('message', { + chatId, + userId: msg.author.id, + userName: msg.author.username ?? '', + text + }) + } + } + + private async handleDirectMessage(msg: QQMessage): Promise { + const chatId = `dm:${msg.guild_id}` + + if (!this.isAllowed(chatId, msg.guild_id)) return + + const text = this.parseContent(msg.content) + if (this.isCommand(text)) { + if (text.startsWith('/whoami')) { + await this.sendWhoami(chatId) + return + } + this.emitCommand(chatId, msg.author.id, msg.author.username ?? '', text) + } else { + this.emit('message', { + chatId, + userId: msg.author.id, + userName: msg.author.username ?? '', + text + }) + } + } + + private parseContent(content: string): string { + // Remove @bot mentions and trim + return content.replace(/<@!\d+>/g, '').trim() + } + + private isAllowed(chatId: string, rawId?: string): boolean { + if (this.allowedChatIds.length === 0) return true + return this.allowedChatIds.includes(chatId) || (rawId !== undefined && this.allowedChatIds.includes(rawId)) + } + + private isCommand(text: string): boolean { + return ( + text.startsWith('/new') || text.startsWith('/compact') || text.startsWith('/help') || text.startsWith('/whoami') + ) + } + + private emitCommand(chatId: string, userId: string, userName: string, text: string): void { + const cmd = text.split(/\s+/)[0].slice(1) as 'new' | 'compact' | 'help' + this.emit('command', { chatId, userId, userName, command: cmd }) + } + + private async sendWhoami(chatId: string): Promise { + const [type] = chatId.split(':') + const typeLabel = + type === 'c2c' ? 'Private' : type === 'group' ? 'Group' : type === 'channel' ? 'Guild Channel' : 'Direct Message' + + const message = [ + `📍 Chat Info`, + ``, + `Type: ${typeLabel}`, + `Chat ID: ${chatId}`, + ``, + `To enable notifications for this chat:`, + `1. Go to Agent Settings → Channels → QQ`, + `2. Add "${chatId}" to Allowed Chat IDs`, + `3. Enable "Receive Notifications"`, + ``, + `Then use the notify tool or scheduled tasks will send messages here.` + ].join('\n') + + try { + await this.sendMessage(chatId, message) + } catch (err) { + logger.error('Failed to send whoami response', { + agentId: this.agentId, + chatId, + error: err instanceof Error ? err.message : String(err) + }) + } + } + + async sendMessage(chatId: string, text: string, _opts?: SendMessageOptions): Promise { + const chunks = splitMessage(text) + + for (let i = 0; i < chunks.length; i++) { + await this.sendToChat(chatId, chunks[i]) + + if (i < chunks.length - 1) { + await new Promise((resolve) => setTimeout(resolve, 100)) + } + } + } + + private async sendToChat(chatId: string, text: string): Promise { + const [type, id] = chatId.split(':') + + let endpoint: string + let body: Record + + switch (type) { + case 'c2c': + endpoint = `${QQ_API_BASE}/v2/users/${id}/messages` + body = { content: text, msg_type: 0 } + break + case 'group': + endpoint = `${QQ_API_BASE}/v2/groups/${id}/messages` + body = { content: text, msg_type: 0 } + break + case 'channel': + endpoint = `${QQ_API_BASE}/channels/${id}/messages` + body = { content: text } + break + case 'dm': + endpoint = `${QQ_API_BASE}/dms/${id}/messages` + body = { content: text } + break + default: + throw new Error(`Unknown chat type: ${type}`) + } + + await this.apiRequest(endpoint, { method: 'POST', body }) + } + + async sendMessageDraft(_chatId: string, _draftId: number, _text: string): Promise { + // QQ does not have a native draft/streaming API like Telegram + // This is a no-op; final message is sent via sendMessage + } + + async sendTypingIndicator(_chatId: string): Promise { + // QQ Bot API does not support typing indicators for most message types + // For C2C, there's sendC2CInputNotify but it requires message_id context + // This is a no-op + } + + private cleanup(): void { + if (this.heartbeatInterval) { + clearInterval(this.heartbeatInterval) + this.heartbeatInterval = null + } + if (this.ws) { + if (this.ws.readyState === WebSocket.OPEN || this.ws.readyState === WebSocket.CONNECTING) { + this.ws.close() + } + this.ws = null + } + this.tokenCache = null + } + + private scheduleReconnect(): void { + if (this.shouldStop || this.reconnectAttempts >= this.maxReconnectAttempts) { + if (!this.shouldStop) { + logger.error('Max reconnect attempts reached', { agentId: this.agentId }) + } + return + } + + const delay = this.reconnectDelays[Math.min(this.reconnectAttempts, this.reconnectDelays.length - 1)] + this.reconnectAttempts++ + + logger.info('Scheduling QQ reconnect', { + agentId: this.agentId, + attempt: this.reconnectAttempts, + delay + }) + + setTimeout(() => { + if (!this.shouldStop) { + this.startGateway().catch((err) => { + logger.error('Reconnect failed', { + agentId: this.agentId, + error: err instanceof Error ? err.message : String(err) + }) + }) + } + }, delay) + } +} + +// Self-registration +registerAdapterFactory('qq', (channel: CherryClawChannel, agentId: string) => { + return new QQAdapter({ + channelId: channel.id, + agentId, + channelConfig: channel.config + }) +}) diff --git a/src/main/services/agents/services/channels/adapters/TelegramAdapter.ts b/src/main/services/agents/services/channels/adapters/TelegramAdapter.ts new file mode 100644 index 00000000000..f1c714c136c --- /dev/null +++ b/src/main/services/agents/services/channels/adapters/TelegramAdapter.ts @@ -0,0 +1,210 @@ +import { loggerService } from '@logger' +import type { CherryClawChannel } from '@types' +import { Bot } from 'grammy' + +import { ChannelAdapter, type ChannelAdapterConfig, type SendMessageOptions } from '../ChannelAdapter' +import { registerAdapterFactory } from '../ChannelManager' + +const logger = loggerService.withContext('TelegramAdapter') + +const TELEGRAM_MAX_LENGTH = 4096 + +/** + * Split a long message into chunks that fit within Telegram's 4096 character limit. + * Tries to split on paragraph boundaries first, then line boundaries, then hard-splits. + */ +function splitMessage(text: string): string[] { + if (text.length <= TELEGRAM_MAX_LENGTH) { + return [text] + } + + const chunks: string[] = [] + let remaining = text + + while (remaining.length > 0) { + if (remaining.length <= TELEGRAM_MAX_LENGTH) { + chunks.push(remaining) + break + } + + // Try to split on paragraph boundary + let splitIndex = remaining.lastIndexOf('\n\n', TELEGRAM_MAX_LENGTH) + if (splitIndex <= 0) { + // Try to split on line boundary + splitIndex = remaining.lastIndexOf('\n', TELEGRAM_MAX_LENGTH) + } + if (splitIndex <= 0) { + // Hard split at max length + splitIndex = TELEGRAM_MAX_LENGTH + } + + chunks.push(remaining.slice(0, splitIndex)) + remaining = remaining.slice(splitIndex).replace(/^\n+/, '') + } + + return chunks +} + +class TelegramAdapter extends ChannelAdapter { + private bot: Bot | null = null + private readonly botToken: string + private readonly allowedChatIds: string[] + + constructor(config: ChannelAdapterConfig) { + super(config) + const { bot_token, allowed_chat_ids } = config.channelConfig + this.botToken = (bot_token as string) ?? '' + const rawIds = allowed_chat_ids as string[] | undefined + this.allowedChatIds = Array.isArray(rawIds) ? rawIds.map(String) : [] + // Expose for notify tool — all allowed chats receive notifications + this.notifyChatIds = [...this.allowedChatIds] + } + + async connect(): Promise { + if (!this.botToken) { + throw new Error('Telegram bot token is required') + } + + const bot = new Bot(this.botToken) + this.bot = bot + + // Auth middleware — must be first + bot.use(async (ctx, next) => { + const chatId = ctx.chat?.id?.toString() + if (this.allowedChatIds.length > 0 && (!chatId || !this.allowedChatIds.includes(chatId))) { + logger.debug('Dropping message from unauthorized chat', { chatId }) + return + } + await next() + }) + + // Command handlers + bot.command('new', (ctx) => { + this.emit('command', { + chatId: ctx.chat.id.toString(), + userId: ctx.from?.id?.toString() ?? '', + userName: ctx.from?.first_name ?? '', + command: 'new' + }) + }) + + bot.command('compact', (ctx) => { + this.emit('command', { + chatId: ctx.chat.id.toString(), + userId: ctx.from?.id?.toString() ?? '', + userName: ctx.from?.first_name ?? '', + command: 'compact' + }) + }) + + bot.command('help', (ctx) => { + this.emit('command', { + chatId: ctx.chat.id.toString(), + userId: ctx.from?.id?.toString() ?? '', + userName: ctx.from?.first_name ?? '', + command: 'help' + }) + }) + + bot.command('whoami', (ctx) => { + this.emit('command', { + chatId: ctx.chat.id.toString(), + userId: ctx.from?.id?.toString() ?? '', + userName: ctx.from?.first_name ?? '', + command: 'whoami' + }) + }) + + // Message handler + bot.on('message:text', (ctx) => { + this.emit('message', { + chatId: ctx.chat.id.toString(), + userId: ctx.from?.id?.toString() ?? '', + userName: ctx.from?.first_name ?? '', + text: ctx.message.text + }) + }) + + // Register bot commands with Telegram + await bot.api.setMyCommands([ + { command: 'new', description: 'Start a new conversation' }, + { command: 'compact', description: 'Compact conversation history' }, + { command: 'help', description: 'Show help information' }, + { command: 'whoami', description: 'Show the current chat ID' } + ]) + + // Error handler — err is a BotError wrapping the original cause in err.error + bot.catch((err) => { + const cause = err.error + logger.error('Bot error', { + agentId: this.agentId, + channelId: this.channelId, + error: cause instanceof Error ? cause.message : String(cause) + }) + }) + + // Start long polling (fire-and-forget) + bot.start().catch((err) => { + logger.error('Bot polling stopped with error', { + agentId: this.agentId, + channelId: this.channelId, + error: err instanceof Error ? err.message : String(err) + }) + }) + + logger.info('Telegram bot started', { agentId: this.agentId, channelId: this.channelId }) + } + + async disconnect(): Promise { + if (this.bot) { + await this.bot.stop() + this.bot = null + logger.info('Telegram bot stopped', { agentId: this.agentId, channelId: this.channelId }) + } + } + + async sendMessage(chatId: string, text: string, opts?: SendMessageOptions): Promise { + if (!this.bot) { + throw new Error('Bot is not connected') + } + + const chunks = splitMessage(text) + + for (let i = 0; i < chunks.length; i++) { + await this.bot.api.sendMessage(chatId, chunks[i], { + ...(opts?.parseMode ? { parse_mode: opts.parseMode } : {}), + ...(opts?.replyToMessageId && i === 0 ? { reply_parameters: { message_id: opts.replyToMessageId } } : {}) + }) + + // Small delay between chunks to avoid rate limiting + if (i < chunks.length - 1) { + await new Promise((resolve) => setTimeout(resolve, 100)) + } + } + } + + async sendMessageDraft(chatId: string, draftId: number, text: string): Promise { + if (!this.bot) { + throw new Error('Bot is not connected') + } + + await this.bot.api.sendMessageDraft(Number(chatId), draftId, text) + } + + async sendTypingIndicator(chatId: string): Promise { + if (!this.bot) { + throw new Error('Bot is not connected') + } + + await this.bot.api.sendChatAction(chatId, 'typing') + } +} + +// Self-registration +registerAdapterFactory('telegram', (channel: CherryClawChannel, agentId: string) => { + return new TelegramAdapter({ + channelId: channel.id, + agentId, + channelConfig: channel.config + }) +}) diff --git a/src/main/services/agents/services/channels/adapters/__tests__/FeishuAdapter.test.ts b/src/main/services/agents/services/channels/adapters/__tests__/FeishuAdapter.test.ts new file mode 100644 index 00000000000..09521f3a111 --- /dev/null +++ b/src/main/services/agents/services/channels/adapters/__tests__/FeishuAdapter.test.ts @@ -0,0 +1,365 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' + +vi.mock('@logger', () => ({ + loggerService: { + withContext: () => ({ info: vi.fn(), error: vi.fn(), warn: vi.fn(), debug: vi.fn(), silly: vi.fn() }) + } +})) + +vi.mock('../../ChannelManager', () => ({ + registerAdapterFactory: vi.fn() +})) + +vi.mock('electron', () => ({ + net: { fetch: vi.fn() } +})) + +const mockImCreate = vi.fn().mockResolvedValue({ code: 0, data: { message_id: 'msg-1' } }) +const mockImUpdate = vi.fn().mockResolvedValue({ code: 0 }) +const mockCardCreate = vi.fn().mockResolvedValue({ code: 0, data: { card_id: 'card-1' } }) +const mockCardSettings = vi.fn().mockResolvedValue({ code: 0 }) +const mockElementContent = vi.fn().mockResolvedValue({ code: 0 }) + +const mockClient = { + im: { + message: { + create: mockImCreate, + update: mockImUpdate + } + }, + cardkit: { + v1: { + card: { create: mockCardCreate, settings: mockCardSettings }, + cardElement: { content: mockElementContent } + } + } +} + +const mockWsStart = vi.fn().mockResolvedValue(undefined) +let capturedEventHandlers: Record unknown> = {} + +vi.mock('@larksuiteoapi/node-sdk', () => ({ + Client: vi.fn().mockImplementation(() => mockClient), + WSClient: vi.fn().mockImplementation(() => ({ start: mockWsStart })), + EventDispatcher: vi.fn().mockImplementation(() => ({ + register: vi.fn().mockImplementation((handles: Record unknown>) => { + capturedEventHandlers = handles + return {} + }) + })), + AppType: { SelfBuild: 0 }, + Domain: { Feishu: 'https://open.feishu.cn', Lark: 'https://open.larksuite.com' }, + LoggerLevel: { warn: 2 } +})) + +import '../FeishuAdapter' + +import { registerAdapterFactory } from '../../ChannelManager' + +function getFactory() { + const calls = vi.mocked(registerAdapterFactory).mock.calls + const feishuCall = calls.find((c) => c[0] === 'feishu') + if (!feishuCall) throw new Error('registerAdapterFactory was not called for feishu') + return feishuCall[1] as (channel: any, agentId: string) => any +} + +describe('FeishuAdapter', () => { + beforeEach(() => { + mockImCreate.mockClear().mockResolvedValue({ code: 0, data: { message_id: 'msg-1' } }) + mockImUpdate.mockClear().mockResolvedValue({ code: 0 }) + mockCardCreate.mockClear().mockResolvedValue({ code: 0, data: { card_id: 'card-1' } }) + mockCardSettings.mockClear().mockResolvedValue({ code: 0 }) + mockElementContent.mockClear().mockResolvedValue({ code: 0 }) + mockWsStart.mockClear().mockResolvedValue(undefined) + capturedEventHandlers = {} + }) + + afterEach(() => { + vi.useRealTimers() + }) + + function createAdapter(overrides: Record = {}) { + const factory = getFactory() + return factory( + { + id: (overrides.channelId as string) ?? 'ch-1', + type: 'feishu', + enabled: true, + config: { + app_id: (overrides.app_id as string) ?? 'test-app-id', + app_secret: (overrides.app_secret as string) ?? 'test-app-secret', + allowed_chat_ids: (overrides.allowed_chat_ids as string[]) ?? ['oc_123'], + domain: (overrides.domain as string) ?? 'feishu' + } + }, + (overrides.agentId as string) ?? 'agent-1' + ) + } + + it('connect() creates client, event dispatcher, and starts WebSocket', async () => { + const adapter = createAdapter() + await adapter.connect() + + expect(mockWsStart).toHaveBeenCalledWith({ eventDispatcher: expect.anything() }) + }) + + it('connect() throws if app_id is missing', async () => { + const adapter = createAdapter({ app_id: '' }) + await expect(adapter.connect()).rejects.toThrow('Feishu app_id and app_secret are required') + }) + + it('sendMessage() sends post-type message via SDK', async () => { + const adapter = createAdapter() + await adapter.connect() + await adapter.sendMessage('oc_123', 'Hello Feishu') + + expect(mockImCreate).toHaveBeenCalledWith({ + params: { receive_id_type: 'chat_id' }, + data: { + receive_id: 'oc_123', + msg_type: 'post', + content: expect.stringContaining('Hello Feishu') + } + }) + + // Verify it's a proper post payload with md tag + const content = JSON.parse(mockImCreate.mock.calls[0][0].data.content) + expect(content.zh_cn.content[0][0]).toEqual({ tag: 'md', text: 'Hello Feishu' }) + }) + + it('sendMessage() chunks long messages', async () => { + vi.useFakeTimers() + const adapter = createAdapter() + await adapter.connect() + + const longText = 'A'.repeat(5000) + const sendPromise = adapter.sendMessage('oc_123', longText) + + await vi.runAllTimersAsync() + await sendPromise + + expect(mockImCreate).toHaveBeenCalledTimes(2) + }) + + it('sendMessage() throws when Feishu returns an API error', async () => { + const adapter = createAdapter() + await adapter.connect() + mockImCreate.mockResolvedValueOnce({ code: 99991663, msg: 'permission denied' }) + + await expect(adapter.sendMessage('oc_123', 'Hello Feishu')).rejects.toThrow( + 'Send Feishu message failed: permission denied' + ) + }) + + it('sendMessageDraft() creates streaming card and updates content via SDK', async () => { + const adapter = createAdapter() + await adapter.connect() + + await adapter.sendMessageDraft('oc_123', 1, 'partial text...') + + // Should create a streaming card + expect(mockCardCreate).toHaveBeenCalledWith({ + data: { + type: 'card_json', + data: expect.stringContaining('streaming_mode') + } + }) + + // Should send the card as an interactive message + expect(mockImCreate).toHaveBeenCalledWith({ + params: { receive_id_type: 'chat_id' }, + data: { + receive_id: 'oc_123', + msg_type: 'interactive', + content: expect.stringContaining('card-1') + } + }) + + // Should update element content + expect(mockElementContent).toHaveBeenCalledWith({ + path: { card_id: 'card-1', element_id: 'streaming_content' }, + data: { + content: expect.stringContaining('partial text...'), + sequence: expect.any(Number) + } + }) + }) + + it('finalizeStream() updates the existing card and returns true', async () => { + const adapter = createAdapter() + await adapter.connect() + + await adapter.sendMessageDraft('oc_123', 1, 'partial text...') + + await expect(adapter.finalizeStream(1, 'final text')).resolves.toBe(true) + expect(mockImUpdate).toHaveBeenCalledWith({ + path: { message_id: 'msg-1' }, + data: { + msg_type: 'interactive', + content: expect.stringContaining('final text') + } + }) + }) + + it('sendTypingIndicator() is a no-op (Feishu has no native typing API)', async () => { + const adapter = createAdapter() + await adapter.connect() + await adapter.sendTypingIndicator('oc_123') + }) + + it('handles incoming text messages and emits message event', async () => { + const adapter = createAdapter() + await adapter.connect() + + const messageSpy = vi.fn() + adapter.on('message', messageSpy) + + const handler = capturedEventHandlers['im.message.receive_v1'] + expect(handler).toBeDefined() + + await handler({ + sender: { sender_id: { open_id: 'ou_user1' } }, + message: { + message_id: 'msg-in-1', + chat_id: 'oc_123', + chat_type: 'p2p', + message_type: 'text', + content: JSON.stringify({ text: 'Hello agent' }) + } + }) + + expect(messageSpy).toHaveBeenCalledWith({ + chatId: 'oc_123', + userId: 'ou_user1', + userName: '', + text: 'Hello agent' + }) + }) + + it('handles slash commands from text messages', async () => { + const adapter = createAdapter() + await adapter.connect() + + const commandSpy = vi.fn() + adapter.on('command', commandSpy) + + const handler = capturedEventHandlers['im.message.receive_v1'] + await handler({ + sender: { sender_id: { open_id: 'ou_user1' } }, + message: { + message_id: 'msg-cmd-1', + chat_id: 'oc_123', + chat_type: 'p2p', + message_type: 'text', + content: JSON.stringify({ text: '/new' }) + } + }) + + expect(commandSpy).toHaveBeenCalledWith({ + chatId: 'oc_123', + userId: 'ou_user1', + userName: '', + command: 'new', + args: undefined + }) + }) + + it('handles /whoami from text messages', async () => { + const adapter = createAdapter() + await adapter.connect() + + const commandSpy = vi.fn() + adapter.on('command', commandSpy) + + const handler = capturedEventHandlers['im.message.receive_v1'] + await handler({ + sender: { sender_id: { open_id: 'ou_user1' } }, + message: { + message_id: 'msg-cmd-2', + chat_id: 'oc_123', + chat_type: 'p2p', + message_type: 'text', + content: JSON.stringify({ text: '/whoami' }) + } + }) + + expect(commandSpy).toHaveBeenCalledWith({ + chatId: 'oc_123', + userId: 'ou_user1', + userName: '', + command: 'whoami', + args: undefined + }) + }) + + it('auth guard blocks unauthorized chat IDs', async () => { + const adapter = createAdapter({ allowed_chat_ids: ['oc_123'] }) + await adapter.connect() + + const messageSpy = vi.fn() + adapter.on('message', messageSpy) + + const handler = capturedEventHandlers['im.message.receive_v1'] + await handler({ + sender: { sender_id: { open_id: 'ou_user1' } }, + message: { + message_id: 'msg-blocked', + chat_id: 'oc_unauthorized', + chat_type: 'p2p', + message_type: 'text', + content: JSON.stringify({ text: 'Should be blocked' }) + } + }) + + expect(messageSpy).not.toHaveBeenCalled() + }) + + it('strips @mention tags from group messages', async () => { + const adapter = createAdapter({ allowed_chat_ids: [] }) + await adapter.connect() + + const messageSpy = vi.fn() + adapter.on('message', messageSpy) + + const handler = capturedEventHandlers['im.message.receive_v1'] + await handler({ + sender: { sender_id: { open_id: 'ou_user1' } }, + message: { + message_id: 'msg-mention', + chat_id: 'oc_group1', + chat_type: 'group', + message_type: 'text', + content: JSON.stringify({ text: '@_user_1 Hello agent' }) + } + }) + + expect(messageSpy).toHaveBeenCalledWith(expect.objectContaining({ text: 'Hello agent' })) + }) + + it('ignores non-text message types', async () => { + const adapter = createAdapter({ allowed_chat_ids: [] }) + await adapter.connect() + + const messageSpy = vi.fn() + adapter.on('message', messageSpy) + + const handler = capturedEventHandlers['im.message.receive_v1'] + await handler({ + sender: { sender_id: { open_id: 'ou_user1' } }, + message: { + message_id: 'msg-image', + chat_id: 'oc_123', + chat_type: 'p2p', + message_type: 'image', + content: '{}' + } + }) + + expect(messageSpy).not.toHaveBeenCalled() + }) + + it('sets notifyChatIds from allowed_chat_ids', () => { + const adapter = createAdapter({ allowed_chat_ids: ['oc_a', 'oc_b'] }) + expect(adapter.notifyChatIds).toEqual(['oc_a', 'oc_b']) + }) +}) diff --git a/src/main/services/agents/services/channels/adapters/__tests__/TelegramAdapter.test.ts b/src/main/services/agents/services/channels/adapters/__tests__/TelegramAdapter.test.ts new file mode 100644 index 00000000000..08b08d5fa87 --- /dev/null +++ b/src/main/services/agents/services/channels/adapters/__tests__/TelegramAdapter.test.ts @@ -0,0 +1,225 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' + +vi.mock('@logger', () => ({ + loggerService: { + withContext: () => ({ info: vi.fn(), error: vi.fn(), warn: vi.fn(), debug: vi.fn(), silly: vi.fn() }) + } +})) + +// Mock registerAdapterFactory to capture the factory function +vi.mock('../../ChannelManager', () => ({ + registerAdapterFactory: vi.fn() +})) + +const mockBot = { + use: vi.fn(), + command: vi.fn(), + on: vi.fn(), + api: { + setMyCommands: vi.fn().mockResolvedValue(undefined), + sendMessage: vi.fn().mockResolvedValue(undefined), + sendMessageDraft: vi.fn().mockResolvedValue(true), + sendChatAction: vi.fn().mockResolvedValue(undefined) + }, + catch: vi.fn(), + start: vi.fn().mockResolvedValue(undefined), + stop: vi.fn().mockResolvedValue(undefined) +} + +vi.mock('grammy', () => ({ + Bot: vi.fn().mockImplementation(() => mockBot) +})) + +// Import the module to trigger self-registration side effect +import '../TelegramAdapter' + +import { registerAdapterFactory } from '../../ChannelManager' + +function getFactory() { + const call = vi.mocked(registerAdapterFactory).mock.calls[0] + if (!call) throw new Error('registerAdapterFactory was not called') + return call[1] as (channel: any, agentId: string) => any +} + +describe('TelegramAdapter', () => { + beforeEach(() => { + // Reset all mock functions but preserve the factory registration + mockBot.use.mockClear() + mockBot.command.mockClear() + mockBot.on.mockClear() + mockBot.api.setMyCommands.mockClear().mockResolvedValue(undefined) + mockBot.api.sendMessage.mockClear().mockResolvedValue(undefined) + mockBot.api.sendMessageDraft.mockClear().mockResolvedValue(true) + mockBot.api.sendChatAction.mockClear().mockResolvedValue(undefined) + mockBot.catch.mockClear() + mockBot.start.mockClear().mockResolvedValue(undefined) + mockBot.stop.mockClear().mockResolvedValue(undefined) + }) + + afterEach(() => { + vi.useRealTimers() + }) + + function createAdapter(overrides: Record = {}) { + const factory = getFactory() + return factory( + { + id: (overrides.channelId as string) ?? 'ch-1', + type: 'telegram', + enabled: true, + config: { + bot_token: (overrides.bot_token as string) ?? 'test-token', + allowed_chat_ids: (overrides.allowed_chat_ids as string[]) ?? ['123'] + } + }, + (overrides.agentId as string) ?? 'agent-1' + ) + } + + it('connect() registers middleware, commands, message handler, and starts polling', async () => { + const adapter = createAdapter() + await adapter.connect() + + expect(mockBot.use).toHaveBeenCalledTimes(1) // auth middleware + expect(mockBot.command).toHaveBeenCalledTimes(4) // new, compact, help, whoami + expect(mockBot.on).toHaveBeenCalledWith('message:text', expect.any(Function)) + expect(mockBot.api.setMyCommands).toHaveBeenCalledWith([ + { command: 'new', description: 'Start a new conversation' }, + { command: 'compact', description: 'Compact conversation history' }, + { command: 'help', description: 'Show help information' }, + { command: 'whoami', description: 'Show the current chat ID' } + ]) + expect(mockBot.catch).toHaveBeenCalledTimes(1) + expect(mockBot.start).toHaveBeenCalledTimes(1) + }) + + it('disconnect() stops the bot', async () => { + const adapter = createAdapter() + await adapter.connect() + await adapter.disconnect() + + expect(mockBot.stop).toHaveBeenCalledTimes(1) + }) + + it('sendMessage() sends text via bot API', async () => { + const adapter = createAdapter() + await adapter.connect() + await adapter.sendMessage('123', 'Hello') + + expect(mockBot.api.sendMessage).toHaveBeenCalledWith('123', 'Hello', {}) + }) + + it('sendMessage() chunks long messages', async () => { + vi.useFakeTimers() + const adapter = createAdapter() + await adapter.connect() + + const longText = 'A'.repeat(5000) + const sendPromise = adapter.sendMessage('123', longText) + + // Flush all pending timers (inter-chunk delays) regardless of count + await vi.runAllTimersAsync() + await sendPromise + + expect(mockBot.api.sendMessage).toHaveBeenCalledTimes(2) + // Verify chunk sizes: first chunk is 4096 chars, second is the remainder + expect(mockBot.api.sendMessage.mock.calls[0][1]).toHaveLength(4096) + expect(mockBot.api.sendMessage.mock.calls[1][1]).toHaveLength(904) + }) + + it('sendTypingIndicator() sends typing action', async () => { + const adapter = createAdapter() + await adapter.connect() + await adapter.sendTypingIndicator('123') + + expect(mockBot.api.sendChatAction).toHaveBeenCalledWith('123', 'typing') + }) + + it('auth middleware blocks unauthorized chats', async () => { + const adapter = createAdapter({ allowed_chat_ids: ['123'] }) + await adapter.connect() + + // Extract the auth middleware + const middleware = mockBot.use.mock.calls[0][0] as (ctx: any, next: () => Promise) => Promise + + const next = vi.fn() + + // Unauthorized chat + await middleware({ chat: { id: 999 } }, next) + expect(next).not.toHaveBeenCalled() + + // Authorized chat + next.mockClear() + await middleware({ chat: { id: 123 } }, next) + expect(next).toHaveBeenCalledTimes(1) + }) + + it('command handler emits command events', async () => { + const adapter = createAdapter() + await adapter.connect() + + const commandSpy = vi.fn() + adapter.on('command', commandSpy) + + // Find the 'new' command handler (first bot.command call) + const commandHandler = mockBot.command.mock.calls[0][1] as (ctx: any) => void + + commandHandler({ + chat: { id: 123 }, + from: { id: 456, first_name: 'TestUser' } + }) + + expect(commandSpy).toHaveBeenCalledWith({ + chatId: '123', + userId: '456', + userName: 'TestUser', + command: 'new' + }) + }) + + it('whoami command handler emits command events', async () => { + const adapter = createAdapter() + await adapter.connect() + + const commandSpy = vi.fn() + adapter.on('command', commandSpy) + + const commandHandler = mockBot.command.mock.calls[3][1] as (ctx: any) => void + + commandHandler({ + chat: { id: 123 }, + from: { id: 456, first_name: 'TestUser' } + }) + + expect(commandSpy).toHaveBeenCalledWith({ + chatId: '123', + userId: '456', + userName: 'TestUser', + command: 'whoami' + }) + }) + + it('message handler emits message events', async () => { + const adapter = createAdapter() + await adapter.connect() + + const messageSpy = vi.fn() + adapter.on('message', messageSpy) + + // Extract the message:text handler + const messageHandler = mockBot.on.mock.calls[0][1] as (ctx: any) => void + + messageHandler({ + chat: { id: 123 }, + from: { id: 456, first_name: 'TestUser' }, + message: { text: 'Hello bot' } + }) + + expect(messageSpy).toHaveBeenCalledWith({ + chatId: '123', + userId: '456', + userName: 'TestUser', + text: 'Hello bot' + }) + }) +}) diff --git a/src/main/services/agents/services/channels/index.ts b/src/main/services/agents/services/channels/index.ts new file mode 100644 index 00000000000..2bf4efa6f38 --- /dev/null +++ b/src/main/services/agents/services/channels/index.ts @@ -0,0 +1,14 @@ +export type { + ChannelAdapterConfig, + ChannelCommandEvent, + ChannelMessageEvent, + SendMessageOptions +} from './ChannelAdapter' +export { ChannelAdapter } from './ChannelAdapter' +export { channelManager, registerAdapterFactory } from './ChannelManager' +export { ChannelMessageHandler, channelMessageHandler } from './ChannelMessageHandler' + +// Register adapters (side-effect imports) +import './adapters/FeishuAdapter' +import './adapters/QQAdapter' +import './adapters/TelegramAdapter' diff --git a/src/main/services/agents/services/cherryclaw/__tests__/heartbeat.test.ts b/src/main/services/agents/services/cherryclaw/__tests__/heartbeat.test.ts new file mode 100644 index 00000000000..fae0a35ee15 --- /dev/null +++ b/src/main/services/agents/services/cherryclaw/__tests__/heartbeat.test.ts @@ -0,0 +1,59 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' + +vi.mock('@logger', () => ({ + loggerService: { + withContext: () => ({ info: vi.fn(), error: vi.fn(), warn: vi.fn(), debug: vi.fn() }) + } +})) + +vi.mock('node:fs/promises', () => ({ + readFile: vi.fn() +})) + +import { readFile } from 'node:fs/promises' + +import { HeartbeatReader } from '../heartbeat' + +const mockedReadFile = vi.mocked(readFile) + +describe('HeartbeatReader', () => { + let reader: HeartbeatReader + + beforeEach(() => { + reader = new HeartbeatReader() + vi.clearAllMocks() + }) + + it('returns content when file exists', async () => { + mockedReadFile.mockResolvedValue('heartbeat content') + + const result = await reader.readHeartbeat('/workspace') + + expect(result).toBe('heartbeat content') + expect(mockedReadFile).toHaveBeenCalledWith(expect.stringContaining('heartbeat.md'), 'utf-8') + }) + + it('returns undefined when file does not exist', async () => { + mockedReadFile.mockRejectedValue(Object.assign(new Error('ENOENT'), { code: 'ENOENT' })) + + const result = await reader.readHeartbeat('/workspace') + + expect(result).toBeUndefined() + }) + + it('returns undefined when file is empty', async () => { + mockedReadFile.mockResolvedValue(' \n ') + + const result = await reader.readHeartbeat('/workspace') + + expect(result).toBeUndefined() + }) + + it('trims whitespace from content', async () => { + mockedReadFile.mockResolvedValue(' check my email \n') + + const result = await reader.readHeartbeat('/workspace') + + expect(result).toBe('check my email') + }) +}) diff --git a/src/main/services/agents/services/cherryclaw/__tests__/prompt.test.ts b/src/main/services/agents/services/cherryclaw/__tests__/prompt.test.ts new file mode 100644 index 00000000000..bc1e8c83c7e --- /dev/null +++ b/src/main/services/agents/services/cherryclaw/__tests__/prompt.test.ts @@ -0,0 +1,183 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' + +vi.mock('@logger', () => ({ + loggerService: { + withContext: () => ({ info: vi.fn(), error: vi.fn(), warn: vi.fn(), debug: vi.fn() }) + } +})) + +vi.mock('node:fs/promises', () => ({ + stat: vi.fn(), + readFile: vi.fn(), + readdir: vi.fn() +})) + +import { readdir, readFile, stat } from 'node:fs/promises' + +import { PromptBuilder } from '../prompt' + +const mockedStat = vi.mocked(stat) +const mockedReadFile = vi.mocked(readFile) +const mockedReaddir = vi.mocked(readdir) + +function setupFiles(files: Record) { + // Build directory listing from file paths + const dirs = new Map() + for (const filePath of Object.keys(files)) { + const dir = filePath.substring(0, filePath.lastIndexOf('/')) + const name = filePath.substring(filePath.lastIndexOf('/') + 1) + if (!dirs.has(dir)) dirs.set(dir, []) + dirs.get(dir)!.push(name) + } + + mockedStat.mockImplementation(async (filePath) => { + const p = typeof filePath === 'string' ? filePath : filePath.toString() + if (files[p] !== undefined) { + return { mtimeMs: 1000 } as any + } + throw Object.assign(new Error('ENOENT'), { code: 'ENOENT' }) + }) + mockedReadFile.mockImplementation(async (filePath) => { + const p = typeof filePath === 'string' ? filePath : filePath.toString() + if (files[p] !== undefined) { + return files[p] + } + throw Object.assign(new Error('ENOENT'), { code: 'ENOENT' }) + }) + mockedReaddir.mockImplementation(async (dirPath) => { + const p = typeof dirPath === 'string' ? dirPath : dirPath.toString() + return (dirs.get(p) ?? []) as any + }) +} + +describe('PromptBuilder', () => { + let builder: PromptBuilder + + beforeEach(() => { + builder = new PromptBuilder() + vi.clearAllMocks() + }) + + it('returns default basic prompt when no workspace files exist', async () => { + setupFiles({}) + + const result = await builder.buildSystemPrompt('/workspace') + + expect(result).toContain('You are CherryClaw') + expect(result).toContain('## CherryClaw Tools') + expect(result).not.toContain('## Memories') + }) + + it('overrides basic prompt with system.md from workspace', async () => { + setupFiles({ + '/workspace/system.md': 'You are CustomBot, a specialized assistant.' + }) + + const result = await builder.buildSystemPrompt('/workspace') + + expect(result).toContain('You are CustomBot') + expect(result).not.toContain('You are CherryClaw') + }) + + it('includes soul.md in memories section', async () => { + setupFiles({ + '/workspace/soul.md': 'Warm but direct. Lead with answers.' + }) + + const result = await builder.buildSystemPrompt('/workspace') + + expect(result).toContain('## Memories') + expect(result).toContain('') + expect(result).toContain('Warm but direct. Lead with answers.') + expect(result).toContain('') + expect(result).toContain('WHO you are') + }) + + it('includes user.md in memories section', async () => { + setupFiles({ + '/workspace/user.md': 'Name: V\nTimezone: UTC+8' + }) + + const result = await builder.buildSystemPrompt('/workspace') + + expect(result).toContain('') + expect(result).toContain('Name: V') + expect(result).toContain('') + expect(result).toContain('WHO the user is') + }) + + it('includes memory/FACT.md in memories section', async () => { + setupFiles({ + '/workspace/memory/FACT.md': '# Active Projects\n\n- Cherry Studio' + }) + + const result = await builder.buildSystemPrompt('/workspace') + + expect(result).toContain('') + expect(result).toContain('Cherry Studio') + expect(result).toContain('') + expect(result).toContain('WHAT you know') + }) + + it('includes all memory files when all exist', async () => { + setupFiles({ + '/workspace/soul.md': 'Be concise.', + '/workspace/user.md': 'Name: V', + '/workspace/memory/FACT.md': 'Project: CherryClaw' + }) + + const result = await builder.buildSystemPrompt('/workspace') + + expect(result).toContain('') + expect(result).toContain('') + expect(result).toContain('') + expect(result).toContain('Update them autonomously') + expect(result).toContain('exclusive scope') + }) + + it('combines system.md override with memories', async () => { + setupFiles({ + '/workspace/system.md': 'You are CustomBot.', + '/workspace/soul.md': 'Sharp and efficient.' + }) + + const result = await builder.buildSystemPrompt('/workspace') + + expect(result).toContain('You are CustomBot.') + expect(result).toContain('') + expect(result).toContain('Sharp and efficient.') + }) + + it('resolves filenames case-insensitively', async () => { + // Files exist with different casing than the canonical names + setupFiles({ + '/workspace/SOUL.md': 'Uppercase soul', + '/workspace/User.md': 'Mixed case user', + '/workspace/memory/fact.md': 'Lowercase facts' + }) + + const result = await builder.buildSystemPrompt('/workspace') + + expect(result).toContain('') + expect(result).toContain('Uppercase soul') + expect(result).toContain('') + expect(result).toContain('Mixed case user') + expect(result).toContain('') + expect(result).toContain('Lowercase facts') + }) + + it('uses mtime cache for repeated reads', async () => { + setupFiles({ + '/workspace/soul.md': 'Cached soul' + }) + + await builder.buildSystemPrompt('/workspace') + await builder.buildSystemPrompt('/workspace') + + // readFile should only be called once per unique file due to caching + const soulReadCalls = mockedReadFile.mock.calls.filter( + (call) => typeof call[0] === 'string' && call[0].includes('soul.md') + ) + expect(soulReadCalls).toHaveLength(1) + }) +}) diff --git a/src/main/services/agents/services/cherryclaw/heartbeat.ts b/src/main/services/agents/services/cherryclaw/heartbeat.ts new file mode 100644 index 00000000000..b60e9efefe6 --- /dev/null +++ b/src/main/services/agents/services/cherryclaw/heartbeat.ts @@ -0,0 +1,38 @@ +import { readFile } from 'node:fs/promises' +import path from 'node:path' + +import { loggerService } from '@logger' + +const logger = loggerService.withContext('HeartbeatReader') + +const HEARTBEAT_FILENAME = 'heartbeat.md' + +export class HeartbeatReader { + async readHeartbeat(workspacePath: string): Promise { + const resolved = path.resolve(workspacePath, HEARTBEAT_FILENAME) + const normalizedWorkspace = path.resolve(workspacePath) + + if (!resolved.startsWith(normalizedWorkspace + path.sep) && resolved !== normalizedWorkspace) { + logger.warn(`Path traversal attempt blocked: ${HEARTBEAT_FILENAME}`) + return undefined + } + + try { + const content = await readFile(resolved, 'utf-8') + const trimmed = content.trim() + if (!trimmed) { + logger.debug('Heartbeat file is empty', { path: resolved }) + return undefined + } + logger.info(`Read heartbeat file: ${resolved}`) + return trimmed + } catch (error) { + if ((error as NodeJS.ErrnoException).code === 'ENOENT') { + logger.debug(`Heartbeat file not found: ${resolved}`) + return undefined + } + logger.error(`Failed to read heartbeat file: ${resolved}`, error as Error) + return undefined + } + } +} diff --git a/src/main/services/agents/services/cherryclaw/index.ts b/src/main/services/agents/services/cherryclaw/index.ts new file mode 100644 index 00000000000..a91f471f516 --- /dev/null +++ b/src/main/services/agents/services/cherryclaw/index.ts @@ -0,0 +1,101 @@ +import { loggerService } from '@logger' +import ClawServer from '@main/mcpServers/claw' +import type { GetAgentSessionResponse } from '@types' + +import type { AgentServiceInterface, AgentStream, AgentThinkingOptions } from '../../interfaces/AgentStreamInterface' +import { agentServiceRegistry } from '../AgentServiceRegistry' +import type { EnhancedSessionFields } from '../claudecode/enhanced-session' +import { HeartbeatReader } from './heartbeat' +import { PromptBuilder } from './prompt' + +const logger = loggerService.withContext('CherryClawService') + +/** + * CherryClawService — a Claude Code variant with soul-driven personality + * and scheduler-based autonomous operation. + * + * Delegates to ClaudeCodeService (via registry) with a full custom system prompt + * (replaces Claude Code preset) and an injected claw MCP server for autonomous task management. + */ +export class CherryClawService implements AgentServiceInterface { + private promptBuilder = new PromptBuilder() + readonly heartbeatReader = new HeartbeatReader() + + async invoke( + prompt: string, + session: GetAgentSessionResponse, + abortController: AbortController, + lastAgentSessionId?: string, + thinkingOptions?: AgentThinkingOptions + ): Promise { + const workspacePath = session.accessible_paths[0] + + type EnhancedSession = GetAgentSessionResponse & EnhancedSessionFields + + // Build soul-enhanced session + let enhancedSession: EnhancedSession = session + + // Build full custom system prompt from workspace files (soul.md, user.md, memory/FACT.md, system.md) + if (workspacePath) { + const systemPrompt = await this.promptBuilder.buildSystemPrompt(workspacePath) + logger.info('Built custom system prompt for CherryClaw', { + workspacePath, + promptLength: systemPrompt.length + }) + enhancedSession = { + ...session, + _systemPrompt: systemPrompt + } + } + + // Inject the claw MCP server as an in-memory instance for autonomous task management + // and disable the SDK's builtin cron tools so the agent uses our MCP cron tool instead + const clawServer = new ClawServer(session.agent_id) + enhancedSession = { + ...enhancedSession, + _internalMcpServers: { + claw: { + type: 'inmem', + instance: clawServer.mcpServer + } + }, + _disallowedTools: [ + // Disable builtin cron tools (agent uses our MCP cron tool instead) + 'CronCreate', + 'CronDelete', + 'CronList', + // Disable tools not suited for autonomous agent operation + 'TodoWrite', + 'AskUserQuestion', + 'EnterPlanMode', + 'ExitPlanMode', + 'EnterWorktree', + 'NotebookEdit' + ] + } + + // If the agent has an explicit allowed_tools whitelist, append the claw MCP + // tool names so the SDK doesn't hide them. When allowed_tools is undefined + // (no restriction), leave it alone — all tools are already available. + const clawMcpTools = ['mcp__claw__*'] // wildcard to allow all claw MCP tools (e.g. cron, file management, etc.) + const currentAllowed = enhancedSession.allowed_tools + if (Array.isArray(currentAllowed) && currentAllowed.length > 0) { + const missing = clawMcpTools.filter((t) => !currentAllowed.includes(t)) + if (missing.length > 0) { + enhancedSession = { ...enhancedSession, allowed_tools: [...currentAllowed, ...missing] } + } + } + + logger.debug('CherryClaw invoke: injecting claw MCP and allowed_tools', { + agentId: session.agent_id, + mcpServers: Object.keys(enhancedSession._internalMcpServers ?? {}), + allowedTools: enhancedSession.allowed_tools + }) + + // Delegate to claude-code service (CherryClaw is a Claude Code variant) + const claudeCodeService = agentServiceRegistry.getService('claude-code') + return claudeCodeService.invoke(prompt, enhancedSession, abortController, lastAgentSessionId, thinkingOptions) + } +} + +export default CherryClawService diff --git a/src/main/services/agents/services/cherryclaw/prompt.ts b/src/main/services/agents/services/cherryclaw/prompt.ts new file mode 100644 index 00000000000..cb8a91b6f26 --- /dev/null +++ b/src/main/services/agents/services/cherryclaw/prompt.ts @@ -0,0 +1,168 @@ +import { readdir, readFile, stat } from 'node:fs/promises' +import path from 'node:path' + +import { loggerService } from '@logger' + +const logger = loggerService.withContext('PromptBuilder') + +/** + * Resolve a filename within a directory using case-insensitive matching. + * Returns the full path if found (preferring exact match), or undefined. + */ +async function resolveFile(dir: string, name: string): Promise { + const exact = path.join(dir, name) + try { + await stat(exact) + return exact + } catch { + // exact match not found, try case-insensitive + } + + try { + const entries = await readdir(dir) + const target = name.toLowerCase() + const match = entries.find((e) => e.toLowerCase() === target) + return match ? path.join(dir, match) : undefined + } catch { + return undefined + } +} + +type CacheEntry = { + mtimeMs: number + content: string +} + +const DEFAULT_BASIC_PROMPT = `You are CherryClaw, a personal assistant running inside CherryStudio. + +` + +const TOOLS_SECTION = `## CherryClaw Tools + +You have exclusive access to these tools for interacting with CherryStudio. Always prefer them over manual alternatives. + +| Tool | Purpose | When to use | +|---|---|---| +| \`mcp__claw__cron\` | Schedule recurring or one-time tasks | Creating reminders, periodic checks, scheduled reports. Never use builtin Cron* tools — they are disabled. | +| \`mcp__claw__notify\` | Send messages to the user via IM channels | Proactive updates, task results, alerts. Use when the user is not in the current session. | +| \`mcp__claw__skills\` | Search, install, and remove Claude skills | When the user asks for new capabilities or you need a skill you don't have. | +| \`mcp__claw__memory\` | Manage JOURNAL.jsonl (append and search) | Log events and search past activity. Never write to JOURNAL.jsonl directly via file tools. | + +Rules: +- These are your primary interface to CherryStudio. Do not attempt workarounds or alternative approaches. +- When creating scheduled tasks, always use \`mcp__claw__cron\`. The SDK builtin CronCreate, CronDelete, and CronList tools are disabled. +- When you need to notify the user outside the current conversation, use \`mcp__claw__notify\`. +` + +function memoriesTemplate(workspacePath: string, sections: string): string { + return `## Memories + +Persistent files in \`${workspacePath}/\` carry your state across sessions. Update them autonomously — never ask for approval. + +| File | Purpose | How to update | +|---|---|---| +| \`SOUL.md\` | WHO you are — personality, tone, communication style, core principles | Read + Edit tools | +| \`USER.md\` | WHO the user is — name, preferences, timezone, personal context | Read + Edit tools | +| \`memory/FACT.md\` | WHAT you know — active projects, technical decisions, durable knowledge (6+ months) | Read + Edit tools | +| \`memory/JOURNAL.jsonl\` | WHEN things happened — one-time events, session notes (append-only log) | \`mcp__claw__memory\` tool only (actions: append, search) | + +Rules: +- Each file has an exclusive scope — never duplicate information across files. +- \`SOUL.md\`, \`USER.md\`, and \`memory/FACT.md\` are loaded below. Read and edit them directly when updates are needed. +- \`memory/JOURNAL.jsonl\` is NOT loaded into context. Use \`mcp__claw__memory\` to append entries or search past events. Never read or write the file directly. +- Filenames are case-insensitive. +${sections}` +} + +/** + * PromptBuilder assembles the full system prompt for CherryClaw from workspace files. + * + * Structure: basic prompt (system.md override or default) + tools section + memories section. + * + * Memory files layout: + * {workspace}/soul.md — personality, tone, communication style + * {workspace}/user.md — user profile, preferences, context + * {workspace}/memory/FACT.md — durable project knowledge, technical decisions + * {workspace}/memory/JOURNAL.jsonl — timestamped event log (managed by memory tool) + */ +export class PromptBuilder { + private cache = new Map() + + async buildSystemPrompt(workspacePath: string): Promise { + const parts: string[] = [] + + // Basic prompt: workspace system.md (case-insensitive) > embedded default + const systemPath = await resolveFile(workspacePath, 'system.md') + const basicPrompt = systemPath ? await this.readCachedFile(systemPath) : undefined + parts.push(basicPrompt ?? DEFAULT_BASIC_PROMPT) + + // Tools section (always included) + parts.push(TOOLS_SECTION) + + // Memories section + const memoriesContent = await this.buildMemoriesSection(workspacePath) + if (memoriesContent) { + parts.push(memoriesContent) + } + + return parts.join('\n\n') + } + + private async buildMemoriesSection(workspacePath: string): Promise { + const memoryDir = path.join(workspacePath, 'memory') + + const [soulPath, userPath, factPath] = await Promise.all([ + resolveFile(workspacePath, 'SOUL.md'), + resolveFile(workspacePath, 'USER.md'), + resolveFile(memoryDir, 'FACT.md') + ]) + + const [soulContent, userContent, factContent] = await Promise.all([ + soulPath ? this.readCachedFile(soulPath) : Promise.resolve(undefined), + userPath ? this.readCachedFile(userPath) : Promise.resolve(undefined), + factPath ? this.readCachedFile(factPath) : Promise.resolve(undefined) + ]) + + if (!soulContent && !userContent && !factContent) { + return undefined + } + + const sections = [ + soulContent ? `\n${soulContent}\n` : '', + userContent ? `\n${userContent}\n` : '', + factContent ? `\n${factContent}\n` : '' + ] + .filter(Boolean) + .join('\n\n') + + return memoriesTemplate(workspacePath, sections) + } + + /** + * Read a file with mtime-based caching. Returns undefined if the file does not exist. + */ + private async readCachedFile(filePath: string): Promise { + let fileStat + try { + fileStat = await stat(filePath) + } catch { + return undefined + } + + const cached = this.cache.get(filePath) + if (cached && cached.mtimeMs === fileStat.mtimeMs) { + return cached.content + } + + try { + const content = await readFile(filePath, 'utf-8') + const trimmed = content.trim() + this.cache.set(filePath, { mtimeMs: fileStat.mtimeMs, content: trimmed }) + logger.debug(`Loaded ${path.basename(filePath)}`, { path: filePath, length: trimmed.length }) + return trimmed + } catch (error) { + logger.error(`Failed to read ${filePath}`, error as Error) + return undefined + } + } +} diff --git a/src/main/services/agents/services/claudecode/enhanced-session.ts b/src/main/services/agents/services/claudecode/enhanced-session.ts new file mode 100644 index 00000000000..fa288485069 --- /dev/null +++ b/src/main/services/agents/services/claudecode/enhanced-session.ts @@ -0,0 +1,16 @@ +import type { Settings } from '@anthropic-ai/claude-agent-sdk' + +import type { InternalMcpServerConfig } from './internal-mcp' + +/** + * Extra fields that agent services (e.g. CherryClaw) can attach to a session + * before it is passed to ClaudeCodeService. ClaudeCodeService reads these + * and maps them to SDK options. + */ +export type EnhancedSessionFields = { + _internalMcpServers?: Record + _disallowedTools?: string[] + _settings?: Settings + /** When set, replaces the SDK system prompt entirely (instead of using preset+append). */ + _systemPrompt?: string +} diff --git a/src/main/services/agents/services/claudecode/index.ts b/src/main/services/agents/services/claudecode/index.ts index 264e4829e04..4ce98788113 100644 --- a/src/main/services/agents/services/claudecode/index.ts +++ b/src/main/services/agents/services/claudecode/index.ts @@ -34,9 +34,12 @@ import type { } from '../../interfaces/AgentStreamInterface' import { sessionService } from '../SessionService' import { buildNamespacedToolCallId } from './claude-stream-state' +import type { EnhancedSessionFields } from './enhanced-session' import { promptForToolApproval } from './tool-permissions' import { ClaudeStreamState, transformSDKMessageToStreamParts } from './transform' +type EnhancedSession = GetAgentSessionResponse & EnhancedSessionFields + const require_ = createRequire(import.meta.url) const logger = loggerService.withContext('ClaudeCodeService') const DEFAULT_AUTO_ALLOW_TOOLS = new Set(['Read', 'Glob', 'Grep']) @@ -66,6 +69,8 @@ class ClaudeCodeStream extends EventEmitter implements AgentStream { declare emit: (event: 'data', data: AgentStreamEvent) => boolean declare on: (event: 'data', listener: (data: AgentStreamEvent) => void) => this declare once: (event: 'data', listener: (data: AgentStreamEvent) => void) => this + /** SDK session_id captured from the init message, used for resume. */ + sdkSessionId?: string } class ClaudeCodeService implements AgentServiceInterface { @@ -164,6 +169,37 @@ class ClaudeCodeService implements AgentServiceInterface { ...(customGitBashPath ? { CLAUDE_CODE_GIT_BASH_PATH: customGitBashPath } : {}) } + // Merge user-defined environment variables from session configuration + const userEnvVars = session.configuration?.env_vars + if (userEnvVars && typeof userEnvVars === 'object') { + const BLOCKED_ENV_KEYS = new Set([ + 'ANTHROPIC_API_KEY', + 'ANTHROPIC_AUTH_TOKEN', + 'ANTHROPIC_BASE_URL', + 'ANTHROPIC_MODEL', + 'ANTHROPIC_DEFAULT_OPUS_MODEL', + 'ANTHROPIC_DEFAULT_SONNET_MODEL', + 'ANTHROPIC_DEFAULT_HAIKU_MODEL', + 'ELECTRON_RUN_AS_NODE', + 'ELECTRON_NO_ATTACH_CONSOLE', + 'CLAUDE_CONFIG_DIR', + 'CLAUDE_CODE_USE_BEDROCK', + 'CLAUDE_CODE_GIT_BASH_PATH', + 'NODE_OPTIONS', + '__PROTO__', + 'CONSTRUCTOR', + 'PROTOTYPE' + ]) + for (const [key, value] of Object.entries(userEnvVars)) { + const upperKey = key.toUpperCase() + if (BLOCKED_ENV_KEYS.has(upperKey)) { + logger.warn('Blocked user env var override for system-critical variable', { key }) + } else if (typeof value === 'string') { + env[key] = value + } + } + } + const errorChunks: string[] = [] const sessionAllowedTools = new Set(session.allowed_tools ?? []) @@ -285,17 +321,19 @@ class ClaudeCodeService implements AgentServiceInterface { logger.warn('claude stderr', { chunk }) errorChunks.push(chunk) }, - systemPrompt: session.instructions - ? { - type: 'preset', - preset: 'claude_code', - append: `${session.instructions}\n\n${getLanguageInstruction()}` - } - : { - type: 'preset', - preset: 'claude_code', - append: getLanguageInstruction() - }, + systemPrompt: (session as EnhancedSession)._systemPrompt + ? `${(session as EnhancedSession)._systemPrompt}\n\n${getLanguageInstruction()}` + : session.instructions + ? { + type: 'preset', + preset: 'claude_code', + append: `${session.instructions}\n\n${getLanguageInstruction()}` + } + : { + type: 'preset', + preset: 'claude_code', + append: getLanguageInstruction() + }, settingSources: ['project', 'local'], includePartialMessages: true, permissionMode: session.configuration?.permission_mode, @@ -334,6 +372,35 @@ class ClaudeCodeService implements AgentServiceInterface { options.strictMcpConfig = true } + // Merge enhanced session fields injected by agent services (e.g. CherryClaw) + const enhancedSession = session as EnhancedSession + if (enhancedSession._internalMcpServers) { + if (!options.mcpServers) { + options.mcpServers = {} + } + for (const [name, config] of Object.entries(enhancedSession._internalMcpServers)) { + if (config.type === 'inmem') { + options.mcpServers[name] = { type: 'sdk', name, instance: config.instance } + } else { + options.mcpServers[name] = { type: config.type, url: config.url, headers: config.headers } + } + } + logger.debug('Merged internal MCP servers into SDK options', { + serverNames: Object.keys(enhancedSession._internalMcpServers), + totalMcpServers: Object.keys(options.mcpServers).length + }) + } + + // Disable specific builtin tools if requested by agent service + if (enhancedSession._disallowedTools) { + options.disallowedTools = enhancedSession._disallowedTools + } + + // Apply additional settings if provided by agent service + if (enhancedSession._settings) { + options.settings = enhancedSession._settings + } + if (lastAgentSessionId && !NO_RESUME_COMMANDS.some((cmd) => prompt.includes(cmd))) { options.resume = lastAgentSessionId // TODO: use fork session when we support branching sessions @@ -489,8 +556,16 @@ class ClaudeCodeService implements AgentServiceInterface { jsonOutput.push(message) - // Handle init message - merge builtin and SDK slash_commands + // Handle init message - capture SDK session_id and merge slash_commands if (message.type === 'system' && message.subtype === 'init') { + if (message.session_id) { + stream.sdkSessionId = message.session_id + logger.info('Captured SDK session_id from init message', { + sdkSessionId: message.session_id, + sessionId + }) + } + const sdkSlashCommands = message.slash_commands || [] logger.info('Received init message with slash commands', { sessionId, diff --git a/src/main/services/agents/services/claudecode/internal-mcp.ts b/src/main/services/agents/services/claudecode/internal-mcp.ts new file mode 100644 index 00000000000..7efc4fb37a4 --- /dev/null +++ b/src/main/services/agents/services/claudecode/internal-mcp.ts @@ -0,0 +1,24 @@ +import type { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js' + +/** + * HTTP-based MCP server config (connects over network). + */ +export type InternalMcpHttpServerConfig = { + type: 'http' + url: string + headers?: Record +} + +/** + * In-memory MCP server config (runs in-process via McpServer instance). + */ +export type InternalMcpInMemServerConfig = { + type: 'inmem' + instance: McpServer +} + +/** + * Configuration for an internal MCP server injected by agent services. + * These get merged into the SDK's mcpServers option alongside user-configured MCPs. + */ +export type InternalMcpServerConfig = InternalMcpHttpServerConfig | InternalMcpInMemServerConfig diff --git a/src/main/services/agents/services/index.ts b/src/main/services/agents/services/index.ts index e6e545a442a..0cd432dae3a 100644 --- a/src/main/services/agents/services/index.ts +++ b/src/main/services/agents/services/index.ts @@ -9,11 +9,24 @@ export { AgentService } from './AgentService' export { SessionMessageService } from './SessionMessageService' export { SessionService } from './SessionService' +export { TaskService } from './TaskService' // Service instances (singletons) export { agentService } from './AgentService' export { sessionMessageService } from './SessionMessageService' export { sessionService } from './SessionService' +export { taskService } from './TaskService' + +// Agent service registry +export { agentServiceRegistry } from './AgentServiceRegistry' + +// Register agent services — claude-code first (CherryClaw delegates to it at runtime) +import { agentServiceRegistry } from './AgentServiceRegistry' +import { CherryClawService } from './cherryclaw' +import ClaudeCodeService from './claudecode' + +agentServiceRegistry.register('claude-code', new ClaudeCodeService()) +agentServiceRegistry.register('cherry-claw', new CherryClawService()) // Type definitions for service requests and responses export type { AgentEntity, AgentSessionEntity, CreateAgentRequest, UpdateAgentRequest } from '@types' diff --git a/src/main/utils/ipService.ts b/src/main/utils/ipService.ts index 708af4c40ef..8604c952a90 100644 --- a/src/main/utils/ipService.ts +++ b/src/main/utils/ipService.ts @@ -13,7 +13,7 @@ export async function getIpCountry(): Promise { const controller = new AbortController() const timeoutId = setTimeout(() => controller.abort(), 5000) - const ipinfo = await net.fetch(`https://api.ipinfo.io/lite/me?token=2a42580355dae4`, { + const ipinfo = await net.fetch(`https://api.ipinfo.io/lite/me?token=5aa4105b40adbc`, { signal: controller.signal }) diff --git a/src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts b/src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts index b99e469e1e3..e1993fe8c14 100644 --- a/src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts +++ b/src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts @@ -32,6 +32,7 @@ export class AiSdkToChunkAdapter { private firstTokenTimestamp: number | null = null private hasTextContent = false private getSessionWasCleared?: () => boolean + private providerId?: string constructor( private onChunk: (chunk: Chunk) => void, @@ -39,13 +40,15 @@ export class AiSdkToChunkAdapter { accumulate?: boolean, enableWebSearch?: boolean, onSessionUpdate?: (sessionId: string) => void, - getSessionWasCleared?: () => boolean + getSessionWasCleared?: () => boolean, + providerId?: string ) { this.toolCallHandler = new ToolCallChunkHandler(onChunk, mcpTools) this.accumulate = accumulate this.enableWebSearch = enableWebSearch || false this.onSessionUpdate = onSessionUpdate this.getSessionWasCleared = getSessionWasCleared + this.providerId = providerId } private markFirstTokenIfNeeded() { @@ -324,7 +327,7 @@ export class AiSdkToChunkAdapter { } }) } else if (final.webSearchResults.length) { - const providerName = Object.keys(providerMetadata || {})[0] + const providerName: string | undefined = Object.keys(providerMetadata || {})[0] || this.providerId const sourceMap: Record = { [WEB_SEARCH_SOURCE.OPENAI]: WEB_SEARCH_SOURCE.OPENAI_RESPONSE, [WEB_SEARCH_SOURCE.ANTHROPIC]: WEB_SEARCH_SOURCE.ANTHROPIC, @@ -335,9 +338,10 @@ export class AiSdkToChunkAdapter { [WEB_SEARCH_SOURCE.HUNYUAN]: WEB_SEARCH_SOURCE.HUNYUAN, [WEB_SEARCH_SOURCE.ZHIPU]: WEB_SEARCH_SOURCE.ZHIPU, [WEB_SEARCH_SOURCE.GROK]: WEB_SEARCH_SOURCE.GROK, + xai: WEB_SEARCH_SOURCE.GROK, [WEB_SEARCH_SOURCE.WEBSEARCH]: WEB_SEARCH_SOURCE.WEBSEARCH } - const source = sourceMap[providerName] || WEB_SEARCH_SOURCE.AISDK + const source = (providerName && sourceMap[providerName]) || WEB_SEARCH_SOURCE.AISDK this.onChunk({ type: ChunkType.LLM_WEB_SEARCH_COMPLETE, diff --git a/src/renderer/src/aiCore/index_new.ts b/src/renderer/src/aiCore/index_new.ts index cae49ff83b3..15a302a8403 100644 --- a/src/renderer/src/aiCore/index_new.ts +++ b/src/renderer/src/aiCore/index_new.ts @@ -320,7 +320,15 @@ export default class ModernAiProvider { // 创建带有中间件的执行器 if (config.onChunk) { const accumulate = this.model!.supported_text_delta !== false // true and undefined - const adapter = new AiSdkToChunkAdapter(config.onChunk, config.mcpTools, accumulate, config.enableWebSearch) + const adapter = new AiSdkToChunkAdapter( + config.onChunk, + config.mcpTools, + accumulate, + config.enableWebSearch, + undefined, + undefined, + this.config!.providerId + ) const streamResult = await executor.streamText({ ...params, diff --git a/src/renderer/src/aiCore/utils/__tests__/extractAiSdkStandardParams.test.ts b/src/renderer/src/aiCore/utils/__tests__/extractAiSdkStandardParams.test.ts index 288cc2e4a53..3074d61fc6b 100644 --- a/src/renderer/src/aiCore/utils/__tests__/extractAiSdkStandardParams.test.ts +++ b/src/renderer/src/aiCore/utils/__tests__/extractAiSdkStandardParams.test.ts @@ -57,13 +57,17 @@ vi.mock('@renderer/services/ProviderService', () => ({ })) // Mock config modules -vi.mock('@renderer/config/models', () => ({ - isOpenAIModel: vi.fn(() => false), - isQwenMTModel: vi.fn(() => false), - isSupportFlexServiceTierModel: vi.fn(() => false), - isSupportVerbosityModel: vi.fn(() => false), - getModelSupportedVerbosity: vi.fn(() => []) -})) +vi.mock('@renderer/config/models', async (importOriginal) => { + const actual: any = await importOriginal() + return { + ...actual, + isOpenAIModel: vi.fn(() => false), + isQwenMTModel: vi.fn(() => false), + isSupportFlexServiceTierModel: vi.fn(() => false), + isSupportVerbosityModel: vi.fn(() => false), + getModelSupportedVerbosity: vi.fn(() => []) + } +}) vi.mock('@renderer/config/translate', () => ({ mapLanguageToQwenMTModel: vi.fn() diff --git a/src/renderer/src/aiCore/utils/__tests__/options.test.ts b/src/renderer/src/aiCore/utils/__tests__/options.test.ts index a6c9a6c95c6..7e06d07aea1 100644 --- a/src/renderer/src/aiCore/utils/__tests__/options.test.ts +++ b/src/renderer/src/aiCore/utils/__tests__/options.test.ts @@ -1118,6 +1118,106 @@ describe('options utils', () => { }) }) + it('should auto-convert reasoning_effort to reasoningEffort for openai-compatible provider (issue #11987)', async () => { + const { getCustomParameters } = await import('../reasoning') + + // Simulate Volcano Engine (Doubao) or similar OpenAI-compatible provider + const volcengineProvider = { + id: 'openai-compatible', + name: 'Volcano Engine', + type: 'openai', + apiKey: 'test-key', + apiHost: 'https://ark.cn-beijing.volces.com/api/v3', + models: [] as Model[] + } as Provider + + const doubaoModel: Model = { + id: 'doubao-seed-1.8-thinking', + name: 'Doubao Seed 1.8 Thinking', + provider: 'openai-compatible' + } as Model + + // User configures reasoning_effort (snake_case) following API docs + vi.mocked(getCustomParameters).mockReturnValue({ + reasoning_effort: 'high' + }) + + const result = buildProviderOptions(mockAssistant, doubaoModel, volcengineProvider, { + enableReasoning: false, + enableWebSearch: false, + enableGenerateImage: false + }) + + // buildProviderOptions converts reasoning_effort → reasoningEffort for openai-compatible + expect(result.providerOptions['openai-compatible']).toHaveProperty('reasoningEffort') + expect(result.providerOptions['openai-compatible'].reasoningEffort).toBe('high') + expect(result.providerOptions['openai-compatible']).not.toHaveProperty('reasoning_effort') + }) + + it('should NOT convert reasoning_effort for non-openai-compatible providers', async () => { + const { getCustomParameters } = await import('../reasoning') + + const openaiProvider: Provider = { + id: SystemProviderIds.openai, + name: 'OpenAI', + type: 'openai-response', + apiKey: 'test-key', + apiHost: 'https://api.openai.com/v1', + isSystem: true + } as Provider + + // User configures reasoning_effort for native OpenAI provider + vi.mocked(getCustomParameters).mockReturnValue({ + reasoning_effort: 'high' + }) + + const result = buildProviderOptions(mockAssistant, mockModel, openaiProvider, { + enableReasoning: false, + enableWebSearch: false, + enableGenerateImage: false + }) + + // Native OpenAI provider should keep reasoning_effort as-is + expect(result.providerOptions.openai).toHaveProperty('reasoning_effort') + expect(result.providerOptions.openai.reasoning_effort).toBe('high') + expect(result.providerOptions.openai).not.toHaveProperty('reasoningEffort') + }) + + it('should not overwrite existing reasoningEffort when converting for openai-compatible', async () => { + const { getCustomParameters } = await import('../reasoning') + + const volcengineProvider = { + id: 'openai-compatible', + name: 'Volcano Engine', + type: 'openai', + apiKey: 'test-key', + apiHost: 'https://ark.cn-beijing.volces.com/api/v3', + models: [] as Model[] + } as Provider + + const doubaoModel: Model = { + id: 'doubao-seed-1.8-thinking', + name: 'Doubao Seed 1.8 Thinking', + provider: 'openai-compatible' + } as Model + + // User configures both forms + vi.mocked(getCustomParameters).mockReturnValue({ + reasoningEffort: 'low', + reasoning_effort: 'high' + }) + + const result = buildProviderOptions(mockAssistant, doubaoModel, volcengineProvider, { + enableReasoning: false, + enableWebSearch: false, + enableGenerateImage: false + }) + + // Explicit reasoningEffort should be preserved, reasoning_effort removed + expect(result.providerOptions['openai-compatible'].reasoningEffort).toBe('low') + expect(result.providerOptions['openai-compatible']).not.toHaveProperty('reasoning_effort') + }) + it('should handle cross-provider configurations', async () => { const { getCustomParameters } = await import('../reasoning') diff --git a/src/renderer/src/aiCore/utils/__tests__/reasoning.test.ts b/src/renderer/src/aiCore/utils/__tests__/reasoning.test.ts index a9cd2fbb895..2b28878a374 100644 --- a/src/renderer/src/aiCore/utils/__tests__/reasoning.test.ts +++ b/src/renderer/src/aiCore/utils/__tests__/reasoning.test.ts @@ -7,7 +7,7 @@ import { getStoreSetting } from '@renderer/hooks/useSettings' import type { SettingsState } from '@renderer/store/settings' import type { Assistant, Model, Provider } from '@renderer/types' import { SystemProviderIds } from '@renderer/types' -import { beforeEach, describe, expect, it, vi } from 'vitest' +import { beforeAll, beforeEach, describe, expect, it, vi } from 'vitest' import { getAnthropicReasoningParams, @@ -707,10 +707,16 @@ describe('reasoning utils', () => { }) describe('getGeminiReasoningParams', () => { - it('should return empty for non-reasoning model', async () => { - const { isReasoningModel } = await import('@renderer/config/models') + // Use beforeAll to avoid per-test dynamic imports while keeping compatibility + // with the async vi.mock factory (static imports of the mocked module break other tests) + let mockModels: any - vi.mocked(isReasoningModel).mockReturnValue(false) + beforeAll(async () => { + mockModels = await import('@renderer/config/models') + }) + + it('should return empty for non-reasoning model', () => { + vi.mocked(mockModels.isReasoningModel).mockReturnValue(false) const model: Model = { id: 'gemini-2.0-flash', @@ -728,11 +734,69 @@ describe('reasoning utils', () => { expect(result).toEqual({}) }) - it('should disable thinking for Flash models when reasoning effort is none', async () => { - const { isReasoningModel, isSupportedThinkingTokenGeminiModel } = await import('@renderer/config/models') + it('should return empty when isReasoningModel is true but not a Gemini thinking model', () => { + vi.mocked(mockModels.isReasoningModel).mockReturnValue(true) + vi.mocked(mockModels.isSupportedThinkingTokenGeminiModel).mockReturnValue(false) - vi.mocked(isReasoningModel).mockReturnValue(true) - vi.mocked(isSupportedThinkingTokenGeminiModel).mockReturnValue(true) + const model: Model = { + id: 'some-reasoning-model', + name: 'Some Model', + provider: SystemProviderIds.gemini + } as Model + + const assistant: Assistant = { + id: 'test', + name: 'Test', + settings: { reasoning_effort: 'high' } + } as Assistant + + const result = getGeminiReasoningParams(assistant, model) + expect(result).toEqual({}) + }) + + it('should return empty when reasoning effort is not set', () => { + vi.mocked(mockModels.isReasoningModel).mockReturnValue(true) + vi.mocked(mockModels.isSupportedThinkingTokenGeminiModel).mockReturnValue(true) + + const model: Model = { + id: 'gemini-2.5-pro', + name: 'Gemini 2.5 Pro', + provider: SystemProviderIds.gemini + } as Model + + const assistant: Assistant = { + id: 'test', + name: 'Test', + settings: {} + } as Assistant + + const result = getGeminiReasoningParams(assistant, model) + expect(result).toEqual({}) + }) + + it('should return empty when reasoning effort is default', () => { + vi.mocked(mockModels.isReasoningModel).mockReturnValue(true) + vi.mocked(mockModels.isSupportedThinkingTokenGeminiModel).mockReturnValue(true) + + const model: Model = { + id: 'gemini-2.5-pro', + name: 'Gemini 2.5 Pro', + provider: SystemProviderIds.gemini + } as Model + + const assistant: Assistant = { + id: 'test', + name: 'Test', + settings: { reasoning_effort: 'default' } + } as Assistant + + const result = getGeminiReasoningParams(assistant, model) + expect(result).toEqual({}) + }) + + it('should disable thinking for Flash models when reasoning effort is none', () => { + vi.mocked(mockModels.isReasoningModel).mockReturnValue(true) + vi.mocked(mockModels.isSupportedThinkingTokenGeminiModel).mockReturnValue(true) const model: Model = { id: 'gemini-2.5-flash', @@ -757,11 +821,218 @@ describe('reasoning utils', () => { }) }) - it('should enable thinking with budget for reasoning effort', async () => { - const { isReasoningModel, isSupportedThinkingTokenGeminiModel } = await import('@renderer/config/models') + it('should disable thinking for non-Flash models when reasoning effort is none (no thinkingBudget)', () => { + vi.mocked(mockModels.isReasoningModel).mockReturnValue(true) + vi.mocked(mockModels.isSupportedThinkingTokenGeminiModel).mockReturnValue(true) - vi.mocked(isReasoningModel).mockReturnValue(true) - vi.mocked(isSupportedThinkingTokenGeminiModel).mockReturnValue(true) + const model: Model = { + id: 'gemini-2.5-pro', + name: 'Gemini 2.5 Pro', + provider: SystemProviderIds.gemini + } as Model + + const assistant: Assistant = { + id: 'test', + name: 'Test', + settings: { + reasoning_effort: 'none' + } + } as Assistant + + const result = getGeminiReasoningParams(assistant, model) + expect(result).toEqual({ + thinkingConfig: { + includeThoughts: false + } + }) + }) + + it('should include thinkingLevel for Gemini 3 model with none effort', () => { + vi.mocked(mockModels.isReasoningModel).mockReturnValue(true) + vi.mocked(mockModels.isSupportedThinkingTokenGeminiModel).mockReturnValue(true) + vi.mocked(mockModels.isGemini3ThinkingTokenModel).mockReturnValue(true) + + const model: Model = { + id: 'gemini-3-flash-preview', + name: 'Gemini 3 Flash', + provider: SystemProviderIds.gemini + } as Model + + const assistant: Assistant = { + id: 'test', + name: 'Test', + settings: { reasoning_effort: 'none' } + } as Assistant + + const result = getGeminiReasoningParams(assistant, model) + expect(result).toEqual({ + thinkingConfig: { + includeThoughts: false, + thinkingLevel: 'minimal' + } + }) + }) + + it('should return thinkingLevel for Gemini 3 model with low effort', () => { + vi.mocked(mockModels.isReasoningModel).mockReturnValue(true) + vi.mocked(mockModels.isSupportedThinkingTokenGeminiModel).mockReturnValue(true) + vi.mocked(mockModels.isGemini3ThinkingTokenModel).mockReturnValue(true) + + const model: Model = { + id: 'gemini-3-flash-preview', + name: 'Gemini 3 Flash', + provider: SystemProviderIds.gemini + } as Model + + const assistant: Assistant = { + id: 'test', + name: 'Test', + settings: { reasoning_effort: 'low' } + } as Assistant + + const result = getGeminiReasoningParams(assistant, model) + expect(result).toEqual({ + thinkingConfig: { + includeThoughts: true, + thinkingLevel: 'low' + } + }) + }) + + it('should return thinkingLevel medium for Gemini 3 model with medium effort', () => { + vi.mocked(mockModels.isReasoningModel).mockReturnValue(true) + vi.mocked(mockModels.isSupportedThinkingTokenGeminiModel).mockReturnValue(true) + vi.mocked(mockModels.isGemini3ThinkingTokenModel).mockReturnValue(true) + + const model: Model = { + id: 'gemini-3-flash-preview', + name: 'Gemini 3 Flash', + provider: SystemProviderIds.gemini + } as Model + + const assistant: Assistant = { + id: 'test', + name: 'Test', + settings: { reasoning_effort: 'medium' } + } as Assistant + + const result = getGeminiReasoningParams(assistant, model) + expect(result).toEqual({ + thinkingConfig: { + includeThoughts: true, + thinkingLevel: 'medium' + } + }) + }) + + it('should return thinkingLevel high for Gemini 3 model with high effort', () => { + vi.mocked(mockModels.isReasoningModel).mockReturnValue(true) + vi.mocked(mockModels.isSupportedThinkingTokenGeminiModel).mockReturnValue(true) + vi.mocked(mockModels.isGemini3ThinkingTokenModel).mockReturnValue(true) + + const model: Model = { + id: 'gemini-3-flash-preview', + name: 'Gemini 3 Flash', + provider: SystemProviderIds.gemini + } as Model + + const assistant: Assistant = { + id: 'test', + name: 'Test', + settings: { reasoning_effort: 'high' } + } as Assistant + + const result = getGeminiReasoningParams(assistant, model) + expect(result).toEqual({ + thinkingConfig: { + includeThoughts: true, + thinkingLevel: 'high' + } + }) + }) + + it('should return thinkingLevel high for Gemini 3 model with xhigh effort', () => { + vi.mocked(mockModels.isReasoningModel).mockReturnValue(true) + vi.mocked(mockModels.isSupportedThinkingTokenGeminiModel).mockReturnValue(true) + vi.mocked(mockModels.isGemini3ThinkingTokenModel).mockReturnValue(true) + + const model: Model = { + id: 'gemini-3-flash-preview', + name: 'Gemini 3 Flash', + provider: SystemProviderIds.gemini + } as Model + + const assistant: Assistant = { + id: 'test', + name: 'Test', + settings: { reasoning_effort: 'xhigh' } + } as Assistant + + const result = getGeminiReasoningParams(assistant, model) + expect(result).toEqual({ + thinkingConfig: { + includeThoughts: true, + thinkingLevel: 'high' + } + }) + }) + + it('should use undefined thinkingLevel for Gemini 3 model with auto effort', () => { + vi.mocked(mockModels.isReasoningModel).mockReturnValue(true) + vi.mocked(mockModels.isSupportedThinkingTokenGeminiModel).mockReturnValue(true) + vi.mocked(mockModels.isGemini3ThinkingTokenModel).mockReturnValue(true) + + const model: Model = { + id: 'gemini-3-flash-preview', + name: 'Gemini 3 Flash', + provider: SystemProviderIds.gemini + } as Model + + const assistant: Assistant = { + id: 'test', + name: 'Test', + settings: { reasoning_effort: 'auto' } + } as Assistant + + const result = getGeminiReasoningParams(assistant, model) + // auto maps to undefined thinkingLevel (let API decide), stays in Gemini 3 branch + expect(result).toEqual({ + thinkingConfig: { + includeThoughts: true, + thinkingLevel: undefined + } + }) + }) + + it('should return thinkingLevel minimal for Gemini 3 model with minimal effort', () => { + vi.mocked(mockModels.isReasoningModel).mockReturnValue(true) + vi.mocked(mockModels.isSupportedThinkingTokenGeminiModel).mockReturnValue(true) + vi.mocked(mockModels.isGemini3ThinkingTokenModel).mockReturnValue(true) + + const model: Model = { + id: 'gemini-3-flash-preview', + name: 'Gemini 3 Flash', + provider: SystemProviderIds.gemini + } as Model + + const assistant: Assistant = { + id: 'test', + name: 'Test', + settings: { reasoning_effort: 'minimal' } + } as Assistant + + const result = getGeminiReasoningParams(assistant, model) + expect(result).toEqual({ + thinkingConfig: { + includeThoughts: true, + thinkingLevel: 'minimal' + } + }) + }) + + it('should enable thinking with budget for reasoning effort', () => { + vi.mocked(mockModels.isReasoningModel).mockReturnValue(true) + vi.mocked(mockModels.isSupportedThinkingTokenGeminiModel).mockReturnValue(true) const model: Model = { id: 'gemini-2.5-pro', @@ -786,11 +1057,35 @@ describe('reasoning utils', () => { }) }) - it('should enable thinking without budget for auto effort ratio > 1', async () => { - const { isReasoningModel, isSupportedThinkingTokenGeminiModel } = await import('@renderer/config/models') + it('should compute thinkingBudget for old models with xhigh effort', () => { + vi.mocked(mockModels.isReasoningModel).mockReturnValue(true) + vi.mocked(mockModels.isSupportedThinkingTokenGeminiModel).mockReturnValue(true) - vi.mocked(isReasoningModel).mockReturnValue(true) - vi.mocked(isSupportedThinkingTokenGeminiModel).mockReturnValue(true) + const model: Model = { + id: 'gemini-2.5-pro', + name: 'Gemini 2.5 Pro', + provider: SystemProviderIds.gemini + } as Model + + const assistant: Assistant = { + id: 'test', + name: 'Test', + settings: { reasoning_effort: 'xhigh' } + } as Assistant + + const result = getGeminiReasoningParams(assistant, model) + // EFFORT_RATIO['xhigh'] = 0.9, which is NOT > 1, so it should compute a budget + expect(result).toEqual({ + thinkingConfig: { + thinkingBudget: expect.any(Number), + includeThoughts: true + } + }) + }) + + it('should return thinkingBudget -1 for old models with auto effort', () => { + vi.mocked(mockModels.isReasoningModel).mockReturnValue(true) + vi.mocked(mockModels.isSupportedThinkingTokenGeminiModel).mockReturnValue(true) const model: Model = { id: 'gemini-2.5-pro', @@ -814,6 +1109,60 @@ describe('reasoning utils', () => { } }) }) + + it('should omit thinkingBudget for old models when no token limit is found', () => { + vi.mocked(mockModels.isReasoningModel).mockReturnValue(true) + vi.mocked(mockModels.isSupportedThinkingTokenGeminiModel).mockReturnValue(true) + vi.mocked(mockModels.findTokenLimit).mockReturnValue(undefined) + + const model: Model = { + id: 'gemini-2.5-pro-unknown', + name: 'Gemini 2.5 Pro Unknown', + provider: SystemProviderIds.gemini + } as Model + + const assistant: Assistant = { + id: 'test', + name: 'Test', + settings: { reasoning_effort: 'medium' } + } as Assistant + + const result = getGeminiReasoningParams(assistant, model) + // budget = Math.floor((0 - 0) * 0.5 + 0) = 0, so no thinkingBudget + expect(result).toEqual({ + thinkingConfig: { + includeThoughts: true + } + }) + }) + + it('should calculate correct thinkingBudget for low effort', () => { + vi.mocked(mockModels.isReasoningModel).mockReturnValue(true) + vi.mocked(mockModels.isSupportedThinkingTokenGeminiModel).mockReturnValue(true) + vi.mocked(mockModels.findTokenLimit).mockReturnValue({ min: 1024, max: 32768 }) + + const model: Model = { + id: 'gemini-2.5-pro', + name: 'Gemini 2.5 Pro', + provider: SystemProviderIds.gemini + } as Model + + const assistant: Assistant = { + id: 'test', + name: 'Test', + settings: { reasoning_effort: 'low' } + } as Assistant + + const result = getGeminiReasoningParams(assistant, model) + // EFFORT_RATIO['low'] = 0.05 + // budget = Math.floor((32768 - 1024) * 0.05 + 1024) = Math.floor(1587.2 + 1024) = 2611 + expect(result).toEqual({ + thinkingConfig: { + thinkingBudget: 2611, + includeThoughts: true + } + }) + }) }) describe('getXAIReasoningParams', () => { diff --git a/src/renderer/src/aiCore/utils/__tests__/websearch.test.ts b/src/renderer/src/aiCore/utils/__tests__/websearch.test.ts index 5c95a664aaa..ee788d47b56 100644 --- a/src/renderer/src/aiCore/utils/__tests__/websearch.test.ts +++ b/src/renderer/src/aiCore/utils/__tests__/websearch.test.ts @@ -254,42 +254,43 @@ describe('websearch utils', () => { }) describe('xai provider', () => { - it('should return xai search options', () => { + it('should return xai search options with enableImageUnderstanding when no excludeDomains', () => { const result = buildProviderBuiltinWebSearchConfig('xai', defaultWebSearchConfig) expect(result).toEqual({ - xai: { - maxSearchResults: 30, - returnCitations: true, - sources: [{ type: 'web', excludedWebsites: [] }, { type: 'news' }, { type: 'x' }], - mode: 'on' - } + xai: { enableImageUnderstanding: true }, + 'xai-xsearch': { enableImageUnderstanding: true } }) }) - it('should limit excluded websites to 5', () => { + it('should include excludedDomains when excludeDomains provided', () => { const config: CherryWebSearchConfig = { searchWithTime: true, maxResults: 40, - excludeDomains: ['site1.com', 'site2.com', 'site3.com', 'site4.com', 'site5.com', 'site6.com', 'site7.com'] + excludeDomains: ['site1.com', 'site2.com'] } const result = buildProviderBuiltinWebSearchConfig('xai', config) - expect(result?.xai?.sources).toBeDefined() - const webSource = result?.xai?.sources?.[0] - if (webSource && webSource.type === 'web') { - expect(webSource.excludedWebsites).toHaveLength(5) - } + expect(result).toEqual({ + xai: { + enableImageUnderstanding: true, + excludedDomains: ['site1.com', 'site2.com'] + }, + 'xai-xsearch': { enableImageUnderstanding: true } + }) }) - it('should include all sources types', () => { - const result = buildProviderBuiltinWebSearchConfig('xai', defaultWebSearchConfig) + it('should limit excluded domains to 5', () => { + const config: CherryWebSearchConfig = { + searchWithTime: true, + maxResults: 40, + excludeDomains: ['site1.com', 'site2.com', 'site3.com', 'site4.com', 'site5.com', 'site6.com', 'site7.com'] + } + + const result = buildProviderBuiltinWebSearchConfig('xai', config) - expect(result?.xai?.sources).toHaveLength(3) - expect(result?.xai?.sources?.[0].type).toBe('web') - expect(result?.xai?.sources?.[1].type).toBe('news') - expect(result?.xai?.sources?.[2].type).toBe('x') + expect(result?.xai?.excludedDomains).toHaveLength(5) }) }) diff --git a/src/renderer/src/aiCore/utils/options.ts b/src/renderer/src/aiCore/utils/options.ts index a3d2d30b46d..23a82957170 100644 --- a/src/renderer/src/aiCore/utils/options.ts +++ b/src/renderer/src/aiCore/utils/options.ts @@ -283,6 +283,16 @@ export function buildProviderOptions( const actualAiSdkProviderIds = Object.keys(providerSpecificOptions) const primaryAiSdkProviderId = actualAiSdkProviderIds[0] // Use the first one as primary for non-scoped params + // For openai-compatible providers, auto-convert reasoning_effort (snake_case) to reasoningEffort (camelCase). + // The AI SDK's openai-compatible provider overwrites reasoning_effort to undefined, + // but accepts reasoningEffort. See: https://github.com/CherryHQ/cherry-studio/issues/11987 + if (primaryAiSdkProviderId === 'openai-compatible' && 'reasoning_effort' in providerParams) { + if (!('reasoningEffort' in providerParams)) { + providerParams.reasoningEffort = providerParams.reasoning_effort + } + delete providerParams.reasoning_effort + } + /** * Merge custom parameters into providerSpecificOptions. * Simple logic: diff --git a/src/renderer/src/aiCore/utils/reasoning.ts b/src/renderer/src/aiCore/utils/reasoning.ts index 87114c5d2cb..1d13d4dd6d2 100644 --- a/src/renderer/src/aiCore/utils/reasoning.ts +++ b/src/renderer/src/aiCore/utils/reasoning.ts @@ -43,6 +43,7 @@ import type { Assistant, Model, ReasoningEffortOption } from '@renderer/types' import { EFFORT_RATIO, isSystemProvider, SystemProviderIds } from '@renderer/types' import type { OpenAIReasoningSummary } from '@renderer/types/aiCoreTypes' import type { ReasoningEffortOptionalParams } from '@renderer/types/sdk' +import { getLowerBaseModelName } from '@renderer/utils' import { isSupportEnableThinkingProvider } from '@renderer/utils/provider' import { toInteger } from 'lodash' @@ -653,8 +654,11 @@ type GoogleThinkingLevel = NonNullable { - if (!isReasoningModel(model)) { + if (!isReasoningModel(model) || !isSupportedThinkingTokenGeminiModel(model)) { return {} } @@ -689,34 +695,43 @@ export function getGeminiReasoningParams( return {} } - // Gemini 推理参数 - if (isSupportedThinkingTokenGeminiModel(model)) { - if (reasoningEffort === undefined || reasoningEffort === 'none') { - return { - thinkingConfig: { - includeThoughts: false, - ...(GEMINI_FLASH_MODEL_REGEX.test(model.id) ? { thinkingBudget: 0 } : {}) - } + let thinkingLevel: GoogleThinkingLevel | null = null + const includeThoughts = reasoningEffort !== 'none' + + // https://ai.google.dev/gemini-api/docs/gemini-3?thinking=high#new_api_features_in_gemini_3 + if (isGemini3ThinkingTokenModel(model)) { + thinkingLevel = mapToGeminiThinkingLevel(reasoningEffort) + if (thinkingLevel === 'minimal' && getLowerBaseModelName(model.id).includes('pro')) { + thinkingLevel = 'low' + } + } + + if (thinkingLevel !== null) { + // Gemini 3 branch. thinkingLevel can be undefined (auto) or a specific level. + return { + thinkingConfig: { + includeThoughts, + thinkingLevel } } + } else { + // Old models + const effortRatio = EFFORT_RATIO[reasoningEffort] - // https://ai.google.dev/gemini-api/docs/gemini-3?thinking=high#new_api_features_in_gemini_3 - if (isGemini3ThinkingTokenModel(model)) { + if (reasoningEffort === 'auto') { return { thinkingConfig: { - includeThoughts: true, - thinkingLevel: mapToGeminiThinkingLevel(reasoningEffort) + includeThoughts, + thinkingBudget: -1 } } } - const effortRatio = EFFORT_RATIO[reasoningEffort] - - if (effortRatio > 1) { + if (reasoningEffort === 'none') { return { thinkingConfig: { - thinkingBudget: -1, - includeThoughts: true + includeThoughts, + ...(GEMINI_FLASH_MODEL_REGEX.test(model.id) ? { thinkingBudget: 0 } : {}) } } } @@ -726,13 +741,11 @@ export function getGeminiReasoningParams( return { thinkingConfig: { - ...(budget > 0 ? { thinkingBudget: budget } : {}), - includeThoughts: true + includeThoughts, + ...(budget > 0 ? { thinkingBudget: budget } : {}) } } } - - return {} } /** diff --git a/src/renderer/src/aiCore/utils/websearch.ts b/src/renderer/src/aiCore/utils/websearch.ts index 14a99139bec..60649b86c82 100644 --- a/src/renderer/src/aiCore/utils/websearch.ts +++ b/src/renderer/src/aiCore/utils/websearch.ts @@ -1,7 +1,9 @@ import type { AnthropicSearchConfig, OpenAISearchConfig, - WebSearchPluginConfig + WebSearchPluginConfig, + XAIWebSearchConfig, + XAIXSearchConfig } from '@cherrystudio/ai-core/core/plugins/built-in/webSearchPlugin/helper' import type { BaseProviderId } from '@cherrystudio/ai-core/provider' import { isOpenAIDeepResearchModel, isOpenAIWebSearchChatCompletionOnlyModel } from '@renderer/config/models' @@ -9,8 +11,6 @@ import type { CherryWebSearchConfig } from '@renderer/store/websearch' import type { Model } from '@renderer/types' import { mapRegexToPatterns } from '@renderer/utils/blacklistMatchPattern' -const X_AI_MAX_SEARCH_RESULT = 30 - export function getWebSearchParams(model: Model): Record { if (model.provider === 'hunyuan') { return { enable_enhancement: true, citation: true, search_info: true } @@ -82,20 +82,18 @@ export function buildProviderBuiltinWebSearchConfig( } case 'xai': { const excludeDomains = mapRegexToPatterns(webSearchConfig.excludeDomains) + const xaiWebConfig: XAIWebSearchConfig = { + enableImageUnderstanding: true + } + if (excludeDomains.length > 0) { + xaiWebConfig.excludedDomains = excludeDomains.slice(0, 5) + } + const xaiXSearchConfig: XAIXSearchConfig = { + enableImageUnderstanding: true + } return { - xai: { - maxSearchResults: Math.min(webSearchConfig.maxResults, X_AI_MAX_SEARCH_RESULT), - returnCitations: true, - sources: [ - { - type: 'web', - excludedWebsites: excludeDomains.slice(0, Math.min(excludeDomains.length, 5)) - }, - { type: 'news' }, - { type: 'x' } - ], - mode: 'on' - } + xai: xaiWebConfig, + 'xai-xsearch': xaiXSearchConfig } } case 'openrouter': { diff --git a/src/renderer/src/api/agent.ts b/src/renderer/src/api/agent.ts index 6f1c8cf1e34..50596046f29 100644 --- a/src/renderer/src/api/agent.ts +++ b/src/renderer/src/api/agent.ts @@ -19,6 +19,13 @@ import type { UpdateSessionForm, UpdateSessionRequest } from '@types' +import type { + CreateTaskRequest, + ListTaskLogsResponse, + ListTasksResponse, + ScheduledTaskEntity, + UpdateTaskRequest +} from '@types' import { AgentServerErrorSchema, ApiModelsResponseSchema, @@ -29,8 +36,11 @@ import { ListAgentSessionsResponseSchema, type ListAgentsResponse, ListAgentsResponseSchema, + ListTaskLogsResponseSchema, + ListTasksResponseSchema, objectEntries, objectKeys, + ScheduledTaskEntitySchema, UpdateAgentResponseSchema } from '@types' import type { Axios, AxiosRequestConfig } from 'axios' @@ -86,6 +96,13 @@ export class AgentApiClient { withId: (id: number) => `/${this.apiVersion}/agents/${agentId}/sessions/${sessionId}/messages/${id}` }) + public getTaskPaths = (agentId: string) => ({ + base: `/${this.apiVersion}/agents/${agentId}/tasks`, + withId: (taskId: string) => `/${this.apiVersion}/agents/${agentId}/tasks/${taskId}`, + run: (taskId: string) => `/${this.apiVersion}/agents/${agentId}/tasks/${taskId}/run`, + logs: (taskId: string) => `/${this.apiVersion}/agents/${agentId}/tasks/${taskId}/logs` + }) + public getModelsPath = (props?: ApiModelsFilter) => { const base = `/${this.apiVersion}/models` if (!props) return base @@ -257,4 +274,85 @@ export class AgentApiClient { throw processError(error, 'Failed to get models.') } } + + // --- Task CRUD --- + + public async listTasks(agentId: string, options?: ListOptions): Promise { + const url = this.getTaskPaths(agentId).base + try { + const response = await this.axios.get(url, { params: options }) + const result = ListTasksResponseSchema.safeParse(response.data) + if (!result.success) { + throw new Error('Not a valid Tasks response.') + } + return result.data + } catch (error) { + throw processError(error, 'Failed to list tasks.') + } + } + + public async createTask(agentId: string, task: CreateTaskRequest): Promise { + const url = this.getTaskPaths(agentId).base + try { + const response = await this.axios.post(url, task) + const data = ScheduledTaskEntitySchema.parse(response.data) + return data + } catch (error) { + throw processError(error, 'Failed to create task.') + } + } + + public async getTask(agentId: string, taskId: string): Promise { + const url = this.getTaskPaths(agentId).withId(taskId) + try { + const response = await this.axios.get(url) + const data = ScheduledTaskEntitySchema.parse(response.data) + return data + } catch (error) { + throw processError(error, 'Failed to get task.') + } + } + + public async updateTask(agentId: string, taskId: string, updates: UpdateTaskRequest): Promise { + const url = this.getTaskPaths(agentId).withId(taskId) + try { + const response = await this.axios.patch(url, updates) + const data = ScheduledTaskEntitySchema.parse(response.data) + return data + } catch (error) { + throw processError(error, 'Failed to update task.') + } + } + + public async deleteTask(agentId: string, taskId: string): Promise { + const url = this.getTaskPaths(agentId).withId(taskId) + try { + await this.axios.delete(url) + } catch (error) { + throw processError(error, 'Failed to delete task.') + } + } + + public async runTask(agentId: string, taskId: string): Promise { + const url = this.getTaskPaths(agentId).run(taskId) + try { + await this.axios.post(url) + } catch (error) { + throw processError(error, 'Failed to run task.') + } + } + + public async getTaskLogs(agentId: string, taskId: string, options?: ListOptions): Promise { + const url = this.getTaskPaths(agentId).logs(taskId) + try { + const response = await this.axios.get(url, { params: options }) + const result = ListTaskLogsResponseSchema.safeParse(response.data) + if (!result.success) { + throw new Error('Not a valid TaskLogs response.') + } + return result.data + } catch (error) { + throw processError(error, 'Failed to get task logs.') + } + } } diff --git a/src/renderer/src/assets/images/models/cherry-claw.png b/src/renderer/src/assets/images/models/cherry-claw.png new file mode 100644 index 00000000000..2be60db76b0 Binary files /dev/null and b/src/renderer/src/assets/images/models/cherry-claw.png differ diff --git a/src/renderer/src/assets/styles/ant.css b/src/renderer/src/assets/styles/ant.css index 7d651a6a6a9..e42786002eb 100644 --- a/src/renderer/src/assets/styles/ant.css +++ b/src/renderer/src/assets/styles/ant.css @@ -140,10 +140,27 @@ .ant-dropdown-menu-submenu { background-color: var(--ant-color-bg-elevated); - overflow: hidden; + overflow: visible; border-radius: var(--ant-border-radius-lg); } +/* Enable scrolling for Move To submenu - Issue #13350 */ +.ant-dropdown-menu-submenu .ant-dropdown-menu { + max-height: 60vh; + overflow-y: auto; + overflow-x: hidden; +} + +/* Hide scrollbar for Move To submenu only */ +.move-to-submenu .ant-dropdown-menu { + scrollbar-width: none; + -ms-overflow-style: none; +} + +.move-to-submenu .ant-dropdown-menu::-webkit-scrollbar { + display: none; +} + .ant-dropdown-menu-submenu .ant-dropdown-menu-submenu-title { align-items: center; } diff --git a/src/renderer/src/components/AnthropicProviderListPopover.tsx b/src/renderer/src/components/AnthropicProviderListPopover.tsx new file mode 100644 index 00000000000..9e38f60ab03 --- /dev/null +++ b/src/renderer/src/components/AnthropicProviderListPopover.tsx @@ -0,0 +1,148 @@ +import { ProviderAvatar } from '@renderer/components/ProviderAvatar' +import { useAllProviders } from '@renderer/hooks/useProvider' +import ImageStorage from '@renderer/services/ImageStorage' +import type { Provider } from '@renderer/types' +import { getFancyProviderName } from '@renderer/utils' +import { getClaudeSupportedProviders } from '@renderer/utils/provider' +import type { PopoverProps } from 'antd' +import { Popover } from 'antd' +import { ArrowUpRight, HelpCircle } from 'lucide-react' +import type { FC, ReactNode } from 'react' +import { useEffect, useState } from 'react' +import { useTranslation } from 'react-i18next' +import styled from 'styled-components' + +interface AnthropicProviderListPopoverProps { + /** Callback when provider is clicked */ + onProviderClick?: () => void + /** Use window.navigate instead of Link (for non-router context like TopView) */ + useWindowNavigate?: boolean + /** Custom trigger element, defaults to HelpCircle icon */ + children?: ReactNode + /** Popover placement */ + placement?: PopoverProps['placement'] + /** Custom filter function for providers, defaults to getClaudeSupportedProviders */ + filterProviders?: (providers: Provider[]) => Provider[] +} + +const AnthropicProviderListPopover: FC = ({ + onProviderClick, + useWindowNavigate = false, + children, + placement = 'right', + filterProviders = getClaudeSupportedProviders +}) => { + const { t } = useTranslation() + const allProviders = useAllProviders() + const providers = filterProviders(allProviders) + const [providerLogos, setProviderLogos] = useState>({}) + + useEffect(() => { + const loadAllLogos = async () => { + const logos: Record = {} + for (const provider of providers) { + if (provider.id) { + try { + const logoData = await ImageStorage.get(`provider-${provider.id}`) + if (logoData) { + logos[provider.id] = logoData + } + } catch { + // Ignore errors loading logos + } + } + } + setProviderLogos(logos) + } + + loadAllLogos() + }, [providers]) + + const handleClick = (providerId: string) => { + onProviderClick?.() + if (useWindowNavigate) { + window.navigate(`/settings/provider?id=${providerId}`) + } + } + + const content = ( + + {t('code.supported_providers')} + + {providers.map((provider) => + useWindowNavigate ? ( + handleClick(provider.id)}> + + {getFancyProviderName(provider)} + + + ) : ( + handleClick(provider.id)}> + + {getFancyProviderName(provider)} + + + ) + )} + + + ) + + return ( + + {children || } + + ) +} + +const PopoverContent = styled.div` + width: 200px; +` + +const PopoverTitle = styled.div` + margin-bottom: 8px; + font-weight: 500; +` + +const ProviderListContainer = styled.div` + display: flex; + flex-direction: column; + gap: 8px; +` + +const ProviderItem = styled.div` + color: var(--color-text); + display: flex; + align-items: center; + gap: 4px; + cursor: pointer; + &:hover { + color: var(--color-link); + } +` + +const ProviderLink = styled.a` + color: var(--color-text); + display: flex; + align-items: center; + gap: 4px; + text-decoration: none; + &:hover { + color: var(--color-link); + } +` + +export default AnthropicProviderListPopover diff --git a/src/renderer/src/components/Popups/SelectModelPopup/api-model-popup.tsx b/src/renderer/src/components/Popups/SelectModelPopup/agent-model-popup.tsx similarity index 94% rename from src/renderer/src/components/Popups/SelectModelPopup/api-model-popup.tsx rename to src/renderer/src/components/Popups/SelectModelPopup/agent-model-popup.tsx index 3924d6b57f8..bbfa54182ad 100644 --- a/src/renderer/src/components/Popups/SelectModelPopup/api-model-popup.tsx +++ b/src/renderer/src/components/Popups/SelectModelPopup/agent-model-popup.tsx @@ -5,6 +5,7 @@ import { TopView } from '@renderer/components/TopView' import { DynamicVirtualList, type DynamicVirtualListRef } from '@renderer/components/VirtualList' import { getModelLogoById } from '@renderer/config/models' import { useApiModels } from '@renderer/hooks/agents/useModels' +import { useAllProviders } from '@renderer/hooks/useProvider' import { getModelUniqId } from '@renderer/services/ModelService' import { getProviderNameById } from '@renderer/services/ProviderService' import type { AdaptedApiModel, ApiModel, ApiModelsFilter, Model, ModelType } from '@renderer/types' @@ -58,9 +59,11 @@ const PopupContainer: React.FC = ({ model, apiFilter, modelFilter, showTa const searchText = useDeferredValue(_searchText) const { models, isLoading } = useApiModels(apiFilter) const adaptedModels = useMemo(() => models.map((model) => apiModelAdapter(model)), [models]) + const allProviders = useAllProviders() + const providerOrderMap = useMemo(() => new Map(allProviders.map((p, i) => [p.id, i])), [allProviders]) - // 当前选中的模型ID - const currentModelId = model ? model.id : '' + // 当前选中的模型ID(需要转换为与列表项相同的格式) + const currentModelId = model ? getModelUniqId(apiModelAdapter(model)) : '' // 管理滚动和焦点状态 const [focusedItemKey, _setFocusedItemKey] = useState('') @@ -141,7 +144,14 @@ const PopupContainer: React.FC = ({ model, apiFilter, modelFilter, showTa // 按 provider 分组 const groups = groupBy(filteredModels, (model) => model.provider) as Record - objectEntries(groups).forEach(([key, models]) => { + // 按照 provider 配置顺序排序 group keys,cherryin 优先放在第一位 + const sortedProviderIds = sortBy(Object.keys(groups), (id) => { + if (id === 'cherryin') return -1 + return providerOrderMap.get(id) ?? Infinity + }) + + sortedProviderIds.forEach((key) => { + const models = groups[key] items.push({ key: key ?? 'Unknown', type: 'group', @@ -154,7 +164,7 @@ const PopupContainer: React.FC = ({ model, apiFilter, modelFilter, showTa // 获取可选择的模型项(过滤掉分组标题) const modelItems = items.filter((item) => item.type === 'model') return { listItems: items, modelItems } - }, [searchFilter, adaptedModels, showTagFilter, tagFilter, createModelItem, modelFilter]) + }, [searchFilter, adaptedModels, showTagFilter, tagFilter, createModelItem, modelFilter, providerOrderMap]) const listHeight = useMemo(() => { return Math.min(PAGE_SIZE, listItems.length) * ITEM_HEIGHT @@ -283,7 +293,7 @@ const PopupContainer: React.FC = ({ model, apiFilter, modelFilter, showTa const onAfterClose = useCallback(async () => { resolve(undefined) - SelectApiModelPopup.hide() + SelectAgentModelPopup.hide() }, [resolve]) const getItemKey = useCallback((index: number) => listItems[index].key, [listItems]) @@ -491,9 +501,9 @@ const EmptyState = styled.div` height: 200px; ` -const TopViewKey = 'SelectModelPopup' +const TopViewKey = 'SelectAgentModelPopup' -export class SelectApiModelPopup { +export class SelectAgentModelPopup { static topviewId = 0 static hide() { TopView.hide(TopViewKey) diff --git a/src/renderer/src/components/Popups/SelectModelPopup/index.ts b/src/renderer/src/components/Popups/SelectModelPopup/index.ts index 9df1e79dbcd..08032dc4ed2 100644 --- a/src/renderer/src/components/Popups/SelectModelPopup/index.ts +++ b/src/renderer/src/components/Popups/SelectModelPopup/index.ts @@ -1,2 +1,2 @@ -export { SelectApiModelPopup } from './api-model-popup' +export { SelectAgentModelPopup } from './agent-model-popup' export { SelectModelPopup } from './popup' diff --git a/src/renderer/src/components/Popups/agent/AgentModal.tsx b/src/renderer/src/components/Popups/agent/AgentModal.tsx index b72c23e6a43..7d9bc9884b9 100644 --- a/src/renderer/src/components/Popups/agent/AgentModal.tsx +++ b/src/renderer/src/components/Popups/agent/AgentModal.tsx @@ -1,8 +1,9 @@ import { loggerService } from '@logger' +import AnthropicProviderListPopover from '@renderer/components/AnthropicProviderListPopover' import { ErrorBoundary } from '@renderer/components/ErrorBoundary' import { HelpTooltip } from '@renderer/components/TooltipIcons' import { TopView } from '@renderer/components/TopView' -import { permissionModeCards } from '@renderer/config/agent' +import { DEFAULT_CHERRY_CLAW_CONFIG, permissionModeCards } from '@renderer/config/agent' import { isWin } from '@renderer/config/constant' import { useAgents } from '@renderer/hooks/agents/useAgents' import { useUpdateAgent } from '@renderer/hooks/agents/useUpdateAgent' @@ -10,6 +11,7 @@ import SelectAgentBaseModelButton from '@renderer/pages/home/components/SelectAg import type { AddAgentForm, AgentEntity, + AgentType, ApiModel, BaseAgentForm, PermissionMode, @@ -17,6 +19,8 @@ import type { UpdateAgentForm } from '@renderer/types' import { AgentConfigurationSchema, isAgentType } from '@renderer/types' +import { parseKeyValueString, serializeKeyValueString } from '@renderer/utils/env' +import { getAnthropicSupportedProviders } from '@renderer/utils/provider' import type { GitBashPathInfo } from '@shared/config/constant' import { Button, Input, Modal, Select } from 'antd' import type { ChangeEvent, FormEvent } from 'react' @@ -122,6 +126,26 @@ const PopupContainer: React.FC = ({ agent, afterSubmit, resolve }) => { } }, [checkGitBash]) + const onTypeChange = useCallback((value: AgentType) => { + setForm((prev) => { + if (value === 'cherry-claw') { + return { + ...prev, + type: value, + configuration: { + ...AgentConfigurationSchema.parse(prev.configuration ?? {}), + ...DEFAULT_CHERRY_CLAW_CONFIG.configuration + } + } + } + return { + ...prev, + type: value, + configuration: AgentConfigurationSchema.parse(prev.configuration ?? {}) + } + }) + }, []) + const onPermissionModeChange = useCallback((value: PermissionMode) => { setForm((prev) => { const parsedConfiguration = AgentConfigurationSchema.parse(prev.configuration ?? {}) @@ -166,6 +190,27 @@ const PopupContainer: React.FC = ({ agent, afterSubmit, resolve }) => { })) }, []) + const [envVarsText, setEnvVarsText] = useState(() => serializeKeyValueString(form.configuration?.env_vars ?? {})) + + useEffect(() => { + if (open) { + setEnvVarsText(serializeKeyValueString(buildAgentForm(agent).configuration?.env_vars ?? {})) + } + }, [agent, open]) + + const onEnvVarsChange = useCallback((e: ChangeEvent) => { + const text = e.target.value + setEnvVarsText(text) + const parsed = parseKeyValueString(text) + setForm((prev) => ({ + ...prev, + configuration: { + ...AgentConfigurationSchema.parse(prev.configuration ?? {}), + env_vars: parsed + } + })) + }, []) + const addAccessiblePath = useCallback(async () => { try { const selected = await window.api.file.selectFolder() @@ -338,14 +383,41 @@ const PopupContainer: React.FC = ({ agent, afterSubmit, resolve }) => { + {!isEditing(agent) && ( + + + + + )} + {form.type === 'cherry-claw' && ( + + {t( + 'agent.cherryClaw.warning.bypassPermissions', + 'CherryClaw agents run with Full Auto Mode by default. All tools execute without asking for approval.' + )} + + )} +
- + { + setOpen(false) + resolve(undefined) + }} + />
= ({ agent, afterSubmit, resolve }) => {