Skip to content

Commit 53fd73b

Browse files
authored
fix: copilot breaking change introduced in 2.8.5 (#2647)
Fixes #2641 - Custom auth and copilot token auth tests added - Copilot tests refactored - Added e2e test for previously fixed security vulnerability
1 parent 0843245 commit 53fd73b

File tree

7 files changed

+381
-135
lines changed

7 files changed

+381
-135
lines changed

backend/chainlit/socket.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ async def _authenticate_connection(
121121
async def connect(sid: str, environ: WSGIEnvironment, auth: WebSocketSessionAuth):
122122
user: User | PersistedUser | None = None
123123
token: str | None = None
124-
thread_id = auth.get("threadId")
124+
thread_id = auth.get("threadId", None)
125125

126126
if require_login():
127127
try:
@@ -134,14 +134,11 @@ async def connect(sid: str, environ: WSGIEnvironment, auth: WebSocketSessionAuth
134134
raise ConnectionRefusedError("authentication failed")
135135

136136
if thread_id:
137-
data_layer = get_data_layer()
138-
if not data_layer:
139-
logger.error("Data layer is not initialized.")
140-
raise ConnectionRefusedError("data layer not initialized")
141-
142-
if not (await data_layer.get_thread_author(thread_id) == user.identifier):
143-
logger.error("Authorization for the thread failed.")
144-
raise ConnectionRefusedError("authorization failed")
137+
if data_layer := get_data_layer():
138+
thread = await data_layer.get_thread(thread_id)
139+
if thread and not (thread["userIdentifier"] == user.identifier):
140+
logger.error("Authorization for the thread failed.")
141+
raise ConnectionRefusedError("authorization failed")
145142

146143
# Session scoped function to emit to the client
147144
def emit_fn(event, data):
@@ -155,11 +152,11 @@ def emit_call_fn(event: Literal["ask", "call_fn"], data, timeout):
155152
if restore_existing_session(sid, session_id, emit_fn, emit_call_fn):
156153
return True
157154

158-
user_env_string = auth.get("userEnv")
155+
user_env_string = auth.get("userEnv", None)
159156
user_env = load_user_env(user_env_string)
160157

161158
client_type = auth["clientType"]
162-
url_encoded_chat_profile = auth.get("chatProfile")
159+
url_encoded_chat_profile = auth.get("chatProfile", None)
163160
chat_profile = (
164161
unquote(url_encoded_chat_profile) if url_encoded_chat_profile else None
165162
)

cypress/e2e/auth/main.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import os
2+
from uuid import uuid4
3+
4+
import chainlit as cl
5+
from chainlit.auth import create_jwt
6+
from chainlit.server import _authenticate_user, app
7+
from chainlit.user import User
8+
from fastapi import Request, Response
9+
10+
os.environ["CHAINLIT_AUTH_SECRET"] = "SUPER_SECRET" # nosec B105
11+
os.environ["CHAINLIT_CUSTOM_AUTH"] = "true"
12+
13+
14+
@app.get("/auth/custom")
15+
async def custom_auth(request: Request) -> Response:
16+
user_id = str(uuid4())
17+
18+
user = User(identifier=user_id, metadata={"role": "user"})
19+
response = await _authenticate_user(request, user)
20+
21+
return response
22+
23+
24+
@app.get("/auth/token")
25+
async def custom_token_auth() -> Response:
26+
user_id = str(uuid4())
27+
28+
user = User(identifier=user_id, metadata={"role": "admin"})
29+
response = create_jwt(user)
30+
31+
return response
32+
33+
34+
catch_all_route = None
35+
for route in app.routes:
36+
if route.path == "/{full_path:path}":
37+
catch_all_route = route
38+
39+
if catch_all_route:
40+
app.routes.remove(catch_all_route)
41+
app.routes.append(catch_all_route)
42+
43+
44+
@cl.on_chat_start
45+
async def on_chat_start():
46+
user = cl.user_session.get("user")
47+
await cl.Message(f"Hello {user.identifier}").send()
48+
49+
50+
@cl.on_message
51+
async def on_message(msg: cl.Message):
52+
await cl.Message(content=f"Echo: {msg.content}").send()

cypress/e2e/auth/spec.cy.ts

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
import { loadCopilotScript, mountCopilotWidget, openCopilot, submitMessage } from '../../support/testUtils';
2+
3+
function login() {
4+
return cy.request({
5+
method: 'GET',
6+
url: '/auth/custom',
7+
followRedirect: false
8+
})
9+
}
10+
11+
function getToken() {
12+
return cy.request({
13+
method: 'GET',
14+
url: '/auth/token',
15+
followRedirect: false
16+
})
17+
}
18+
19+
function shouldShowGreetingMessage() {
20+
it('should show greeting message', () => {
21+
cy.get('.step').should('exist');
22+
cy.get('.step').should('contain', 'Hello');
23+
});
24+
}
25+
26+
function shouldSendMessageAndRecieveAnswer() {
27+
it('should send message and receive answer', () => {
28+
cy.get('.step').should('contain', 'Hello');
29+
30+
const testMessage = 'Test message from custom auth';
31+
submitMessage(testMessage);
32+
33+
cy.get('.step').should('contain', 'Echo:');
34+
cy.get('.step').should('contain', testMessage);
35+
});
36+
37+
}
38+
39+
describe('Custom Auth', () => {
40+
describe('when unauthenticated', () => {
41+
beforeEach(() => {
42+
cy.intercept('GET', '/user').as('user');
43+
});
44+
45+
it('should attempt to and not have permission to access /user', () => {
46+
cy.wait('@user').then((interception) => {
47+
expect(interception.response.statusCode).to.equal(401);
48+
});
49+
});
50+
51+
it('should redirect to login dialog', () => {
52+
cy.location('pathname').should('eq', '/login');
53+
});
54+
});
55+
56+
describe('authenticating via custom endpoint', () => {
57+
beforeEach(() => {
58+
login().then((response) => {
59+
expect(response.status).to.equal(200);
60+
// Verify cookie is set in response headers
61+
expect(response.headers).to.have.property('set-cookie');
62+
const cookies = Array.isArray(response.headers['set-cookie'])
63+
? response.headers['set-cookie']
64+
: [response.headers['set-cookie']];
65+
expect(cookies[0]).to.contain('access_token');
66+
});
67+
});
68+
69+
const shouldBeLoggedIn = () => {
70+
it('should not be on /login', () => {
71+
cy.location('pathname').should('not.contain', '/login');
72+
});
73+
74+
shouldShowGreetingMessage();
75+
};
76+
77+
shouldBeLoggedIn();
78+
79+
it('should request and have access to /user', () => {
80+
cy.intercept('GET', '/user').as('user');
81+
cy.wait('@user').then((interception) => {
82+
expect(interception.response.statusCode).to.equal(200);
83+
});
84+
});
85+
86+
shouldSendMessageAndRecieveAnswer();
87+
88+
describe('after reloading', () => {
89+
beforeEach(() => {
90+
cy.reload();
91+
});
92+
93+
shouldBeLoggedIn();
94+
});
95+
});
96+
});
97+
98+
describe('Copilot Token', { includeShadowDom: true }, () => {
99+
beforeEach(() => {
100+
cy.location('pathname').should('eq', '/login');
101+
102+
loadCopilotScript();
103+
});
104+
105+
describe('when unauthenticated', () => {
106+
it('should throw error about missing authentication token', () => {
107+
mountCopilotWidget();
108+
openCopilot();
109+
cy.get('#chainlit-copilot-chat').should('contain', 'No authentication token provided.');
110+
});
111+
});
112+
113+
describe('authenticating via custom endpoint', () => {
114+
beforeEach(() => {
115+
getToken().then((response) => {
116+
expect(response.status).to.equal(200);
117+
118+
const accessToken = response.body
119+
expect(accessToken).to.not.be.null;
120+
121+
mountCopilotWidget({ accessToken });
122+
openCopilot();
123+
});
124+
})
125+
126+
shouldShowGreetingMessage();
127+
128+
shouldSendMessageAndRecieveAnswer();
129+
});
130+
});

0 commit comments

Comments
 (0)