|
1 | 1 | import { desc, eq, inArray, sql, sum } from "drizzle-orm" |
2 | 2 |
|
| 3 | +import { ToolUsage } from "@evals/types" |
| 4 | + |
3 | 5 | import { RecordNotFoundError, RecordNotCreatedError } from "./errors.js" |
4 | 6 | import type { InsertRun, UpdateRun } from "../schema.js" |
5 | 7 | import { insertRunSchema, schema } from "../schema.js" |
6 | 8 | import { db } from "../db.js" |
7 | 9 | import { createTaskMetrics } from "./taskMetrics.js" |
| 10 | +import { getTasks } from "./tasks.js" |
8 | 11 |
|
9 | 12 | const table = schema.runs |
10 | 13 |
|
@@ -71,17 +74,30 @@ export const finishRun = async (runId: number) => { |
71 | 74 | throw new RecordNotFoundError() |
72 | 75 | } |
73 | 76 |
|
| 77 | + const tasks = await getTasks(runId) |
| 78 | + |
| 79 | + const toolUsage = tasks.reduce((acc, task) => { |
| 80 | + Object.entries(task.taskMetrics?.toolUsage || {}).forEach(([key, { attempts, failures }]) => { |
| 81 | + const tool = key as keyof ToolUsage |
| 82 | + acc[tool] ??= { attempts: 0, failures: 0 } |
| 83 | + acc[tool].attempts += attempts |
| 84 | + acc[tool].failures += failures |
| 85 | + }) |
| 86 | + |
| 87 | + return acc |
| 88 | + }, {} as ToolUsage) |
| 89 | + |
74 | 90 | const { passed, failed, ...rest } = values |
75 | | - const taskMetrics = await createTaskMetrics(rest) |
| 91 | + const taskMetrics = await createTaskMetrics({ ...rest, toolUsage }) |
76 | 92 | await updateRun(runId, { taskMetricsId: taskMetrics.id, passed, failed }) |
77 | 93 |
|
78 | | - const run = await db.query.runs.findFirst({ where: eq(table.id, runId), with: { taskMetrics: true } }) |
| 94 | + const run = await findRun(runId) |
79 | 95 |
|
80 | 96 | if (!run) { |
81 | 97 | throw new RecordNotFoundError() |
82 | 98 | } |
83 | 99 |
|
84 | | - return run |
| 100 | + return { ...run, taskMetrics } |
85 | 101 | } |
86 | 102 |
|
87 | 103 | export const deleteRun = async (runId: number) => { |
|
0 commit comments