Skip to content

Commit 6bd3623

Browse files
committed
Prompt relation to ragIndex change
1 parent 71ba0f3 commit 6bd3623

File tree

2 files changed

+19
-42
lines changed

2 files changed

+19
-42
lines changed

src/server/db/models/prompt.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import { type CreationOptional, DataTypes, type InferAttributes, type InferCreat
33
import type { CustomMessage } from '../../types'
44
import { sequelize } from '../connection'
55

6-
export const PromptTypeValues = ['CHAT_INSTANCE', 'PERSONAL', 'RAG_INDEX'] as const
6+
export const PromptTypeValues = ['CHAT_INSTANCE', 'PERSONAL'] as const
77
export type PromptType = (typeof PromptTypeValues)[number]
88

99
class Prompt extends Model<InferAttributes<Prompt>, InferCreationAttributes<Prompt>> {

src/server/routes/prompt.ts

Lines changed: 18 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import { ChatInstance, Prompt, RagIndex, Responsibility } from '../db/models'
44
import type { RequestWithUser } from '../types'
55
import type { User } from '../../shared/user'
66
import { ApplicationError } from '../util/ApplicationError'
7-
import { InferAttributes } from 'sequelize'
7+
import type { InferAttributes } from 'sequelize'
88

99
const promptRouter = express.Router()
1010

@@ -51,6 +51,7 @@ const PromptUpdateableParams = z.object({
5151
systemMessage: z.string().max(20_000),
5252
hidden: z.boolean().default(false),
5353
mandatory: z.boolean().default(false),
54+
ragIndexId: z.number().min(1).optional(),
5455
})
5556

5657
const PromptCreationParams = z.intersection(
@@ -73,11 +74,6 @@ const PromptCreationParams = z.intersection(
7374
z.object({
7475
type: z.literal('PERSONAL'),
7576
}),
76-
z.object({
77-
type: z.literal('RAG_INDEX'),
78-
ragIndexId: z.number().min(1),
79-
chatInstanceId: z.string().min(1).optional(),
80-
}),
8177
]),
8278
)
8379

@@ -101,14 +97,6 @@ const getPotentialNameConflicts = async (prompt: InferAttributes<Prompt, { omit:
10197
},
10298
})
10399
}
104-
case 'RAG_INDEX': {
105-
return Prompt.findAll({
106-
attributes: ['id', 'name'],
107-
where: {
108-
ragIndexId: prompt.ragIndexId,
109-
},
110-
})
111-
}
112100
}
113101
}
114102

@@ -138,33 +126,12 @@ const authorizeChatInstancePromptResponsible = async (user: User, prompt: ChatIn
138126
}
139127
}
140128

141-
interface RagIndexPrompt {
142-
ragIndexId: number
143-
chatInstanceId?: string
144-
}
145-
146-
const authorizeRagIndexPromptResponsible = async (user: User, prompt: RagIndexPrompt) => {
147-
const ragIndex = await RagIndex.findByPk(prompt.ragIndexId)
148-
const isAuthor = ragIndex?.userId === user.id
149-
150-
if (!isAuthor && !user.isAdmin) {
151-
if (!prompt.chatInstanceId) {
152-
throw ApplicationError.Forbidden('Not allowed')
153-
}
154-
await authorizeChatInstancePromptResponsible(user, prompt as ChatInstancePrompt)
155-
}
156-
}
157-
158129
const authorizePromptCreation = async (user: User, promptParams: PromptCreationParamsType) => {
159130
switch (promptParams.type) {
160131
case 'CHAT_INSTANCE': {
161132
await authorizeChatInstancePromptResponsible(user, promptParams)
162133
break
163134
}
164-
case 'RAG_INDEX': {
165-
await authorizeRagIndexPromptResponsible(user, promptParams)
166-
break
167-
}
168135
case 'PERSONAL': {
169136
// This is fine. Anyone can create a personal prompt. Lets just limit the number of prompts per user to 200
170137
const count = await Prompt.count({ where: { userId: user.id } })
@@ -173,6 +140,9 @@ const authorizePromptCreation = async (user: User, promptParams: PromptCreationP
173140
}
174141
break
175142
}
143+
default: {
144+
throw ApplicationError.InternalServerError('Unknown prompt type')
145+
}
176146
}
177147
}
178148

@@ -200,16 +170,15 @@ const authorizePromptUpdate = async (user: User, prompt: Prompt) => {
200170
await authorizeChatInstancePromptResponsible(user, prompt as ChatInstancePrompt)
201171
break
202172
}
203-
case 'RAG_INDEX': {
204-
await authorizeRagIndexPromptResponsible(user, prompt as RagIndexPrompt)
205-
break
206-
}
207173
case 'PERSONAL': {
208174
if (user.id !== prompt.userId) {
209175
throw ApplicationError.Forbidden('Not allowed')
210176
}
211177
break
212178
}
179+
default: {
180+
throw ApplicationError.InternalServerError('Unknown prompt type')
181+
}
213182
}
214183
}
215184

@@ -234,7 +203,7 @@ promptRouter.put('/:id', async (req, res) => {
234203
const { id } = req.params
235204
const { user } = req as unknown as RequestWithUser
236205
const updates = PromptUpdateableParams.parse(req.body)
237-
const { systemMessage, name, hidden, mandatory } = updates
206+
const { systemMessage, name, hidden, mandatory, ragIndexId } = updates
238207

239208
const prompt = await Prompt.findByPk(id)
240209

@@ -249,6 +218,7 @@ promptRouter.put('/:id', async (req, res) => {
249218
throw ApplicationError.Conflict('Prompt name already exists')
250219
}
251220

221+
prompt.ragIndexId = ragIndexId
252222
prompt.systemMessage = systemMessage
253223
prompt.name = name
254224
prompt.hidden = hidden
@@ -263,7 +233,14 @@ promptRouter.get('/:id', async (req, res) => {
263233
const { id } = req.params
264234

265235
// Note: we dont have any authorization checks here. Consider?
266-
const prompt = await Prompt.findByPk(id)
236+
const prompt = await Prompt.findByPk(id, {
237+
include: [
238+
{
239+
model: RagIndex,
240+
as: 'ragIndex',
241+
},
242+
],
243+
})
267244

268245
if (!prompt) {
269246
// We dont throw error here, since this is expected behaviour when the prompt has been deleted but someone still has it in their local storage.

0 commit comments

Comments
 (0)