@@ -91,7 +91,7 @@ def train_step(self, data):
9191 images , labels = features ['image' ], labels ['label' ]
9292
9393 with tf .GradientTape () as tape :
94- pred = self (images , training = True )[ 0 ]
94+ pred = self (images , training = True )
9595 pred = tf .cast (pred , tf .float32 )
9696 loss = self .compiled_loss (
9797 labels ,
@@ -105,7 +105,7 @@ def train_step(self, data):
105105 def test_step (self , data ):
106106 features , labels = data
107107 images , labels = features ['image' ], labels ['label' ]
108- pred = self (images , training = False )[ 0 ]
108+ pred = self (images , training = False )
109109 pred = tf .cast (pred , tf .float32 )
110110
111111 self .compiled_loss (
@@ -174,9 +174,9 @@ def main(_) -> None:
174174 weight_decay = config .train .weight_decay )
175175
176176 if config .train .ft_init_ckpt : # load pretrained ckpt for finetuning.
177- model (tf .ones ([ 1 , 224 , 224 , 3 ]))
177+ model (tf .keras . Input ([ None , None , 3 ]))
178178 ckpt = config .train .ft_init_ckpt
179- utils .restore_tf2_ckpt (model , ckpt , exclude_layers = ('_head ' , 'optimizer' ))
179+ utils .restore_tf2_ckpt (model , ckpt , exclude_layers = ('_fc ' , 'optimizer' ))
180180
181181 steps_per_epoch = num_train_images // config .train .batch_size
182182 total_steps = steps_per_epoch * config .train .epochs
0 commit comments