1
1
/**
2
2
* @fileoverview Example of using MLflow.js for machine learning projects
3
- * The example creates a synthetic dataset and trains a linear regression model using TensorFlow.js,
4
- * tracking various metrics and parameters throughout the training process with MLflow tracking server.
3
+ * The example creates a synthetic dataset and trains a linear regression model using TensorFlow.js,
4
+ * tracking various metrics, hyperparameters and other meta data throughout the training process with MLflow tracking server.
5
5
* It showcases:
6
6
* - Experiment and run management
7
7
* - Hyperparameter logging
12
12
* - Artifact storage
13
13
*
14
14
* @requires @tensorflow /tfjs-node
15
+ * @requires @mlflow .js
15
16
*
16
17
* @note Ensure MLflow server is running at http://localhost:5001 before executing
17
18
*/
@@ -24,7 +25,7 @@ import { dirname } from 'path';
24
25
const mlflow = new MLflow ( 'http://localhost:5001' ) ;
25
26
26
27
const HYPERPARAMETERS = {
27
- learningRate : 0.1 ,
28
+ learningRate : 0.25 ,
28
29
epochs : 10 ,
29
30
batchSize : 32 ,
30
31
validationSplit : 0.2 ,
@@ -175,6 +176,18 @@ async function main() {
175
176
const runId = run . info . run_id ;
176
177
console . log ( `MLflow Run ID: ${ runId } ` ) ;
177
178
179
+ // Generate and prepare data
180
+ console . log ( 'Generating data...' ) ;
181
+ const { xTrain, yTrain, xTest, yTest } = generateData ( ) ;
182
+
183
+ // Log dataset info
184
+ const datasetInfo = {
185
+ dataset : {
186
+ name : 'synthetic_linear_regression_data' ,
187
+ } ,
188
+ } ;
189
+ await mlflow . logInputs ( runId , [ datasetInfo ] ) ;
190
+
178
191
// Log hyperparameters
179
192
const params = [
180
193
{ key : 'learning_rate' , value : HYPERPARAMETERS . learningRate . toString ( ) } ,
@@ -189,18 +202,6 @@ async function main() {
189
202
] ;
190
203
await mlflow . logBatch ( runId , undefined , params ) ;
191
204
192
- // Log dataset info
193
- const datasetInfo = {
194
- dataset : {
195
- name : 'synthetic_linear_regression_data' ,
196
- } ,
197
- } ;
198
- await mlflow . logInputs ( runId , [ datasetInfo ] ) ;
199
-
200
- // Generate and prepare data
201
- console . log ( 'Generating data...' ) ;
202
- const { xTrain, yTrain, xTest, yTest } = generateData ( ) ;
203
-
204
205
// Train model
205
206
console . log ( 'Creating and training model...' ) ;
206
207
const model = createModel ( ) ;
@@ -209,12 +210,6 @@ async function main() {
209
210
// Evaluate model
210
211
const evaluation = evaluateModel ( model , xTest , yTest ) ;
211
212
212
- // Save model artifects
213
- const __filename = fileURLToPath ( import . meta. url ) ;
214
- const __dirname = dirname ( __filename ) ;
215
- const artifactsPath = `${ __dirname } /../mlruns/${ experimentId } /${ runId } /artifacts` ;
216
- await model . save ( `file://${ artifactsPath } ` ) ;
217
-
218
213
// Log evaluation metrics
219
214
const finalMetrics = [
220
215
{ key : 'test_mse' , value : evaluation . metrics . mse , timestamp : Date . now ( ) } ,
@@ -235,6 +230,12 @@ async function main() {
235
230
] ;
236
231
await mlflow . logBatch ( runId , undefined , undefined , paramTags ) ;
237
232
233
+ // Save model artifects
234
+ const __filename = fileURLToPath ( import . meta. url ) ;
235
+ const __dirname = dirname ( __filename ) ;
236
+ const artifactsPath = `${ __dirname } /../mlruns/${ experimentId } /${ runId } /artifacts` ;
237
+ await model . save ( `file://${ artifactsPath } ` ) ;
238
+
238
239
// Register model if performance meets threshold
239
240
if ( evaluation . metrics . r2 > 0.9 ) {
240
241
const modelName = 'LinearRegression' ;
@@ -259,7 +260,10 @@ async function main() {
259
260
`runs:/${ runId } /model` ,
260
261
runId ,
261
262
[
262
- { key : 'r2' , value : evaluation . metrics . r2 . toString ( ) } ,
263
+ {
264
+ key : 'learning_rate' ,
265
+ value : HYPERPARAMETERS . learningRate . toString ( ) ,
266
+ } ,
263
267
{ key : 'rmse' , value : evaluation . metrics . rmse . toString ( ) } ,
264
268
]
265
269
) ;
@@ -277,13 +281,6 @@ async function main() {
277
281
}
278
282
}
279
283
280
- // Log additional metadata
281
- const tags = [
282
- { key : 'model_type' , value : 'linear_regression' } ,
283
- { key : 'data_source' , value : 'synthetic' } ,
284
- ] ;
285
- await mlflow . logBatch ( runId , undefined , undefined , tags ) ;
286
-
287
284
// Finish run
288
285
await mlflow . updateRun ( runId , 'FINISHED' ) ;
289
286
0 commit comments