Skip to content

Commit 8bb325a

Browse files
Merge pull request #107 from oslabs-beta/yiqun/new
feat: add lib, revise gitignore, revise example
2 parents 09d2c64 + d525032 commit 8bb325a

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+2590
-48
lines changed

mlflow/.gitignore

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
# Dependencies
22
node_modules/
33

4+
# Environment variables
5+
.env
6+
7+
# IDE
8+
.vscode
9+
410
# Misc
511
.DS_Store
612
../.DS_Store
@@ -9,8 +15,12 @@ node_modules/
915
.venv/
1016
mlruns/
1117

12-
# Build output
13-
lib/
14-
1518
# Temporary files
16-
temp/
19+
temp/
20+
21+
# Test coverage
22+
coverage/
23+
24+
25+
26+
=

mlflow/examples/LinearRegressionExample.js

Lines changed: 39 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
/**
22
* @fileoverview Example of using MLflow.js for machine learning projects
3-
* The example creates a synthetic dataset and trains a linear regression model using TensorFlow.js,
4-
* tracking various metrics and parameters throughout the training process with MLflow tracking server.
3+
* The example creates a synthetic dataset and trains a linear regression model using TensorFlow.js,
4+
* tracking various metrics, hyperparameters and other meta data throughout the training process with MLflow tracking server.
55
* It showcases:
66
* - Experiment and run management
77
* - Hyperparameter logging
@@ -12,19 +12,28 @@
1212
* - Artifact storage
1313
*
1414
* @requires @tensorflow/tfjs-node
15+
* @requires @mlflow.js
1516
*
16-
* @note Ensure MLflow server is running at http://localhost:5001 before executing
17+
* @note Ensure MLflow server is running before executing
1718
*/
1819

1920
import * as tf from '@tensorflow/tfjs-node';
20-
import MLflow from 'mlflow-js';
21+
import Mlflow from 'mlflow-js';
2122
import { fileURLToPath } from 'url';
22-
import { dirname } from 'path';
23+
import { dirname, join } from 'path';
24+
import dotenv from 'dotenv';
2325

24-
const mlflow = new MLflow('http://localhost:5001');
26+
const __filename = fileURLToPath(import.meta.url);
27+
const __dirname = dirname(__filename);
28+
dotenv.config({ path: join(__dirname, '../.env') });
29+
30+
const MLFLOW_TRACKING_URI =
31+
process.env.MLFLOW_TRACKING_URI || 'http://localhost:5001';
32+
const mlflow = new Mlflow(MLFLOW_TRACKING_URI);
33+
console.log('MLflow server URL:', MLFLOW_TRACKING_URI);
2534

2635
const HYPERPARAMETERS = {
27-
learningRate: 0.1,
36+
learningRate: 0.25,
2837
epochs: 10,
2938
batchSize: 32,
3039
validationSplit: 0.2,
@@ -168,13 +177,22 @@ async function main() {
168177
console.log(`MLflow Experiment ID: ${experimentId}`);
169178

170179
// Create run
171-
const run = await mlflow.createRun(
172-
experimentId,
173-
'Simple Linear Regression'
174-
);
180+
const run = await mlflow.createRun(experimentId);
175181
const runId = run.info.run_id;
176182
console.log(`MLflow Run ID: ${runId}`);
177183

184+
// Generate and prepare data
185+
console.log('Generating data...');
186+
const { xTrain, yTrain, xTest, yTest } = generateData();
187+
188+
// Log dataset info
189+
const datasetInfo = {
190+
dataset: {
191+
name: 'synthetic_linear_regression_data',
192+
},
193+
};
194+
await mlflow.logInputs(runId, [datasetInfo]);
195+
178196
// Log hyperparameters
179197
const params = [
180198
{ key: 'learning_rate', value: HYPERPARAMETERS.learningRate.toString() },
@@ -189,18 +207,6 @@ async function main() {
189207
];
190208
await mlflow.logBatch(runId, undefined, params);
191209

192-
// Log dataset info
193-
const datasetInfo = {
194-
dataset: {
195-
name: 'synthetic_linear_regression_data',
196-
},
197-
};
198-
await mlflow.logInputs(runId, [datasetInfo]);
199-
200-
// Generate and prepare data
201-
console.log('Generating data...');
202-
const { xTrain, yTrain, xTest, yTest } = generateData();
203-
204210
// Train model
205211
console.log('Creating and training model...');
206212
const model = createModel();
@@ -209,12 +215,6 @@ async function main() {
209215
// Evaluate model
210216
const evaluation = evaluateModel(model, xTest, yTest);
211217

212-
// Save model artifects
213-
const __filename = fileURLToPath(import.meta.url);
214-
const __dirname = dirname(__filename);
215-
const artifactsPath = `${__dirname}/../mlruns/${experimentId}/${runId}/artifacts`;
216-
await model.save(`file://${artifactsPath}`);
217-
218218
// Log evaluation metrics
219219
const finalMetrics = [
220220
{ key: 'test_mse', value: evaluation.metrics.mse, timestamp: Date.now() },
@@ -235,6 +235,12 @@ async function main() {
235235
];
236236
await mlflow.logBatch(runId, undefined, undefined, paramTags);
237237

238+
// Save model artifects
239+
const __filename = fileURLToPath(import.meta.url);
240+
const __dirname = dirname(__filename);
241+
const artifactsPath = `${__dirname}/../mlruns/${experimentId}/${runId}/artifacts`;
242+
await model.save(`file://${artifactsPath}`);
243+
238244
// Register model if performance meets threshold
239245
if (evaluation.metrics.r2 > 0.9) {
240246
const modelName = 'LinearRegression';
@@ -259,7 +265,10 @@ async function main() {
259265
`runs:/${runId}/model`,
260266
runId,
261267
[
262-
{ key: 'r2', value: evaluation.metrics.r2.toString() },
268+
{
269+
key: 'learning_rate',
270+
value: HYPERPARAMETERS.learningRate.toString(),
271+
},
263272
{ key: 'rmse', value: evaluation.metrics.rmse.toString() },
264273
]
265274
);
@@ -277,13 +286,6 @@ async function main() {
277286
}
278287
}
279288

280-
// Log additional metadata
281-
const tags = [
282-
{ key: 'model_type', value: 'linear_regression' },
283-
{ key: 'data_source', value: 'synthetic' },
284-
];
285-
await mlflow.logBatch(runId, undefined, undefined, tags);
286-
287289
// Finish run
288290
await mlflow.updateRun(runId, 'FINISHED');
289291

mlflow/examples/NeuralNetworkHyperparameterTuning.js

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@ const mlflow = new MLflow('http://localhost:5001');
77

88
const HYPERPARAMETER_SPACE = {
99
networkArchitectures: [
10-
[16, 8], // Small network
11-
[32, 16], // Medium network
12-
[64, 32], // Larger network
10+
[16, 8],
11+
[32, 16],
12+
[64, 32],
1313
],
1414
learningRates: [0.001, 0.01],
1515
batchSizes: [32, 64],
@@ -23,7 +23,7 @@ const TRAINING_CONFIG = {
2323
datasetSize: 2000,
2424
inputFeatures: 5,
2525
outputClasses: 3,
26-
minibatchSize: 128, // Added for faster training
26+
minibatchSize: 128,
2727
};
2828

2929
// Data generation
@@ -252,7 +252,7 @@ async function main() {
252252
try {
253253
console.time('Total Execution Time');
254254

255-
const experimentName = 'Neural_Network_Hyperparameter_Tuning_Fast';
255+
const experimentName = 'Neural_Network_Hyperparameter_Tuning';
256256
let experimentId;
257257
try {
258258
const experiment = await mlflow.getExperimentByName(experimentName);

mlflow/lib/index.d.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
import Mlflow from './mlflow.js';
2+
export default Mlflow;

mlflow/lib/index.js

Lines changed: 3 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

mlflow/lib/index.js.map

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

mlflow/lib/mlflow.d.ts

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import ExperimentClient from './tracking/ExperimentClient.js';
2+
import ExperimentManager from './workflows/ExperimentManager.js';
3+
import RunClient from './tracking/RunClient.js';
4+
import RunManager from './workflows/RunManager.js';
5+
import ModelRegistryClient from './model-registry/ModelRegistryClient.js';
6+
import ModelVersionClient from './model-registry/ModelVersionClient.js';
7+
import ModelManager from './workflows/ModelManager.js';
8+
declare class Mlflow {
9+
private components;
10+
constructor(trackingUri: string);
11+
private initializeMethods;
12+
getExperimentClient(): ExperimentClient;
13+
getRunClient(): RunClient;
14+
getModelRegistryClient(): ModelRegistryClient;
15+
getModelVersionClient(): ModelVersionClient;
16+
getExperimentManager(): ExperimentManager;
17+
getRunManager(): RunManager;
18+
getModelManager(): ModelManager;
19+
}
20+
export default Mlflow;

mlflow/lib/mlflow.js

Lines changed: 56 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

mlflow/lib/mlflow.js.map

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)