Skip to content

Commit a9ca2cb

Browse files
committed
Can select rag when creating prompt
1 parent 6bd3623 commit a9ca2cb

File tree

4 files changed

+35
-25
lines changed

4 files changed

+35
-25
lines changed

src/client/components/ChatV2/ChatV2.tsx

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,12 @@ import { handleCompletionStreamError } from './error'
3232
import ToolResult from './ToolResult'
3333
import { OutlineButtonBlack } from './general/Buttons'
3434
import { ChatInfo } from './general/ChatInfo'
35-
import RagSelector from './RagSelector'
35+
import RagSelector, { RagSelectorDescription } from './RagSelector'
3636
import { SettingsModal, useUrlPromptId } from './SettingsModal'
3737
import { useChatStream } from './useChatStream'
3838
import { getCompletionStreamV3 } from './util'
3939
import PromptSelector from './PromptSelector'
4040
import { useQuery } from '@tanstack/react-query'
41-
import { useMutation } from '@tanstack/react-query'
42-
import apiClient from '../../util/apiClient'
4341
import ModelSelector from './ModelSelector'
4442

4543
function useLocalStorageStateWithURLDefault(key: string, defaultValue: string, urlKey: string) {
@@ -382,7 +380,6 @@ export const ChatV2 = () => {
382380
onClose={() => {
383381
setChatLeftSidePanelOpen(false)
384382
}}
385-
t={t}
386383
setSettingsModalOpen={setSettingsModalOpen}
387384
setDisclaimerStatus={setDisclaimerStatus}
388385
showRagSelector={showRagSelector}
@@ -406,7 +403,6 @@ export const ChatV2 = () => {
406403
}}
407404
course={course}
408405
handleReset={handleReset}
409-
t={t}
410406
setSettingsModalOpen={setSettingsModalOpen}
411407
setDisclaimerStatus={setDisclaimerStatus}
412408
showRagSelector={showRagSelector}
@@ -599,7 +595,6 @@ const LeftMenu = ({
599595
course,
600596
handleReset,
601597
onClose,
602-
t,
603598
setSettingsModalOpen,
604599
setDisclaimerStatus,
605600
showRagSelector,
@@ -612,13 +607,11 @@ const LeftMenu = ({
612607
currentModel,
613608
setModel,
614609
availableModels,
615-
616610
}: {
617611
sx?: object
618612
course?: Course
619613
handleReset: () => void
620614
onClose?: () => void
621-
t: TFunction
622615
setSettingsModalOpen: React.Dispatch<React.SetStateAction<boolean>>
623616
setDisclaimerStatus: React.Dispatch<React.SetStateAction<boolean>>
624617
showRagSelector: boolean
@@ -628,11 +621,11 @@ const LeftMenu = ({
628621
messages: Message[]
629622
activePrompt: Prompt | undefined
630623
setActivePrompt: (prompt: Prompt | undefined) => void
631-
632624
currentModel: string
633625
setModel: (model: string) => void
634626
availableModels: string[]
635627
}) => {
628+
const { t } = useTranslation()
636629
const { courseId } = useParams()
637630
const { userStatus, isLoading: statusLoading } = useUserStatus(courseId)
638631
const [isTokenLimitExceeded, setIsTokenLimitExceeded] = useState<boolean>(false)
@@ -668,12 +661,7 @@ const LeftMenu = ({
668661
<OutlineButtonBlack startIcon={<RestartAltIcon />} onClick={handleReset} data-testid="empty-conversation-button">
669662
{t('chat:emptyConversation')}
670663
</OutlineButtonBlack>
671-
<ModelSelector
672-
currentModel={currentModel}
673-
setModel={setModel}
674-
availableModels={availableModels}
675-
isTokenLimitExceeded={isTokenLimitExceeded}
676-
/>
664+
<ModelSelector currentModel={currentModel} setModel={setModel} availableModels={availableModels} isTokenLimitExceeded={isTokenLimitExceeded} />
677665
<PromptSelector
678666
sx={{ width: '100%' }}
679667
coursePrompts={course?.prompts ?? []}
@@ -692,13 +680,7 @@ const LeftMenu = ({
692680
</OutlineButtonBlack>
693681
{course && showRagSelector && (
694682
<>
695-
<Typography variant="h6" sx={{ mb: 1, display: 'flex', gap: 1, alignItems: 'center' }} fontWeight="bold">
696-
<MenuBookTwoTone />
697-
{t('chat:sources')}
698-
</Typography>
699-
<Typography variant="body2" sx={{ mb: 2 }}>
700-
{t('settings:sourceDescription')}
701-
</Typography>
683+
<RagSelectorDescription />
702684
<RagSelector currentRagIndex={ragIndex} setRagIndex={setRagIndexId} ragIndices={ragIndices ?? []} />
703685
</>
704686
)}

src/client/components/ChatV2/RagSelector.tsx

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import { useTranslation } from 'react-i18next'
2-
import { Box, MenuItem, Menu } from '@mui/material'
2+
import { Box, MenuItem, Menu, Typography } from '@mui/material'
33
import KeyboardArrowDownIcon from '@mui/icons-material/KeyboardArrowDown'
44
import { useState } from 'react'
55
import { RagIndexAttributes } from '../../../shared/types'
66
import { OutlineButtonBlack } from './general/Buttons'
7-
import { MenuBook } from '@mui/icons-material'
7+
import { MenuBook, MenuBookTwoTone } from '@mui/icons-material'
88

99
const RagSelector = ({
1010
currentRagIndex,
@@ -72,4 +72,20 @@ const RagSelector = ({
7272
)
7373
}
7474

75+
export const RagSelectorDescription = () => {
76+
const { t } = useTranslation()
77+
78+
return (
79+
<>
80+
<Typography variant="h6" sx={{ mb: 1, display: 'flex', gap: 1, alignItems: 'center' }} fontWeight="bold">
81+
<MenuBookTwoTone />
82+
{t('chat:sources')}
83+
</Typography>
84+
<Typography variant="body2" sx={{ mb: 2 }}>
85+
{t('settings:sourceDescription')}
86+
</Typography>
87+
</>
88+
)
89+
}
90+
7591
export default RagSelector

src/client/components/Courses/Course/index.tsx

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ import Discussion from './Discussions'
2323
import { ApiErrorView } from '../../common/ApiErrorView'
2424
import apiClient from '../../../util/apiClient'
2525
import { ActionUserSearch } from '../../Admin/UserSearch'
26+
import { useCourseRagIndices } from '../../../hooks/useRagIndices'
27+
import RagSelector, { RagSelectorDescription } from '../../ChatV2/RagSelector'
2628

2729
const Course = () => {
2830
const [showTeachers, setShowTeachers] = useState(false)
@@ -320,11 +322,13 @@ const AssignedResponsibilityManagement = ({ responsibility, handleRemove }) => {
320322

321323
const Prompts = ({ courseId, chatInstanceId }: { courseId: string; chatInstanceId: string }) => {
322324
const { t } = useTranslation()
325+
const { ragIndices } = useCourseRagIndices(courseId)
323326
const [name, setName] = useState('')
324327
const [system, setSystem] = useState('')
325328
const [messages, setMessages] = useState<MessageType[]>([])
326329
const [hidden, setHidden] = useState(false)
327330
const [mandatory, setMandatory] = useState(false)
331+
const [ragIndexId, setRagIndexId] = useState<number | undefined>(undefined)
328332

329333
const createMutation = useCreatePromptMutation()
330334
const deleteMutation = useDeletePromptMutation()
@@ -351,6 +355,7 @@ const Prompts = ({ courseId, chatInstanceId }: { courseId: string; chatInstanceI
351355
messages,
352356
hidden,
353357
mandatory,
358+
ragIndexId,
354359
})
355360
enqueueSnackbar('Prompt created', { variant: 'success' })
356361
handleReset()
@@ -397,7 +402,7 @@ const Prompts = ({ courseId, chatInstanceId }: { courseId: string; chatInstanceI
397402

398403
<Conversation messages={messages} completion="" />
399404

400-
<Box sx={{ paddingBottom: 2 }}>
405+
<Box sx={{ py: 2, display: 'flex', alignItems: 'start' }}>
401406
{!mandatoryPromptId ? (
402407
<FormControlLabel
403408
control={<Checkbox checked={mandatory} onChange={() => setMandatory((prev) => !prev)} />}
@@ -411,6 +416,12 @@ const Prompts = ({ courseId, chatInstanceId }: { courseId: string; chatInstanceI
411416
)}
412417
<FormControlLabel control={<Checkbox value={hidden} onChange={() => setHidden((prev) => !prev)} />} label={t('hidePrompt')} />
413418
</Box>
419+
{ragIndices && (
420+
<div>
421+
<RagSelectorDescription />
422+
<RagSelector ragIndices={ragIndices} setRagIndex={setRagIndexId} currentRagIndex={ragIndices.find((rag) => rag.id === ragIndexId)} />
423+
</div>
424+
)}
414425
<Button variant="contained" onClick={handleSave} sx={{ mr: 2 }}>
415426
{t('common:save')}
416427
</Button>

src/client/hooks/usePromptMutation.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ interface NewPromptData {
1313
messages: Message[]
1414
hidden: boolean
1515
mandatory: boolean
16+
ragIndexId: number | undefined
1617
}
1718

1819
export const useCreatePromptMutation = () => {

0 commit comments

Comments
 (0)