@@ -459,13 +459,6 @@ def __init__(self, mconfig, name=None):
 
     self._avg_pooling = tf.keras.layers.GlobalAveragePooling2D(
         data_format=mconfig.data_format)
-    if mconfig.num_classes:
-      self._fc = tf.keras.layers.Dense(
-          mconfig.num_classes,
-          kernel_initializer=dense_kernel_initializer,
-          bias_initializer=tf.constant_initializer(mconfig.headbias or 0))
-    else:
-      self._fc = None
 
     if mconfig.dropout_rate > 0:
       self._dropout = tf.keras.layers.Dropout(mconfig.dropout_rate)
@@ -498,9 +491,6 @@ def call(self, inputs, training):
       self.endpoints['pooled_features'] = outputs
       if self._dropout:
         outputs = self._dropout(outputs, training=training)
-      self.endpoints['global_pool'] = outputs
-      if self._fc:
-        outputs = self._fc(outputs)
       self.endpoints['head'] = outputs
       return outputs
 
@@ -514,12 +504,13 @@ class EffNetV2Model(tf.keras.Model):
   def __init__(self,
                model_name='efficientnetv2-s',
                model_config=None,
+               include_top=True,
                name=None):
     """Initializes an `Model` instance.
 
     Args:
       model_name: A string of model name.
-      model_config: A dict of model configureations or a string of hparams.
+      model_config: A dict of model configurations or a string of hparams.
       name: A string of layer name.
 
     Raises:
@@ -533,6 +524,7 @@ def __init__(self,
     self.cfg = cfg
     self._mconfig = cfg.model
     self.endpoints = None
+    self.include_top = include_top
     self._build()
 
   def _build(self):
@@ -574,12 +566,25 @@ def _build(self):
     # Head part.
     self._head = Head(self._mconfig)
 
+    # top part for classification
+    if self.include_top and self._mconfig.num_classes:
+      self._fc = tf.keras.layers.Dense(
+          self._mconfig.num_classes,
+          kernel_initializer=dense_kernel_initializer,
+          bias_initializer=tf.constant_initializer(self._mconfig.headbias or 0))
+    else:
+      self._fc = None
+
   def summary(self, input_shape=(224, 224, 3), **kargs):
     x = tf.keras.Input(shape=input_shape)
     model = tf.keras.Model(inputs=[x], outputs=self.call(x, training=True))
     return model.summary()
 
-  def call(self, inputs, training, features_only=None, single_out=None):
+  def get_model_with_inputs(self, inputs, **kargs):
+    model = tf.keras.Model(inputs=[inputs], outputs=self.call(inputs, training=True))
+    return model
+
+  def call(self, inputs, training, with_endpoints=False):
     """Implementation of call().
 
     Args:
@@ -624,19 +629,70 @@ def call(self, inputs, training, features_only=None, single_out=None):
           self.endpoints['reduction_%s/%s' % (reduction_idx, k)] = v
     self.endpoints['features'] = outputs
 
-    if not features_only:
-      # Calls final layers and returns logits.
-      outputs = self._head(outputs, training)
-      self.endpoints.update(self._head.endpoints)
-
-    if single_out:  # Use for building sequential models.
-      return outputs
-
-    return [outputs] + list(
-        filter(lambda endpoint: endpoint is not None, [
-            self.endpoints.get('reduction_1'),
-            self.endpoints.get('reduction_2'),
-            self.endpoints.get('reduction_3'),
-            self.endpoints.get('reduction_4'),
-            self.endpoints.get('reduction_5'),
-        ]))
+    # Head to obtain the final feature.
+    outputs = self._head(outputs, training)
+    self.endpoints.update(self._head.endpoints)
+
+    # Calls final dense layers and returns logits.
+    if self._fc:
+      with tf.name_scope('head'):  # legacy
+        outputs = self._fc(outputs)
+
+    if with_endpoints:  # Use for building sequential models.
+      return [outputs] + list(
+          filter(lambda endpoint: endpoint is not None, [
+              self.endpoints.get('reduction_1'),
+              self.endpoints.get('reduction_2'),
+              self.endpoints.get('reduction_3'),
+              self.endpoints.get('reduction_4'),
+              self.endpoints.get('reduction_5'),
+          ]))
+
+    return outputs
+
+
+def get_model(model_name,
+              model_config=None,
+              include_top=True,
+              pretrained=True,
+              training=True,
+              with_endpoints=False,
+              **kargs):
+  """Gets an EfficientNet V1 or V2 model instance.
+
+  This is a simple utility for finetuning or inference.
+
+  Args:
+    model_name: a string such as 'efficientnetv2-s' or 'efficientnet-b0'.
+    model_config: A dict of model configurations or a string of hparams.
+    include_top: whether to include the final dense layer for classification.
+    pretrained: if True, download the checkpoint; if a string, load that ckpt.
+    training: If true, all model variables are trainable.
+    with_endpoints: whether to return all intermediate endpoints.
+
+  Returns:
+    An EffNetV2Model instance (a tf.keras.Model).
+  """
+  net = EffNetV2Model(model_name, model_config, include_top)
+  net(tf.keras.Input(shape=(None, None, 3)),
+      training=training,
+      with_endpoints=with_endpoints)
+  if pretrained is True:
+    # download checkpoint and set pretrained path. Supported models include:
+    #   efficientnetv2-s, efficientnetv2-m, efficientnetv2-l,
+    #   efficientnetv2-b0, efficientnetv2-b1, efficientnetv2-b2, efficientnetv2-b3,
+    #   efficientnet-b0, efficientnet-b1, efficientnet-b2, efficientnet-b3,
+    #   efficientnet-b4, efficientnet-b5, efficientnet-b6, efficientnet-b7, efficientnet-l2
+    # More V2 ckpts: https://github.com/google/automl/tree/master/efficientnetv2
+    # More V1 ckpts: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet
+    url = f'https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/v2/{model_name}.tgz'
+    pretrained_ckpt = tf.keras.utils.get_file(model_name, url, untar=True)
+  else:
+    pretrained_ckpt = pretrained
+
+  if pretrained_ckpt:
+    if tf.io.gfile.isdir(pretrained_ckpt):
+      pretrained_ckpt = tf.train.latest_checkpoint(pretrained_ckpt)
+    net.load_weights(pretrained_ckpt)
+
+  return net
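
For reference, a minimal usage sketch of the new get_model() helper, assuming this module is importable as effnetv2_model; the checkpoint path and input size below are placeholders:

# Usage sketch for get_model(); `effnetv2_model` is an assumed import name for
# this module, and '/tmp/my_ckpt' is a placeholder checkpoint path/directory.
import tensorflow as tf

import effnetv2_model

# Classifier with the public pretrained checkpoint downloaded automatically.
net = effnetv2_model.get_model(
    'efficientnetv2-b0', include_top=True, pretrained=True, training=False)
images = tf.random.uniform([1, 224, 224, 3])
logits = net(images, training=False)

# with_endpoints=True makes call() also return the reduction_1..5 feature maps.
outs = net(images, training=False, with_endpoints=True)

# Headless backbone that loads a local checkpoint instead of downloading one.
backbone = effnetv2_model.get_model(
    'efficientnetv2-s',
    include_top=False,
    pretrained='/tmp/my_ckpt',  # placeholder: a checkpoint file or directory
    training=False)
features = backbone(images, training=False)  # pooled 'head' features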
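
Similarly, a sketch of the new get_model_with_inputs() method, which wraps the subclassed model's call() into a functional tf.keras.Model; the effnetv2_model import name and the 224x224 input shape are assumptions for illustration:

import tensorflow as tf

import effnetv2_model  # assumed import name for this module

# Build the subclassed model, then expose it as a functional Keras model so it
# can be inspected with .summary() or composed with other layers.
net = effnetv2_model.EffNetV2Model('efficientnetv2-b0')
inputs = tf.keras.Input(shape=(224, 224, 3))
functional = net.get_model_with_inputs(inputs)
functional.summary()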