11/**
22 * @fileoverview Example of using MLflow.js for machine learning projects
3- * The example creates a synthetic dataset and trains a linear regression model using TensorFlow.js,
4- * tracking various metrics and parameters throughout the training process with MLflow tracking server.
3+ * The example creates a synthetic dataset and trains a linear regression model using TensorFlow.js,
4+ * tracking various metrics, hyperparameters and other meta data throughout the training process with MLflow tracking server.
55 * It showcases:
66 * - Experiment and run management
77 * - Hyperparameter logging
1212 * - Artifact storage
1313 *
1414 * @requires @tensorflow /tfjs-node
15+ * @requires @mlflow .js
1516 *
16- * @note Ensure MLflow server is running at http://localhost:5001 before executing
17+ * @note Ensure MLflow server is running before executing
1718 */
1819
1920import * as tf from '@tensorflow/tfjs-node' ;
20- import MLflow from 'mlflow-js' ;
21+ import Mlflow from 'mlflow-js' ;
2122import { fileURLToPath } from 'url' ;
22- import { dirname } from 'path' ;
23+ import { dirname , join } from 'path' ;
24+ import dotenv from 'dotenv' ;
2325
24- const mlflow = new MLflow ( 'http://localhost:5001' ) ;
26+ const __filename = fileURLToPath ( import . meta. url ) ;
27+ const __dirname = dirname ( __filename ) ;
28+ dotenv . config ( { path : join ( __dirname , '../.env' ) } ) ;
29+
30+ const MLFLOW_TRACKING_URI =
31+ process . env . MLFLOW_TRACKING_URI || 'http://localhost:5001' ;
32+ const mlflow = new Mlflow ( MLFLOW_TRACKING_URI ) ;
33+ console . log ( 'MLflow server URL:' , MLFLOW_TRACKING_URI ) ;
2534
2635const HYPERPARAMETERS = {
27- learningRate : 0.1 ,
36+ learningRate : 0.25 ,
2837 epochs : 10 ,
2938 batchSize : 32 ,
3039 validationSplit : 0.2 ,
@@ -168,13 +177,22 @@ async function main() {
168177 console . log ( `MLflow Experiment ID: ${ experimentId } ` ) ;
169178
170179 // Create run
171- const run = await mlflow . createRun (
172- experimentId ,
173- 'Simple Linear Regression'
174- ) ;
180+ const run = await mlflow . createRun ( experimentId ) ;
175181 const runId = run . info . run_id ;
176182 console . log ( `MLflow Run ID: ${ runId } ` ) ;
177183
184+ // Generate and prepare data
185+ console . log ( 'Generating data...' ) ;
186+ const { xTrain, yTrain, xTest, yTest } = generateData ( ) ;
187+
188+ // Log dataset info
189+ const datasetInfo = {
190+ dataset : {
191+ name : 'synthetic_linear_regression_data' ,
192+ } ,
193+ } ;
194+ await mlflow . logInputs ( runId , [ datasetInfo ] ) ;
195+
178196 // Log hyperparameters
179197 const params = [
180198 { key : 'learning_rate' , value : HYPERPARAMETERS . learningRate . toString ( ) } ,
@@ -189,18 +207,6 @@ async function main() {
189207 ] ;
190208 await mlflow . logBatch ( runId , undefined , params ) ;
191209
192- // Log dataset info
193- const datasetInfo = {
194- dataset : {
195- name : 'synthetic_linear_regression_data' ,
196- } ,
197- } ;
198- await mlflow . logInputs ( runId , [ datasetInfo ] ) ;
199-
200- // Generate and prepare data
201- console . log ( 'Generating data...' ) ;
202- const { xTrain, yTrain, xTest, yTest } = generateData ( ) ;
203-
204210 // Train model
205211 console . log ( 'Creating and training model...' ) ;
206212 const model = createModel ( ) ;
@@ -209,12 +215,6 @@ async function main() {
209215 // Evaluate model
210216 const evaluation = evaluateModel ( model , xTest , yTest ) ;
211217
212- // Save model artifects
213- const __filename = fileURLToPath ( import . meta. url ) ;
214- const __dirname = dirname ( __filename ) ;
215- const artifactsPath = `${ __dirname } /../mlruns/${ experimentId } /${ runId } /artifacts` ;
216- await model . save ( `file://${ artifactsPath } ` ) ;
217-
218218 // Log evaluation metrics
219219 const finalMetrics = [
220220 { key : 'test_mse' , value : evaluation . metrics . mse , timestamp : Date . now ( ) } ,
@@ -235,6 +235,12 @@ async function main() {
235235 ] ;
236236 await mlflow . logBatch ( runId , undefined , undefined , paramTags ) ;
237237
238+ // Save model artifects
239+ const __filename = fileURLToPath ( import . meta. url ) ;
240+ const __dirname = dirname ( __filename ) ;
241+ const artifactsPath = `${ __dirname } /../mlruns/${ experimentId } /${ runId } /artifacts` ;
242+ await model . save ( `file://${ artifactsPath } ` ) ;
243+
238244 // Register model if performance meets threshold
239245 if ( evaluation . metrics . r2 > 0.9 ) {
240246 const modelName = 'LinearRegression' ;
@@ -259,7 +265,10 @@ async function main() {
259265 `runs:/${ runId } /model` ,
260266 runId ,
261267 [
262- { key : 'r2' , value : evaluation . metrics . r2 . toString ( ) } ,
268+ {
269+ key : 'learning_rate' ,
270+ value : HYPERPARAMETERS . learningRate . toString ( ) ,
271+ } ,
263272 { key : 'rmse' , value : evaluation . metrics . rmse . toString ( ) } ,
264273 ]
265274 ) ;
@@ -277,13 +286,6 @@ async function main() {
277286 }
278287 }
279288
280- // Log additional metadata
281- const tags = [
282- { key : 'model_type' , value : 'linear_regression' } ,
283- { key : 'data_source' , value : 'synthetic' } ,
284- ] ;
285- await mlflow . logBatch ( runId , undefined , undefined , tags ) ;
286-
287289 // Finish run
288290 await mlflow . updateRun ( runId , 'FINISHED' ) ;
289291
0 commit comments