|
| 1 | +/** |
| 2 | + * @fileoverview Example of using MLflow.js for machine learning projects |
| 3 | + * The example creates a synthetic dataset and trains a linear regression model using TensorFlow.js, |
| 4 | + * tracking various metrics and parameters throughout the training process with MLflow tracking server. |
| 5 | + * It showcases: |
| 6 | + * - Experiment and run management |
| 7 | + * - Hyperparameter logging |
| 8 | + * - Metric tracking |
| 9 | + * - Tagging |
| 10 | + * - Model registry |
| 11 | + * - Model versioning |
| 12 | + * - Artifact storage |
| 13 | + * |
| 14 | + * @requires @tensorflow/tfjs-node |
| 15 | + * |
| 16 | + * @note Ensure MLflow server is running at http://localhost:5001 before executing |
| 17 | + */ |
| 18 | + |
| 19 | +import * as tf from '@tensorflow/tfjs-node'; |
| 20 | +import MLflow from 'mlflow-js'; |
| 21 | +import { fileURLToPath } from 'url'; |
| 22 | +import { dirname } from 'path'; |
| 23 | + |
| 24 | +const mlflow = new MLflow('http://localhost:5001'); |
| 25 | + |
| 26 | +const HYPERPARAMETERS = { |
| 27 | + learningRate: 0.1, |
| 28 | + epochs: 10, |
| 29 | + batchSize: 32, |
| 30 | + validationSplit: 0.2, |
| 31 | +}; |
| 32 | + |
| 33 | +const DATA_CONFIG = { |
| 34 | + numSamples: 1000, |
| 35 | + trainTestSplit: 0.7, |
| 36 | + trueSlope: 2, |
| 37 | + noiseStd: 0.1, |
| 38 | +}; |
| 39 | + |
| 40 | +function generateData() { |
| 41 | + // Generate synthetic data: y = 2x + noise |
| 42 | + const x = tf.randomUniform([DATA_CONFIG.numSamples, 1]); |
| 43 | + const y = tf.tidy(() => { |
| 44 | + const trueValues = x.mul(DATA_CONFIG.trueSlope); |
| 45 | + const noise = tf.randomNormal( |
| 46 | + [DATA_CONFIG.numSamples, 1], |
| 47 | + 0, |
| 48 | + DATA_CONFIG.noiseStd |
| 49 | + ); |
| 50 | + return trueValues.add(noise); |
| 51 | + }); |
| 52 | + |
| 53 | + // Split into train and test sets |
| 54 | + const splitIdx = Math.floor( |
| 55 | + DATA_CONFIG.numSamples * DATA_CONFIG.trainTestSplit |
| 56 | + ); |
| 57 | + const xTrain = x.slice([0, 0], [splitIdx, 1]); |
| 58 | + const xTest = x.slice([splitIdx, 0], [-1, 1]); |
| 59 | + const yTrain = y.slice([0, 0], [splitIdx, 1]); |
| 60 | + const yTest = y.slice([splitIdx, 0], [-1, 1]); |
| 61 | + |
| 62 | + return { xTrain, yTrain, xTest, yTest }; |
| 63 | +} |
| 64 | + |
| 65 | +function createModel() { |
| 66 | + const model = tf.sequential({ |
| 67 | + layers: [tf.layers.dense({ units: 1, inputShape: [1] })], |
| 68 | + }); |
| 69 | + |
| 70 | + model.compile({ |
| 71 | + optimizer: tf.train.sgd(HYPERPARAMETERS.learningRate), |
| 72 | + loss: 'meanSquaredError', |
| 73 | + metrics: ['mae'], |
| 74 | + }); |
| 75 | + |
| 76 | + return model; |
| 77 | +} |
| 78 | + |
| 79 | +async function trainModel(model, xTrain, yTrain, runId) { |
| 80 | + const history = await model.fit(xTrain, yTrain, { |
| 81 | + epochs: HYPERPARAMETERS.epochs, |
| 82 | + batchSize: HYPERPARAMETERS.batchSize, |
| 83 | + validationSplit: HYPERPARAMETERS.validationSplit, |
| 84 | + verbose: 1, |
| 85 | + callbacks: { |
| 86 | + onEpochEnd: async (epoch, logs) => { |
| 87 | + // Log metrics for both training and validation |
| 88 | + const metrics = [ |
| 89 | + { |
| 90 | + key: 'train_loss', |
| 91 | + value: logs.loss, |
| 92 | + timestamp: Date.now(), |
| 93 | + step: epoch, |
| 94 | + }, |
| 95 | + { |
| 96 | + key: 'train_mae', |
| 97 | + value: logs.mae, |
| 98 | + timestamp: Date.now(), |
| 99 | + step: epoch, |
| 100 | + }, |
| 101 | + { |
| 102 | + key: 'val_loss', |
| 103 | + value: logs.val_loss, |
| 104 | + timestamp: Date.now(), |
| 105 | + step: epoch, |
| 106 | + }, |
| 107 | + { |
| 108 | + key: 'val_mae', |
| 109 | + value: logs.val_mae, |
| 110 | + timestamp: Date.now(), |
| 111 | + step: epoch, |
| 112 | + }, |
| 113 | + ]; |
| 114 | + |
| 115 | + await mlflow.logBatch(runId, metrics); |
| 116 | + }, |
| 117 | + }, |
| 118 | + }); |
| 119 | + |
| 120 | + return history; |
| 121 | +} |
| 122 | + |
| 123 | +function evaluateModel(model, xTest, yTest) { |
| 124 | + return tf.tidy(() => { |
| 125 | + const yPred = model.predict(xTest); |
| 126 | + |
| 127 | + // Evaluation Metrics |
| 128 | + const mse = tf.metrics.meanSquaredError(yTest, yPred).dataSync()[0]; |
| 129 | + const rmse = Math.sqrt(mse); |
| 130 | + const mae = tf.metrics.meanAbsoluteError(yTest, yPred).dataSync()[0]; |
| 131 | + |
| 132 | + // R-squared |
| 133 | + const yMean = tf.mean(yTest); |
| 134 | + const totalSS = yTest.sub(yMean).square().sum().dataSync()[0]; |
| 135 | + const residualSS = yTest.sub(yPred).square().sum().dataSync()[0]; |
| 136 | + const r2 = 1 - residualSS / totalSS; |
| 137 | + |
| 138 | + // Model Parameters |
| 139 | + const weight = model.layers[0].getWeights()[0].dataSync()[0]; |
| 140 | + const bias = model.layers[0].getWeights()[1].dataSync()[0]; |
| 141 | + |
| 142 | + return { |
| 143 | + metrics: { |
| 144 | + mse, |
| 145 | + rmse, |
| 146 | + mae, |
| 147 | + r2, |
| 148 | + }, |
| 149 | + parameters: { |
| 150 | + weight, |
| 151 | + bias, |
| 152 | + }, |
| 153 | + }; |
| 154 | + }); |
| 155 | +} |
| 156 | + |
| 157 | +async function main() { |
| 158 | + try { |
| 159 | + // Initialize experiment |
| 160 | + const experimentName = 'Linear_Regression_Example'; |
| 161 | + let experimentId; |
| 162 | + try { |
| 163 | + const experiment = await mlflow.getExperimentByName(experimentName); |
| 164 | + experimentId = experiment.experiment_id; |
| 165 | + } catch { |
| 166 | + experimentId = await mlflow.createExperiment(experimentName); |
| 167 | + } |
| 168 | + console.log(`MLflow Experiment ID: ${experimentId}`); |
| 169 | + |
| 170 | + // Create run |
| 171 | + const run = await mlflow.createRun( |
| 172 | + experimentId, |
| 173 | + 'Simple Linear Regression' |
| 174 | + ); |
| 175 | + const runId = run.info.run_id; |
| 176 | + console.log(`MLflow Run ID: ${runId}`); |
| 177 | + |
| 178 | + // Log hyperparameters |
| 179 | + const params = [ |
| 180 | + { key: 'learning_rate', value: HYPERPARAMETERS.learningRate.toString() }, |
| 181 | + { key: 'epochs', value: HYPERPARAMETERS.epochs.toString() }, |
| 182 | + { key: 'batch_size', value: HYPERPARAMETERS.batchSize.toString() }, |
| 183 | + { |
| 184 | + key: 'validation_split', |
| 185 | + value: HYPERPARAMETERS.validationSplit.toString(), |
| 186 | + }, |
| 187 | + { key: 'train_test_split', value: DATA_CONFIG.trainTestSplit.toString() }, |
| 188 | + { key: 'num_samples', value: DATA_CONFIG.numSamples.toString() }, |
| 189 | + ]; |
| 190 | + await mlflow.logBatch(runId, undefined, params); |
| 191 | + |
| 192 | + // Log dataset info |
| 193 | + const datasetInfo = { |
| 194 | + dataset: { |
| 195 | + name: 'synthetic_linear_regression_data', |
| 196 | + }, |
| 197 | + }; |
| 198 | + await mlflow.logInputs(runId, [datasetInfo]); |
| 199 | + |
| 200 | + // Generate and prepare data |
| 201 | + console.log('Generating data...'); |
| 202 | + const { xTrain, yTrain, xTest, yTest } = generateData(); |
| 203 | + |
| 204 | + // Train model |
| 205 | + console.log('Creating and training model...'); |
| 206 | + const model = createModel(); |
| 207 | + const history = await trainModel(model, xTrain, yTrain, runId); |
| 208 | + |
| 209 | + // Evaluate model |
| 210 | + const evaluation = evaluateModel(model, xTest, yTest); |
| 211 | + |
| 212 | + // Save model artifects |
| 213 | + const __filename = fileURLToPath(import.meta.url); |
| 214 | + const __dirname = dirname(__filename); |
| 215 | + const artifactsPath = `${__dirname}/../mlruns/${experimentId}/${runId}/artifacts`; |
| 216 | + await model.save(`file://${artifactsPath}`); |
| 217 | + |
| 218 | + // Log evaluation metrics |
| 219 | + const finalMetrics = [ |
| 220 | + { key: 'test_mse', value: evaluation.metrics.mse, timestamp: Date.now() }, |
| 221 | + { |
| 222 | + key: 'test_rmse', |
| 223 | + value: evaluation.metrics.rmse, |
| 224 | + timestamp: Date.now(), |
| 225 | + }, |
| 226 | + { key: 'test_mae', value: evaluation.metrics.mae, timestamp: Date.now() }, |
| 227 | + { key: 'test_r2', value: evaluation.metrics.r2, timestamp: Date.now() }, |
| 228 | + ]; |
| 229 | + await mlflow.logBatch(runId, finalMetrics); |
| 230 | + |
| 231 | + // Log model parameters |
| 232 | + const paramTags = [ |
| 233 | + { key: 'model_weight', value: evaluation.parameters.weight.toString() }, |
| 234 | + { key: 'model_bias', value: evaluation.parameters.bias.toString() }, |
| 235 | + ]; |
| 236 | + await mlflow.logBatch(runId, undefined, undefined, paramTags); |
| 237 | + |
| 238 | + // Register model if performance meets threshold |
| 239 | + if (evaluation.metrics.r2 > 0.9) { |
| 240 | + const modelName = 'LinearRegression'; |
| 241 | + try { |
| 242 | + let modelExists = true; |
| 243 | + try { |
| 244 | + await mlflow.getRegisteredModel(modelName); |
| 245 | + } catch (err) { |
| 246 | + modelExists = false; |
| 247 | + } |
| 248 | + |
| 249 | + if (!modelExists) { |
| 250 | + await mlflow.createRegisteredModel( |
| 251 | + modelName, |
| 252 | + [{ key: 'task', value: 'regression' }], |
| 253 | + 'Simple linear regression model' |
| 254 | + ); |
| 255 | + } |
| 256 | + |
| 257 | + const modelVersion = await mlflow.createModelVersion( |
| 258 | + modelName, |
| 259 | + `runs:/${runId}/model`, |
| 260 | + runId, |
| 261 | + [ |
| 262 | + { key: 'r2', value: evaluation.metrics.r2.toString() }, |
| 263 | + { key: 'rmse', value: evaluation.metrics.rmse.toString() }, |
| 264 | + ] |
| 265 | + ); |
| 266 | + |
| 267 | + if (evaluation.metrics.r2 > 0.95) { |
| 268 | + await mlflow.transitionModelVersionStage( |
| 269 | + modelName, |
| 270 | + modelVersion.version, |
| 271 | + 'staging', |
| 272 | + true |
| 273 | + ); |
| 274 | + } |
| 275 | + } catch (e) { |
| 276 | + console.error('Model registration error:', e.message); |
| 277 | + } |
| 278 | + } |
| 279 | + |
| 280 | + // Log additional metadata |
| 281 | + const tags = [ |
| 282 | + { key: 'model_type', value: 'linear_regression' }, |
| 283 | + { key: 'data_source', value: 'synthetic' }, |
| 284 | + ]; |
| 285 | + await mlflow.logBatch(runId, undefined, undefined, tags); |
| 286 | + |
| 287 | + // Finish run |
| 288 | + await mlflow.updateRun(runId, 'FINISHED'); |
| 289 | + |
| 290 | + // Cleanup |
| 291 | + tf.dispose([xTrain, yTrain, xTest, yTest]); |
| 292 | + |
| 293 | + console.log('\nMLflow tracking completed successfully!'); |
| 294 | + console.log( |
| 295 | + `View run details at http://localhost:5001/#/experiments/${experimentId}/runs/${runId}` |
| 296 | + ); |
| 297 | + } catch (error) { |
| 298 | + console.error('Error:', error); |
| 299 | + } |
| 300 | +} |
| 301 | + |
| 302 | +main(); |
0 commit comments