@@ -423,6 +423,8 @@ class PaliGemmaVit(keras.Model):
423423 Args:
424424 image_size: int. The height/width of the image. Both height and width is
425425 expected to be the same.
426+ include_rescaling: bool. If true, the image input will be rescaled from
427+ the range `[0, 255]`, to the range `[0, 1]`.
426428 patch_size: int. The size of each square patch in the input image.
427429 num_heads: int. The number of attention heads for the vision(image)
428430 transformer encoder.
@@ -463,6 +465,7 @@ def __init__(
463465 num_layers ,
464466 intermediate_dim ,
465467 num_classes ,
468+ include_rescaling = True ,
466469 pooling = None ,
467470 classifier_activation = None ,
468471 dtype = None ,
@@ -472,7 +475,13 @@ def __init__(
472475 image_input = keras .Input (
473476 shape = (image_size , image_size , 3 ), name = "images"
474477 )
475- encoded = PaliGemmaVitEncoder (
478+ x = image_input # Intermediate result.
479+ if include_rescaling :
480+ rescaling = keras .layers .Rescaling (
481+ scale = 1.0 / 127.5 , offset = - 1.0 , name = "rescaling"
482+ )
483+ x = rescaling (image_input )
484+ x = PaliGemmaVitEncoder (
476485 hidden_dim = hidden_dim ,
477486 num_layers = num_layers ,
478487 num_heads = num_heads ,
@@ -481,20 +490,20 @@ def __init__(
481490 image_size = image_size ,
482491 dtype = dtype ,
483492 name = "image_encoder" ,
484- )(image_input )
493+ )(x )
485494 if pooling == "map" :
486- pooled = MultiHeadAttentionPooling (
495+ x = MultiHeadAttentionPooling (
487496 num_heads = num_heads ,
488497 hidden_dim = hidden_dim ,
489498 dtype = dtype ,
490499 name = "pooling" ,
491- )(encoded )
500+ )(x )
492501 elif pooling == "gap" :
493- pooled = ops .mean (encoded , axis = 1 )
502+ x = ops .mean (x , axis = 1 )
494503 elif pooling == "zero" :
495- pooled = encoded [:, 0 ]
504+ x = x [:, 0 ]
496505 elif pooling is None :
497- pooled = encoded
506+ x = x
498507 else :
499508 raise ValueError (
500509 "Invalid value for argument `pooling`. "
@@ -506,7 +515,7 @@ def __init__(
506515 activation = classifier_activation ,
507516 dtype = dtype ,
508517 name = "image_classifier" ,
509- )(pooled )
518+ )(x )
510519 super ().__init__ (
511520 inputs = image_input ,
512521 outputs = outputs ,
@@ -521,6 +530,7 @@ def __init__(
521530 self .pooling = pooling
522531 self .num_classes = num_classes
523532 self .image_size = image_size
533+ self .include_rescaling = include_rescaling
524534 self .patch_size = patch_size
525535 self .classifier_activation = keras .activations .get (
526536 classifier_activation
@@ -541,6 +551,7 @@ def get_config(self):
541551 self .classifier_activation
542552 ),
543553 "image_size" : self .image_size ,
554+ "include_rescaling" : self .include_rescaling ,
544555 "patch_size" : self .patch_size ,
545556 }
546557 )
0 commit comments