|
| 1 | +import { useEffect, useState } from 'react' |
| 2 | +import type { PromptCreationParams, PromptEditableParams } from '@shared/prompt' |
| 3 | +import type { ValidModelName } from '@config' |
| 4 | +import { TextField, Box, Checkbox, FormControlLabel, FormControl, InputLabel, Select, MenuItem, Slider } from '@mui/material' |
| 5 | +import { validModels } from '@config' |
| 6 | +import { useTranslation } from 'react-i18next' |
| 7 | +import type { RagIndexAttributes } from '@shared/types' |
| 8 | +import { useCreatePromptMutation, useEditPromptMutation } from '../../hooks/usePromptMutation' |
| 9 | +import { enqueueSnackbar } from 'notistack' |
| 10 | +import { BlueButton } from '../ChatV2/general/Buttons' |
| 11 | + |
| 12 | +interface PromptEditorProps { |
| 13 | + prompt?: PromptEditableParams & { id: string } |
| 14 | + ragIndices?: RagIndexAttributes[] |
| 15 | + type: PromptCreationParams['type'] |
| 16 | + chatInstanceId?: string |
| 17 | +} |
| 18 | + |
| 19 | +export const PromptEditor = ({ prompt, ragIndices, type, chatInstanceId }: PromptEditorProps) => { |
| 20 | + const { t } = useTranslation() |
| 21 | + |
| 22 | + const editMutation = useEditPromptMutation() |
| 23 | + const createMutation = useCreatePromptMutation() |
| 24 | + |
| 25 | + const [name, setName] = useState<string>(prompt?.name ?? '') |
| 26 | + const [systemMessage, setSystemMessage] = useState<string>(prompt?.systemMessage ?? '') |
| 27 | + const [hidden, setHidden] = useState<boolean>(prompt?.hidden ?? false) |
| 28 | + const [mandatory, setMandatory] = useState<boolean>(prompt?.mandatory ?? false) |
| 29 | + const [ragIndexId, setRagIndexId] = useState<number | undefined>(prompt?.ragIndexId) |
| 30 | + |
| 31 | + const [selectedModel, setModel] = useState<ValidModelName | 'none'>(prompt?.model ?? 'none') |
| 32 | + |
| 33 | + const [temperatureDefined, setTemperatureDefined] = useState<boolean>(prompt?.temperature !== undefined) |
| 34 | + const [temperature, setTemperature] = useState<number>(prompt?.temperature ?? 0.5) |
| 35 | + |
| 36 | + useEffect(() => { |
| 37 | + const selectedModelConfig = validModels.find((m) => m.name === selectedModel) |
| 38 | + if (selectedModelConfig && 'temperature' in selectedModelConfig) { |
| 39 | + setTemperature(selectedModelConfig.temperature) |
| 40 | + setTemperatureDefined(false) |
| 41 | + } |
| 42 | + }, [selectedModel]) |
| 43 | + |
| 44 | + const handleSubmit = async (event: React.FormEvent<HTMLFormElement>) => { |
| 45 | + event.preventDefault() |
| 46 | + const model = selectedModel !== 'none' ? selectedModel : undefined |
| 47 | + try { |
| 48 | + if (prompt) { |
| 49 | + await editMutation.mutateAsync({ |
| 50 | + id: prompt.id, |
| 51 | + name, |
| 52 | + systemMessage, |
| 53 | + hidden, |
| 54 | + mandatory, |
| 55 | + ragIndexId, |
| 56 | + model, |
| 57 | + temperature, |
| 58 | + }) |
| 59 | + enqueueSnackbar('Prompt updated', { variant: 'success' }) |
| 60 | + } else { |
| 61 | + await createMutation.mutateAsync({ |
| 62 | + name, |
| 63 | + type, |
| 64 | + ...(type === 'CHAT_INSTANCE' ? { chatInstanceId } : {}), |
| 65 | + systemMessage, |
| 66 | + hidden, |
| 67 | + mandatory, |
| 68 | + ragIndexId, |
| 69 | + model, |
| 70 | + temperature, |
| 71 | + }) |
| 72 | + } |
| 73 | + } catch (error: any) { |
| 74 | + enqueueSnackbar(error.message, { variant: 'error' }) |
| 75 | + } |
| 76 | + } |
| 77 | + |
| 78 | + /** |
| 79 | + * If model has temperature, temperature is not relevant option and should not be shown to user. |
| 80 | + */ |
| 81 | + const modelHasTemperature = selectedModel && 'temperature' in (validModels.find((m) => m.name === selectedModel) ?? {}) |
| 82 | + |
| 83 | + return ( |
| 84 | + <Box component="form" onSubmit={handleSubmit} sx={{ mt: 2 }}> |
| 85 | + <TextField |
| 86 | + slotProps={{ |
| 87 | + htmlInput: { |
| 88 | + 'data-testid': 'prompt-name-input', |
| 89 | + minLength: 3, |
| 90 | + }, |
| 91 | + }} |
| 92 | + label={t('common:promptName')} |
| 93 | + value={name} |
| 94 | + onChange={(e) => setName(e.target.value)} |
| 95 | + fullWidth |
| 96 | + margin="normal" |
| 97 | + /> |
| 98 | + <TextField |
| 99 | + slotProps={{ |
| 100 | + htmlInput: { |
| 101 | + 'data-testid': 'system-message-input', |
| 102 | + }, |
| 103 | + }} |
| 104 | + label={t('prompt:systemMessage')} |
| 105 | + value={systemMessage} |
| 106 | + onChange={(e) => setSystemMessage(e.target.value)} |
| 107 | + fullWidth |
| 108 | + margin="normal" |
| 109 | + multiline |
| 110 | + minRows={4} |
| 111 | + maxRows={24} |
| 112 | + /> |
| 113 | + <FormControlLabel control={<Checkbox checked={hidden} onChange={(e) => setHidden(e.target.checked)} />} label={t('prompt:hidePrompt')} /> |
| 114 | + <FormControlLabel control={<Checkbox checked={mandatory} onChange={(e) => setMandatory(e.target.checked)} />} label={t('prompt:editMandatoryPrompt')} /> |
| 115 | + |
| 116 | + <FormControl fullWidth margin="normal"> |
| 117 | + <InputLabel>{t('rag:sourceMaterials')}</InputLabel> |
| 118 | + {ragIndices && ( |
| 119 | + <Select |
| 120 | + value={ragIndexId || ''} |
| 121 | + onChange={(e) => setRagIndexId(e.target.value ? Number(e.target.value) : undefined)} |
| 122 | + disabled={ragIndices === undefined || ragIndices.length === 0} |
| 123 | + > |
| 124 | + <MenuItem value=""> |
| 125 | + <em>{t('prompt:noSourceMaterials')}</em> |
| 126 | + </MenuItem> |
| 127 | + {ragIndices?.map((index) => ( |
| 128 | + <MenuItem key={index.id} value={index.id}> |
| 129 | + {index.metadata.name} |
| 130 | + </MenuItem> |
| 131 | + ))} |
| 132 | + </Select> |
| 133 | + )} |
| 134 | + </FormControl> |
| 135 | + <FormControl fullWidth margin="normal"> |
| 136 | + <InputLabel>{t('common:model')}</InputLabel> |
| 137 | + <Select value={selectedModel || ''} onChange={(e) => setModel(e.target.value as ValidModelName | 'none')}> |
| 138 | + <MenuItem value="none"> |
| 139 | + <em>{t('prompt:modelFreeToChoose')}</em> |
| 140 | + </MenuItem> |
| 141 | + {validModels.map((m) => ( |
| 142 | + <MenuItem key={m.name} value={m.name}> |
| 143 | + {m.name} |
| 144 | + </MenuItem> |
| 145 | + ))} |
| 146 | + </Select> |
| 147 | + </FormControl> |
| 148 | + <Box> |
| 149 | + <FormControlLabel |
| 150 | + control={ |
| 151 | + <Checkbox |
| 152 | + checked={temperatureDefined && !modelHasTemperature} |
| 153 | + onChange={(e) => setTemperatureDefined(e.target.checked)} |
| 154 | + disabled={modelHasTemperature} |
| 155 | + /> |
| 156 | + } |
| 157 | + label={t('chat:temperature')} |
| 158 | + /> |
| 159 | + {temperatureDefined && !modelHasTemperature && ( |
| 160 | + <Slider |
| 161 | + value={temperature} |
| 162 | + onChange={(_, newValue) => setTemperature(newValue as number)} |
| 163 | + aria-labelledby="temperature-slider" |
| 164 | + valueLabelDisplay="auto" |
| 165 | + step={0.1} |
| 166 | + marks |
| 167 | + min={0} |
| 168 | + max={1} |
| 169 | + disabled={modelHasTemperature} |
| 170 | + /> |
| 171 | + )} |
| 172 | + </Box> |
| 173 | + <BlueButton type="submit" variant="contained" sx={{ mt: 2 }}> |
| 174 | + {t('common:save')} |
| 175 | + </BlueButton> |
| 176 | + </Box> |
| 177 | + ) |
| 178 | +} |
0 commit comments