Skip to content

Commit fa1c301

Browse files
authored
feat(ui): auth token storage (#802)
1 parent 36d8e5a commit fa1c301

File tree

24 files changed

+325
-141
lines changed

24 files changed

+325
-141
lines changed

packages/ragbits-chat/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# CHANGELOG
22

33
## Unreleased
4+
- Add auth token storage and automatic logout on 401 (#802)
45
- Improve user settings storage when history is disabled (#799)
56
- Remove redundant test for `/api/config` endpoint (#795)
67
- Fix bug causing infinite initialization screen (#793)

typescript/@ragbits/api-client-react/tsconfig.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"target": "ES2020",
44
"module": "ESNext",
55
"lib": ["DOM", "DOM.Iterable", "ESNext"],
6-
"moduleResolution": "node",
6+
"moduleResolution": "bundler",
77
"jsx": "react-jsx",
88
"strict": true,
99
"declaration": true,

typescript/@ragbits/api-client/README.md

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,12 @@ import { RagbitsClient } from '@ragbits/api-client'
2525
// Initialize the client
2626
const client = new RagbitsClient({
2727
baseUrl: 'http://127.0.0.1:8000', // Optional, defaults to http://127.0.0.1:8000
28+
auth: {
29+
getToken: () => 'token',
30+
onUnauthorized: () => {
31+
console.warn('⚠️ Unauthorized')
32+
},
33+
}, // Optional, used to provide auth details
2834
})
2935

3036
// Get API configuration
@@ -80,6 +86,7 @@ new RagbitsClient(config?: ClientConfig)
8086
**Parameters:**
8187

8288
- `config.baseUrl` (optional): Base URL for the API. Defaults to 'http://127.0.0.1:8000'
89+
- `config.auth` (optional): An object containing authentication details. Provide `getToken` to automatically attach `Authorization: Bearer <token>` to every request. `onUnauthorized` is called if the library encounters a 401 status code.
8390

8491
**Throws:** `Error` if the base URL is invalid
8592

@@ -172,8 +179,8 @@ const cleanup = client.makeStreamRequest(
172179
{
173180
message: 'Tell me about AI',
174181
history: [
175-
{ role: 'user', content: 'Hello', id: 'msg-1' },
176-
{ role: 'assistant', content: 'Hi there!', id: 'msg-2' },
182+
{ role: 'user', content: 'Hello' },
183+
{ role: 'assistant', content: 'Hi there!' },
177184
],
178185
context: { user_id: 'user-123' },
179186
},

typescript/@ragbits/api-client/src/index.ts

Lines changed: 49 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import type {
1515
*/
1616
export class RagbitsClient {
1717
private readonly baseUrl: string
18+
private readonly auth: ClientConfig['auth']
1819
private chunkQueue: Map<
1920
string,
2021
{
@@ -29,6 +30,7 @@ export class RagbitsClient {
2930
*/
3031
constructor(config: ClientConfig = {}) {
3132
this.baseUrl = config.baseUrl ?? ''
33+
this.auth = config.auth
3234

3335
if (this.baseUrl.endsWith('/')) {
3436
this.baseUrl = this.baseUrl.slice(0, -1)
@@ -71,13 +73,25 @@ export class RagbitsClient {
7173
url: string,
7274
options: RequestInit = {}
7375
): Promise<Response> {
74-
const defaultOptions: RequestInit = {
75-
headers: {
76-
'Content-Type': 'application/json',
77-
},
76+
const defaultHeaders: Record<string, string> = {
77+
'Content-Type': 'application/json',
78+
}
79+
80+
const headers = {
81+
...defaultHeaders,
82+
...this.normalizeHeaders(options.headers),
83+
}
84+
85+
if (this.auth?.getToken) {
86+
headers['Authorization'] = `Bearer ${this.auth.getToken()}`
87+
}
88+
89+
const response = await fetch(url, { ...options, headers })
90+
91+
if (response.status === 401) {
92+
this.auth?.onUnauthorized?.()
7893
}
7994

80-
const response = await fetch(url, { ...defaultOptions, ...options })
8195
if (!response.ok) {
8296
throw new Error(`HTTP error! status: ${response.status}`)
8397
}
@@ -206,20 +220,34 @@ export class RagbitsClient {
206220

207221
const startStream = async (): Promise<void> => {
208222
try {
223+
const defaultHeaders: Record<string, string> = {
224+
'Content-Type': 'application/json',
225+
Accept: 'text/event-stream',
226+
}
227+
228+
const headers = {
229+
...defaultHeaders,
230+
...customHeaders,
231+
}
232+
233+
if (this.auth?.getToken) {
234+
headers['Authorization'] = `Bearer ${this.auth.getToken()}`
235+
}
236+
209237
const response = await fetch(
210238
this._buildApiUrl(endpoint.toString()),
211239
{
212240
method: 'POST',
213-
headers: {
214-
'Content-Type': 'application/json',
215-
Accept: 'text/event-stream',
216-
...customHeaders,
217-
},
241+
headers,
218242
body: JSON.stringify(data),
219243
signal,
220244
}
221245
)
222246

247+
if (response.status === 401) {
248+
this.auth?.onUnauthorized?.()
249+
}
250+
223251
if (!response.ok) {
224252
throw new Error(`HTTP error! status: ${response.status}`)
225253
}
@@ -254,6 +282,17 @@ export class RagbitsClient {
254282
}
255283
}
256284

285+
private normalizeHeaders(init?: HeadersInit): Record<string, string> {
286+
if (!init) return {}
287+
if (init instanceof Headers) {
288+
return Object.fromEntries(init.entries())
289+
}
290+
if (Array.isArray(init)) {
291+
return Object.fromEntries(init)
292+
}
293+
return init
294+
}
295+
257296
private async handleChunkedContent<T>(
258297
data: T,
259298
callbacks: StreamCallbacks<T>

typescript/@ragbits/api-client/src/types.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ export interface GenericResponse {
1818
*/
1919
export interface ClientConfig {
2020
baseUrl?: string
21+
auth?: {
22+
getToken?: () => string
23+
onUnauthorized?: () => Promise<void> | void
24+
}
2125
}
2226

2327
/**

typescript/@ragbits/api-client/tsconfig.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"target": "ES2020",
44
"module": "ESNext",
55
"lib": ["DOM", "DOM.Iterable", "ESNext"],
6-
"moduleResolution": "node",
6+
"moduleResolution": "bundler",
77
"strict": true,
88
"declaration": true,
99
"esModuleInterop": true,

typescript/ui/__tests__/integration/intergation.test.tsx

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ import { createHistoryStore } from "../../src/core/stores/HistoryStore/historySt
3535
import { createStore } from "zustand";
3636
import { useHistoryStore } from "../../src/core/stores/HistoryStore/useHistoryStore";
3737
import { HistoryStore } from "../../src/types/history";
38+
import { API_URL } from "../../src/config";
3839

3940
vi.mock("../../src/core/stores/HistoryStore/useHistoryStore", () => {
4041
return {
@@ -51,6 +52,7 @@ vi.mock("idb-keyval", () => ({
5152
}));
5253

5354
const historyStore = createStore(createHistoryStore);
55+
const ragbitsClient = new RagbitsClient({ baseUrl: API_URL });
5456
historyStore.getState()._internal._setHasHydrated(true);
5557

5658
(useHistoryStore as Mock).mockImplementation(
@@ -72,7 +74,9 @@ describe("Integration tests", () => {
7274
"makeStreamRequest",
7375
);
7476
await act(() => {
75-
historyStore.getState().actions.sendMessage("Test message");
77+
historyStore
78+
.getState()
79+
.actions.sendMessage("Test message", ragbitsClient);
7680
});
7781

7882
expect(makeStreamRequestSpy).toHaveBeenCalledWith(
@@ -84,7 +88,6 @@ describe("Integration tests", () => {
8488
},
8589
expect.anything(), // We don't care about callbacks
8690
expect.anything(), // We don't care about AbortSignal
87-
expect.anything(), // We don't care about headers
8891
);
8992

9093
await waitFor(
@@ -106,7 +109,9 @@ describe("Integration tests", () => {
106109
"makeStreamRequest",
107110
);
108111
await act(() => {
109-
historyStore.getState().actions.sendMessage("Test message 2");
112+
historyStore
113+
.getState()
114+
.actions.sendMessage("Test message 2", ragbitsClient);
110115
});
111116

112117
expect(makeStreamRequestSpy).toHaveBeenCalledWith(
@@ -128,7 +133,6 @@ describe("Integration tests", () => {
128133
},
129134
expect.anything(), // We don't care about callbacks
130135
expect.anything(), // We don't care about AbortSignal
131-
expect.anything(), // We don't care about headers
132136
);
133137

134138
await waitFor(
@@ -155,7 +159,7 @@ describe("Integration tests", () => {
155159
<ConfigContextProvider>
156160
<PromptInput
157161
isLoading={false}
158-
submit={sendMessage}
162+
submit={(text) => sendMessage(text, ragbitsClient)}
159163
stopAnswering={stopAnswering}
160164
followupMessages={getCurrentConversation().followupMessages}
161165
/>
@@ -218,7 +222,6 @@ describe("Integration tests", () => {
218222
},
219223
expect.anything(), // We don't care about callbacks
220224
expect.anything(), // We don't care about AbortSignal
221-
expect.anything(), // We don't care about headers
222225
);
223226
await waitFor(
224227
() => {
@@ -256,7 +259,9 @@ describe("Integration tests", () => {
256259
);
257260

258261
await act(() => {
259-
historyStore.getState().actions.sendMessage("Test message");
262+
historyStore
263+
.getState()
264+
.actions.sendMessage("Test message", ragbitsClient);
260265
});
261266

262267
await waitFor(

typescript/ui/__tests__/unit/plugin/AuthPlugin/AuthGuard.test.tsx

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,31 @@ describe("AuthGuard", () => {
5050
mockPathname = "/"; // default
5151
});
5252

53+
it("renders Initialization screen when not hydrated", () => {
54+
mockPathname = "/login";
55+
useStoreMock.mockImplementation((_, selector) =>
56+
selector({
57+
hasHydrated: false,
58+
}),
59+
);
60+
61+
render(
62+
<AuthGuard>
63+
<div data-testid="child">Login page child</div>
64+
</AuthGuard>,
65+
);
66+
67+
expect(screen.getByText("Initializing...")).toBeInTheDocument();
68+
});
69+
5370
it("renders children if route is /login regardless of auth state", () => {
5471
mockPathname = "/login";
55-
useStoreMock.mockReturnValue(false);
72+
useStoreMock.mockImplementation((_, selector) =>
73+
selector({
74+
hasHydrated: true,
75+
isAuthenticated: false,
76+
}),
77+
);
5678

5779
render(
5880
<AuthGuard>
@@ -67,7 +89,15 @@ describe("AuthGuard", () => {
6789

6890
it("wraps children and renders AuthWatcher if authenticated", () => {
6991
mockPathname = "/dashboard";
70-
useStoreMock.mockReturnValue(true);
92+
useStoreMock.mockImplementation((_, selector) =>
93+
selector({
94+
hasHydrated: true,
95+
isAuthenticated: true,
96+
token: {
97+
access_token: "token",
98+
},
99+
}),
100+
);
71101

72102
render(
73103
<AuthGuard>
@@ -83,7 +113,12 @@ describe("AuthGuard", () => {
83113

84114
it("renders Navigate to /login if not authenticated", () => {
85115
mockPathname = "/dashboard";
86-
useStoreMock.mockReturnValue(false);
116+
useStoreMock.mockImplementation((_, selector) =>
117+
selector({
118+
hasHydrated: true,
119+
isAuthenticated: false,
120+
}),
121+
);
87122

88123
render(
89124
<AuthGuard>

0 commit comments

Comments
 (0)