diff --git a/.github/workflows/mlflow-js.yaml b/.github/workflows/mlflow-js.yaml index 1209f844..8f1664ee 100644 --- a/.github/workflows/mlflow-js.yaml +++ b/.github/workflows/mlflow-js.yaml @@ -3,6 +3,7 @@ on: pull_request: branches: - main + - dev jobs: build-and-test: diff --git a/mlflow/tests/RunClient.test.ts b/mlflow/tests/RunClient.test.ts index fe0ec1d3..2b79997d 100644 --- a/mlflow/tests/RunClient.test.ts +++ b/mlflow/tests/RunClient.test.ts @@ -1,35 +1,50 @@ -import { describe, test, expect, beforeAll, beforeEach } from '@jest/globals'; +import { + describe, + test, + expect, + beforeAll, + beforeEach, + afterAll, +} from '@jest/globals'; import RunClient from '../src/tracking/RunClient'; import ExperimentClient from '../src/tracking/ExperimentClient'; -import { - Run, - Metrics, - Params, - Tags, - MetricHistoryResponse, -} from '../src/utils/interface'; +import { Run, Metrics, MetricHistoryResponse } from '../src/utils/interface'; +import { TRACKING_SERVER_URI } from './testUtils'; +import { TEST_DATA } from './testUtils'; describe('RunClient', () => { let runClient: RunClient; let experimentClient: ExperimentClient; let experimentId: string; + let run: Run; + const testIds: string[] = []; beforeAll(async () => { await new Promise((resolve) => setTimeout(resolve, 2000)); - runClient = new RunClient('http://127.0.0.1:5002'); - experimentClient = new ExperimentClient('http://127.0.0.1:5002'); + runClient = new RunClient(TRACKING_SERVER_URI); + experimentClient = new ExperimentClient(TRACKING_SERVER_URI); // Generate the experiment ID for test runs const timestamp = Date.now(); experimentId = await experimentClient.createExperiment( `Testing ${timestamp}` ); + testIds.push(experimentId); + }); + + beforeEach(async () => { + run = (await runClient.createRun(experimentId)) as Run; + }); + + afterAll(async () => { + for (const testId of testIds) { + await experimentClient.deleteExperiment(testId); + } }); // POST - Create a new run within an experiment describe('createRun', () => { test('- Should create a run with experiment_id', async () => { - const run = (await runClient.createRun(experimentId)) as Run; expect(run.info.experiment_id).toBe(experimentId); }); @@ -46,10 +61,7 @@ describe('RunClient', () => { }); test('- Should create a run with optional tags', async () => { - const tags = [ - { key: 'test_key1', value: 'test_value1' }, - { key: 'test_key2', value: 'test_value2' }, - ]; + const { tags } = TEST_DATA; const run = (await runClient.createRun( experimentId, @@ -67,7 +79,7 @@ describe('RunClient', () => { test('- Should create a run with all parameters', async () => { const run_name = 'Test Run 2'; const start_time = Date.now(); - const tags = [{ key: 'test_key', value: 'test_value' }]; + const { tags } = TEST_DATA; const run = (await runClient.createRun( experimentId, @@ -101,8 +113,6 @@ describe('RunClient', () => { // DELETE - Mark a run for deletion describe('deleteRun', () => { test('- Should delete a run with run_id', async () => { - const run = (await runClient.createRun(experimentId)) as Run; - await expect(runClient.deleteRun(run.info.run_id)).resolves.not.toThrow(); // check if the run's lifecycle_stage has changed to "deleted" @@ -120,7 +130,7 @@ describe('RunClient', () => { await expect(runClient.deleteRun(invalid_id)).rejects.toThrow( expect.objectContaining({ name: 'ApiError', - message: expect.stringMatching(/Error deleting run:.+invalid_id/), + message: expect.stringMatching(/Error deleting run:.+/), }) ); }); @@ -128,11 +138,6 @@ describe('RunClient', () 
=> { // POST - Restore a deleted run describe('restoreRun', () => { - let run: Run; - beforeEach(async () => { - run = (await runClient.createRun(experimentId)) as Run; - }); - test('- Should restore a deleted run with run_id', async () => { await runClient.deleteRun(run.info.run_id); @@ -171,7 +176,7 @@ describe('RunClient', () => { await expect(runClient.restoreRun(invalid_id)).rejects.toThrow( expect.objectContaining({ name: 'ApiError', - message: expect.stringMatching(/Error restoring run:.+invalid_id/), + message: expect.stringMatching(/Error restoring run:.+/), }) ); }); @@ -180,21 +185,8 @@ describe('RunClient', () => { // GET - Get metadata, metrics, params, and tags for a run describe('getRun', () => { test('- Should retrieve metadata for a run with run_id', async () => { - const run = (await runClient.createRun(experimentId)) as Run; - - // create dummy data for created run - const metrics: Metrics[] = [ - { key: 'accuracy', value: 0.83, timestamp: 1694000700000 }, - { key: 'loss', value: 0.18, timestamp: 1694000700000 }, - ]; - const params = [ - { key: 'learning_rate', value: '0.0001' }, - { key: 'batch_size', value: '256' }, - ]; - const tags = [ - { key: 'model_type', value: 'GradientBoosting' }, - { key: 'data_version', value: 'v1.7' }, - ]; + const { metrics, params, tags } = TEST_DATA; + await runClient.logBatch(run.info.run_id, metrics, params, tags); const fetchedRun = (await runClient.getRun(run.info.run_id)) as Run; @@ -232,7 +224,7 @@ describe('RunClient', () => { await expect(runClient.getRun(invalid_id)).rejects.toThrow( expect.objectContaining({ name: 'ApiError', - message: expect.stringMatching(/Error fetching run:.+invalid_id/), + message: expect.stringMatching(/Error fetching run:.+/), }) ); }); @@ -240,12 +232,6 @@ describe('RunClient', () => { // POST - Update run metadata describe('updateRun', () => { - let run: Run; - - beforeEach(async () => { - run = (await runClient.createRun(experimentId)) as Run; - }); - // parameterized testing for input status const allStatuses = [ 'RUNNING', @@ -304,7 +290,7 @@ describe('RunClient', () => { // test missing run_id // @ts-expect-error: testing for missing arguments await expect(runClient.updateRun()).rejects.toThrow( - /Error updating run:/ + /Error updating run:.+/ ); // Test invalid run_id @@ -312,7 +298,7 @@ describe('RunClient', () => { await expect(runClient.updateRun(invalid_id)).rejects.toThrow( expect.objectContaining({ name: 'ApiError', - message: expect.stringMatching(/Error updating run:.+invalid_id/), + message: expect.stringMatching(/Error updating run:.+/), }) ); @@ -330,19 +316,18 @@ describe('RunClient', () => { // POST - Log a metric for a run describe('logMetric', () => { - let run: Run; - const key = 'accuracy'; - const value = 0.9; - - beforeEach(async () => { - run = (await runClient.createRun(experimentId)) as Run; - }); + const { metrics } = TEST_DATA; test('- Should log a metric with run_id, key, value, and timestamp', async () => { const timestamp = Date.now(); await expect( - runClient.logMetric(run.info.run_id, key, value, timestamp) + runClient.logMetric( + run.info.run_id, + metrics[0].key, + metrics[0].value, + timestamp + ) ).resolves.not.toThrow(); // fetch run to confirm changes @@ -352,8 +337,8 @@ describe('RunClient', () => { expect(fetchedRun.data.metrics).toEqual( expect.arrayContaining([ expect.objectContaining({ - key: key, - value: value, + key: metrics[0].key, + value: metrics[0].value, timestamp: expect.any(Number), step: expect.any(Number), }), @@ -374,43 +359,41 @@ 
describe('RunClient', () => { await expect(runClient.logMetric()).rejects.toThrow(); // @ts-expect-error: testing for missing key and value await expect(runClient.logMetric(run.info.run_id)).rejects.toThrow(); - // @ts-expect-error: testing for all missing value - await expect(runClient.logMetric(run.info.run_id, key)).rejects.toThrow(); + await expect( + // @ts-expect-error: testing for missing value + runClient.logMetric(run.info.run_id, metrics[0].key) + ).rejects.toThrow(); // test invalid run_id const invalid_id = 'invalid_id'; - await expect(runClient.logMetric(invalid_id, key, value)).rejects.toThrow( + await expect( + runClient.logMetric(invalid_id, metrics[0].key, metrics[0].value) + ).rejects.toThrow( expect.objectContaining({ name: 'ApiError', - message: expect.stringMatching(/Error logging metric:.+invalid_id/), + message: expect.stringMatching(/Error logging metric:.+/), }) ); // test invalid key await expect( - runClient.logMetric(run.info.run_id, '', value) + runClient.logMetric(run.info.run_id, '', metrics[0].value) ).rejects.toThrow(); // test invalid value await expect( - runClient.logMetric(run.info.run_id, key, NaN) + runClient.logMetric(run.info.run_id, metrics[0].key, NaN) ).rejects.toThrow(); // All required args provided, should not throw await expect( - runClient.logMetric(run.info.run_id, key, value) + runClient.logMetric(run.info.run_id, metrics[0].key, metrics[0].value) ).resolves.not.toThrow(); }); }); // POST - Log a batch of metrics, params, and tags for a run describe('logBatch', () => { - let run: Run; - - beforeEach(async () => { - run = (await runClient.createRun(experimentId)) as Run; - }); - test('- Should not throw error with just run_id', async () => { await expect( runClient.logBatch(run.info.run_id) @@ -418,10 +401,7 @@ describe('RunClient', () => { }); test('- Should log batch with optional metrics', async () => { - const metrics: Metrics[] = [ - { key: 'accuracy', value: 0.83, timestamp: 1694000700000 }, - { key: 'loss', value: 0.18, timestamp: 1694000700000 }, - ]; + const { metrics } = TEST_DATA; await runClient.logBatch(run.info.run_id, metrics); @@ -441,21 +421,11 @@ describe('RunClient', () => { expect(fetchedMetric.timestamp).toBe(metric.timestamp); expect(fetchedMetric).toHaveProperty('step'); } - - const runTag = fetchedRun.data.tags?.find( - (tag) => tag.key === 'mlflow.runName' - ); - - expect(runTag?.key).toBe('mlflow.runName'); - expect(runTag?.value).toBe(run.info.run_name); }); }); test('- Should log batch with optional params', async () => { - const params: Params[] = [ - { key: 'learning_rate', value: '0.0001' }, - { key: 'batch_size', value: '256' }, - ]; + const { params } = TEST_DATA; await runClient.logBatch(run.info.run_id, undefined, params); @@ -466,10 +436,7 @@ describe('RunClient', () => { }); test('- Should log batch with optional tags', async () => { - const tags: Tags[] = [ - { key: 'model_type', value: 'GradientBoosting' }, - { key: 'data_version', value: 'v1.7' }, - ]; + const { tags } = TEST_DATA; await runClient.logBatch(run.info.run_id, undefined, undefined, tags); @@ -554,7 +521,7 @@ describe('RunClient', () => { await expect(runClient.logBatch()).rejects.toThrow(); // @ts-expect-error: testing for missing arguments await expect(runClient.logBatch()).rejects.toThrow( - /Error logging batch:/ + /Error logging batch:.+/ ); // test invalid run_id @@ -564,7 +531,7 @@ describe('RunClient', () => { await expect(runClient.logBatch(invalid_id)).rejects.toThrow( expect.objectContaining({ name: 'ApiError', - message:
expect.stringMatching(/Error logging batch:.+invalid_id/), + message: expect.stringMatching(/Error logging batch:.+/), }) ); }); @@ -577,28 +544,21 @@ describe('RunClient', () => { beforeEach(async () => { run = (await runClient.createRun(experimentId)) as Run; - validModel = { - artifact_path: 'pytorch_dnn', - flavors: { - python_function: { - env: 'conda.yaml', - loader_module: 'mlflow.pytorch', - model_path: 'model.pth', - python_version: '3.8.10', - }, - pytorch: { - model_data: 'model.pth', - pytorch_version: '1.9.0', - code: 'model-code', - }, - }, - utc_time_created: '2023-09-14 10:15:00.000000', - run_id: run.info.run_id, - }; + // const { validModel } = TEST_DATA; }); test('- Should log a model with run_id and model_json', async () => { - const model_json = JSON.stringify(validModel); + expect(run.info).toBeDefined(); + expect(run.info.run_id).toBeDefined(); + expect(run.info.run_name).toBeDefined(); + + // create model JSON with run_id included + const modelWithRunId = { + ...TEST_DATA.validModel, + run_id: run.info.run_id, + }; + + const model_json = JSON.stringify(modelWithRunId); await expect( runClient.logModel(run.info.run_id, model_json) @@ -615,23 +575,29 @@ describe('RunClient', () => { expect(runNameTag?.key).toBe('mlflow.runName'); expect(runNameTag?.value).toBe(run.info.run_name); - // check mlflow.log-model.history ttag + // check mlflow.log-model.history tag const logModelHistoryTag = fetchedRun.data.tags?.find( (tag) => tag.key === 'mlflow.log-model.history' ); expect(logModelHistoryTag?.key).toBe('mlflow.log-model.history'); - const loggedModelHistory = JSON.parse(logModelHistoryTag?.value || '[]'); + let loggedModelHistory; + try { + loggedModelHistory = JSON.parse(logModelHistoryTag?.value || '[]'); + } catch (e) { + console.error('Failed to parse log-model.history:', e); + throw e; + } expect(loggedModelHistory).toHaveLength(1); const loggedModel = loggedModelHistory[0]; expect(loggedModel).toMatchObject({ run_id: run.info.run_id, - artifact_path: validModel.artifact_path, - utc_time_created: validModel.utc_time_created, - flavors: validModel.flavors, + artifact_path: TEST_DATA.validModel.artifact_path, + utc_time_created: TEST_DATA.validModel.utc_time_created, + flavors: TEST_DATA.validModel.flavors, }); }); @@ -649,7 +615,7 @@ describe('RunClient', () => { ).rejects.toThrow( expect.objectContaining({ name: 'ApiError', - message: expect.stringMatching(/Error logging model:/), + message: expect.stringMatching(/Error logging model:.+/), }) ); // Test invalid model_json structure @@ -659,9 +625,7 @@ describe('RunClient', () => { ).rejects.toThrow( expect.objectContaining({ name: 'ApiError', - message: expect.stringMatching( - /Error logging model: Model json is missing mandatory fields/ - ), + message: expect.stringMatching(/Error logging model:.+/), }) ); @@ -672,9 +636,7 @@ describe('RunClient', () => { ).rejects.toThrow( expect.objectContaining({ name: 'ApiError', - message: expect.stringMatching( - /Error logging model: Malformed model info/ - ), + message: expect.stringMatching(/Error logging model:.+/), }) ); }); @@ -698,7 +660,6 @@ describe('RunClient', () => { ]; test('- Should log inputs with run_id and datasets', async () => { - const run = (await runClient.createRun(experimentId)) as Run; await runClient.logInputs(run.info.run_id, datasets); // fetch run to confirm changes @@ -709,8 +670,6 @@ describe('RunClient', () => { }); test('- Should handle errors and edge cases', async () => { - const run = (await runClient.createRun(experimentId)) as Run; -
// test with invalid_id const invalid_id = 'invalid_id'; await expect(runClient.logInputs(invalid_id, datasets)).rejects.toThrow(); @@ -733,25 +692,19 @@ describe('RunClient', () => { // POST - Set a tag on a run describe('setTag', () => { - test('- Should set a tag on a run with run_id, key, and value', async () => { - const run = (await runClient.createRun(experimentId)) as Run; - const key = 'accuracy'; - const value = '0.99'; + const { tags } = TEST_DATA; - await runClient.setTag(run.info.run_id, key, value); + test('- Should set a tag on a run with run_id, key, and value', async () => { + await runClient.setTag(run.info.run_id, tags[0].key, tags[0].value); // fetch run to confirm changes const fetchedRun = (await runClient.getRun(run.info.run_id)) as Run; - const tag = fetchedRun.data.tags?.find((t) => t.key === key); + const tag = fetchedRun.data.tags?.find((t) => t.key === tags[0].key); expect(tag).toBeDefined(); - expect(tag?.value).toBe(value); + expect(tag?.value).toBe(tags[0].value); }); test('- Should handle errors and edge cases', async () => { - const run = (await runClient.createRun(experimentId)) as Run; - const key = 'accuracy'; - const value = '0.99'; - // test missing arguments // @ts-expect-error: testing for all missing arguments await expect(runClient.setTag()).rejects.toThrow(); @@ -761,15 +714,17 @@ describe('RunClient', () => { // test invalid run_id const invalid_id = 'invalid_id'; - await expect(runClient.setTag(invalid_id, key, value)).rejects.toThrow( + await expect( + runClient.setTag(invalid_id, tags[0].key, tags[0].value) + ).rejects.toThrow( expect.objectContaining({ name: 'ApiError', - message: expect.stringMatching(/Error setting tag:.+invalid_id/), + message: expect.stringMatching(/Error setting tag:.+/), }) ); await expect( - runClient.setTag(run.info.run_id, '', value) + runClient.setTag(run.info.run_id, '', tags[0].value) ).rejects.toThrow(); await expect( @@ -790,26 +745,21 @@ describe('RunClient', () => { // POST - Delete a tag on a run describe('deleteTag', () => { - const key = 'test_key'; - const value = 'test_value'; + const { tags } = TEST_DATA; test('- Should delete a tag on a run with run_id and key', async () => { - const run = (await runClient.createRun(experimentId)) as Run; - - await runClient.setTag(run.info.run_id, key, value); + await runClient.setTag(run.info.run_id, tags[0].key, tags[0].value); - await runClient.deleteTag(run.info.run_id, key); + await runClient.deleteTag(run.info.run_id, tags[0].key); // fetch run to confirm changes const fetchedRun = (await runClient.getRun(run.info.run_id)) as Run; - expect(fetchedRun.data.tags).not.toContainEqual({ key, value }); + expect(fetchedRun.data.tags).not.toContainEqual(tags[0]); expect( - fetchedRun.data.tags.find((tag) => tag.key === key) + fetchedRun.data.tags.find((tag) => tag.key === tags[0].key) ).toBeUndefined(); }); test('- Should handle errors and edge cases', async () => { - const run = (await runClient.createRun(experimentId)) as Run; - // testing missing arguments // @ts-expect-error: testing for all missing arguments await expect(runClient.deleteTag()).rejects.toThrow(); @@ -818,10 +768,12 @@ describe('RunClient', () => { // test invalid run_id const invalid_id = 'invalid_id'; - await expect(runClient.deleteTag(invalid_id, key)).rejects.toThrow( + await expect( + runClient.deleteTag(invalid_id, tags[0].key) + ).rejects.toThrow( expect.objectContaining({ name: 'ApiError', - message: expect.stringMatching(/Error deleting tag:.+invalid_id/), + message: 
expect.stringMatching(/Error deleting tag:.+/), }) ); @@ -831,38 +783,34 @@ describe('RunClient', () => { ).rejects.toThrow(); // All required args provided, should not throw - await runClient.setTag(run.info.run_id, key, value); + await runClient.setTag(run.info.run_id, tags[0].key, tags[0].value); await expect( - runClient.deleteTag(run.info.run_id, key) + runClient.deleteTag(run.info.run_id, tags[0].key) ).resolves.not.toThrow(); // verify the tag was deleted const fetchedRun = (await runClient.getRun(run.info.run_id)) as Run; - expect(fetchedRun.data.tags).not.toHaveProperty(key); + expect(fetchedRun.data.tags).not.toHaveProperty(tags[0].key); }); }); // POST - Log a param used for a run describe('logParam', () => { - test('- Should log a param used for a run with run_id, key, and value', async () => { - const run = (await runClient.createRun(experimentId)) as Run; + const { params } = TEST_DATA; - const key = 'learning_rate'; - const value = '0.001'; - await runClient.logParam(run.info.run_id, key, value); + test('- Should log a param used for a run with run_id, key, and value', async () => { + await runClient.logParam(run.info.run_id, params[0].key, params[0].value); // fetch run to confirm changes const fetchedRun = (await runClient.getRun(run.info.run_id)) as Run; - const param = fetchedRun.data.params?.find((p) => p.key === key); + const param = fetchedRun.data.params?.find( + (p) => p.key === params[0].key + ); expect(param).toBeDefined(); - expect(param?.value).toBe(value); + expect(param?.value).toBe(params[0].value); }); test('- Should handle errors and edge cases', async () => { - const run = (await runClient.createRun(experimentId)) as Run; - const key = 'learning_rate'; - const value = '0.001'; - // @ts-expect-error: testing for all missing arguments await expect(runClient.logParam()).rejects.toThrow(); // @ts-expect-error: testing for missing key and value @@ -870,31 +818,35 @@ describe('RunClient', () => { // Test invalid run_id const invalid_id = 'invalid_id'; - await expect(runClient.logParam(invalid_id, key, value)).rejects.toThrow( + await expect( + runClient.logParam(invalid_id, params[0].key, params[0].value) + ).rejects.toThrow( expect.objectContaining({ name: 'ApiError', - message: expect.stringMatching(/Error logging param:.+invalid_id/), + message: expect.stringMatching(/Error logging param:.+/), }) ); // All required args provided, should not throw await expect( - runClient.logParam(run.info.run_id, key, value) + runClient.logParam(run.info.run_id, params[0].key, params[0].value) ).resolves.not.toThrow(); }); }); // Get a list of all values for the specified metric for a given run describe('getMetricHistory', () => { - test('- Should get a list of all values for the specified metric for a given run with run_id and metric_key', async () => { - const run = (await runClient.createRun(experimentId)) as Run; - const key = 'accuracy'; - const value = 0.95; + const { metrics } = TEST_DATA; - await runClient.logMetric(run.info.run_id, key, value); + test('- Should get a list of all values for the specified metric for a given run with run_id and metric_key', async () => { + await runClient.logMetric( + run.info.run_id, + metrics[0].key, + metrics[0].value + ); const metricHistory = (await runClient.getMetricHistory( run.info.run_id, - key + metrics[0].key )) as MetricHistoryResponse; expect(metricHistory).toHaveProperty('metrics'); @@ -902,13 +854,14 @@ describe('RunClient', () => { expect(metricHistory.metrics.length).toBeGreaterThan(0); const loggedMetric =
metricHistory.metrics.find( - (metric) => metric.key === key && metric.value === value + (metric) => + metric.key === metrics[0].key && metric.value === metrics[0].value ); expect(loggedMetric).toBeDefined(); if (loggedMetric) { - expect(loggedMetric).toHaveProperty('key', key); - expect(loggedMetric).toHaveProperty('value', value); + expect(loggedMetric).toHaveProperty('key', metrics[0].key); + expect(loggedMetric).toHaveProperty('value', metrics[0].value); } if (metricHistory.next_page_token) { @@ -921,7 +874,7 @@ describe('RunClient', () => { await expect(runClient.getMetricHistory()).rejects.toThrow( expect.objectContaining({ name: 'ApiError', - message: expect.stringMatching(/Error fetching metric history:/), + message: expect.stringMatching(/Error fetching metric history:.+/), }) ); @@ -932,7 +885,7 @@ describe('RunClient', () => { await expect(runClient.getMetricHistory(invalid_id, key)).rejects.toThrow( expect.objectContaining({ name: 'ApiError', - message: expect.stringMatching(/Error fetching metric history:/), + message: expect.stringMatching(/Error fetching metric history:.+/), }) ); }); @@ -940,19 +893,21 @@ describe('RunClient', () => { // Search for runs that satisfy expressions describe('searchRuns', () => { - let run1: Run; - let run2: Run; + let searchRunsExpId: string; beforeEach(async () => { - // create two runs with different metrics - run1 = (await runClient.createRun(experimentId)) as Run; - await runClient.logMetric(run1.info.run_id, 'metric', 1.0); - await new Promise((resolve) => setTimeout(resolve, 100)); - run2 = (await runClient.createRun(experimentId)) as Run; - await runClient.logMetric(run2.info.run_id, 'metric', 2.0); + searchRunsExpId = await experimentClient.createExperiment( + `Search Runs Test ${Date.now()}` + ); + testIds.push(searchRunsExpId); }); test('- Should search for runs with experiment_ids and filter', async () => { + const run1 = (await runClient.createRun(experimentId)) as Run; + await runClient.logMetric(run1.info.run_id, 'metric', 1.0); + const run2 = (await runClient.createRun(experimentId)) as Run; + await runClient.logMetric(run2.info.run_id, 'metric', 2.0); + // search for runs with only the experiment_ids const searchResult = (await runClient.searchRuns([ experimentId.toString(), @@ -973,6 +928,13 @@ }); test('- Should handle different run_view_types', async () => { + // create two runs with different metrics + const run1 = (await runClient.createRun(experimentId)) as Run; + await runClient.logMetric(run1.info.run_id, 'metric', 1.0); + await new Promise((resolve) => setTimeout(resolve, 100)); + const run2 = (await runClient.createRun(experimentId)) as Run; + await runClient.logMetric(run2.info.run_id, 'metric', 2.0); + await runClient.deleteRun(run2.info.run_id); // search for active runs only @@ -1008,6 +970,7 @@ const searchRunsExpId = await experimentClient.createExperiment( `Search Runs Test ${Date.now()}` ); + testIds.push(searchRunsExpId); const runA = (await runClient.createRun(searchRunsExpId)) as Run; await runClient.logMetric(runA.info.run_id, 'metric', 1.0); @@ -1067,16 +1030,47 @@ describe('RunClient', () => { // List artifacts for a run describe('listArtifacts', () => { test('- Should list artifacts with run_id', async () => { - const run = (await runClient.createRun(experimentId)) as Run; const artifacts = await runClient.listArtifacts(run.info.run_id); expect(artifacts).toHaveProperty('root_uri'); }); + test('- Should list artifacts with path parameter',
async () => { + const path = 'test/path'; + const artifacts = await runClient.listArtifacts(run.info.run_id, path); + expect(artifacts).toHaveProperty('root_uri'); + }); + + test('- Should list artifacts with page_token', async () => { + const page_token = 'test_token'; + const artifacts = await runClient.listArtifacts( + run.info.run_id, + undefined, + page_token + ); + expect(artifacts).toHaveProperty('root_uri'); + }); + + test('- Should list artifacts with both path and page_token', async () => { + const path = 'test/path'; + const page_token = 'test_token'; + const artifacts = await runClient.listArtifacts( + run.info.run_id, + path, + page_token + ); + expect(artifacts).toHaveProperty('root_uri'); + }); + test('- Should handle errors and edge cases', async () => { // test invalid id const invalid_id = 'invalid_id'; - await expect(runClient.listArtifacts(invalid_id)).rejects.toThrow(); + await expect(runClient.listArtifacts(invalid_id)).rejects.toThrow( + expect.objectContaining({ + name: 'ApiError', + message: expect.stringMatching(/Error listing artifacts:.+/), + }) + ); }); }); }); diff --git a/mlflow/tests/RunManager.test.ts b/mlflow/tests/RunManager.test.ts index 31dcea17..4d425c7d 100644 --- a/mlflow/tests/RunManager.test.ts +++ b/mlflow/tests/RunManager.test.ts @@ -1,8 +1,9 @@ -import { describe, test, expect, beforeEach } from '@jest/globals'; +import { describe, test, expect, beforeEach, afterAll } from '@jest/globals'; import RunManager from '../src/workflows/RunManager'; import RunClient from '../src/tracking/RunClient'; import ExperimentClient from '../src/tracking/ExperimentClient'; import { Run, CleanupRuns, CopyRun } from '../src/utils/interface'; +import { TRACKING_SERVER_URI, TEST_DATA } from './testUtils'; describe('RunManager', () => { let runManager: RunManager; @@ -10,12 +11,23 @@ describe('RunManager', () => { let experimentClient: ExperimentClient; let experimentId: string; const runIds: string[] = []; + const experimentsToDelete: string[] = []; beforeEach(async () => { await new Promise((resolve) => setTimeout(resolve, 2000)); - runClient = new RunClient('http://127.0.0.1:5002'); - experimentClient = new ExperimentClient('http://127.0.0.1:5002'); - runManager = new RunManager('http://127.0.0.1:5002'); + runClient = new RunClient(TRACKING_SERVER_URI); + experimentClient = new ExperimentClient(TRACKING_SERVER_URI); + runManager = new RunManager(TRACKING_SERVER_URI); + }); + + afterAll(async () => { + for (const runId of runIds) { + await runClient.deleteRun(runId); + } + + for (const expId of experimentsToDelete) { + await experimentClient.deleteExperiment(expId); + } }); describe('cleanupRuns', () => { @@ -24,6 +36,7 @@ describe('RunManager', () => { experimentId = await experimentClient.createExperiment( `Testing ${timestamp}` ); + experimentsToDelete.push(experimentId); // create test runs const createRun = async (metricKey: string, metricValue: number) => { @@ -133,6 +146,7 @@ describe('RunManager', () => { targetExperimentId = await experimentClient.createExperiment( `Target Exp ${timestamp}` ); + experimentsToDelete.push(sourceExperimentId, targetExperimentId); // log data for original run const run = (await runClient.createRun(sourceExperimentId)) as Run; @@ -156,21 +170,9 @@ describe('RunManager', () => { const model_json = JSON.stringify(model); - await runClient.logBatch( - originalRunId, - [ - { key: 'metric-key1', value: 10, timestamp: 1694000700000 }, - { key: 'metric-key2', value: 20, timestamp: 1694000700000 }, - ], - [ - { key: 'param-key1', 
value: 'param-value1' }, - { key: 'param-key2', value: 'param-value2' }, - ], - [ - { key: 'tag-key1', value: 'tag-value1' }, - { key: 'tag-key2', value: 'tag-value2' }, - ] - ); + const { metrics, params, tags } = TEST_DATA; + + await runClient.logBatch(originalRunId, metrics, params, tags); await runClient.logInputs(originalRunId, datasets); await runClient.logModel(originalRunId, model_json); @@ -183,6 +185,8 @@ }); test('- Should copy run from one experiment to another', async () => { + const { metrics } = TEST_DATA; + const result = (await runManager.copyRun( originalRunId, targetExperimentId @@ -195,25 +199,15 @@ // fetch copied run and check the metrics const copiedRun = (await runClient.getRun(result.newRunId)) as Run; - expect(copiedRun.data.metrics).toEqual( - expect.arrayContaining([ - expect.objectContaining({ key: 'metric-key1', value: 10 }), - expect.objectContaining({ key: 'metric-key2', value: 20 }), - ]) - ); - - expect(copiedRun.data.params).toEqual( - expect.arrayContaining([ - expect.objectContaining({ key: 'param-key1', value: 'param-value1' }), - expect.objectContaining({ key: 'param-key2', value: 'param-value2' }), - ]) + const metricMatchers = metrics.map((metric) => + expect.objectContaining({ + key: metric.key, + value: metric.value, + }) ); - expect(copiedRun.data.tags).toEqual( - expect.arrayContaining([ - expect.objectContaining({ key: 'tag-key1', value: 'tag-value1' }), - expect.objectContaining({ key: 'tag-key2', value: 'tag-value2' }), - ]) + expect(copiedRun.data.metrics).toEqual( + expect.arrayContaining(metricMatchers) ); }); }); diff --git a/mlflow/tests/testUtils.ts b/mlflow/tests/testUtils.ts index e44c5bba..62991ea3 100644 --- a/mlflow/tests/testUtils.ts +++ b/mlflow/tests/testUtils.ts @@ -1,7 +1,8 @@ import ExperimentClient from '../src/tracking/ExperimentClient'; -import { Experiment } from '../src/utils/interface'; +import { Experiment, Metrics, Params, Tags } from '../src/utils/interface'; export const TRACKING_SERVER_URI: string = 'http://127.0.0.1:5002'; + export const experimentProperties: string[] = [ 'experiment_id', 'name', @@ -10,6 +11,7 @@ export const experimentProperties: string[] = [ 'last_update_time', 'creation_time', ]; + export const runProperties: string[] = [ 'run_id', 'run_uuid', @@ -21,6 +23,7 @@ export const runProperties: string[] = [ 'artifact_uri', 'lifecycle_stage', ]; + export type ExpSearchResults = { experiments?: Experiment[]; next_page_token?: string; @@ -28,6 +31,38 @@ export type ExpSearchResults = { const experimentClient = new ExperimentClient(TRACKING_SERVER_URI); +export const TEST_DATA = { + metrics: [ + { key: 'accuracy', value: 0.83, timestamp: 1694000700000 }, + { key: 'loss', value: 0.18, timestamp: 1694000700000 }, + ] as Metrics[], + params: [ + { key: 'learning_rate', value: '0.0001' }, + { key: 'batch_size', value: '256' }, + ] as Params[], + tags: [ + { key: 'model_type', value: 'GradientBoosting' }, + { key: 'data_version', value: 'v1.7' }, + ] as Tags[], + validModel: { + artifact_path: 'pytorch_dnn', + flavors: { + python_function: { + env: 'conda.yaml', + loader_module: 'mlflow.pytorch', + model_path: 'model.pth', + python_version: '3.8.10', + }, + pytorch: { + model_data: 'model.pth', + pytorch_version: '1.9.0', + code: 'model-code', + }, + }, + utc_time_created: '2023-09-14 10:15:00.000000', + }, +}; + export const createTestExperiment = async ( prefix = 'Test experiment' ): Promise<string> => {