diff --git a/backend/chainlit/data/chainlit_data_layer.py b/backend/chainlit/data/chainlit_data_layer.py
index 847668b6de..2d0cb729f3 100644
--- a/backend/chainlit/data/chainlit_data_layer.py
+++ b/backend/chainlit/data/chainlit_data_layer.py
@@ -338,9 +338,9 @@ async def create_step(self, step_dict: StepDict):
query = """
INSERT INTO "Step" (
id, "threadId", "parentId", input, metadata, name, output,
- type, "startTime", "endTime", "showInput", "isError"
+ type, "startTime", "endTime", "showInput", "isError", icon
) VALUES (
- $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12
+ $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13
)
ON CONFLICT (id) DO UPDATE SET
"parentId" = COALESCE(EXCLUDED."parentId", "Step"."parentId"),
@@ -359,7 +359,8 @@ async def create_step(self, step_dict: StepDict):
"endTime" = COALESCE(EXCLUDED."endTime", "Step"."endTime"),
"startTime" = LEAST(EXCLUDED."startTime", "Step"."startTime"),
"showInput" = COALESCE(EXCLUDED."showInput", "Step"."showInput"),
- "isError" = COALESCE(EXCLUDED."isError", "Step"."isError")
+ "isError" = COALESCE(EXCLUDED."isError", "Step"."isError"),
+ icon = COALESCE(EXCLUDED.icon, "Step".icon)
"""
timestamp = await self.get_current_timestamp()
@@ -380,6 +381,7 @@ async def create_step(self, step_dict: StepDict):
"end_time": timestamp,
"show_input": str(step_dict.get("showInput", "json")),
"is_error": step_dict.get("isError", False),
+ "icon": step_dict.get("icon"),
}
await self.execute_query(query, params)
diff --git a/backend/chainlit/data/literalai.py b/backend/chainlit/data/literalai.py
index 2d750b6f86..d8df7dbaaa 100644
--- a/backend/chainlit/data/literalai.py
+++ b/backend/chainlit/data/literalai.py
@@ -367,6 +367,7 @@ async def create_step(self, step_dict: "StepDict"):
waitForAnswer=step_dict.get("waitForAnswer"),
language=step_dict.get("language"),
showInput=step_dict.get("showInput"),
+ icon=step_dict.get("icon"),
)
step: LiteralStepDict = {
diff --git a/backend/chainlit/step.py b/backend/chainlit/step.py
index ec77746a19..edde2d78e0 100644
--- a/backend/chainlit/step.py
+++ b/backend/chainlit/step.py
@@ -63,6 +63,7 @@ class StepDict(TypedDict, total=False):
showInput: Optional[Union[bool, str]]
defaultOpen: Optional[bool]
language: Optional[str]
+ icon: Optional[str]
feedback: Optional[FeedbackDict]
@@ -83,6 +84,7 @@ def step(
tags: Optional[List[str]] = None,
metadata: Optional[Dict] = None,
language: Optional[str] = None,
+ icon: Optional[str] = None,
show_input: Union[bool, str] = "json",
default_open: bool = False,
):
@@ -106,6 +108,7 @@ async def async_wrapper(*args, **kwargs):
parent_id=parent_id,
tags=tags,
language=language,
+ icon=icon,
show_input=show_input,
default_open=default_open,
metadata=metadata,
@@ -135,6 +138,7 @@ def sync_wrapper(*args, **kwargs):
parent_id=parent_id,
tags=tags,
language=language,
+ icon=icon,
show_input=show_input,
default_open=default_open,
metadata=metadata,
@@ -182,6 +186,7 @@ class Step:
end: Union[str, None]
generation: Optional[BaseGeneration]
language: Optional[str]
+ icon: Optional[str]
default_open: Optional[bool]
elements: Optional[List[Element]]
fail_on_persist_error: bool
@@ -196,6 +201,7 @@ def __init__(
metadata: Optional[Dict] = None,
tags: Optional[List[str]] = None,
language: Optional[str] = None,
+ icon: Optional[str] = None,
default_open: Optional[bool] = False,
show_input: Union[bool, str] = "json",
thread_id: Optional[str] = None,
@@ -214,6 +220,7 @@ def __init__(
self.parent_id = parent_id
self.language = language
+ self.icon = icon
self.default_open = default_open
self.generation = None
self.elements = elements or []
@@ -303,6 +310,7 @@ def to_dict(self) -> StepDict:
"start": self.start,
"end": self.end,
"language": self.language,
+ "icon": self.icon,
"defaultOpen": self.default_open,
"showInput": self.show_input,
"generation": self.generation.to_dict() if self.generation else None,
diff --git a/backend/tests/data/test_literalai.py b/backend/tests/data/test_literalai.py
index 1eeb91243e..e7506b1d36 100644
--- a/backend/tests/data/test_literalai.py
+++ b/backend/tests/data/test_literalai.py
@@ -83,6 +83,7 @@ def test_step_dict(test_thread) -> StepDict:
"waitForAnswer": True,
"showInput": True,
"language": "en",
+ "icon": "search",
}
@@ -179,6 +180,7 @@ async def test_create_step(
"waitForAnswer": True,
"language": "en",
"showInput": True,
+ "icon": "search",
},
"input": {"content": "test input"},
"output": {"content": "test output"},
@@ -768,6 +770,7 @@ async def test_update_step(
"waitForAnswer": True,
"language": "en",
"showInput": True,
+ "icon": "search",
},
"input": {"content": "test input"},
"output": {"content": "test output"},
diff --git a/backend/tests/data/test_sql_alchemy.py b/backend/tests/data/test_sql_alchemy.py
index decd5e34c5..3afcfe401e 100644
--- a/backend/tests/data/test_sql_alchemy.py
+++ b/backend/tests/data/test_sql_alchemy.py
@@ -75,6 +75,7 @@ async def data_layer(mock_storage_client: BaseStorageClient, tmp_path: Path):
"generation" JSONB,
"showInput" TEXT,
"language" TEXT,
+ "icon" TEXT,
"indent" INT
);
"""
diff --git a/backend/tests/test_emitter.py b/backend/tests/test_emitter.py
index 9a8290583c..984907476b 100644
--- a/backend/tests/test_emitter.py
+++ b/backend/tests/test_emitter.py
@@ -54,6 +54,22 @@ async def test_send_step(
mock_websocket_session.emit.assert_called_once_with("new_message", step_dict)
+async def test_send_step_with_icon(
+ emitter: ChainlitEmitter, mock_websocket_session: MagicMock
+) -> None:
+ step_dict: StepDict = {
+ "id": "test_step_with_icon",
+ "type": "tool",
+ "name": "Test Step with Icon",
+ "output": "This is a test step with an icon",
+ "icon": "search",
+ }
+
+ await emitter.send_step(step_dict)
+
+ mock_websocket_session.emit.assert_called_once_with("new_message", step_dict)
+
+
async def test_update_step(
emitter: ChainlitEmitter, mock_websocket_session: MagicMock
) -> None:
@@ -69,6 +85,22 @@ async def test_update_step(
mock_websocket_session.emit.assert_called_once_with("update_message", step_dict)
+async def test_update_step_with_icon(
+ emitter: ChainlitEmitter, mock_websocket_session: MagicMock
+) -> None:
+ step_dict: StepDict = {
+ "id": "test_step_with_icon",
+ "type": "tool",
+ "name": "Updated Test Step with Icon",
+ "output": "This is an updated test step with an icon",
+ "icon": "database",
+ }
+
+ await emitter.update_step(step_dict)
+
+ mock_websocket_session.emit.assert_called_once_with("update_message", step_dict)
+
+
async def test_delete_step(
emitter: ChainlitEmitter, mock_websocket_session: MagicMock
) -> None:
@@ -139,6 +171,20 @@ async def test_stream_start(
mock_websocket_session.emit.assert_called_once_with("stream_start", step_dict)
+async def test_stream_start_with_icon(
+ emitter: ChainlitEmitter, mock_websocket_session: MagicMock
+) -> None:
+ step_dict: StepDict = {
+ "id": "test_stream_with_icon",
+ "type": "tool",
+ "name": "Test Stream with Icon",
+ "output": "This is a test stream with an icon",
+ "icon": "cpu",
+ }
+ await emitter.stream_start(step_dict)
+ mock_websocket_session.emit.assert_called_once_with("stream_start", step_dict)
+
+
async def test_send_toast(
emitter: ChainlitEmitter, mock_websocket_session: MagicMock
) -> None:
diff --git a/cypress/e2e/step_icon/main.py b/cypress/e2e/step_icon/main.py
new file mode 100644
index 0000000000..a5587605d8
--- /dev/null
+++ b/cypress/e2e/step_icon/main.py
@@ -0,0 +1,33 @@
+import chainlit as cl
+
+
+@cl.step(name="search", type="tool", icon="search")
+async def search():
+ await cl.sleep(1)
+ return "Response from search"
+
+
+@cl.step(name="database", type="tool", icon="database")
+async def database():
+ await cl.sleep(1)
+ return "Response from database"
+
+
+@cl.step(name="regular", type="tool")
+async def regular():
+ await cl.sleep(1)
+ return "Response from regular"
+
+
+async def cpu():
+ async with cl.Step(name="cpu", type="tool", icon="cpu") as s:
+ await cl.sleep(1)
+ s.output = "Response from cpu"
+
+
+@cl.on_message
+async def main(message: cl.Message):
+ await search()
+ await database()
+ await regular()
+ await cpu()
diff --git a/cypress/e2e/step_icon/spec.cy.ts b/cypress/e2e/step_icon/spec.cy.ts
new file mode 100644
index 0000000000..a0240af594
--- /dev/null
+++ b/cypress/e2e/step_icon/spec.cy.ts
@@ -0,0 +1,42 @@
+import { submitMessage } from '../../support/testUtils';
+
+describe('Step with Icon', () => {
+ it('should display icons for steps with icon property', () => {
+ submitMessage('Hello');
+
+ cy.get('.step').should('have.length', 5);
+
+ // Check that steps with icons have SVG icons (not avatar images)
+ // The avatar is a sibling of the step content in the .ai-message container
+ cy.get('#step-search')
+ .closest('.ai-message')
+ .within(() => {
+ // Should have an svg icon (Lucide icons are SVGs)
+ cy.get('svg').should('exist');
+ // Should NOT have an avatar image
+ cy.get('img').should('not.exist');
+ });
+
+ cy.get('#step-database')
+ .closest('.ai-message')
+ .within(() => {
+ cy.get('svg').should('exist');
+ cy.get('img').should('not.exist');
+ });
+
+ // Check that step without icon has avatar (image)
+ cy.get('#step-regular')
+ .closest('.ai-message')
+ .within(() => {
+ // Should have an avatar image
+ cy.get('img').should('exist');
+ });
+
+ cy.get('#step-cpu')
+ .closest('.ai-message')
+ .within(() => {
+ cy.get('svg').should('exist');
+ cy.get('img').should('not.exist');
+ });
+ });
+});
diff --git a/frontend/src/components/chat/Messages/Message/Avatar.tsx b/frontend/src/components/chat/Messages/Message/Avatar.tsx
index c414cfefb0..525a701c44 100644
--- a/frontend/src/components/chat/Messages/Message/Avatar.tsx
+++ b/frontend/src/components/chat/Messages/Message/Avatar.tsx
@@ -8,6 +8,7 @@ import {
useConfig
} from '@chainlit/react-client';
+import Icon from '@/components/Icon';
import { Avatar, AvatarFallback, AvatarImage } from '@/components/ui/avatar';
import { Skeleton } from '@/components/ui/skeleton';
import {
@@ -21,9 +22,10 @@ interface Props {
author?: string;
hide?: boolean;
isError?: boolean;
+ iconName?: string;
}
-const MessageAvatar = ({ author, hide, isError }: Props) => {
+const MessageAvatar = ({ author, hide, isError, iconName }: Props) => {
const apiClient = useContext(ChainlitContext);
const { chatProfile } = useChatSession();
const { config } = useConfig();
@@ -48,22 +50,29 @@ const MessageAvatar = ({ author, hide, isError }: Props) => {
);
}
+ // Render icon or avatar based on iconName
+ const avatarContent = iconName ? (
+
+
+
+ ) : (
+
+
+
+
+
+
+ );
+
return (
-
-
-
-
-
-
-
-
+ {avatarContent}
{author}
diff --git a/frontend/src/components/chat/Messages/Message/index.tsx b/frontend/src/components/chat/Messages/Message/index.tsx
index e1957819e0..bde80d0708 100644
--- a/frontend/src/components/chat/Messages/Message/index.tsx
+++ b/frontend/src/components/chat/Messages/Message/index.tsx
@@ -99,6 +99,7 @@ const Message = memo(
) : null}
{/* Display the step and its children */}
diff --git a/libs/react-client/src/types/step.ts b/libs/react-client/src/types/step.ts
index 08e9af4065..c0c7e274a7 100644
--- a/libs/react-client/src/types/step.ts
+++ b/libs/react-client/src/types/step.ts
@@ -16,6 +16,7 @@ export interface IStep {
id: string;
name: string;
type: StepType;
+ icon?: string;
threadId?: string;
parentId?: string;
isError?: boolean;