77
@keras_hub_export("keras_hub.models.MoondreamBackbone")
class MoondreamBackbone(Backbone):
    """The Moondream Backbone model.

    This model connects a vision encoder (SigLIP) and a text decoder
    (Phi-1.5) using a projection layer. It is designed for
    vision-language tasks where image features are projected into the
    text embedding space.

    The text decoder must expose a `get_input_embeddings(token_ids)`
    method and accept `inputs`, `decoder_inputs_embeds`, and
    `padding_mask` call arguments, because this backbone passes the
    precomputed embeddings (projected image features concatenated with
    token embeddings) directly to the decoder.

    Args:
        vision_encoder: A Keras model (e.g., SigLIP). The vision encoder
            responsible for processing input images. Expected to map a
            batch of images to a `(batch, num_patches, feature_dim)`
            tensor.
        text_decoder: A Keras model (e.g., Phi-1.5). The text decoder
            responsible for generating text tokens.
        projection_dim: int. The dimension to project image features
            into. Should match the text decoder's embedding dimension.
            Defaults to `2048`.
        **kwargs: Standard Keras keyword arguments.

    Example:
    ```python
    import numpy as np
    from keras_hub.models import MoondreamBackbone

    # `vision_encoder` and `text_decoder` are pretrained sub-models.
    backbone = MoondreamBackbone(
        vision_encoder=vision_encoder,
        text_decoder=text_decoder,
        projection_dim=2048,
    )
    outputs = backbone(
        {
            "images": np.random.rand(2, 378, 378, 3),
            "token_ids": np.random.randint(0, 50000, (2, 10)),
            "padding_mask": np.ones((2, 10)),
        }
    )
    ```
    """

    def __init__(
        self, vision_encoder, text_decoder, projection_dim=2048, **kwargs
    ):
        # === Layers ===
        # The connector that maps vision features into the text
        # embedding space.
        self.vision_projection = keras.layers.Dense(
            projection_dim, name="vision_projection"
        )

        # === Functional Model ===
        images = keras.Input(shape=(None, None, 3), name="images")
        token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids")
        padding_mask = keras.Input(
            shape=(None,), dtype="int32", name="padding_mask"
        )
        inputs = {
            "images": images,
            "token_ids": token_ids,
            "padding_mask": padding_mask,
        }

        # Encode the images and project the features.
        image_features = vision_encoder(images)
        projected_images = self.vision_projection(image_features)

        # Prepend the projected image features to the token embeddings.
        text_embeddings = text_decoder.get_input_embeddings(token_ids)
        combined_embeddings = ops.concatenate(
            [projected_images, text_embeddings], axis=1
        )

        # Extend the padding mask so every image patch position counts
        # as a valid (non-padded) position. Use int32 to match the
        # `padding_mask` input dtype.
        batch_size = ops.shape(images)[0]
        num_patches = ops.shape(projected_images)[1]
        image_mask = ops.ones((batch_size, num_patches), dtype="int32")
        combined_mask = ops.concatenate([image_mask, padding_mask], axis=1)

        # We set `inputs=None` because we are passing the calculated
        # embeddings directly via `decoder_inputs_embeds`.
        outputs = text_decoder(
            inputs=None,
            decoder_inputs_embeds=combined_embeddings,
            padding_mask=combined_mask,
        )

        super().__init__(inputs=inputs, outputs=outputs, **kwargs)

        # === Config ===
        self.vision_encoder = vision_encoder
        self.text_decoder = text_decoder
        self.projection_dim = projection_dim

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "vision_encoder": keras.saving.serialize_keras_object(
                    self.vision_encoder
                ),
                "text_decoder": keras.saving.serialize_keras_object(
                    self.text_decoder
                ),
                "projection_dim": self.projection_dim,
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        # `get_config()` serializes the sub-models; without this
        # override the default `cls(**config)` would receive plain
        # dicts instead of Keras models and fail when the functional
        # graph is rebuilt.
        config = config.copy()
        config["vision_encoder"] = keras.saving.deserialize_keras_object(
            config["vision_encoder"]
        )
        config["text_decoder"] = keras.saving.deserialize_keras_object(
            config["text_decoder"]
        )
        return cls(**config)
0 commit comments