|
| 1 | +import numpy as np |
| 2 | +from hls4ml.optimization.keras import optimize_model |
| 3 | +from hls4ml.optimization.attributes import get_attributes_from_keras_model_and_hls4ml_config |
| 4 | + |
| 5 | +def optimize_keras_for_hls4ml( |
| 6 | + keras_model, hls_config, objective, scheduler, X_train, y_train, X_val, y_val, |
| 7 | + batch_size, epochs, optimizer, loss_fn, validation_metric, increasing, rtol, |
| 8 | + callbacks=[], ranking_metric='l1', local=False, verbose=False, rewinding_epochs=1, cutoff_bad_trials=3, |
| 9 | + directory='hls4ml-optimization', tuner='Bayesian', knapsack_solver='CBC_MIP', |
| 10 | + regularization_range=np.logspace(-6, -2, num=16).tolist() |
| 11 | +): |
| 12 | + ''' |
| 13 | + Top-level function for optimizing a Keras model, given hls4ml config and a hardware objective(s) |
| 14 | +
|
| 15 | + Args: |
| 16 | + - keras_model (keras.Model): Model to be optimized |
| 17 | + - hls_config (dict): hls4ml configuration, obtained from hls4ml.utils.config.config_from_keras_model(...) |
| 18 | + - objective (hls4ml.optimization.objectives.ObjectiveEstimator): Parameter, hardware or user-defined objective of optimization |
| 19 | + - scheduler (hls4ml.optimization.schduler.OptimizationScheduler): Sparsity scheduler, choose between constant, polynomial and binary |
| 20 | + - X_train (np.array): Training inputs |
| 21 | + - y_train (np.array): Training labels |
| 22 | + - X_val (np.array): Validation inputs |
| 23 | + - y_val (np.array): Validation labels |
| 24 | + - batch_size (int): Batch size during training |
| 25 | + - epochs (int): Maximum number of epochs to fine-tune model, in one iteration of pruning |
| 26 | + - optimizer (keras.optimizers.Optimizer or equivalent-string description): Optimizer used during training |
| 27 | + - loss_fn (keras.losses.Loss or equivalent loss description): Loss function used during training |
| 28 | + - validation_metric (keras.metrics.Metric or equivalent loss description): Validation metric, used as a baseline |
| 29 | + - increasing (boolean): If the metric improves with increased values; e.g. accuracy -> increasing = True, MSE -> increasing = False |
| 30 | + - rtol (float): Relative tolerance; pruning stops when pruned_validation_metric < (or >) rtol * baseline_validation_metric |
| 31 | + |
| 32 | + Kwargs: |
| 33 | + - callbacks (list of keras.callbacks.Callback) Currently not supported, developed in future versions |
| 34 | + - ranking_metric (string): Metric used for rannking weights and structures; currently supported l1, l2, saliency and Oracle |
| 35 | + - local (boolean): Layer-wise or global pruning |
| 36 | + - verbose (boolean): Display debug logs during model optimization |
| 37 | + - rewinding_epochs (int): Number of epochs to retrain model without weight freezing, allows regrowth of previously pruned weights |
| 38 | + - cutoff_bad_trials (int): After how many bad trials (performance below threshold), should model pruning / weight sharing stop |
| 39 | + - directory (string): Directory to store temporary results |
| 40 | + - tuner (str): Tuning alogorithm, choose between Bayesian, Hyperband and None |
| 41 | + - knapsack_solver (str): Algorithm to solve Knapsack problem when optimizing; default usually works well; for very large networks, greedy algorithm might be more suitable |
| 42 | + - regularization_range (list): List of suitable hyperparameters for weight decay |
| 43 | + ''' |
| 44 | + |
| 45 | + # Extract model attributes |
| 46 | + model_attributes = get_attributes_from_keras_model_and_hls4ml_config(keras_model, hls_config) |
| 47 | + |
| 48 | + # Optimize model |
| 49 | + return optimize_model( |
| 50 | + keras_model, model_attributes, objective, scheduler, |
| 51 | + X_train, y_train, X_val, y_val, batch_size, epochs, |
| 52 | + optimizer, loss_fn, validation_metric, increasing, rtol, |
| 53 | + callbacks=callbacks, ranking_metric=ranking_metric, local=local, verbose=verbose, |
| 54 | + rewinding_epochs=rewinding_epochs, cutoff_bad_trials=cutoff_bad_trials, |
| 55 | + directory=directory, tuner=tuner, knapsack_solver=knapsack_solver, |
| 56 | + regularization_range=regularization_range |
| 57 | + ) |
0 commit comments