1
1
import { api } from '@/trpc/server' ;
2
2
import { trackEvent } from '@/utils/analytics/server' ;
3
- import { AgentStreamer , RootAgent } from '@onlook/ai' ;
3
+ import { AgentStreamer , BaseAgent , RootAgent , UserAgent } from '@onlook/ai' ;
4
4
import { toDbMessage } from '@onlook/db' ;
5
- import { ChatType , type ChatMessage } from '@onlook/models' ;
5
+ import { AgentType , ChatType } from '@onlook/models' ;
6
6
import { type NextRequest } from 'next/server' ;
7
7
import { v4 as uuidv4 } from 'uuid' ;
8
8
import { checkMessageLimit , decrementUsage , errorHandler , getSupabaseUser , incrementUsage , repairToolCall } from './helpers' ;
9
+ import { z } from 'zod' ;
9
10
10
11
export async function POST ( req : NextRequest ) {
11
12
try {
@@ -51,14 +52,24 @@ export async function POST(req: NextRequest) {
51
52
}
52
53
}
53
54
55
+ const streamResponseSchema = z . object ( {
56
+ agentType : z . enum ( AgentType ) . optional ( ) . default ( AgentType . ROOT ) ,
57
+ messages : z . array ( z . any ( ) ) ,
58
+ chatType : z . enum ( ChatType ) . optional ( ) ,
59
+ conversationId : z . string ( ) ,
60
+ projectId : z . string ( ) ,
61
+ } ) . refine ( ( data ) => {
62
+ // Only allow chatType if agentType is ROOT
63
+ if ( data . chatType !== undefined && data . agentType !== AgentType . ROOT ) {
64
+ return false ;
65
+ }
66
+ return true ;
67
+ } , { message : "chatType is only allowed if agentType is root" } ) ;
68
+
54
69
export const streamResponse = async ( req : NextRequest , userId : string ) => {
55
70
const body = await req . json ( ) ;
56
- const { messages, chatType, conversationId, projectId } = body as {
57
- messages : ChatMessage [ ] ,
58
- chatType : ChatType ,
59
- conversationId : string ,
60
- projectId : string ,
61
- } ;
71
+ const { agentType, messages, chatType, conversationId, projectId } = streamResponseSchema . parse ( body ) ;
72
+
62
73
// Updating the usage record and rate limit is done here to avoid
63
74
// abuse in the case where a single user sends many concurrent requests.
64
75
// If the call below fails, the user will not be penalized.
@@ -71,12 +82,20 @@ export const streamResponse = async (req: NextRequest, userId: string) => {
71
82
const lastUserMessage = messages . findLast ( ( message ) => message . role === 'user' ) ;
72
83
const traceId = lastUserMessage ?. id ?? uuidv4 ( ) ;
73
84
74
- if ( chatType === ChatType . EDIT ) {
75
- usageRecord = await incrementUsage ( req , traceId ) ;
76
- }
77
-
78
85
// Create RootAgent instance
79
- const agent = await RootAgent . create ( chatType ) ;
86
+ let agent : BaseAgent ;
87
+ if ( agentType === AgentType . ROOT ) {
88
+ if ( chatType === ChatType . EDIT ) {
89
+ usageRecord = await incrementUsage ( req , traceId ) ;
90
+ }
91
+
92
+ agent = new RootAgent ( chatType ! ) ;
93
+ } else if ( agentType === AgentType . USER ) {
94
+ agent = new UserAgent ( ) ;
95
+ } else {
96
+ // agent = new WeatherAgent();
97
+ throw new Error ( 'Agent type not supported' ) ;
98
+ }
80
99
const streamer = new AgentStreamer ( agent , conversationId ) ;
81
100
82
101
return streamer . streamText ( messages , {
@@ -87,7 +106,8 @@ export const streamResponse = async (req: NextRequest, userId: string) => {
87
106
conversationId,
88
107
projectId,
89
108
userId,
90
- chatType : chatType ,
109
+ agentType : agentType ?? AgentType . ROOT ,
110
+ chatType : chatType ?? "null" ,
91
111
tags : [ 'chat' ] ,
92
112
langfuseTraceId : traceId ,
93
113
sessionId : conversationId ,
0 commit comments