Skip to content

Commit 8ce7151

Browse files
committed
fix(evals): infer task input types from eval data
1 parent 2989dd2 commit 8ce7151

File tree

3 files changed: 90 additions & 5 deletions

packages/ai/src/evals/builder.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ class EvalBuilderImpl<
9595
}
9696

9797
// Call existing Eval function - this handles all Vitest registration
98-
Eval<TInput, TExpected, TOutput>(finalName, finalParams);
98+
Eval(finalName, finalParams);
9999
}
100100
}
101101

packages/ai/src/evals/eval.ts

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,30 @@ declare module 'vitest' {
5050

5151
// Built by handing customAlphabet a 36-character alphabet (digits 0-9 plus
// uppercase A-Z) and the size 10.
// NOTE(review): customAlphabet's import is outside this hunk — presumably
// nanoid's, in which case each call yields a 10-character ID; confirm.
const createVersionId = customAlphabet('0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ', 10);
5252

53+
/**
 * Loosest record shape used to bound data sources: an `input` and an
 * `expected` property, both typed `unknown`.
 */
type EvalDataSourceBase = { input: unknown; expected: unknown };
54+
/**
 * The forms a data source may take: a readonly array of records, a promise
 * of such an array, or a zero-argument function returning either of those.
 */
type EvalDataSource =
55+
| readonly EvalDataSourceBase[]
56+
| Promise<readonly EvalDataSourceBase[]>
57+
| (() => readonly EvalDataSourceBase[] | Promise<readonly EvalDataSourceBase[]>);
58+
59+
/**
 * Normalizes a data source to the value it ultimately provides: when `TData`
 * is a function its return type is taken (otherwise `TData` itself), and
 * `Awaited` then unwraps a promise if there is one.
 */
type ResolveEvalData<TData extends EvalDataSource> = Awaited<
60+
TData extends (...args: any[]) => infer TResult ? TResult : TData
61+
>;
62+
63+
/**
 * Extracts the type of the `input` property from the records of a resolved
 * data source. Evaluates to `never` when the resolved data is not a readonly
 * array or its record type has no `input` property.
 */
type InferEvalDataInput<TData extends EvalDataSource> =
64+
ResolveEvalData<TData> extends readonly (infer TRecord)[]
65+
? TRecord extends { input: infer TInput }
66+
? TInput
67+
: never
68+
: never;
69+
70+
/**
 * Extracts the type of the `expected` property from the records of a resolved
 * data source. Evaluates to `never` when the resolved data is not a readonly
 * array or its record type has no `expected` property.
 */
type InferEvalDataExpected<TData extends EvalDataSource> =
71+
ResolveEvalData<TData> extends readonly (infer TRecord)[]
72+
? TRecord extends { expected: infer TExpected }
73+
? TExpected
74+
: never
75+
: never;
76+
5377
type RunTaskFailureDetails = {
5478
duration: number;
5579
outOfScopeFlags: OutOfScopeFlagAccess[];
@@ -111,17 +135,24 @@ function getRunTaskFailureDetails(error: unknown): RunTaskFailureDetails | undef
111135
* ```
112136
*/
113137
export function Eval<
114-
TInput,
115-
TExpected,
116-
TOutput,
138+
TData extends EvalDataSource,
139+
TInput extends string | Record<string, any> = InferEvalDataInput<TData>,
140+
TExpected extends string | Record<string, any> = InferEvalDataExpected<TData>,
141+
TOutput extends string | Record<string, any> = string | Record<string, any>,
117142
Name extends string = string,
118143
Capability extends string = string,
119144
Step extends string = string,
120145
>(
121146
name: ValidateName<Name>,
122-
params: Omit<EvalParams<TInput, TExpected, TOutput>, 'capability' | 'step'> & {
147+
params: Omit<
148+
EvalParams<TInput, TExpected, TOutput>,
149+
'capability' | 'step' | 'data' | 'task' | 'scorers'
150+
> & {
123151
capability: ValidateName<Capability>;
124152
step?: ValidateName<Step> | undefined;
153+
data: TData;
154+
task: EvalTask<TInput, TExpected, TOutput>;
155+
scorers: ReadonlyArray<ScorerLike<TInput, TExpected, TOutput>>;
125156
},
126157
): void {
127158
// Record eval name for validation
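(File path header missing from this capture: a newly added test file — the hunk is `@@ -0,0 +1,54 @@` — that imports from `../../src/evals` and `../../src/scorers/scorers`.)

Lines changed: 54 additions & 0 deletions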
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import { describe, it } from 'vitest';
2+
import { Eval } from '../../src/evals';
3+
import { Scorer } from '../../src/scorers/scorers';
4+
5+
describe('Eval type inference', () => {
6+
// Type-level check: with data records whose `input` and `expected` are
// strings, Eval should accept a task annotated with `input: string` together
// with a scorer whose callback only declares `output` and `expected`.
it('infers task input and expected from data when scorer omits input', () => {
7+
const answerSimilarityScorer = Scorer(
8+
'answer-similarity',
9+
({ output, expected }: { output: string; expected: string }) => {
10+
// Bare references only; neither value is otherwise used (presumably to
// keep unused-variable diagnostics quiet).
output;
11+
expected;
12+
return 1;
13+
},
14+
);
15+
16+
// The Eval call sits inside an arrow function that this test never invokes,
// so it is type-checked by the compiler but not executed here.
const compileOnly = () =>
17+
Eval('name-apl-query', {
18+
capability: 'name_query',
19+
data: () => [
20+
{
21+
input: "['nginx-access-logs'] | where status >= 500",
22+
expected: 'Nginx 5xx Errors',
23+
},
24+
],
25+
task: async ({ input }: { input: string }) => input,
26+
scorers: [answerSimilarityScorer],
27+
});
28+
29+
// Referenced without being called.
compileOnly;
30+
});
31+
32+
// Type-level negative check: the data records carry a string `input`, while
// the task below annotates `input` as `number`; the `@ts-expect-error`
// directive makes compilation fail if that line stops producing an error.
it('rejects task input that conflicts with the data source', () => {
33+
// NOTE(review): PascalCase binding, unlike the camelCase
// `answerSimilarityScorer` used by the other test in this file.
const OutputOnlyScorer = Scorer(
34+
'output-only',
35+
({ output }: { output: string }) => output.length > 0,
36+
);
37+
38+
// Never invoked: the Eval call is only type-checked, not executed.
const invalid = () =>
39+
Eval('mismatched-task-input', {
40+
capability: 'name_query',
41+
data: () => [
42+
{
43+
input: 'foo',
44+
expected: 'bar',
45+
},
46+
],
47+
// NOTE(review): `number` also falls outside the `string | Record<string, any>`
// bound declared on Eval's TInput, so the expected error may arise from that
// bound rather than from the data's input type. Consider a conflicting type
// that satisfies the bound (e.g. an object-shaped input) to pin the
// data/task link — confirm against EvalTask, which is not visible here.
// @ts-expect-error task input must match the data input type
48+
task: async ({ input }: { input: number }) => String(input),
49+
scorers: [OutputOnlyScorer],
50+
});
51+
52+
// Referenced without being called.
invalid;
53+
});
54+
});

0 commit comments

Comments
 (0)