Skip to content

Commit 5710b61

Browse files
authored
Merge pull request #85 from oslabs-beta/stephany/jest
feat: Add jest test for RunManager
2 parents a327093 + 3970cc5 commit 5710b61

File tree

5 files changed

+312
-254
lines changed

5 files changed

+312
-254
lines changed

mlflow/src/utils/interface.ts

Lines changed: 52 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,40 @@
11
export interface Run {
2-
info: {
3-
run_id: string;
4-
run_name: string;
5-
experiment_id: string;
6-
status: string;
7-
start_time: number;
8-
end_time: number;
9-
artifact_uri: string;
10-
lifecycle_stage: string;
11-
};
12-
data: {
13-
metrics: Metrics[];
14-
params: Params[];
15-
tags: Tags[];
2+
info: RunInfo;
3+
data: RunData;
4+
inputs?: RunInputs;
5+
}
6+
7+
interface RunInfo {
8+
run_id: string;
9+
run_name: string;
10+
experiment_id: string;
11+
status: 'RUNNING' | 'SCHEDULED' | 'FINISHED' | 'FAILED' | 'KILLED';
12+
start_time: number;
13+
end_time: number;
14+
artifact_uri: string;
15+
lifecycle_stage: string;
16+
}
17+
18+
interface RunData {
19+
metrics: Metrics[];
20+
params: Params[];
21+
tags: Tags[];
22+
}
23+
24+
interface RunInputs {
25+
dataset_inputs?: DatasetInput[];
26+
}
27+
28+
interface DatasetInput {
29+
tags?: Tags[];
30+
dataset: {
31+
name: string;
32+
digest: string;
33+
source_type: string;
34+
source: string;
35+
schema?: string;
36+
profile?: string;
1637
};
17-
inputs: Array<{
18-
tags?: Tags[];
19-
dataset: {
20-
name: string;
21-
digest: string;
22-
source_type: string;
23-
source: string;
24-
schema?: string;
25-
profile?: string;
26-
};
27-
}>;
2838
}
2939

3040
export interface Metrics {
@@ -47,3 +57,20 @@ export interface MetricHistoryResponse {
4757
metrics: Metrics[];
4858
next_page_token?: string;
4959
}
60+
61+
export interface SearchedRuns {
62+
runs: Run[];
63+
next_page_token?: string;
64+
}
65+
66+
export interface CleanupRuns {
67+
deletedRuns: Run[];
68+
total: number;
69+
dryRun: boolean;
70+
}
71+
72+
export interface CopyRun {
73+
originalRunId: string;
74+
newRunId: string;
75+
targetExperimentId: string;
76+
}

mlflow/src/workflows/RunManager.ts

Lines changed: 29 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,8 @@
11
import RunClient from '@tracking/RunClient';
22
import ModelVersionClient from '@model-registry/ModelVersionClient';
3-
import { Run } from '@utils/interface';
3+
import { Run, SearchedRuns } from '@utils/interface';
44
import { ApiError } from '@utils/apiError';
55

6-
interface keyable {
7-
[key: string]: any;
8-
}
9-
106
class RunManager {
117
private runClient: RunClient;
128
private modelVersion: ModelVersionClient;
@@ -34,65 +30,61 @@ class RunManager {
3430
): Promise<object> {
3531
const deletedRuns = [];
3632
const keepRunIds = new Set();
37-
let pageToken = null;
33+
const runViewType = undefined;
34+
let pageToken = undefined;
3835
const maxResults = 1000;
3936

4037
try {
4138
do {
4239
// get all runs
43-
const searchResult: keyable = await this.runClient.searchRuns(
40+
const searchResult = (await this.runClient.searchRuns(
4441
experimentIds,
4542
'',
46-
undefined, // run_view_type
43+
runViewType,
4744
maxResults,
4845
['start_time DESC'],
4946
pageToken
50-
);
47+
)) as SearchedRuns;
5148

5249
// get runs that match the keep crteria
53-
const keepRunsResult: keyable = await this.runClient.searchRuns(
50+
const keepRunsResult = (await this.runClient.searchRuns(
5451
experimentIds,
5552
query_string,
56-
undefined, // run_view_type
53+
runViewType,
5754
maxResults,
5855
['start_time DESC'],
5956
pageToken
60-
);
57+
)) as SearchedRuns;
6158

6259
// Add runs from keepRunsResult to keepResult
63-
keepRunsResult.runs.forEach((run: Run) =>
60+
keepRunsResult.runs?.forEach((run: Run) =>
6461
keepRunIds.add(run.info.run_id)
6562
);
6663

67-
// Add runs without the specified metric key to keepRunIds
64+
// Process each run
6865
for (const run of searchResult.runs) {
69-
if (Array.isArray(run.data.metrics)) {
70-
const hasMetricKey = run.data.metrics.some(
71-
(metric: { key: string }) => metric.key === metric_key
72-
);
73-
if (!hasMetricKey) {
74-
keepRunIds.add(run.info.run_id);
75-
}
76-
} else {
77-
// If run.data.metrics is not an array (e.g., undefined), keep the run
78-
keepRunIds.add(run.info.run_id);
79-
}
80-
}
66+
const metrics = run.data.metrics;
67+
const hasMetricKey = Array.isArray(metrics)
68+
? metrics.some((metric) => metric.key === metric_key)
69+
: metric_key in metrics;
8170

82-
// Delete runs that are not in keepRunIds
83-
for (const run of searchResult.runs) {
84-
if (!keepRunIds.has(run.info.run_id)) {
71+
if (!hasMetricKey || keepRunIds.has(run.info.run_id)) {
72+
keepRunIds.add(run.info.run_id);
73+
} else {
74+
deletedRuns.push(run);
8575
if (!dryRun) {
8676
await this.runClient.deleteRun(run.info.run_id);
8777
}
88-
deletedRuns.push(run);
8978
}
9079
}
91-
92-
pageToken = searchResult.page_token;
80+
pageToken = searchResult.next_page_token;
9381
} while (pageToken);
9482

95-
return { deletedRuns, total: deletedRuns.length, dryRun };
83+
return {
84+
deletedRuns: deletedRuns,
85+
total: deletedRuns.length,
86+
dryRun,
87+
};
9688
} catch (error) {
9789
if (error instanceof ApiError) {
9890
console.error(`API Error (${error.statusCode}): ${error.message}`);
@@ -119,14 +111,14 @@ class RunManager {
119111
): Promise<object> {
120112
try {
121113
// get original run
122-
const originalRun: keyable = await this.runClient.getRun(runId);
114+
const originalRun = (await this.runClient.getRun(runId)) as Run;
123115

124116
// create a new run in the target experiment
125-
const newRun: keyable = await this.runClient.createRun(
117+
const newRun = (await this.runClient.createRun(
126118
targetExperimentId,
127119
undefined,
128120
originalRun.info.start_time
129-
);
121+
)) as Run;
130122

131123
const newRunId = newRun.info.run_id;
132124

@@ -138,6 +130,7 @@ class RunManager {
138130
originalRun.info.status,
139131
endTime
140132
);
133+
141134
if (originalRun.info.lifecycle_stage !== 'active') {
142135
await this.runClient.setTag(
143136
newRunId,

mlflow/tests/RunClient.test.ts

Lines changed: 11 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,12 @@ describe('RunClient', () => {
128128

129129
// POST - Restore a deleted run
130130
describe('restoreRun', () => {
131+
let run: Run;
132+
beforeEach(async () => {
133+
run = (await runClient.createRun(experimentId)) as Run;
134+
});
135+
131136
test('- Should restore a deleted run with run_id', async () => {
132-
const run = (await runClient.createRun(experimentId)) as Run;
133137
await runClient.deleteRun(run.info.run_id);
134138

135139
// check if the run is marked as deleted
@@ -147,8 +151,6 @@ describe('RunClient', () => {
147151
});
148152

149153
test('- Should not throw error when trying to restore a non-deleted run', async () => {
150-
const run = (await runClient.createRun(experimentId)) as Run;
151-
152154
// Attempt to restore a non-deleted run
153155
await expect(
154156
runClient.restoreRun(run.info.run_id)
@@ -403,17 +405,19 @@ describe('RunClient', () => {
403405

404406
// POST - Log a batch of metrics, params, and tags for a run
405407
describe('logBatch', () => {
406-
test('- Should not throw error with just run_id', async () => {
407-
const run = (await runClient.createRun(experimentId)) as Run;
408+
let run: Run;
409+
410+
beforeEach(async () => {
411+
run = (await runClient.createRun(experimentId)) as Run;
412+
});
408413

414+
test('- Should not throw error with just run_id', async () => {
409415
await expect(
410416
runClient.logBatch(run.info.run_id)
411417
).resolves.toBeUndefined();
412418
});
413419

414420
test('- Should log batch with optional metrics', async () => {
415-
const run = (await runClient.createRun(experimentId)) as Run;
416-
417421
const metrics: Metrics[] = [
418422
{ key: 'accuracy', value: 0.83, timestamp: 1694000700000 },
419423
{ key: 'loss', value: 0.18, timestamp: 1694000700000 },
@@ -448,8 +452,6 @@ describe('RunClient', () => {
448452
});
449453

450454
test('- Should log batch with optional params', async () => {
451-
const run = (await runClient.createRun(experimentId)) as Run;
452-
453455
const params: Params[] = [
454456
{ key: 'learning_rate', value: '0.0001' },
455457
{ key: 'batch_size', value: '256' },
@@ -464,8 +466,6 @@ describe('RunClient', () => {
464466
});
465467

466468
test('- Should log batch with optional tags', async () => {
467-
const run = (await runClient.createRun(experimentId)) as Run;
468-
469469
const tags: Tags[] = [
470470
{ key: 'model_type', value: 'GradientBoosting' },
471471
{ key: 'data_version', value: 'v1.7' },
@@ -480,8 +480,6 @@ describe('RunClient', () => {
480480
});
481481

482482
test('- Should be able to log up to 1000 metrics', async () => {
483-
const run = (await runClient.createRun(experimentId)) as Run;
484-
485483
const metrics = Array.from({ length: 1000 }, (_, index) => ({
486484
key: `metric${index}`,
487485
value: index,
@@ -495,8 +493,6 @@ describe('RunClient', () => {
495493
});
496494

497495
test('- Should throw error when exceeding 1000 metrics', async () => {
498-
const run = (await runClient.createRun(experimentId)) as Run;
499-
500496
const metrics = Array.from({ length: 1001 }, (_, index) => ({
501497
key: `metric${index}`,
502498
value: index,
@@ -510,8 +506,6 @@ describe('RunClient', () => {
510506
});
511507

512508
test('- Should be able to log up to 100 params', async () => {
513-
const run = (await runClient.createRun(experimentId)) as Run;
514-
515509
const params = Array.from({ length: 100 }, (_, index) => ({
516510
key: `param${index}`,
517511
value: `value${index}`,
@@ -523,8 +517,6 @@ describe('RunClient', () => {
523517
});
524518

525519
test('- Should throw error when exceeding 100 params', async () => {
526-
const run = (await runClient.createRun(experimentId)) as Run;
527-
528520
const params = Array.from({ length: 101 }, (_, index) => ({
529521
key: `param${index}`,
530522
value: `value${index}`,
@@ -536,8 +528,6 @@ describe('RunClient', () => {
536528
});
537529

538530
test('- Should be able to log up to 100 tags', async () => {
539-
const run = (await runClient.createRun(experimentId)) as Run;
540-
541531
const tags = Array.from({ length: 100 }, (_, index) => ({
542532
key: `tag${index}`,
543533
value: `value${index}`,
@@ -549,8 +539,6 @@ describe('RunClient', () => {
549539
});
550540

551541
test('- Should throw error when exceeding 100 tags', async () => {
552-
const run = (await runClient.createRun(experimentId)) as Run;
553-
554542
const tags = Array.from({ length: 101 }, (_, index) => ({
555543
key: `tag${index}`,
556544
value: `value${index}`,

0 commit comments

Comments
 (0)