Skip to content

Commit 485fd5b

Browse files
committed
better prompt formatting and string only support
1 parent 7d62d4c commit 485fd5b

File tree

7 files changed

+100
-82
lines changed

7 files changed

+100
-82
lines changed

packages/prai/src/history.ts

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { Schema } from 'zod'
1+
import { Schema, ZodString } from 'zod'
22
import { isStepResponse, Message, MessageContent, wrapStepResponse } from './step.js'
33
import { getSchemaOptional, setSchema } from './schema/store.js'
44
import { buildSchemaType } from './schema/type.js'
@@ -16,13 +16,28 @@ export function buildStepRequestMessage(
1616
reason?: string
1717
}>,
1818
): Message {
19+
const schemaDescriptionLines: Array<string> = []
20+
if (!(schema instanceof ZodString)) {
21+
schemaDescriptionLines.push(
22+
'', //newline
23+
`Types:`,
24+
buildSchemaType(schema, `Response`, { prefix: `Step${stepId + 1}`, schemaTypeDefinitions, usedSchemas }),
25+
)
26+
} else if (schema.description != null) {
27+
schemaDescriptionLines.push(
28+
'', //new line
29+
`Response Format Description:`,
30+
schema.description,
31+
)
32+
}
1933
return {
2034
role: 'user',
2135
content: [
2236
{
2337
type: 'text',
2438
text: [
2539
`# Step${stepId + 1}`,
40+
'', //new line
2641
`Instructions:`,
2742
prompt,
2843
...(examples?.map(
@@ -31,8 +46,7 @@ export function buildStepRequestMessage(
3146
example.reason != null ? `, since ${example.reason}` : ''
3247
}.`,
3348
) ?? []),
34-
`Types:`,
35-
buildSchemaType(schema, `Response`, { prefix: `Step${stepId + 1}`, schemaTypeDefinitions, usedSchemas }),
49+
...schemaDescriptionLines,
3650
].join('\n'),
3751
},
3852
],

packages/prai/src/model.ts

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { Schema } from 'zod'
1+
import { Schema, ZodString } from 'zod'
22
import { Message } from './step.js'
33

44
export type Provider = {
@@ -43,9 +43,12 @@ export class Model {
4343
): { value: Promise<unknown>; stream?: AsyncIterable<string> } {
4444
if (!streamOption) {
4545
return {
46-
value: this.options.provider
47-
.query(this.options.name, messages, schema, abortSignal)
48-
.then((response) => JSON.parse(response)),
46+
value: this.options.provider.query(this.options.name, messages, schema, abortSignal).then((response) => {
47+
if (!(schema instanceof ZodString) || (response.at(0) === '"' && response.at(-1) === '"')) {
48+
return JSON.parse(response)
49+
}
50+
return response
51+
}),
4952
}
5053
}
5154
const responseStream = this.options.provider.streamingQuery(this.options.name, messages, schema, abortSignal)

packages/prai/src/provider/gemini.ts

Lines changed: 32 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
import { GenerateContentConfig, GoogleGenAI, GoogleGenAIOptions, Part } from '@google/genai'
22
import { Provider } from '../model.js'
33
import { buildJsonSchema } from '../schema/json.js'
4-
import { Schema } from 'zod'
4+
import { Schema, ZodString } from 'zod'
55
import { Message } from '../step.js'
66

77
function buildAdditionalConfig(schema: Schema): GenerateContentConfig {
8+
if (schema instanceof ZodString) {
9+
return {}
10+
}
811
let responseSchema = buildJsonSchema(schema, false)
912
return {
1013
responseMimeType: 'application/json',
@@ -16,65 +19,41 @@ export function gemini(options: GoogleGenAIOptions): Provider {
1619
const client = new GoogleGenAI(options)
1720
return {
1821
async query(model, messages, schema, abortSignal) {
19-
if (schema == null) {
20-
return query(model, client, messages, abortSignal)
21-
}
22-
return query(model, client, messages, abortSignal, buildAdditionalConfig(schema))
22+
const additionalConfig = buildAdditionalConfig(schema)
23+
const chat = client.chats.create({
24+
model,
25+
history: messages.slice(0, -1).map(({ role, content }) => ({
26+
role: role === 'assistant' ? 'model' : role,
27+
parts: content.map(messageContentToPartUnion),
28+
})),
29+
})
30+
const response = await chat.sendMessage({
31+
config: { abortSignal, ...additionalConfig },
32+
message: messages.at(-1)!.content.map(messageContentToPartUnion),
33+
})
34+
return response.text ?? ''
2335
},
2436
async *streamingQuery(model, messages, schema, abortSignal) {
25-
if (schema == null) {
26-
return streamingQueryOpenai(model, client, messages, abortSignal)
37+
const additionalConfig = buildAdditionalConfig(schema)
38+
const chat = client.chats.create({
39+
model,
40+
history: messages.slice(0, -1).map(({ role, content }) => ({
41+
role: role === 'assistant' ? 'model' : role,
42+
parts: content.map(messageContentToPartUnion),
43+
})),
44+
})
45+
const response = await chat.sendMessageStream({
46+
config: { abortSignal, ...additionalConfig },
47+
message: messages.at(-1)!.content.map(messageContentToPartUnion),
48+
})
49+
50+
for await (const chunk of response) {
51+
yield chunk.text ?? ''
2752
}
28-
return streamingQueryOpenai(model, client, messages, abortSignal, buildAdditionalConfig(schema))
2953
},
3054
}
3155
}
3256

33-
export async function* streamingQueryOpenai(
34-
model: string,
35-
client: GoogleGenAI,
36-
messages: Array<Message>,
37-
abortSignal: AbortSignal | undefined,
38-
additionalConfig?: GenerateContentConfig,
39-
): AsyncIterable<string> {
40-
const chat = client.chats.create({
41-
model,
42-
history: messages.slice(0, -1).map(({ role, content }) => ({
43-
role: role === 'assistant' ? 'model' : role,
44-
parts: content.map(messageContentToPartUnion),
45-
})),
46-
})
47-
const response = await chat.sendMessageStream({
48-
config: { abortSignal, ...additionalConfig },
49-
message: messages.at(-1)!.content.map(messageContentToPartUnion),
50-
})
51-
52-
for await (const chunk of response) {
53-
yield chunk.text ?? ''
54-
}
55-
}
56-
57-
export async function query(
58-
model: string,
59-
client: GoogleGenAI,
60-
messages: Array<Message>,
61-
abortSignal: AbortSignal | undefined,
62-
additionalConfig?: GenerateContentConfig,
63-
): Promise<string> {
64-
const chat = client.chats.create({
65-
model,
66-
history: messages.slice(0, -1).map(({ role, content }) => ({
67-
role: role === 'assistant' ? 'model' : role,
68-
parts: content.map(messageContentToPartUnion),
69-
})),
70-
})
71-
const response = await chat.sendMessage({
72-
config: { abortSignal, ...additionalConfig },
73-
message: messages.at(-1)!.content.map(messageContentToPartUnion),
74-
})
75-
return response.text ?? ''
76-
}
77-
7857
function messageContentToPartUnion(content: Message['content'][number]): Part {
7958
switch (content.type) {
8059
case 'image_url':

packages/prai/src/provider/groq.ts

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@ import OpenAI, { ClientOptions } from 'openai'
22
import { Provider } from '../model.js'
33
import { Message } from '../step.js'
44
import { buildJsonSchema } from '../schema/json.js'
5-
import { Schema, ZodObject, ZodUnion } from 'zod'
5+
import { Schema, ZodObject, ZodString, ZodUnion } from 'zod'
66
import { extractResultProperty, streamingQueryOpenai, queryOpenai } from './utils.js'
77

88
function buildAdditionalParams(schema: Schema, wrapInObject: boolean) {
9+
if (schema instanceof ZodString) {
10+
return {}
11+
}
912
let responseSchema = buildJsonSchema(schema)
1013
if (wrapInObject) {
1114
responseSchema = {
@@ -35,10 +38,7 @@ export function groq(options: ClientOptions): Provider {
3538
return {
3639
async query(model, messages, schema, abortSignal) {
3740
const transformedMessages = transformMessages(messages)
38-
if (schema == null) {
39-
return queryOpenai(model, client, transformedMessages, abortSignal)
40-
}
41-
if (!(schema instanceof ZodObject || schema instanceof ZodUnion)) {
41+
if (!(schema instanceof ZodObject || schema instanceof ZodUnion || schema instanceof ZodString)) {
4242
const { result } = JSON.parse(
4343
await queryOpenai(model, client, transformedMessages, abortSignal, buildAdditionalParams(schema, true)),
4444
)
@@ -48,10 +48,7 @@ export function groq(options: ClientOptions): Provider {
4848
},
4949
async *streamingQuery(model, messages, schema, abortSignal) {
5050
const transformedMessages = transformMessages(messages)
51-
if (schema == null) {
52-
return streamingQueryOpenai(model, client, transformedMessages, abortSignal)
53-
}
54-
if (!(schema instanceof ZodObject || schema instanceof ZodUnion)) {
51+
if (!(schema instanceof ZodObject || schema instanceof ZodUnion || schema instanceof ZodString)) {
5552
return extractResultProperty(
5653
streamingQueryOpenai(model, client, transformedMessages, abortSignal, buildAdditionalParams(schema, true)),
5754
)

packages/prai/src/provider/openai.ts

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
import OpenAI, { ClientOptions } from 'openai'
22
import { Provider } from '../model.js'
33
import { Message } from '../step.js'
4-
import { Schema, ZodObject, ZodUnion } from 'zod'
4+
import { Schema, ZodObject, ZodString, ZodUnion } from 'zod'
55
import { extractResultProperty } from './utils.js'
66
import { buildJsonSchema } from '../schema/json.js'
77

88
function buildAdditionalParams(schema: Schema, wrapInObject: boolean) {
9+
if (schema instanceof ZodString) {
10+
return {}
11+
}
912
let responseSchema = buildJsonSchema(schema)
1013
if (wrapInObject) {
1114
responseSchema = {
@@ -34,10 +37,7 @@ export function openai(options: ClientOptions): Provider {
3437
const client = new OpenAI(options)
3538
return {
3639
async query(model, messages, schema, abortSignal) {
37-
if (schema == null) {
38-
return openaiQuery(model, client, messages, abortSignal)
39-
}
40-
if (!(schema instanceof ZodObject || schema instanceof ZodUnion)) {
40+
if (!(schema instanceof ZodObject || schema instanceof ZodUnion || schema instanceof ZodString)) {
4141
const { result } = JSON.parse(
4242
await openaiQuery(model, client, messages, abortSignal, buildAdditionalParams(schema, true)),
4343
)
@@ -46,10 +46,7 @@ export function openai(options: ClientOptions): Provider {
4646
return openaiQuery(model, client, messages, abortSignal, buildAdditionalParams(schema, false))
4747
},
4848
async *streamingQuery(model, messages, schema, abortSignal) {
49-
if (schema == null) {
50-
return openaiStreamingQuery(model, client, messages, abortSignal)
51-
}
52-
if (!(schema instanceof ZodObject || schema instanceof ZodUnion)) {
49+
if (!(schema instanceof ZodObject || schema instanceof ZodUnion || schema instanceof ZodString)) {
5350
return extractResultProperty(
5451
openaiStreamingQuery(model, client, messages, abortSignal, buildAdditionalParams(schema, true)),
5552
)

packages/prai/src/utils.ts

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
export const NewLine = `\n`
2-
31
export function lines(...lines: Array<string>): string {
4-
return lines.join(NewLine)
2+
return lines.join('\n')
53
}

packages/prai/test/step.test.ts

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import { describe, expect, it } from 'vitest'
2-
import { listStep, mapStep, History, Model, mock } from '../src/index.js'
2+
import { listStep, mapStep, History, Model, mock, step } from '../src/index.js'
33
import { object, string } from 'zod'
44

55
describe('step', () => {
@@ -13,8 +13,10 @@ describe('step', () => {
1313
{
1414
text: [
1515
'# Step1',
16+
'',
1617
'Instructions:',
1718
'Return a list of list colors',
19+
'',
1820
'Types:',
1921
'type Step1Response = Array<{',
2022
'\tname: string',
@@ -32,8 +34,10 @@ describe('step', () => {
3234
{
3335
text: [
3436
`# Step2`,
37+
'',
3538
`Instructions:`,
3639
`For each entry in response of the previous step, a word that tymes with the name field from the object of each entry. The resulting list should have the same length and order as the input list of response of the previous step`,
40+
'',
3741
`Types:`,
3842
`type Step2Response = Array<string>`,
3943
``,
@@ -44,4 +48,30 @@ describe('step', () => {
4448
role: 'user',
4549
})
4650
})
51+
52+
it('string schema with describe should contain Response Format Description', async () => {
53+
const history = new History()
54+
const model = new Model({ provider: mock({ startupDelaySeconds: 0, tokensPerSecond: Infinity }), name: 'unknown' })
55+
56+
await step('generate a greeting', string().describe('A friendly greeting message'), { history, model })
57+
58+
const messageContent = history['messages'][0].content[0]
59+
if (messageContent.type === 'text') {
60+
expect(messageContent.text).to.include('Response Format Description')
61+
expect(messageContent.text).to.not.include('Types:')
62+
}
63+
})
64+
65+
it('plain string schema should not contain Types or Response Format Description', async () => {
66+
const history = new History()
67+
const model = new Model({ provider: mock({ startupDelaySeconds: 0, tokensPerSecond: Infinity }), name: 'unknown' })
68+
69+
await step('generate a greeting', string(), { history, model })
70+
71+
const messageContent = history['messages'][0].content[0]
72+
if (messageContent.type === 'text') {
73+
expect(messageContent.text).to.not.include('Types:')
74+
expect(messageContent.text).to.not.include('Response Format Description')
75+
}
76+
})
4777
})

0 commit comments

Comments
 (0)