Skip to content

Commit 55d2d77

Browse files
committed
fix type issues
1 parent 4a71943 commit 55d2d77

File tree

3 files changed

+109
-124
lines changed

3 files changed

+109
-124
lines changed

src/helpers/beta/zod.ts

Lines changed: 24 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import * as z from 'zod';
44
import { OpenAIError } from '../../core/error';
55
import type { BetaRunnableTool, Promisable } from '../../lib/beta/BetaRunnableTool';
66
import type { AutoParseableBetaOutputFormat } from '../../lib/beta-parser';
7+
import { FunctionTool } from '../../resources/beta';
78
// import { AutoParseableBetaOutputFormat } from '../../lib/beta-parser';
89
// import { BetaRunnableTool, Promisable } from '../../lib/tools/BetaRunnableTool';
910
// import { BetaToolResultContentBlockParam } from '../../resources/beta';
@@ -52,7 +53,7 @@ export function betaZodTool<InputSchema extends ZodType>(options: {
5253
name: string;
5354
inputSchema: InputSchema;
5455
description: string;
55-
run: (args: zodInfer<InputSchema>) => Promisable<string | Array<BetaToolResultContentBlockParam>>;
56+
run: (args: zodInfer<InputSchema>) => Promisable<string | Array<FunctionTool>>; // TODO: I changed this but double check
5657
}): BetaRunnableTool<zodInfer<InputSchema>> {
5758
const jsonSchema = z.toJSONSchema(options.inputSchema, { reused: 'ref' });
5859

@@ -63,27 +64,18 @@ export function betaZodTool<InputSchema extends ZodType>(options: {
6364
// TypeScript doesn't narrow the type after the runtime check, so we need to assert it
6465
const objectSchema = jsonSchema as typeof jsonSchema & { type: 'object' };
6566

66-
// return {
67-
// type: 'function', // TODO: should this be custom or function?
68-
// name: options.name,
69-
// input_schema: objectSchema,
70-
// description: options.description,
71-
// run: options.run,
72-
// parse: (args: unknown) => options.inputSchema.parse(args) as zodInfer<InputSchema>,
73-
// };
7467
return {
7568
type: 'function',
7669
function: {
7770
name: options.name,
78-
// input_schema: objectSchema,
7971
description: options.description,
80-
// run: options.run,
81-
// parse: (args: unknown) => options.inputSchema.parse(args) as zodInfer<InputSchema>,
8272
parameters: {
8373
type: 'object',
8474
properties: objectSchema.properties,
8575
},
8676
},
77+
run: options.run,
78+
parse: (args: unknown) => options.inputSchema.parse(args) as zodInfer<InputSchema>,
8779
};
8880
}
8981

@@ -94,37 +86,26 @@ export function betaZodTool<InputSchema extends ZodType>(options: {
9486
// * input arguments will also be validated against the provided schema.
9587
// */
9688
// export function betaZodTool<InputSchema extends ZodType>(options: {
97-
// name: string;
98-
// inputSchema: InputSchema;
99-
// description: string;
100-
// run: (args: zodInfer<InputSchema>) => Promisable<string | Array<any>>;
101-
// }): BetaRunnableTool<zodInfer<InputSchema>> {
102-
// const jsonSchema = z.toJSONSchema(options.inputSchema, { reused: 'ref' });
89+
// name: string;
90+
// inputSchema: InputSchema;
91+
// description: string;
92+
// run: (args: zodInfer<InputSchema>) => Promisable<string | Array<BetaToolResultContentBlockParam>>;
93+
// }): BetaRunnableTool<zodInfer<InputSchema>> {
94+
// const jsonSchema = z.toJSONSchema(options.inputSchema, { reused: 'ref' });
10395

104-
// if (jsonSchema.type !== 'object') {
105-
// throw new Error(`Zod schema for tool "${options.name}" must be an object, but got ${jsonSchema.type}`);
106-
// }
96+
// if (jsonSchema.type !== 'object') {
97+
// throw new Error(`Zod schema for tool "${options.name}" must be an object, but got ${jsonSchema.type}`);
98+
// }
10799

108-
// // TypeScript doesn't narrow the type after the runtime check, so we need to assert it
109-
// const objectSchema = jsonSchema as typeof jsonSchema & { type: 'object' };
100+
// // TypeScript doesn't narrow the type after the runtime check, so we need to assert it
101+
// const objectSchema = jsonSchema as typeof jsonSchema & { type: 'object' };
110102

111-
// // return {
112-
// // type: 'function', // TODO: should this be custom or function?
113-
// // name: options.name,
114-
// // input_schema: objectSchema,
115-
// // description: options.description,
116-
// // run: options.run,
117-
// // parse: (args: unknown) => options.inputSchema.parse(args) as zodInfer<InputSchema>,
118-
// // };
119-
// return {
120-
// type: 'function',
121-
// function: {
122-
// name: options.name,
123-
// // input_schema: objectSchema,
124-
// description: options.description,
125-
// // run: options.run,
126-
// // parse: (args: unknown) => options.inputSchema.parse(args) as zodInfer<InputSchema>,
127-
// parameters: objectSchema.properties ?? {}, // the json schema
128-
// },
129-
// };
130-
// }
103+
// return {
104+
// type: 'custom',
105+
// name: options.name,
106+
// input_schema: objectSchema,
107+
// description: options.description,
108+
// run: options.run,
109+
// parse: (args: unknown) => options.inputSchema.parse(args) as zodInfer<InputSchema>,
110+
// };
111+
// }

src/lib/beta/BetaRunnableTool.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1+
import { FunctionTool } from '../../resources/beta';
12
import type { ChatCompletionTool } from '../../resources';
23

34
export type Promisable<T> = T | Promise<T>;
45

56
// this type is just an extension of BetaTool with a run and parse method
67
// that will be called by `toolRunner()` helpers
78
export type BetaRunnableTool<Input = any> = ChatCompletionTool & {
8-
run: (args: Input) => Promisable<string | Array<BetaToolResultContentBlockParam>>;
9+
run: (args: Input) => Promisable<string | Array<FunctionTool>>;
910
parse: (content: unknown) => Input;
1011
};

src/lib/beta/BetaToolRunner.ts

Lines changed: 83 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ export class BetaToolRunner<Stream extends boolean> {
4545
/** Promise for the last message received from the assistant */
4646
#message?: Promise<ChatCompletion> | undefined;
4747
/** Cached tool response to avoid redundant executions */
48-
#toolResponse?: Promise<ChatCompletionToolMessageParam | null> | undefined;
48+
#toolResponse?: Promise<null | ChatCompletionToolMessageParam[]> | undefined;
4949
/** Promise resolvers for waiting on completion */
5050
#completion: {
5151
promise: Promise<ChatCompletion>;
@@ -107,24 +107,21 @@ export class BetaToolRunner<Stream extends boolean> {
107107
this.#toolResponse = undefined;
108108
this.#iterationCount++;
109109

110-
const { max_iterations, ...params } = this.#state.params;
110+
const { ...params } = this.#state.params;
111111
if (params.stream) {
112-
throw new Error('TODO'); // TODO
113-
// stream = this.client.beta.chat.completions.stream({ ...params });
114-
// this.#message = stream.finalMessage();
115-
// // Make sure that this promise doesn't throw before we get the option to do something about it.
116-
// // Error will be caught when we call await this.#message ultimately
117-
// this.#message.catch(() => {});
118-
// yield stream as any;
112+
stream = this.client.beta.chat.completions.stream({ ...params, stream: true });
113+
this.#message = stream.finalMessage();
114+
// Make sure that this promise doesn't throw before we get the option to do something about it.
115+
// Error will be caught when we call await this.#message ultimately
116+
this.#message?.catch(() => {});
117+
yield stream as any;
119118
} else {
120-
console.log('making request with params:', JSON.stringify(params, null, 2));
121119
this.#message = this.client.beta.chat.completions.create({
122120
stream: false,
123121
tools: params.tools,
124122
messages: params.messages,
125123
model: params.model,
126124
});
127-
console.log('Message created:', JSON.stringify(await this.#message, null, 2));
128125
yield this.#message as any;
129126
}
130127

@@ -133,27 +130,24 @@ export class BetaToolRunner<Stream extends boolean> {
133130
}
134131

135132
// TODO: we should probably hit the user with a callback or somehow offer for them to choice between the choices
136-
const { choices } = await this.#message;
137-
138133
if (!this.#firstChoiceInCurrentMessage) {
139134
throw new Error('No choices found in message'); // TODO: use better error
140135
}
141136

142-
const { role: firstChoiceRole, content: firstChoiceContent } = this.#firstChoiceInCurrentMessage;
143-
144137
if (!this.#mutated) {
145-
console.log(choices);
146138
// this.#state.params.messages.push({ role, content }); TODO: we want to add all
147-
this.#state.params.messages.push(this.#firstChoiceInCurrentMessage as ChatCompletionMessageParam);
139+
this.#state.params.messages.push(this.#firstChoiceInCurrentMessage);
148140
}
149141

150-
const toolMessage = await this.#generateToolResponse((await this.#message).choices[0]!);
151-
console.log('Tool message:', toolMessage);
152-
if (toolMessage) {
153-
this.#state.params.messages.push(toolMessage);
142+
const toolMessages = await this.#generateToolResponse(await this.#message);
143+
if (toolMessages) {
144+
for (const toolMessage of toolMessages) {
145+
this.#state.params.messages.push(toolMessage);
146+
}
154147
}
155148

156-
if (!toolMessage && !this.#mutated) {
149+
// TODO: make sure this is correct?
150+
if (!toolMessages && !this.#mutated) {
157151
break;
158152
}
159153
} finally {
@@ -229,17 +223,17 @@ export class BetaToolRunner<Stream extends boolean> {
229223
if (!message) {
230224
return null;
231225
}
232-
console.log("Message:", message[0]);
233-
// TODO: this cast is probably bad
234-
return this.#generateToolResponse(message[0]);
226+
return this.#generateToolResponse(message);
235227
}
236228

237-
async #generateToolResponse(lastMessage: ChatCompletion.Choice) {
238-
console.log('Last message:', lastMessage.message);
229+
async #generateToolResponse(lastMessage: ChatCompletion | ChatCompletionMessageParam) {
239230
if (this.#toolResponse !== undefined) {
240231
return this.#toolResponse;
241232
}
242-
this.#toolResponse = generateToolResponse(this.#state.params, lastMessage.message!); // TODO: maybe undefined
233+
this.#toolResponse = generateToolResponse(
234+
lastMessage,
235+
this.#state.params.tools.filter((tool): tool is BetaRunnableTool<any> => 'run' in tool),
236+
);
243237
return this.#toolResponse;
244238
}
245239

@@ -339,70 +333,79 @@ export class BetaToolRunner<Stream extends boolean> {
339333
}
340334

341335
async function generateToolResponse(
342-
params: BetaToolRunnerParams,
343-
lastMessageFirstChoice = params.messages.at(-1),
344-
): Promise<ChatCompletionToolMessageParam | null> {
336+
params: ChatCompletion | ChatCompletionMessageParam,
337+
tools: BetaRunnableTool<any>[],
338+
): Promise<null | ChatCompletionToolMessageParam[]> {
339+
if (!('choices' in params)) {
340+
return null;
341+
}
342+
const { choices } = params;
343+
const lastMessage = choices[0]?.message;
344+
if (!lastMessage) {
345+
return null;
346+
}
347+
345348
// Only process if the last message is from the assistant and has tool use blocks
346349
if (
347-
!lastMessageFirstChoice ||
348-
lastMessageFirstChoice.role !== 'assistant' ||
349-
!lastMessageFirstChoice.tool_calls ||
350-
lastMessageFirstChoice.tool_calls.length === 0
350+
!lastMessage ||
351+
lastMessage.role !== 'assistant' ||
352+
!lastMessage.content ||
353+
typeof lastMessage.content === 'string'
351354
) {
352355
return null;
353356
}
354357

355-
const toolUseBlocks = lastMessageFirstChoice.tool_calls.filter((toolCall) => toolCall.type === 'function');
356-
if (toolUseBlocks.length === 0) {
358+
const { tool_calls: prevToolCalls = [] } = lastMessage;
359+
360+
if ((lastMessage.tool_calls ?? []).length === 0) {
357361
return null;
358362
}
359363

360-
const toolResults = await Promise.all(
361-
toolUseBlocks.map(async (toolUse) => {
362-
// TODO: we should be able to infer that toolUseBlocks is FunctionDefinition[] or can cast it (maybe!)
363-
const tool = params.tools.find(
364-
(t) =>
365-
('name' in t && 'function' in t ? t.function.name
366-
: 'name' in t ? t.name
367-
: undefined) === toolUse.function.name,
368-
);
369-
if (!tool || !('run' in tool)) {
370-
return {
371-
type: 'tool_result' as const,
372-
tool_use_id: toolUse.id,
373-
content: `Error: Tool '${toolUse.function.name}' not found`,
374-
is_error: true,
375-
};
376-
}
377-
378-
try {
379-
let input = JSON.parse(toolUse.function.arguments);
380-
if ('parse' in tool && tool.parse) {
381-
input = tool.parse(input);
364+
return (
365+
await Promise.all(
366+
prevToolCalls.map(async (toolUse) => {
367+
if (toolUse.type !== 'function') return; // TODO: what about other calls?
368+
369+
const tool = tools.find(
370+
(t) => t.type === 'function' && toolUse.function.name === t.function.name,
371+
) as BetaRunnableTool;
372+
373+
if (!tool || !('run' in tool)) {
374+
return {
375+
type: 'tool_result' as const,
376+
tool_call_id: toolUse.id,
377+
content: `Error: Tool '${toolUse.function.name}' not found`,
378+
is_error: true,
379+
};
382380
}
383381

384-
const result = await tool.run(input);
385-
return {
386-
type: 'tool_result' as const,
387-
tool_use_id: toolUse.id,
388-
content: result,
389-
};
390-
} catch (error) {
391-
return {
392-
type: 'tool_result' as const,
393-
tool_use_id: toolUse.id,
394-
content: `Error: ${error instanceof Error ? error.message : String(error)}`,
395-
is_error: true,
396-
};
397-
}
398-
}),
399-
);
382+
try {
383+
let input = toolUse.function.arguments;
384+
input = tool.parse(input);
400385

401-
return {
402-
role: 'tool' as const,
403-
content: JSON.stringify(toolResults),
404-
tool_call_id: toolUseBlocks[0]!.id,
405-
};
386+
const result = await tool.run(input);
387+
return {
388+
type: 'tool_result' as const,
389+
tool_call_id: toolUse.id,
390+
content: typeof result === 'string' ? result : JSON.stringify(result),
391+
};
392+
} catch (error) {
393+
return {
394+
type: 'tool_result' as const,
395+
tool_call_id: toolUse.id,
396+
content: `Error: ${error instanceof Error ? error.message : String(error)}`,
397+
is_error: true,
398+
};
399+
}
400+
}),
401+
)
402+
)
403+
.filter((result): result is NonNullable<typeof result> => result != null)
404+
.map((toolResult) => ({
405+
role: 'tool' as const,
406+
content: toolResult.content,
407+
tool_call_id: toolResult.tool_call_id,
408+
}));
406409
}
407410

408411
// vendored from typefest just to make things look a bit nicer on hover

0 commit comments

Comments
 (0)