
Commit a778e39

hls4ml objectives & top-level optimization function
1 parent 47392ba commit a778e39

File tree

2 files changed: +313 −0 lines


hls4ml/optimization/__init__.py

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
import numpy as np

from hls4ml.optimization.keras import optimize_model
from hls4ml.optimization.attributes import get_attributes_from_keras_model_and_hls4ml_config


def optimize_keras_for_hls4ml(
    keras_model, hls_config, objective, scheduler, X_train, y_train, X_val, y_val,
    batch_size, epochs, optimizer, loss_fn, validation_metric, increasing, rtol,
    callbacks=[], ranking_metric='l1', local=False, verbose=False, rewinding_epochs=1, cutoff_bad_trials=3,
    directory='hls4ml-optimization', tuner='Bayesian', knapsack_solver='CBC_MIP',
    regularization_range=np.logspace(-6, -2, num=16).tolist()
):
    '''
    Top-level function for optimizing a Keras model, given an hls4ml config and one or more hardware objectives

    Args:
    - keras_model (keras.Model): Model to be optimized
    - hls_config (dict): hls4ml configuration, obtained from hls4ml.utils.config.config_from_keras_model(...)
    - objective (hls4ml.optimization.objectives.ObjectiveEstimator): Parameter, hardware or user-defined objective of optimization
    - scheduler (hls4ml.optimization.scheduler.OptimizationScheduler): Sparsity scheduler, choose between constant, polynomial and binary
    - X_train (np.array): Training inputs
    - y_train (np.array): Training labels
    - X_val (np.array): Validation inputs
    - y_val (np.array): Validation labels
    - batch_size (int): Batch size during training
    - epochs (int): Maximum number of epochs to fine-tune the model in one iteration of pruning
    - optimizer (keras.optimizers.Optimizer or equivalent string description): Optimizer used during training
    - loss_fn (keras.losses.Loss or equivalent loss description): Loss function used during training
    - validation_metric (keras.metrics.Metric or equivalent metric description): Validation metric, used as a baseline
    - increasing (boolean): Whether the metric improves with increasing values; e.g. accuracy -> increasing = True, MSE -> increasing = False
    - rtol (float): Relative tolerance; pruning stops when pruned_validation_metric < (or >) rtol * baseline_validation_metric

    Kwargs:
    - callbacks (list of keras.callbacks.Callback): Currently not supported, to be developed in future versions
    - ranking_metric (string): Metric used for ranking weights and structures; currently supported: l1, l2, saliency and Oracle
    - local (boolean): Layer-wise or global pruning
    - verbose (boolean): Display debug logs during model optimization
    - rewinding_epochs (int): Number of epochs to retrain the model without weight freezing, allowing regrowth of previously pruned weights
    - cutoff_bad_trials (int): Number of bad trials (performance below threshold) after which model pruning / weight sharing stops
    - directory (string): Directory to store temporary results
    - tuner (str): Tuning algorithm, choose between Bayesian, Hyperband and None
    - knapsack_solver (str): Algorithm to solve the Knapsack problem when optimizing; the default usually works well; for very large networks, a greedy algorithm might be more suitable
    - regularization_range (list): List of suitable hyperparameters for weight decay
    '''

    # Extract model attributes
    model_attributes = get_attributes_from_keras_model_and_hls4ml_config(keras_model, hls_config)

    # Optimize model
    return optimize_model(
        keras_model, model_attributes, objective, scheduler,
        X_train, y_train, X_val, y_val, batch_size, epochs,
        optimizer, loss_fn, validation_metric, increasing, rtol,
        callbacks=callbacks, ranking_metric=ranking_metric, local=local, verbose=verbose,
        rewinding_epochs=rewinding_epochs, cutoff_bad_trials=cutoff_bad_trials,
        directory=directory, tuner=tuner, knapsack_solver=knapsack_solver,
        regularization_range=regularization_range
    )
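
For context, a minimal usage sketch of the new top-level function. Only the call signature of optimize_keras_for_hls4ml and the config_from_keras_model helper mentioned in its docstring come from this commit; the concrete objective and scheduler classes (ParameterEstimator, ConstantScheduler), the toy Keras model, the synthetic data and all hyperparameter values are illustrative assumptions.

# Illustrative usage sketch, not part of this commit. ParameterEstimator and
# ConstantScheduler are assumed concrete implementations of the ObjectiveEstimator
# and OptimizationScheduler interfaces named in the docstring; their import paths
# and constructor defaults are assumptions.
import numpy as np
from tensorflow import keras

from hls4ml.optimization import optimize_keras_for_hls4ml
from hls4ml.optimization.objectives import ParameterEstimator  # assumed objective class
from hls4ml.optimization.scheduler import ConstantScheduler    # assumed scheduler class
from hls4ml.utils.config import config_from_keras_model

# Toy classifier and synthetic data, for illustration only
keras_model = keras.Sequential([
    keras.layers.Dense(64, activation='relu', input_shape=(16,)),
    keras.layers.Dense(5, activation='softmax'),
])
X_train = np.random.rand(1024, 16)
y_train = keras.utils.to_categorical(np.random.randint(5, size=1024), 5)
X_val = np.random.rand(256, 16)
y_val = keras.utils.to_categorical(np.random.randint(5, size=256), 5)

# hls4ml configuration, as referenced in the docstring
hls_config = config_from_keras_model(keras_model)

# Prune the model against the (assumed) parameter-count objective
optimized_model = optimize_keras_for_hls4ml(
    keras_model, hls_config, ParameterEstimator, ConstantScheduler(),
    X_train, y_train, X_val, y_val,
    batch_size=128, epochs=10,
    optimizer='adam', loss_fn='categorical_crossentropy',
    validation_metric='accuracy', increasing=True, rtol=0.98,
)

The string forms of optimizer, loss_fn and validation_metric follow the docstring's "or equivalent string description" wording; any keras.optimizers.Optimizer, keras.losses.Loss or keras.metrics.Metric instance should work equally well.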
