Commit 3a10d6d

Absolutely glorious domain invariant enforcement
1 parent 426de54 commit 3a10d6d

18 files changed: 171 additions, 177 deletions

src/client/components/ChatV2/ChatV2.tsx

Lines changed: 38 additions & 32 deletions

@@ -7,17 +7,15 @@ import { enqueueSnackbar } from 'notistack'
 import { useCallback, useEffect, useRef, useState } from 'react'
 import { useTranslation } from 'react-i18next'
 import { useParams, useSearchParams } from 'react-router-dom'
-import { DEFAULT_MODEL, DEFAULT_MODEL_TEMPERATURE, FREE_MODEL, inProduction, validModels } from '../../../config'
+import { DEFAULT_MODEL, DEFAULT_MODEL_TEMPERATURE, FREE_MODEL, inProduction, ValidModelName, ValidModelNameSchema, validModels } from '../../../config'
 import type { ChatMessage, MessageGenerationInfo, ToolCallResultEvent } from '../../../shared/chat'
-import type { RagIndexAttributes } from '../../../shared/types'
 import { getLanguageValue } from '../../../shared/utils'
 import { useIsEmbedded } from '../../contexts/EmbeddedContext'
 import { useChatScroll } from './useChatScroll'
 import useCourse from '../../hooks/useCourse'
 import useCurrentUser from '../../hooks/useCurrentUser'
 import useInfoTexts from '../../hooks/useInfoTexts'
 import useLocalStorageState from '../../hooks/useLocalStorageState'
-import { useCourseRagIndices } from '../../hooks/useRagIndices'
 import useRetryTimeout from '../../hooks/useRetryTimeout'
 import useUserStatus from '../../hooks/useUserStatus'
 import { useAnalyticsDispatch } from '../../stores/analytics'

@@ -31,31 +29,37 @@ import { handleCompletionStreamError } from './error'
 import ToolResult from './ToolResult'
 import { OutlineButtonBlack } from './general/Buttons'
 import { ChatInfo } from './general/ChatInfo'
-import RagSelector, { RagSelectorDescription } from './RagSelector'
 import { SettingsModal } from './SettingsModal'
 import { useChatStream } from './useChatStream'
 import { postCompletionStreamV3 } from './api'
 import PromptSelector from './PromptSelector'
 import ModelSelector from './ModelSelector'
 import { ConversationSplash } from './general/ConversationSplash'
 import { PromptStateProvider, usePromptState } from './PromptState'
+import z from 'zod/v4'
 
-function useLocalStorageStateWithURLDefault(key: string, defaultValue: string, urlKey: string) {
+function useLocalStorageStateWithURLDefault<T>(key: string, defaultValue: string, urlKey: string, schema: z.ZodType<T>) {
   const [value, setValue] = useLocalStorageState(key, defaultValue)
   const [searchParams, setSearchParams] = useSearchParams()
   const urlValue = searchParams.get(urlKey)
 
   // If urlValue is defined, it overrides the localStorage setting.
   // However if user changes the setting, the urlValue is removed.
-  const modifiedSetValue = (newValue: string) => {
+  const modifiedSetValue = (newValue: T) => {
     if (newValue !== urlValue) {
-      setValue(newValue)
+      if (typeof newValue === 'string') {
+        setValue(newValue)
+      } else {
+        setValue(String(newValue))
+      }
       searchParams.delete(urlKey)
       setSearchParams(searchParams)
     }
   }
 
-  return [urlValue ?? value, modifiedSetValue] as const
+  const parsedValue = schema.parse(urlValue ?? value)
+
+  return [parsedValue, modifiedSetValue] as const
 }
 
 const ChatV2Content = () => {

@@ -65,7 +69,6 @@ const ChatV2Content = () => {
   const isMobile = useMediaQuery(theme.breakpoints.down('md'))
 
   const { data: course } = useCourse(courseId)
-  const { ragIndices } = useCourseRagIndices(course?.id)
   const { infoTexts } = useInfoTexts()
 
   const { userStatus, isLoading: statusLoading, refetch: refetchStatus } = useUserStatus(courseId)

@@ -74,12 +77,13 @@ const ChatV2Content = () => {
 
   // local storage states
   const localStoragePrefix = courseId ? `course-${courseId}` : 'general'
-  const [activeModel, setActiveModel] = useLocalStorageStateWithURLDefault('model-v2', DEFAULT_MODEL, 'model')
+  const [activeModel, setActiveModel] = useLocalStorageStateWithURLDefault('model-v2', DEFAULT_MODEL, 'model', ValidModelNameSchema)
   const [disclaimerStatus, setDisclaimerStatus] = useLocalStorageState<boolean>('disclaimer-status', true)
   const [modelTemperature, setModelTemperature] = useLocalStorageStateWithURLDefault(
     `${localStoragePrefix}-chat-model-temperature`,
     String(DEFAULT_MODEL_TEMPERATURE),
     'temperature',
+    z.number(),
   )
 
   const [messages, setMessages] = useLocalStorageState(`${localStoragePrefix}-chat-messages`, [] as ChatMessage[])

@@ -90,12 +94,9 @@ const ChatV2Content = () => {
   const [fileName, setFileName] = useState<string>('')
   const [tokenUsageWarning, setTokenUsageWarning] = useState<string>('')
   const [tokenUsageAlertOpen, setTokenUsageAlertOpen] = useState<boolean>(false)
-  const [allowedModels, setAllowedModels] = useState<string[]>([])
+  const [allowedModels, setAllowedModels] = useState<ValidModelName[]>([])
   const [chatLeftSidePanelOpen, setChatLeftSidePanelOpen] = useState<boolean>(false)
-  // RAG states
-  const [ragIndexId, _setRagIndexId] = useState<number | undefined>()
   const [activeToolResult, setActiveToolResult0] = useState<ToolCallResultEvent | undefined>()
-  const ragIndex = ragIndices?.find((index) => index.id === ragIndexId)
 
   // Analytics
   const dispatchAnalytics = useAnalyticsDispatch()

@@ -106,11 +107,9 @@ const ChatV2Content = () => {
         model: activeModel,
         courseId,
         nMessages: messages.length,
-        ragIndexId,
-        ragIndexName: ragIndex?.metadata.name,
       },
     })
-  }, [messages, courseId, ragIndexId, activeModel, dispatchAnalytics])
+  }, [messages, courseId, activeModel, dispatchAnalytics])
 
   // Refs
   const chatContainerRef = useRef<HTMLDivElement | null>(null)

@@ -194,16 +193,23 @@ const ChatV2Content = () => {
     }
 
     try {
-      const { tokenUsageAnalysis, stream } = await postCompletionStreamV3({
-        generationInfo,
-        messages: newMessages,
-        ragIndexId,
+      if (!streamController) {
+        throw new Error('streamController is not defined')
+      }
+
+      const { tokenUsageAnalysis, stream } = await postCompletionStreamV3(
         formData,
-        modelTemperature: parseFloat(modelTemperature),
-        courseId,
-        abortController: streamController,
-        saveConsent,
-      })
+        {
+          options: {
+            generationInfo,
+            chatMessages: newMessages,
+            modelTemperature,
+            saveConsent,
+          },
+          courseId,
+        },
+        streamController,
+      )
 
       if (!stream && !tokenUsageAnalysis) {
         console.error('getCompletionStream did not return a stream or token usage analysis')

@@ -263,7 +269,7 @@ const ChatV2Content = () => {
 
     const { usage, limit, model: defaultCourseModel, models: courseModels } = userStatus
 
-    let allowedModels: string[] = []
+    let allowedModels: ValidModelName[] = []
 
     if (course && courseModels) {
       allowedModels = courseModels

@@ -558,8 +564,8 @@ const ChatV2Content = () => {
       <SettingsModal
        open={settingsModalOpen}
        setOpen={setSettingsModalOpen}
-       modelTemperature={parseFloat(modelTemperature)}
-       setModelTemperature={(updatedTemperature) => setModelTemperature(String(updatedTemperature))}
+       modelTemperature={modelTemperature}
+       setModelTemperature={(updatedTemperature) => setModelTemperature(updatedTemperature)}
      />
 
      <DisclaimerModal disclaimer={disclaimerInfo} disclaimerStatus={disclaimerStatus} setDisclaimerStatus={setDisclaimerStatus} />

@@ -586,9 +592,9 @@ const LeftMenu = ({
   setSettingsModalOpen: React.Dispatch<React.SetStateAction<boolean>>
   setDisclaimerStatus: React.Dispatch<React.SetStateAction<boolean>>
   messages: ChatMessage[]
-  currentModel: string
-  setModel: (model: string) => void
-  availableModels: string[]
+  currentModel: ValidModelName
+  setModel: (model: ValidModelName) => void
+  availableModels: ValidModelName[]
 }) => {
   const { t } = useTranslation()
   const { courseId } = useParams()
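
The heart of this change is the reworked useLocalStorageStateWithURLDefault: whatever it reads from the URL or localStorage now goes through the supplied Zod schema before reaching component state, so an out-of-range ?model= value fails loudly instead of leaking an arbitrary string into the chat. A minimal standalone sketch of that validation step, using an illustrative schema in place of the repo's ValidModelNameSchema:

import z from 'zod/v4'

// Illustrative stand-in for ValidModelNameSchema; the real schema is derived
// from validModels in src/config.ts.
const ModelNameSchema = z.union([z.literal('gpt-4o-mini'), z.literal('gpt-5'), z.literal('mock')])

// Raw values from the URL (or localStorage) are plain strings.
const raw: string | null = new URLSearchParams('model=gpt-5').get('model')

// parse narrows the value to the literal union, or throws a ZodError for
// anything outside the allowed set (e.g. '?model=bogus').
const model = ModelNameSchema.parse(raw ?? 'gpt-4o-mini')
console.log(model) // 'gpt-5'

The same parse call covers the localStorage fallback, which is why the hook can return a value typed as the literal union rather than as string.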

src/client/components/ChatV2/ModelSelector.tsx

Lines changed: 10 additions & 13 deletions

@@ -2,7 +2,7 @@ import React from 'react'
 import { useTranslation } from 'react-i18next'
 import { MenuItem, Typography, Tooltip, Menu } from '@mui/material'
 import { KeyboardArrowDown, SmartToy } from '@mui/icons-material'
-import { FREE_MODEL } from '../../../config'
+import { FREE_MODEL, ValidModelName } from '../../../config'
 import { OutlineButtonBlack } from './general/Buttons'
 
 const ModelSelector = ({

@@ -11,9 +11,9 @@ const ModelSelector = ({
   availableModels,
   isTokenLimitExceeded,
 }: {
-  currentModel: string
-  setModel: (model: string) => void
-  availableModels: string[]
+  currentModel: ValidModelName
+  setModel: (model: ValidModelName) => void
+  availableModels: ValidModelName[]
   isTokenLimitExceeded: boolean
 }) => {
   const { t } = useTranslation()

@@ -25,19 +25,14 @@ const ModelSelector = ({
     setAnchorEl(event.currentTarget)
   }
 
-  const handleSelect = (model: string) => {
+  const handleSelect = (model: ValidModelName) => {
     setModel(model)
     setAnchorEl(null)
   }
 
   return (
     <>
-      <OutlineButtonBlack
-        startIcon={<SmartToy />}
-        endIcon={<KeyboardArrowDown />}
-        onClick={handleClick}
-        data-testid="model-selector"
-      >
+      <OutlineButtonBlack startIcon={<SmartToy />} endIcon={<KeyboardArrowDown />} onClick={handleClick} data-testid="model-selector">
        {`${t('admin:model')}: ${validModel}`}
       </OutlineButtonBlack>
       <Menu

@@ -56,8 +51,10 @@ const ModelSelector = ({
           <MenuItem
             key={model}
             value={model}
-            onClick={() => handleSelect(model)} disabled={isTokenLimitExceeded && model !== FREE_MODEL}
-            data-testid={`${model}-option`}>
+            onClick={() => handleSelect(model)}
+            disabled={isTokenLimitExceeded && model !== FREE_MODEL}
+            data-testid={`${model}-option`}
+          >
             <Typography>
               {model}
               {model === FREE_MODEL && (

src/client/components/ChatV2/api.ts

Lines changed: 3 additions & 40 deletions

@@ -1,4 +1,4 @@
-import type { ChatMessage, MessageGenerationInfo } from '../../../shared/chat'
+import type { PostStreamSchemaV3Type } from '../../../shared/chat'
 import { postAbortableStream } from '../../util/apiClient'
 import type { ChatToolOutput } from '../../../shared/tools'
 import { useGetQuery } from '../../hooks/apiHooks'

@@ -12,45 +12,8 @@ export const useToolResults = (toolCallId: string) => {
   })
 }
 
-interface PostCompletionStreamProps {
-  generationInfo: MessageGenerationInfo
-  courseId?: string
-  messages: ChatMessage[]
-  formData: FormData
-  ragIndexId?: number
-  userConsent?: boolean
-  modelTemperature: number
-  prevResponseId?: string
-  abortController?: AbortController
-  saveConsent: boolean
-}
-export const postCompletionStreamV3 = async ({
-  generationInfo,
-  courseId,
-  messages,
-  formData,
-  ragIndexId,
-  userConsent,
-  modelTemperature,
-  prevResponseId,
-  abortController,
-  saveConsent,
-}: PostCompletionStreamProps) => {
-  const data = {
-    courseId,
-    options: {
-      chatMessages: messages,
-      systemMessage: generationInfo.promptInfo.systemMessage,
-      model: generationInfo.model,
-      ragIndexId,
-      userConsent,
-      modelTemperature,
-      saveConsent,
-      prevResponseId,
-    },
-  }
-
-  formData.set('data', JSON.stringify(data))
+export const postCompletionStreamV3 = async (formData: FormData, input: PostStreamSchemaV3Type, abortController: AbortController) => {
+  formData.set('data', JSON.stringify(input))
 
   return postAbortableStream('/ai/v3/stream', formData, abortController)
 }

src/client/components/Courses/Course/EditCourseForm.tsx

Lines changed: 2 additions & 2 deletions

@@ -7,7 +7,7 @@ import { enqueueSnackbar } from 'notistack'
 
 import { Course, SetState, User } from '../../../types'
 import { useEditCourseMutation } from '../../../hooks/useCourseMutation'
-import { validModels } from '../../../../config'
+import { ValidModelName, validModels } from '../../../../config'
 
 const EditCourseForm = forwardRef(({ course, setOpen, user }: { course: Course; setOpen: SetState<boolean>; user: User }, ref) => {
   const { t } = useTranslation()

@@ -64,7 +64,7 @@ const EditCourseForm = forwardRef(({ course, setOpen, user }: { course: Course;
        {t('admin:model')}
      </Typography>
      <Typography mb={1}>{t('admin:modelInfo')}</Typography>
-     <Select sx={{ m: 1, width: '300px' }} value={model} onChange={(e) => setModel(e.target.value)}>
+     <Select sx={{ m: 1, width: '300px' }} value={model} onChange={(e) => setModel(e.target.value as ValidModelName)}>
        {validModels.map(({ name: modelName }) => (
          <MenuItem key={modelName} value={modelName}>
            {modelName}

src/client/types.ts

Lines changed: 5 additions & 4 deletions

@@ -1,3 +1,4 @@
+import { ValidModelName } from '../config'
 import type { ChatMessage } from '../shared/chat'
 import type { UserPreferences } from '../shared/user'
 

@@ -71,7 +72,7 @@ export interface ChatInstance
   id: string
   name: Locales
   description: string
-  model: string
+  model: ValidModelName
   usageLimit: number
   resetCron?: string
   courseId?: string

@@ -83,7 +84,7 @@
 export interface AccessGroup {
   id: string
   iamGroup: string
-  model: string
+  model: ValidModelName
   usageLimit: number | null
   resetCron: string | null
 }

@@ -139,8 +140,8 @@
 }
 
 export type UserStatus = {
-  model: string
-  models: string[]
+  model: ValidModelName
+  models: ValidModelName[]
   usage: number
   limit: number
   isTike: boolean

src/config.ts

Lines changed: 11 additions & 3 deletions

@@ -1,3 +1,5 @@
+import z from 'zod/v4'
+
 export const inDevelopment = process.env.NODE_ENV === 'development'
 
 export const inStaging = process.env.STAGING === 'true'

@@ -12,8 +14,6 @@ export const PUBLIC_URL = process.env.PUBLIC_URL || ''
 
 export const DEFAULT_TOKEN_LIMIT = Number(process.env.DEFAULT_TOKEN_LIMIT) || 150_000
 
-export const FREE_MODEL = process.env.FREE_MODEL || 'gpt-4o-mini' // as it was decided in 23th Sept 2024 meeting
-export const DEFAULT_MODEL = process.env.DEFAUL_MODEL || 'gpt-4o-mini'
 export const DEFAUL_CONTEXT_LIMIT = Number(process.env.DEFAUL_CONTEXT_LIMIT) || 4_096
 
 export const DEFAULT_RESET_CRON = process.env.DEFAULT_RESET_CRON || '0 0 1 */3 *'

@@ -41,7 +41,15 @@ export const validModels = [
     name: 'mock',
     context: 128_000,
   },
-]
+] as const
+
+export const ValidModelNameSchema = z.union(validModels.map((model) => z.literal(model.name)))
+
+export type ValidModelName = z.infer<typeof ValidModelNameSchema>
+
+export const DEFAULT_MODEL = ValidModelNameSchema.parse(process.env.DEFAULT_MODEL || 'gpt-4o-mini')
+
+export const FREE_MODEL = ValidModelNameSchema.parse(process.env.FREE_MODEL || 'gpt-4o-mini') // as it was decided in 23th Sept 2024 meeting
 
 export const DEFAULT_MODEL_ON_ENABLE = 'gpt-5'
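
The config change is where the invariant is anchored: marking validModels as const lets a single array drive both the runtime schema and the ValidModelName type, and the env-derived defaults are now parsed at module load. A standalone sketch of the same pattern, with illustrative model names rather than the repo's:

import z from 'zod/v4'

// `as const` keeps the literal names so they can be lifted into a union type.
const models = [{ name: 'alpha' }, { name: 'beta' }] as const

// One source of truth yields both the runtime guard and the compile-time type.
const ModelNameSchema = z.union(models.map((m) => z.literal(m.name)))
type ModelName = z.infer<typeof ModelNameSchema> // 'alpha' | 'beta'

// Parsing environment input at module load fails fast on a typo instead of
// letting an unknown model name propagate through the app.
const DEFAULT: ModelName = ModelNameSchema.parse(process.env.DEFAULT_MODEL ?? 'alpha')
console.log(DEFAULT)

Adding or renaming a model in the array then updates the schema, the type, and every signature that uses ValidModelName in one move.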

src/server/db/models/prompt.ts

Lines changed: 4 additions & 1 deletion

@@ -1,7 +1,8 @@
-import { type CreationOptional, DataTypes, type InferAttributes, type InferCreationAttributes, Model } from 'sequelize'
+import { type CreationOptional, DataTypes, type InferAttributes, type InferCreationAttributes, Model, NonAttribute } from 'sequelize'
 
 import type { CustomMessage } from '../../types'
 import { sequelize } from '../connection'
+import type RagIndex from './ragIndex'
 
 export const PromptTypeValues = ['CHAT_INSTANCE', 'PERSONAL'] as const
 export type PromptType = (typeof PromptTypeValues)[number]

@@ -26,6 +27,8 @@ class Prompt extends Model<InferAttributes<Prompt>, InferCreationAttributes<Prom
   declare hidden: CreationOptional<boolean>
 
   declare mandatory: CreationOptional<boolean>
+
+  declare ragIndex?: NonAttribute<RagIndex>
 }
 
 Prompt.init(
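
NonAttribute<RagIndex> tells Sequelize's typings that ragIndex is an association populated at query time, not a column to persist. A hypothetical read, assuming the Prompt-to-RagIndex association and default exports are defined elsewhere in the models:

import Prompt from './prompt'
import RagIndex from './ragIndex'

// Hypothetical usage: `include` is what fills the NonAttribute field at query time.
const loadPromptWithIndex = async (promptId: number) => {
  const prompt = await Prompt.findByPk(promptId, { include: RagIndex })
  // Typed on the instance as RagIndex | undefined, never written to the prompts table.
  return prompt?.ragIndex
}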
