diff --git a/examples/agents/todo_tools_example.py b/examples/agents/todo_tools_example.py index c811a12a2..c639345f0 100644 --- a/examples/agents/todo_tools_example.py +++ b/examples/agents/todo_tools_example.py @@ -7,9 +7,8 @@ from ragbits.core.llms import LiteLLM, ToolCall -async def main(): +async def main() -> None: """Demonstrate the new instance-based todo approach with streaming and logging.""" - # Create a dedicated TodoList instance for this agent my_todo_list = TodoList() my_todo_manager = create_todo_manager(my_todo_list) @@ -32,14 +31,21 @@ async def main(): - Transportation details with times, costs, parking info - Weather considerations and backup plans - Safety information and emergency contacts - """ + get_todo_instruction_tpl(task_range=(3, 5)), + """ + + get_todo_instruction_tpl(task_range=(3, 5)), tools=[my_todo_manager], # Use the instance-specific todo manager - default_options=AgentOptions(max_turns=30) + default_options=AgentOptions(max_turns=30), ) - query = "Plan a 1-day hiking trip for 2 people in Tatra Mountains, Poland. Focus on scenic routes under 15km, avoiding crowds." + query = ( + "Plan a 1-day hiking trip for 2 people in Tatra Mountains, Poland. ", + "Focus on scenic routes under 15km, avoiding crowds.", + ) # query = "How long is hike to Giewont from Kuźnice?" - # query = "Is it difficult to finish Orla Perć? Would you recommend me to go there if I've never been in mountains before?" + # query = ( + # "Is it difficult to finish Orla Perć? Would you recommend me ", + # "to go there if I've never been in mountains before?", + # ) stream = my_agent.run_streaming(query) @@ -63,9 +69,9 @@ async def main(): for i, task in enumerate(tasks, 1): print(f" {i}. {task}") - print("\n\n" + "="*50) + print("\n\n" + "=" * 50) print("🎉 Systematic hiking trip planning completed!") if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/examples/chat/chat.py b/examples/chat/chat.py index ccbac8bc6..6457f6eff 100644 --- a/examples/chat/chat.py +++ b/examples/chat/chat.py @@ -26,6 +26,7 @@ from pydantic import BaseModel, ConfigDict, Field +from ragbits.agents.tools.todo import Task, TaskStatus from ragbits.chat.interface import ChatInterface from ragbits.chat.interface.forms import FeedbackConfig, UserSettings from ragbits.chat.interface.types import ChatContext, ChatResponse, LiveUpdateType @@ -141,10 +142,28 @@ async def chat( ), ] + parentTask = Task(id="task_id_1", description="Example task with a subtask") + subtaskTask = Task(id="task_id_2", description="Example subtask", parent_id="task_id_1") + for live_update in example_live_updates: yield live_update await asyncio.sleep(2) + yield self.create_todo_item_response(parentTask) + yield self.create_todo_item_response(subtaskTask) + + await asyncio.sleep(2) + parentTask.status = TaskStatus.IN_PROGRESS + yield self.create_todo_item_response(parentTask) + await asyncio.sleep(2) + subtaskTask.status = TaskStatus.IN_PROGRESS + yield self.create_todo_item_response(subtaskTask) + await asyncio.sleep(2) + parentTask.status = TaskStatus.COMPLETED + subtaskTask.status = TaskStatus.COMPLETED + yield self.create_todo_item_response(subtaskTask) + yield self.create_todo_item_response(parentTask) + streaming_result = self.llm.generate_streaming([*history, {"role": "user", "content": message}]) async for chunk in streaming_result: yield self.create_text_response(chunk) diff --git a/packages/ragbits-agents/CHANGELOG.md b/packages/ragbits-agents/CHANGELOG.md index 52949840d..83ba7280a 100644 --- a/packages/ragbits-agents/CHANGELOG.md +++ b/packages/ragbits-agents/CHANGELOG.md @@ -1,6 +1,7 @@ # CHANGELOG ## Unreleased +- Add support for todo lists generated by agents with examples (#823) ## 1.3.0 (2025-09-11) ### Changed diff --git a/packages/ragbits-agents/src/ragbits/agents/__init__.py b/packages/ragbits-agents/src/ragbits/agents/__init__.py index 1b8b22fbb..b363cc9f6 100644 --- a/packages/ragbits-agents/src/ragbits/agents/__init__.py +++ b/packages/ragbits-agents/src/ragbits/agents/__init__.py @@ -7,7 +7,7 @@ AgentRunContext, ToolCallResult, ) -from ragbits.agents.tools import get_todo_instruction_tpl, create_todo_manager +from ragbits.agents.tools import create_todo_manager, get_todo_instruction_tpl from ragbits.agents.types import QuestionAnswerAgent, QuestionAnswerPromptInput, QuestionAnswerPromptOutput __all__ = [ @@ -21,6 +21,6 @@ "QuestionAnswerPromptInput", "QuestionAnswerPromptOutput", "ToolCallResult", - "get_todo_instruction_tpl", "create_todo_manager", + "get_todo_instruction_tpl", ] diff --git a/packages/ragbits-agents/src/ragbits/agents/tools/__init__.py b/packages/ragbits-agents/src/ragbits/agents/tools/__init__.py index 6960f5396..35f89bd19 100644 --- a/packages/ragbits-agents/src/ragbits/agents/tools/__init__.py +++ b/packages/ragbits-agents/src/ragbits/agents/tools/__init__.py @@ -1,5 +1,5 @@ """Agent tools for extending functionality.""" -from .todo import get_todo_instruction_tpl, create_todo_manager +from .todo import create_todo_manager, get_todo_instruction_tpl -__all__ = ["create_todo_manager", "get_todo_instruction_tpl"] \ No newline at end of file +__all__ = ["create_todo_manager", "get_todo_instruction_tpl"] diff --git a/packages/ragbits-agents/src/ragbits/agents/tools/todo.py b/packages/ragbits-agents/src/ragbits/agents/tools/todo.py index 4213e5b84..5858b4707 100644 --- a/packages/ragbits-agents/src/ragbits/agents/tools/todo.py +++ b/packages/ragbits-agents/src/ragbits/agents/tools/todo.py @@ -1,31 +1,37 @@ """Todo list management tool for agents.""" import uuid +from collections.abc import Callable from dataclasses import dataclass, field from enum import Enum -from typing import Any, Literal, Callable +from typing import Any, Literal + +from pydantic import BaseModel class TaskStatus(str, Enum): """Task status options.""" + PENDING = "pending" IN_PROGRESS = "in_progress" COMPLETED = "completed" -@dataclass -class Task: +class Task(BaseModel): """Simple task representation.""" + id: str description: str status: TaskStatus = TaskStatus.PENDING order: int = 0 summary: str | None = None + parent_id: str | None = None @dataclass class TodoList: """Simple todo list for one agent run.""" + tasks: list[Task] = field(default_factory=list) current_index: int = 0 @@ -35,7 +41,7 @@ def get_current_task(self) -> Task | None: return self.tasks[self.current_index] return None - def advance_to_next(self): + def advance_to_next(self) -> None: """Move to next task.""" self.current_index += 1 @@ -49,18 +55,14 @@ def create_tasks(self, task_descriptions: list[str]) -> dict[str, Any]: self.current_index = 0 for i, desc in enumerate(task_descriptions): - task = Task( - id=str(uuid.uuid4()), - description=desc.strip(), - order=i - ) + task = Task(id=str(uuid.uuid4()), description=desc.strip(), order=i) self.tasks.append(task) return { "action": "create", "tasks": [{"id": t.id, "description": t.description, "order": t.order} for t in self.tasks], "total_count": len(self.tasks), - "message": f"Created {len(task_descriptions)} tasks" + "message": f"Created {len(task_descriptions)} tasks", } def get_current(self) -> dict[str, Any]: @@ -71,14 +73,14 @@ def get_current(self) -> dict[str, Any]: "action": "get_current", "current_task": None, "all_completed": True, - "message": "All tasks completed!" + "message": "All tasks completed!", } return { "action": "get_current", "current_task": {"id": current.id, "description": current.description, "status": current.status.value}, "progress": f"{self.current_index + 1}/{len(self.tasks)}", - "message": f"Current task: {current.description}" + "message": f"Current task: {current.description}", } def start_current_task(self) -> dict[str, Any]: @@ -91,7 +93,7 @@ def start_current_task(self) -> dict[str, Any]: return { "action": "start_task", "task": {"id": current.id, "description": current.description, "status": current.status.value}, - "message": f"Started task: {current.description}" + "message": f"Started task: {current.description}", } def complete_current_task(self, summary: str) -> dict[str, Any]: @@ -119,7 +121,7 @@ def complete_current_task(self, summary: str) -> dict[str, Any]: "next_task": {"id": next_task.id, "description": next_task.description} if next_task else None, "progress": f"{completed_count}/{len(self.tasks)}", "all_completed": next_task is None, - "message": f"Completed: {current.description}" + "message": f"Completed: {current.description}", } def get_final_summary(self) -> dict[str, Any]: @@ -127,11 +129,7 @@ def get_final_summary(self) -> dict[str, Any]: completed_tasks = [t for t in self.tasks if t.status == TaskStatus.COMPLETED] if not completed_tasks: - return { - "action": "get_final_summary", - "final_summary": "", - "message": "No completed tasks found." - } + return {"action": "get_final_summary", "final_summary": "", "message": "No completed tasks found."} # Create comprehensive final summary final_content = [] @@ -147,7 +145,7 @@ def get_final_summary(self) -> dict[str, Any]: "action": "get_final_summary", "final_summary": final_summary, "total_completed": len(completed_tasks), - "message": f"Final summary with {len(completed_tasks)} completed tasks." + "message": f"Final summary with {len(completed_tasks)} completed tasks.", } @@ -163,6 +161,7 @@ def create_todo_manager(todo_list: TodoList) -> Callable[..., dict[str, Any]]: Returns: A todo_manager function that operates on the provided TodoList """ + def todo_manager( action: Literal["create", "get_current", "start_task", "complete_task", "get_final_summary"], tasks: list[str] | None = None, @@ -216,4 +215,4 @@ def get_todo_instruction_tpl(task_range: tuple[int, int] = (3, 5)) -> str: IMPORTANT: Task summaries should be DETAILED and COMPREHENSIVE (3-5 sentences). Include specific information, recommendations, and actionable details. - """ \ No newline at end of file + """ diff --git a/packages/ragbits-chat/CHANGELOG.md b/packages/ragbits-chat/CHANGELOG.md index aebabd8c1..1c43ff86b 100644 --- a/packages/ragbits-chat/CHANGELOG.md +++ b/packages/ragbits-chat/CHANGELOG.md @@ -1,6 +1,7 @@ # CHANGELOG ## Unreleased +- Add todo list component to the UI, add support for todo events in API (#827) ## 1.3.0 (2025-09-11) diff --git a/packages/ragbits-chat/src/ragbits/chat/interface/_interface.py b/packages/ragbits-chat/src/ragbits/chat/interface/_interface.py index 8e4b73df2..7a68af9e9 100644 --- a/packages/ragbits-chat/src/ragbits/chat/interface/_interface.py +++ b/packages/ragbits-chat/src/ragbits/chat/interface/_interface.py @@ -9,6 +9,7 @@ from collections.abc import AsyncGenerator, Callable from typing import Any +from ragbits.agents.tools.todo import Task from ragbits.chat.interface.ui_customization import UICustomization from ragbits.core.audit.metrics import record_metric from ragbits.core.audit.metrics.base import MetricType @@ -251,6 +252,10 @@ def create_usage_response(usage: Usage) -> ChatResponse: content={model: MessageUsage.from_usage(usage) for model, usage in usage.model_breakdown.items()}, ) + @staticmethod + def create_todo_item_response(task: Task) -> ChatResponse: + return ChatResponse(type=ChatResponseType.TODO_ITEM, content=task) + @staticmethod def _sign_state(state: dict[str, Any]) -> str: """ diff --git a/packages/ragbits-chat/src/ragbits/chat/interface/types.py b/packages/ragbits-chat/src/ragbits/chat/interface/types.py index 920003651..75bfd4d01 100644 --- a/packages/ragbits-chat/src/ragbits/chat/interface/types.py +++ b/packages/ragbits-chat/src/ragbits/chat/interface/types.py @@ -3,6 +3,7 @@ from pydantic import BaseModel, ConfigDict, Field +from ragbits.agents.tools.todo import Task from ragbits.chat.auth.types import User from ragbits.chat.interface.forms import UserSettings from ragbits.chat.interface.ui_customization import UICustomization @@ -122,6 +123,7 @@ class ChatResponseType(str, Enum): CHUNKED_CONTENT = "chunked_content" CLEAR_MESSAGE = "clear_message" USAGE = "usage" + TODO_ITEM = "todo_item" class ChatContext(BaseModel): @@ -140,7 +142,16 @@ class ChatResponse(BaseModel): type: ChatResponseType content: ( - str | Reference | StateUpdate | LiveUpdate | list[str] | Image | dict[str, MessageUsage] | ChunkedContent | None + str + | Reference + | StateUpdate + | LiveUpdate + | list[str] + | Image + | dict[str, MessageUsage] + | ChunkedContent + | None + | Task ) def as_text(self) -> str | None: @@ -217,6 +228,12 @@ def as_usage(self) -> dict[str, MessageUsage] | None: """ return cast(dict[str, MessageUsage], self.content) if self.type == ChatResponseType.USAGE else None + def as_task(self) -> Task | None: + """ + Return the content as Task if this is an todo_item response, else None. + """ + return cast(Task, self.content) if self.type == ChatResponseType.TODO_ITEM else None + class ChatMessageRequest(BaseModel): """Client-side chat request interface.""" diff --git a/packages/ragbits-chat/src/ragbits/chat/providers/model_provider.py b/packages/ragbits-chat/src/ragbits/chat/providers/model_provider.py index 6b1191c74..b94bfe4bc 100644 --- a/packages/ragbits-chat/src/ragbits/chat/providers/model_provider.py +++ b/packages/ragbits-chat/src/ragbits/chat/providers/model_provider.py @@ -10,6 +10,7 @@ from pydantic import BaseModel +from ragbits.agents.tools.todo import Task, TaskStatus from ragbits.chat.interface.types import AuthType @@ -82,6 +83,7 @@ def get_models(self) -> dict[str, type[BaseModel | Enum]]: "FeedbackType": FeedbackType, "LiveUpdateType": LiveUpdateType, "MessageRole": MessageRole, + "TaskStatus": TaskStatus, # Core data models "ChatContext": ChatContext, "ChunkedContent": ChunkedContent, @@ -93,6 +95,7 @@ def get_models(self) -> dict[str, type[BaseModel | Enum]]: "FeedbackItem": FeedbackItem, "Image": Image, "MessageUsage": MessageUsage, + "Task": Task, # Configuration models "HeaderCustomization": HeaderCustomization, "UICustomization": UICustomization, @@ -151,6 +154,8 @@ def get_categories(self) -> dict[str, list[str]]: "JWTToken", "User", "MessageUsage", + "Task", + "TaskStatus", ], "configuration": [ "HeaderCustomization", diff --git a/scripts/generate_typescript_from_json_schema.py b/scripts/generate_typescript_from_json_schema.py index 903e85e12..317b7fce8 100644 --- a/scripts/generate_typescript_from_json_schema.py +++ b/scripts/generate_typescript_from_json_schema.py @@ -201,6 +201,7 @@ def _generate_chat_response_union_type() -> str: ("ImageChatResponse", "image", "Image"), ("ClearMessageResponse", "clear_message", "never"), ("MessageUsageChatResponse", "usage", "Record"), + ("TodoItemChatResonse", "todo_item", "Task"), ] internal_response_interfaces = [ diff --git a/typescript/@ragbits/api-client/package.json b/typescript/@ragbits/api-client/package.json index 5d5a27567..f10063cf4 100644 --- a/typescript/@ragbits/api-client/package.json +++ b/typescript/@ragbits/api-client/package.json @@ -5,7 +5,7 @@ "repository": { "type": "git", "url": "https://github.com/deepsense-ai/ragbits" - }, + }, "main": "dist/index.cjs", "module": "dist/index.js", "types": "dist/index.d.ts", diff --git a/typescript/@ragbits/api-client/src/autogen.types.ts b/typescript/@ragbits/api-client/src/autogen.types.ts index 5ac4f55af..0240f250c 100644 --- a/typescript/@ragbits/api-client/src/autogen.types.ts +++ b/typescript/@ragbits/api-client/src/autogen.types.ts @@ -23,6 +23,7 @@ export const ChatResponseType = { ChunkedContent: 'chunked_content', ClearMessage: 'clear_message', Usage: 'usage', + TodoItem: 'todo_item', } as const export type ChatResponseType = TypeFrom @@ -58,6 +59,17 @@ export const MessageRole = { export type MessageRole = TypeFrom +/** + * Represents the TaskStatus enum + */ +export const TaskStatus = { + Pending: 'pending', + InProgress: 'in_progress', + Completed: 'completed', +} as const + +export type TaskStatus = TypeFrom + /** * Represents the AuthType enum */ @@ -170,6 +182,21 @@ export interface MessageUsage { total_tokens: number } +/** + * Simple task representation. + */ +export interface Task { + id: string + description: string + /** + * Task status options. + */ + status: 'pending' | 'in_progress' | 'completed' + order: number + summary: string | null + parent_id: string | null +} + /** * Customization for the header section of the UI. */ @@ -465,6 +492,11 @@ export interface MessageUsageChatResponse { content: Record } +export interface TodoItemChatResonse { + type: 'todo_item' + content: Task +} + export interface ChunkedChatResponse { type: 'chunked_content' content: ChunkedContent @@ -484,3 +516,4 @@ export type ChatResponse = | ImageChatResponse | ClearMessageResponse | MessageUsageChatResponse + | TodoItemChatResonse diff --git a/typescript/ui/__tests__/unit/TodoList.test.tsx b/typescript/ui/__tests__/unit/TodoList.test.tsx new file mode 100644 index 000000000..efc5b6b1b --- /dev/null +++ b/typescript/ui/__tests__/unit/TodoList.test.tsx @@ -0,0 +1,146 @@ +import { cleanup, render, screen, within } from "@testing-library/react"; +import { Task, TaskStatus } from "@ragbits/api-client-react"; +import TodoList from "../../src/core/components/TodoList"; +import { afterEach, describe, expect, it, vi } from "vitest"; + +function makeTask(partial: Partial): Task { + return { + id: partial.id ?? "1", + description: partial.description ?? "Test task", + status: partial.status ?? TaskStatus.Pending, + order: partial.order ?? 1, + summary: partial.summary ?? null, + parent_id: partial.parent_id ?? null, + } as Task; +} + +describe("TodoList", () => { + afterEach(() => { + cleanup(); + vi.clearAllMocks(); + }); + + it("renders root tasks", () => { + const tasks = [ + makeTask({ id: "1", description: "Task A" }), + makeTask({ id: "2", description: "Task B" }), + ]; + + render(); + + expect(screen.getByTestId("todo-task-1")).toBeInTheDocument(); + expect(screen.getByTestId("todo-task-2")).toBeInTheDocument(); + + expect(screen.getByText("Task A")).toBeInTheDocument(); + expect(screen.getByText("Task B")).toBeInTheDocument(); + }); + + it("applies completed styles and checks checkbox", () => { + const tasks = [ + makeTask({ + id: "done", + description: "Done", + status: TaskStatus.Completed, + }), + ]; + + render(); + + const task = screen.getByTestId("todo-task-done"); + const checkbox = within(task).getByRole("checkbox"); + expect(checkbox).toBeChecked(); + expect(screen.getByText("Done")).toHaveClass("line-through"); + }); + + it("shows spinner for in-progress tasks and applies italic class", () => { + const tasks = [ + makeTask({ + id: "working", + description: "Working", + status: TaskStatus.InProgress, + }), + ]; + + render(); + + const task = screen.getByTestId("todo-task-working"); + + const spinner = within(task).getByLabelText("Task is in progress"); + expect(spinner).toBeInTheDocument(); + expect(screen.getByText("Working")).toHaveClass("italic"); + }); + + it("renders nested tasks and applies margin-left on children wrapper", () => { + const tasks = [ + makeTask({ id: "1", description: "Parent" }), + makeTask({ id: "2", description: "Child", parent_id: "1" }), + ]; + + render(); + + expect(screen.getByTestId("todo-task-1")).toBeInTheDocument(); + expect(screen.getByTestId("todo-task-2")).toBeInTheDocument(); + const childrenWrapper = screen.getByTestId("todo-children-wrapper-1"); + expect(childrenWrapper).toHaveStyle("margin-left: 0.5rem"); + }); + + it("orders root tasks according to `order` field", () => { + const tasks = [ + makeTask({ id: "a", description: "A", order: 2 }), + makeTask({ id: "b", description: "B", order: 1 }), + ]; + + render(); + + const root = screen.getByTestId("todo-list-root-0"); + const renderedTasks = Array.from( + root.querySelectorAll('[data-testid^="todo-task-"]'), + ); + const ids = renderedTasks.map((el) => el.getAttribute("data-testid")); + expect(ids).toEqual(["todo-task-b", "todo-task-a"]); + }); + + it("orders children according to `order` field", () => { + const tasks = [ + makeTask({ id: "parent", description: "Parent", order: 1 }), + makeTask({ + id: "c1", + description: "Child 1", + order: 2, + parent_id: "parent", + }), + makeTask({ + id: "c2", + description: "Child 2", + order: 1, + parent_id: "parent", + }), + ]; + + render(); + + const wrapper = screen.getByTestId("todo-children-wrapper-parent"); + const childTasks = Array.from( + wrapper.querySelectorAll('[data-testid^="todo-task-"]'), + ); + const ids = childTasks.map((el) => el.getAttribute("data-testid")); + expect(ids).toEqual(["todo-task-c2", "todo-task-c1"]); + }); + + it("renders task summary when provided", () => { + const tasks = [ + makeTask({ + id: "with-summary", + description: "Task with summary", + summary: "This is a summary", + }), + ]; + + render(); + + const task = screen.getByTestId("todo-task-with-summary"); + expect(task).toBeInTheDocument(); + expect(screen.getByText("Task with summary")).toBeInTheDocument(); + expect(screen.getByText("This is a summary")).toBeInTheDocument(); + }); +}); diff --git a/typescript/ui/src/core/components/ChatMessage/ChatMessage.tsx b/typescript/ui/src/core/components/ChatMessage/ChatMessage.tsx index 989a1d87c..2a802cb6e 100644 --- a/typescript/ui/src/core/components/ChatMessage/ChatMessage.tsx +++ b/typescript/ui/src/core/components/ChatMessage/ChatMessage.tsx @@ -12,6 +12,8 @@ import { useMessage, } from "../../stores/HistoryStore/selectors.ts"; import { MessageRole } from "@ragbits/api-client"; +import TodoList from "../TodoList.tsx"; +import { AnimatePresence, motion } from "framer-motion"; type ChatMessageProps = { classNames?: { @@ -92,6 +94,20 @@ const ChatMessage = forwardRef( classNames={{ liveUpdates: classNames?.liveUpdates }} /> )} + {message.tasks && message.tasks.length > 0 && ( + + +

Execution plan

+ +
+
+ )} + {tasksTree + .getRoots() + .map(({ id, description, status, summary, children }) => { + const inProgressIcon = ( + + ); + return ( +
+
+ inProgressIcon + : undefined + } + classNames={{ + hiddenInput: "cursor-default", + wrapper: + status === TaskStatus.InProgress && "before:border-none", + base: "pointer-events-none hover:bg-transparent", + label: [ + "transition-colors", + status === TaskStatus.Completed && + "line-through text-default-400", + status === TaskStatus.InProgress && "text-primary italic", + status === TaskStatus.Pending && "text-default-900", + ] + .filter(Boolean) + .join(" "), + }} + > + {description} + {summary &&

{summary}

} +
+
+ {children.length > 0 && ( +
+ +
+ )} +
+ ); + })} + + ); +} diff --git a/typescript/ui/src/core/stores/HistoryStore/eventHandlers/eventHandlerRegistry.ts b/typescript/ui/src/core/stores/HistoryStore/eventHandlers/eventHandlerRegistry.ts index 8b4799cc3..7e0d456e9 100644 --- a/typescript/ui/src/core/stores/HistoryStore/eventHandlers/eventHandlerRegistry.ts +++ b/typescript/ui/src/core/stores/HistoryStore/eventHandlers/eventHandlerRegistry.ts @@ -13,6 +13,7 @@ import { handleMessageId, handleReference, handleText, + handleTodoItem, handleUsage, } from "./messageHandlers"; @@ -101,3 +102,6 @@ ChatHandlerRegistry.register(ChatResponseType.ClearMessage, { ChatHandlerRegistry.register(ChatResponseType.Usage, { handle: handleUsage, }); +ChatHandlerRegistry.register(ChatResponseType.TodoItem, { + handle: handleTodoItem, +}); diff --git a/typescript/ui/src/core/stores/HistoryStore/eventHandlers/messageHandlers.ts b/typescript/ui/src/core/stores/HistoryStore/eventHandlers/messageHandlers.ts index 2811625a3..d03437a1b 100644 --- a/typescript/ui/src/core/stores/HistoryStore/eventHandlers/messageHandlers.ts +++ b/typescript/ui/src/core/stores/HistoryStore/eventHandlers/messageHandlers.ts @@ -7,6 +7,7 @@ import { MessageUsageChatResponse, ReferenceChatResponse, TextChatResponse, + TodoItemChatResonse, } from "@ragbits/api-client-react"; import { PrimaryHandler } from "./eventHandlerRegistry"; import { produce } from "immer"; @@ -97,3 +98,22 @@ export const handleUsage: PrimaryHandler = ( const message = draft.history[ctx.messageId]; message.usage = response.content; }; + +export const handleTodoItem: PrimaryHandler = ( + { content }, + draft, + ctx, +) => { + const message = draft.history[ctx.messageId]; + const tasks = message.tasks ?? []; + const newTasks = produce(tasks, (tasksDraft) => { + const taskIndex = tasksDraft.findIndex((t) => t.id === content.id); + if (taskIndex === -1) { + tasksDraft.push(content); + } else { + tasksDraft[taskIndex] = content; + } + }); + + message.tasks = newTasks; +}; diff --git a/typescript/ui/src/core/utils/tasks.ts b/typescript/ui/src/core/utils/tasks.ts new file mode 100644 index 000000000..192fa37b7 --- /dev/null +++ b/typescript/ui/src/core/utils/tasks.ts @@ -0,0 +1,80 @@ +import { Task } from "@ragbits/api-client-react"; + +export interface TaskNode extends Task { + children: TaskNode[]; +} + +export class TaskTree { + private nodes: Map = new Map(); + private roots: TaskNode[] = []; + + constructor(tasks: Task[]) { + this.buildTree(tasks); + } + + private buildTree(tasks: Task[]) { + for (const task of tasks) { + this.nodes.set(task.id, { ...task, children: [] }); + } + + for (const node of this.nodes.values()) { + if (node.parent_id) { + const parent = this.nodes.get(node.parent_id); + if (parent) { + parent.children.push(node); + parent.children.sort((a, b) => a.order - b.order); + } else { + this.roots.push(node); + this.roots.sort((a, b) => a.order - b.order); + } + } else { + this.roots.push(node); + this.roots.sort((a, b) => a.order - b.order); + } + } + } + + *iterate(): IterableIterator { + function* dfs(nodes: TaskNode[]): IterableIterator { + for (const node of nodes) { + yield node; + yield* dfs(node.children); + } + } + yield* dfs(this.roots); + } + + get(id: string): TaskNode | undefined { + return this.nodes.get(id); + } + + update(id: string, updates: Partial): void { + const node = this.nodes.get(id); + if (!node) return; + + Object.assign(node, updates); + if (updates.order !== undefined) { + if (node.parent_id) { + const parent = this.nodes.get(node.parent_id); + parent?.children.sort((a, b) => a.order - b.order); + } else { + this.roots.sort((a, b) => a.order - b.order); + } + } + + if (updates.status === "completed") { + this.completeChildren(node); + } + } + + private completeChildren(node: TaskNode) { + for (const child of node.children) { + child.status = "completed"; + this.completeChildren(child); + } + } + + getRoots(): TaskNode[] { + return this.roots; + } +} diff --git a/typescript/ui/src/types/history.ts b/typescript/ui/src/types/history.ts index 4e3bfc7f8..74205eb36 100644 --- a/typescript/ui/src/types/history.ts +++ b/typescript/ui/src/types/history.ts @@ -7,6 +7,7 @@ import { Image, MessageUsage, RagbitsClient, + Task, } from "@ragbits/api-client-react"; export type UnsubscribeFn = (() => void) | null; @@ -24,6 +25,7 @@ export interface ChatMessage { extensions?: Record; images?: Record; usage?: Record; + tasks?: Task[]; } export interface Conversation {