Skip to content

Commit b18cf18

Browse files
committed
code change for demo
1 parent 0abdf8d commit b18cf18

File tree

1 file changed

+26
-29
lines changed

1 file changed

+26
-29
lines changed

mlflow/examples/LinearRegressionExample.js

Lines changed: 26 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
/**
22
* @fileoverview Example of using MLflow.js for machine learning projects
3-
* The example creates a synthetic dataset and trains a linear regression model using TensorFlow.js,
4-
* tracking various metrics and parameters throughout the training process with MLflow tracking server.
3+
* The example creates a synthetic dataset and trains a linear regression model using TensorFlow.js,
4+
* tracking various metrics, hyperparameters and other meta data throughout the training process with MLflow tracking server.
55
* It showcases:
66
* - Experiment and run management
77
* - Hyperparameter logging
@@ -12,6 +12,7 @@
1212
* - Artifact storage
1313
*
1414
* @requires @tensorflow/tfjs-node
15+
* @requires @mlflow.js
1516
*
1617
* @note Ensure MLflow server is running at http://localhost:5001 before executing
1718
*/
@@ -24,7 +25,7 @@ import { dirname } from 'path';
2425
const mlflow = new MLflow('http://localhost:5001');
2526

2627
const HYPERPARAMETERS = {
27-
learningRate: 0.1,
28+
learningRate: 0.25,
2829
epochs: 10,
2930
batchSize: 32,
3031
validationSplit: 0.2,
@@ -175,6 +176,18 @@ async function main() {
175176
const runId = run.info.run_id;
176177
console.log(`MLflow Run ID: ${runId}`);
177178

179+
// Generate and prepare data
180+
console.log('Generating data...');
181+
const { xTrain, yTrain, xTest, yTest } = generateData();
182+
183+
// Log dataset info
184+
const datasetInfo = {
185+
dataset: {
186+
name: 'synthetic_linear_regression_data',
187+
},
188+
};
189+
await mlflow.logInputs(runId, [datasetInfo]);
190+
178191
// Log hyperparameters
179192
const params = [
180193
{ key: 'learning_rate', value: HYPERPARAMETERS.learningRate.toString() },
@@ -189,18 +202,6 @@ async function main() {
189202
];
190203
await mlflow.logBatch(runId, undefined, params);
191204

192-
// Log dataset info
193-
const datasetInfo = {
194-
dataset: {
195-
name: 'synthetic_linear_regression_data',
196-
},
197-
};
198-
await mlflow.logInputs(runId, [datasetInfo]);
199-
200-
// Generate and prepare data
201-
console.log('Generating data...');
202-
const { xTrain, yTrain, xTest, yTest } = generateData();
203-
204205
// Train model
205206
console.log('Creating and training model...');
206207
const model = createModel();
@@ -209,12 +210,6 @@ async function main() {
209210
// Evaluate model
210211
const evaluation = evaluateModel(model, xTest, yTest);
211212

212-
// Save model artifects
213-
const __filename = fileURLToPath(import.meta.url);
214-
const __dirname = dirname(__filename);
215-
const artifactsPath = `${__dirname}/../mlruns/${experimentId}/${runId}/artifacts`;
216-
await model.save(`file://${artifactsPath}`);
217-
218213
// Log evaluation metrics
219214
const finalMetrics = [
220215
{ key: 'test_mse', value: evaluation.metrics.mse, timestamp: Date.now() },
@@ -235,6 +230,12 @@ async function main() {
235230
];
236231
await mlflow.logBatch(runId, undefined, undefined, paramTags);
237232

233+
// Save model artifects
234+
const __filename = fileURLToPath(import.meta.url);
235+
const __dirname = dirname(__filename);
236+
const artifactsPath = `${__dirname}/../mlruns/${experimentId}/${runId}/artifacts`;
237+
await model.save(`file://${artifactsPath}`);
238+
238239
// Register model if performance meets threshold
239240
if (evaluation.metrics.r2 > 0.9) {
240241
const modelName = 'LinearRegression';
@@ -259,7 +260,10 @@ async function main() {
259260
`runs:/${runId}/model`,
260261
runId,
261262
[
262-
{ key: 'r2', value: evaluation.metrics.r2.toString() },
263+
{
264+
key: 'learning_rate',
265+
value: HYPERPARAMETERS.learningRate.toString(),
266+
},
263267
{ key: 'rmse', value: evaluation.metrics.rmse.toString() },
264268
]
265269
);
@@ -277,13 +281,6 @@ async function main() {
277281
}
278282
}
279283

280-
// Log additional metadata
281-
const tags = [
282-
{ key: 'model_type', value: 'linear_regression' },
283-
{ key: 'data_source', value: 'synthetic' },
284-
];
285-
await mlflow.logBatch(runId, undefined, undefined, tags);
286-
287284
// Finish run
288285
await mlflow.updateRun(runId, 'FINISHED');
289286

0 commit comments

Comments
 (0)