@@ -511,6 +511,7 @@ def __init__(self,
511511 Args:
512512 model_name: A string of model name.
513513 model_config: A dict of model configurations or a string of hparams.
514+ include_top: If True, include the top layer for classification.
514515 name: A string of layer name.
515516
516517 Raises:
@@ -581,17 +582,17 @@ def summary(self, input_shape=(224, 224, 3), **kargs):
581582 return model .summary ()
582583
583584 def get_model_with_inputs (self , inputs , ** kargs ):
584- model = tf .keras .Model (inputs = [inputs ], outputs = self .call (inputs , training = True ))
585- return model
585+ model = tf .keras .Model (
586+ inputs = [inputs ], outputs = self .call (inputs , training = True ))
587+ return model
586588
587589 def call (self , inputs , training , with_endpoints = False ):
588590 """Implementation of call().
589591
590592 Args:
591593 inputs: input tensors.
592594 training: boolean, whether the model is constructed for training.
593- features_only: build the base feature network only.
594- single_out: If true, only return the single output.
595+ with_endpoints: If true, return a list of endpoints.
595596
596597 Returns:
597598 output tensors.
@@ -657,7 +658,7 @@ def get_model(model_name,
657658 pretrained = True ,
658659 training = True ,
659660 with_endpoints = False ,
660- ** kargs ):
661+ ** kwargs ):
661662 """Get a EfficientNet V1 or V2 model instance.
662663
663664 This is a simply utility for finetuning or inference.
@@ -669,23 +670,29 @@ def get_model(model_name,
669670 pretrained: if true, download the checkpoint. If string, load the ckpt.
670671 training: If true, all model variables are trainable.
671672 with_endpoints: whether to return all intermedia endpoints.
673+ **kwargs: additional parameters for keras model, such as name=xx.
672674
673675 Returns:
674676 A single tensor if with_endpoints if False; otherwise, a list of tensor.
675677 """
676- net = EffNetV2Model (model_name , model_config , include_top )
678+ net = EffNetV2Model (model_name , model_config , include_top , ** kwargs )
677679 net (tf .keras .Input (shape = (None , None , 3 )),
678680 training = training ,
679681 with_endpoints = with_endpoints )
680- if pretrained is True :
682+ if pretrained is True : # pylint: disable=g-bool-id-comparison
683+ # pylint: disable=line-too-long
681684 # download checkpoint and set pretrained path. Supported models include:
682- # efficientnetv2-s, efficientnetv2-m, efficientnetv2-l,
683- # efficientnetv2-b0, efficientnetv2-b1, efficientnetv2-b2, efficientnetv2-b3,
684- # efficientnet-b0, efficientnet-b1, efficientnet-b2, efficientnet-b3,
685- # efficientnet-b4, efficientnet-b5, efficientnet-b6, efficientnet-b7, efficientnet-l2
686- # More V2 ckpts: https://github.com/google/automl/tree/master/efficientnetv2
687- # More V1 ckpts: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet
688- url = f'https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/v2/{ model_name } .tgz'
685+ # efficientnetv2-s, efficientnetv2-m, efficientnetv2-l,
686+ # efficientnetv2-b0, efficientnetv2-b1, efficientnetv2-b2, efficientnetv2-b3,
687+ # efficientnet-b0, efficientnet-b1, efficientnet-b2, efficientnet-b3,
688+ # efficientnet-b4, efficientnet-b5, efficientnet-b6, efficientnet-b7,
689+ # efficientnet-l2
690+ # v2: https://github.com/google/automl/tree/master/efficientnetv2
691+ # v1: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet
692+ # pylint: enable=line-too-long
693+
694+ url = ('https://storage.googleapis.com/cloud-tpu-checkpoints/'
695+ f'efficientnet/v2/{ model_name } .tgz' )
689696 pretrained_ckpt = tf .keras .utils .get_file (model_name , url , untar = True )
690697 else :
691698 pretrained_ckpt = pretrained
@@ -695,4 +702,4 @@ def get_model(model_name,
695702 pretrained_ckpt = tf .train .latest_checkpoint (pretrained_ckpt )
696703 net .load_weights (pretrained_ckpt )
697704
698- return net
705+ return net
0 commit comments