Skip to content

Commit 482df3a

Browse files
committed
merging latest changes into austin-vimeo
Merge branch 'dev' into austin-vimeo
2 parents a83ead7 + 10d0cf3 commit 482df3a

15 files changed

+2055
-372
lines changed

mlflow-site/src/app/components/Button.tsx

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,8 @@
33
const Button = () => {
44
return (
55
<div className='button'>
6-
<button
7-
onClick={() => {
8-
window.location.assign('https://github.com/oslabs-beta/mlflow-js');
9-
}}
10-
className='homeButton homeButtonDownload text-white'
11-
>
12-
Download
13-
</button>
14-
<button
15-
onClick={() => {
16-
window.location.assign('https://github.com/oslabs-beta/mlflow-js/tree/dev/mlflow/docs');
17-
}}
18-
className='homeButton homeButtonRead'
19-
>
20-
Read the Docs
21-
</button>
6+
<a href='https://github.com/oslabs-beta/mlflow-js' className='homeButton homeButtonDownload text-white'>Download</a>
7+
<a href='https://github.com/oslabs-beta/mlflow-js/tree/dev/mlflow/docs' className='homeButton homeButtonRead'>Read the Docs</a>
228
</div>
239
);
2410
};

mlflow-site/src/app/globals.css

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,10 +175,14 @@ body {
175175

176176
.homeButtonDownload {
177177
background-color: rgb(66, 107, 31);
178+
padding-top: 0.6rem;
179+
padding-bottom: 0.6rem;
178180
}
179181

180182
.homeButtonRead {
181183
background-color: rgb(204, 204, 204);
184+
padding-top: 0.6rem;
185+
padding-bottom: 0.6rem;
182186
}
183187

184188
.button {

mlflow/.gitignore

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,4 @@ mlruns/
1313
lib/
1414

1515
# Temporary files
16-
temp/
17-
18-
package-lock.json
16+
temp/

mlflow/docs/tech-design-doc.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
## 1. Overview
55

66

7-
The MLflow-JS library aims to provide a comprehensive and intuitive interface for interacting with MLflow services in a JavaScript/Typescript environment. It simplifies the integration with MLflow REST APIs by offering high-level methods for tracking, model management, and advanced workflows.
7+
The MLflow-JS library aims to provide a comprehensive and intuitive interface for interacting with MLflow services in a JavaScript/TypeScript environment. It simplifies the integration with MLflow REST APIs by offering high-level methods for tracking, model management, and advanced workflows.
88

99

1010
## 2. Architecture
Lines changed: 302 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,302 @@
1+
/**
2+
* @fileoverview Example of using MLflow.js for machine learning projects
3+
* The example creates a synthetic dataset and trains a linear regression model using TensorFlow.js,
4+
* tracking various metrics and parameters throughout the training process with MLflow tracking server.
5+
* It showcases:
6+
* - Experiment and run management
7+
* - Hyperparameter logging
8+
* - Metric tracking
9+
* - Tagging
10+
* - Model registry
11+
* - Model versioning
12+
* - Artifact storage
13+
*
14+
* @requires @tensorflow/tfjs-node
15+
*
16+
* @note Ensure MLflow server is running at http://localhost:5001 before executing
17+
*/
18+
19+
import * as tf from '@tensorflow/tfjs-node';
20+
import MLflow from 'mlflow-js';
21+
import { fileURLToPath } from 'url';
22+
import { dirname } from 'path';
23+
24+
// Client for the MLflow tracking server that every call in this example goes through.
const mlflow = new MLflow('http://localhost:5001');

// Settings handed to the optimizer (learningRate) and to model.fit() (the rest).
const HYPERPARAMETERS = {
  learningRate: 0.1,
  epochs: 10,
  batchSize: 32,
  validationSplit: 0.2,
};

// Shape of the synthetic dataset: y = trueSlope * x + N(0, noiseStd) noise,
// with the first trainTestSplit fraction of the rows used for training.
const DATA_CONFIG = {
  numSamples: 1000,
  trainTestSplit: 0.7,
  trueSlope: 2,
  noiseStd: 0.1,
};
39+
40+
/**
 * Builds the synthetic dataset y = trueSlope * x + Gaussian noise and splits
 * it into training and test partitions.
 *
 * Reads numSamples, trueSlope, noiseStd and trainTestSplit from DATA_CONFIG.
 * Everything is created inside a single tf.tidy so the intermediate tensors
 * (the full x and y columns and the noise) are released on return; only the
 * four returned slices stay allocated.
 *
 * @returns {{xTrain: tf.Tensor, yTrain: tf.Tensor, xTest: tf.Tensor, yTest: tf.Tensor}}
 *   Single-column tensors. The first floor(numSamples * trainTestSplit) rows
 *   form the training set, the remaining rows the test set. The caller owns
 *   these tensors and must dispose them.
 */
function generateData() {
  const { numSamples, trainTestSplit, trueSlope, noiseStd } = DATA_CONFIG;
  const splitIdx = Math.floor(numSamples * trainTestSplit);

  return tf.tidy(() => {
    // Generate synthetic data: y = trueSlope * x + noise
    const x = tf.randomUniform([numSamples, 1]);
    const noise = tf.randomNormal([numSamples, 1], 0, noiseStd);
    const y = x.mul(trueSlope).add(noise);

    // Split into train and test sets. A size of -1 means "to the end of the axis".
    // tf.tidy keeps the tensors found in the returned object and disposes the rest.
    return {
      xTrain: x.slice([0, 0], [splitIdx, 1]),
      yTrain: y.slice([0, 0], [splitIdx, 1]),
      xTest: x.slice([splitIdx, 0], [-1, 1]),
      yTest: y.slice([splitIdx, 0], [-1, 1]),
    };
  });
}
64+
65+
/**
 * Creates and compiles the regression model: a sequential network made of a
 * single dense layer with one unit and one input feature.
 *
 * The model is compiled with an SGD optimizer (learning rate taken from
 * HYPERPARAMETERS), mean-squared-error loss and MAE as an extra metric.
 *
 * @returns {tf.Sequential} The compiled model, ready for fit().
 */
function createModel() {
  const denseLayer = tf.layers.dense({ units: 1, inputShape: [1] });
  const model = tf.sequential({ layers: [denseLayer] });

  const optimizer = tf.train.sgd(HYPERPARAMETERS.learningRate);
  model.compile({ optimizer, loss: 'meanSquaredError', metrics: ['mae'] });

  return model;
}
78+
79+
/**
 * Fits the model on the training data and reports per-epoch metrics to MLflow.
 *
 * After every epoch the training and validation loss/MAE from the fit logs are
 * sent to the run in one logBatch call, using the epoch index as the step.
 *
 * @param {tf.LayersModel} model - Compiled model to train.
 * @param {tf.Tensor} xTrain - Training inputs.
 * @param {tf.Tensor} yTrain - Training targets.
 * @param {string} runId - MLflow run that receives the metrics.
 * @returns {Promise<tf.History>} The history object produced by model.fit().
 */
async function trainModel(model, xTrain, yTrain, runId) {
  const { epochs, batchSize, validationSplit } = HYPERPARAMETERS;

  // Log metrics for both training and validation
  const reportEpoch = async (epoch, logs) => {
    const readings = [
      ['train_loss', logs.loss],
      ['train_mae', logs.mae],
      ['val_loss', logs.val_loss],
      ['val_mae', logs.val_mae],
    ];
    const metrics = readings.map(([key, value]) => ({
      key,
      value,
      timestamp: Date.now(),
      step: epoch,
    }));

    await mlflow.logBatch(runId, metrics);
  };

  return model.fit(xTrain, yTrain, {
    epochs,
    batchSize,
    validationSplit,
    verbose: 1,
    callbacks: { onEpochEnd: reportEpoch },
  });
}
122+
123+
/**
 * Evaluates the trained model on the held-out test set.
 *
 * Runs inside tf.tidy so every tensor created during evaluation is released;
 * only plain numbers are returned.
 *
 * @param {tf.LayersModel} model - Trained model; its first layer's weights are
 *   reported as the fitted slope and intercept.
 * @param {tf.Tensor} xTest - Test inputs.
 * @param {tf.Tensor} yTest - Test targets.
 * @returns {{metrics: {mse: number, rmse: number, mae: number, r2: number},
 *            parameters: {weight: number, bias: number}}}
 *   Test-set error metrics and the learned weight/bias.
 */
function evaluateModel(model, xTest, yTest) {
  return tf.tidy(() => {
    const yPred = model.predict(xTest);

    // Evaluation Metrics
    // tf.metrics.* only average over the last axis, which yields one value per
    // sample here. Take the mean over all samples before reading the scalar;
    // otherwise index 0 would be the error of the first test sample alone.
    const mse = tf.metrics
      .meanSquaredError(yTest, yPred)
      .mean()
      .dataSync()[0];
    const rmse = Math.sqrt(mse);
    const mae = tf.metrics
      .meanAbsoluteError(yTest, yPred)
      .mean()
      .dataSync()[0];

    // R-squared: 1 - (residual sum of squares / total sum of squares)
    const yMean = tf.mean(yTest);
    const totalSS = yTest.sub(yMean).square().sum().dataSync()[0];
    const residualSS = yTest.sub(yPred).square().sum().dataSync()[0];
    const r2 = 1 - residualSS / totalSS;

    // Model Parameters: getWeights() on the dense layer returns [kernel, bias]
    const [kernel, biasTensor] = model.layers[0].getWeights();
    const weight = kernel.dataSync()[0];
    const bias = biasTensor.dataSync()[0];

    return {
      metrics: { mse, rmse, mae, r2 },
      parameters: { weight, bias },
    };
  });
}
156+
157+
/**
 * Runs the end-to-end example against the MLflow tracking server:
 *
 *  1. Looks up (or creates) the experiment and starts a run.
 *  2. Logs hyperparameters and dataset info.
 *  3. Generates data, trains the model (per-epoch metrics are logged during
 *     training) and evaluates it on the test split.
 *  4. Saves the model files under the run's artifacts directory.
 *  5. Logs the test metrics and the learned weight/bias as tags.
 *  6. Registers a model version when R² > 0.9 and moves it to staging when
 *     R² > 0.95. Registration problems are logged but do not fail the run.
 *  7. Marks the run FINISHED.
 *
 * On any other error the error is logged, the run (if one was started) is
 * marked FAILED on a best-effort basis, and the process exit code is set to 1.
 * Tensors and the model are always disposed, whether or not the run succeeded.
 *
 * @returns {Promise<void>}
 */
async function main() {
  // Declared outside the try block so the error and cleanup paths can reach them.
  let activeRunId = null;
  let model = null;
  let tensors = [];

  try {
    // Initialize experiment
    const experimentName = 'Linear_Regression_Example';
    let experimentId;
    try {
      const experiment = await mlflow.getExperimentByName(experimentName);
      experimentId = experiment.experiment_id;
    } catch {
      // Any lookup failure is treated as "experiment does not exist yet".
      // If the server is actually unreachable, the create call below fails
      // too and the outer handler reports it.
      experimentId = await mlflow.createExperiment(experimentName);
    }
    console.log(`MLflow Experiment ID: ${experimentId}`);

    // Create run
    const run = await mlflow.createRun(
      experimentId,
      'Simple Linear Regression'
    );
    const runId = run.info.run_id;
    activeRunId = runId;
    console.log(`MLflow Run ID: ${runId}`);

    // Log hyperparameters
    const params = [
      { key: 'learning_rate', value: HYPERPARAMETERS.learningRate.toString() },
      { key: 'epochs', value: HYPERPARAMETERS.epochs.toString() },
      { key: 'batch_size', value: HYPERPARAMETERS.batchSize.toString() },
      {
        key: 'validation_split',
        value: HYPERPARAMETERS.validationSplit.toString(),
      },
      { key: 'train_test_split', value: DATA_CONFIG.trainTestSplit.toString() },
      { key: 'num_samples', value: DATA_CONFIG.numSamples.toString() },
    ];
    await mlflow.logBatch(runId, undefined, params);

    // Log dataset info
    const datasetInfo = {
      dataset: {
        name: 'synthetic_linear_regression_data',
      },
    };
    await mlflow.logInputs(runId, [datasetInfo]);

    // Generate and prepare data
    console.log('Generating data...');
    const { xTrain, yTrain, xTest, yTest } = generateData();
    tensors = [xTrain, yTrain, xTest, yTest];

    // Train model
    console.log('Creating and training model...');
    model = createModel();
    await trainModel(model, xTrain, yTrain, runId);

    // Evaluate model
    const evaluation = evaluateModel(model, xTest, yTest);

    // Save model artifacts
    const __filename = fileURLToPath(import.meta.url);
    const __dirname = dirname(__filename);
    const artifactsPath = `${__dirname}/../mlruns/${experimentId}/${runId}/artifacts`;
    await model.save(`file://${artifactsPath}`);

    // Log evaluation metrics
    const finalMetrics = [
      { key: 'test_mse', value: evaluation.metrics.mse, timestamp: Date.now() },
      {
        key: 'test_rmse',
        value: evaluation.metrics.rmse,
        timestamp: Date.now(),
      },
      { key: 'test_mae', value: evaluation.metrics.mae, timestamp: Date.now() },
      { key: 'test_r2', value: evaluation.metrics.r2, timestamp: Date.now() },
    ];
    await mlflow.logBatch(runId, finalMetrics);

    // Log model parameters (sent as tags: fourth argument of logBatch)
    const paramTags = [
      { key: 'model_weight', value: evaluation.parameters.weight.toString() },
      { key: 'model_bias', value: evaluation.parameters.bias.toString() },
    ];
    await mlflow.logBatch(runId, undefined, undefined, paramTags);

    // Register model if performance meets threshold
    if (evaluation.metrics.r2 > 0.9) {
      const modelName = 'LinearRegression';
      try {
        let modelExists = true;
        try {
          await mlflow.getRegisteredModel(modelName);
        } catch {
          // A failed lookup is taken to mean the model is not registered yet.
          modelExists = false;
        }

        if (!modelExists) {
          await mlflow.createRegisteredModel(
            modelName,
            [{ key: 'task', value: 'regression' }],
            'Simple linear regression model'
          );
        }

        const modelVersion = await mlflow.createModelVersion(
          modelName,
          `runs:/${runId}/model`,
          runId,
          [
            { key: 'r2', value: evaluation.metrics.r2.toString() },
            { key: 'rmse', value: evaluation.metrics.rmse.toString() },
          ]
        );

        if (evaluation.metrics.r2 > 0.95) {
          await mlflow.transitionModelVersionStage(
            modelName,
            modelVersion.version,
            'staging',
            true
          );
        }
      } catch (e) {
        // Registration is optional: report the problem and keep the run going.
        console.error('Model registration error:', e.message);
      }
    }

    // Log additional metadata
    const tags = [
      { key: 'model_type', value: 'linear_regression' },
      { key: 'data_source', value: 'synthetic' },
    ];
    await mlflow.logBatch(runId, undefined, undefined, tags);

    // Finish run
    await mlflow.updateRun(runId, 'FINISHED');
    activeRunId = null;

    console.log('\nMLflow tracking completed successfully!');
    console.log(
      `View run details at http://localhost:5001/#/experiments/${experimentId}/runs/${runId}`
    );
  } catch (error) {
    console.error('Error:', error);

    // Do not leave the run stuck in RUNNING on the server. This is best
    // effort: a failure here must not hide the original error logged above.
    if (activeRunId !== null) {
      try {
        await mlflow.updateRun(activeRunId, 'FAILED');
      } catch (updateError) {
        console.error('Could not mark run as FAILED:', updateError);
      }
    }

    // Signal failure to the calling shell without cutting off pending I/O.
    process.exitCode = 1;
  } finally {
    // Cleanup runs on success and on failure so tensor memory is always released.
    tf.dispose(tensors);
    if (model !== null) {
      model.dispose();
    }
  }
}
301+
302+
main();

0 commit comments

Comments
 (0)