|
13 | 13 | from hls4ml.model.layers import GRU, LSTM, SeparableConv1D, SeparableConv2D |
14 | 14 |
|
15 | 15 | try: |
16 | | - import qkeras |
17 | | - from tensorflow import keras |
| 16 | + import keras |
18 | 17 |
|
19 | | - __tf_profiling_enabled__ = True |
| 18 | + __keras_profiling_enabled__ = True |
20 | 19 | except ImportError: |
21 | | - __tf_profiling_enabled__ = False |
| 20 | + __keras_profiling_enabled__ = False |
22 | 21 |
|
23 | 22 | try: |
24 | 23 | import torch |
|
27 | 26 | except ImportError: |
28 | 27 | __torch_profiling_enabled__ = False |
29 | 28 |
|
| 29 | +try: |
| 30 | + import qkeras |
| 31 | + |
| 32 | + __qkeras_profiling_enabled__ = True |
| 33 | +except ImportError: |
| 34 | + __qkeras_profiling_enabled__ = False |
| 35 | + |
| 36 | +_activations = list() |
| 37 | +if __keras_profiling_enabled__: |
| 38 | + _activations.append(keras.layers.Activation) |
| 39 | +if __qkeras_profiling_enabled__: |
| 40 | + _activations.append(qkeras.qactivations) |
| 41 | + |
30 | 42 |
|
31 | 43 | def get_unoptimized_hlsmodel(model): |
32 | 44 | from hls4ml.converters import convert_from_config |
@@ -482,7 +494,7 @@ def numerical(model=None, hls_model=None, X=None, plot='boxplot'): |
482 | 494 | if hls_model_present: |
483 | 495 | data = weights_hlsmodel(hls_model_unoptimized, fmt='summary', plot=plot) |
484 | 496 | elif model_present: |
485 | | - if __tf_profiling_enabled__ and isinstance(model, keras.Model): |
| 497 | + if __keras_profiling_enabled__ and isinstance(model, keras.Model): |
486 | 498 | data = weights_keras(model, fmt='summary', plot=plot) |
487 | 499 | elif __torch_profiling_enabled__ and isinstance(model, torch.nn.Sequential): |
488 | 500 | data = weights_torch(model, fmt='summary', plot=plot) |
@@ -520,7 +532,7 @@ def numerical(model=None, hls_model=None, X=None, plot='boxplot'): |
520 | 532 | if X is not None: |
521 | 533 | print("Profiling activations" + before) |
522 | 534 | data = None |
523 | | - if __tf_profiling_enabled__ and isinstance(model, keras.Model): |
| 535 | + if __keras_profiling_enabled__ and isinstance(model, keras.Model): |
524 | 536 | data = activations_keras(model, X, fmt='summary', plot=plot) |
525 | 537 | elif __torch_profiling_enabled__ and isinstance(model, torch.nn.Sequential): |
526 | 538 | data = activations_torch(model, X, fmt='summary', plot=plot) |
@@ -590,7 +602,7 @@ def get_ymodel_keras(keras_model, X): |
590 | 602 | if ( |
591 | 603 | hasattr(layer, 'activation') |
592 | 604 | and layer.activation is not None |
593 | | - and not isinstance(layer, (keras.layers.Activation, qkeras.qlayers.QActivation)) |
| 605 | + and not isinstance(layer, _activations) |
594 | 606 | and layer.activation.__name__ != 'linear' |
595 | 607 | ): |
596 | 608 | tmp_activation = layer.activation |
|
0 commit comments