77
88@keras_hub_export ("keras_hub.models.MoondreamBackbone" )
99class MoondreamBackbone (Backbone ):
10- def __init__ (self , vision_encoder , text_decoder , projection_dim = 2048 , ** kwargs ):
11- super ().__init__ (** kwargs )
10+ """
11+ The Moondream Backbone model.
12+
13+ This model connects a vision encoder (SigLIP) and a text decoder (Phi-1.5)
14+ using a projection layer. It is designed for vision-language tasks where
15+ image features are projected into the text embedding space.
16+
17+ Args:
18+ vision_encoder: A Keras model (e.g., SigLIP). The vision encoder
19+ responsible for processing input images.
20+ text_decoder: A Keras model (e.g., Phi-1.5). The text decoder
21+ responsible for generating text tokens.
22+ projection_dim: int. The dimension to project image features into.
23+ Defaults to `2048`.
24+ **kwargs: Standard Keras keyword arguments.
25+
26+ Example:
27+ ```python
28+ import keras
29+ import numpy as np
30+ from keras_hub.src.models.moondream.moondream_backbone import (
31+ MoondreamBackbone
32+ )
33+
34+ # 1. Create Mock Encoders
35+ # Vision Encoder: Maps (378, 378, 3) -> (729, 1152)
36+ image_input = keras.Input(shape=(378, 378, 3))
37+ vision_output = keras.layers.Lambda(
38+ lambda x: keras.ops.ones((keras.ops.shape(x)[0], 729, 1152))
39+ )(image_input)
40+ vision_encoder = keras.Model(inputs=image_input, outputs=vision_output)
41+
42+ # Text Decoder: Maps (Seq,) -> (Seq, 2048)
43+ text_input = keras.Input(shape=(None,), dtype="int32")
44+ text_output = keras.layers.Lambda(
45+ lambda x: keras.ops.ones(
46+ (keras.ops.shape(x)[0], keras.ops.shape(x)[1], 2048)
47+ )
48+ )(text_input)
49+ text_decoder = keras.Model(inputs=text_input, outputs=text_output)
50+
51+ # Helper for embeddings
52+ text_decoder.get_input_embeddings = lambda x: keras.layers.Embedding(
53+ 50000, 2048
54+ )(x)
1255
56+ # 2. Instantiate Backbone
57+ backbone = MoondreamBackbone(
58+ vision_encoder=vision_encoder,
59+ text_decoder=text_decoder,
60+ projection_dim=2048
61+ )
62+
63+ # 3. Run Forward Pass
64+ inputs = {
65+ "images": np.random.rand(2, 378, 378, 3),
66+ "token_ids": np.random.randint(0, 50000, (2, 10)),
67+ "padding_mask": np.ones((2, 10))
68+ }
69+ outputs = backbone(inputs)
70+ ```
71+ """
72+
73+ def __init__ (
74+ self , vision_encoder , text_decoder , projection_dim = 2048 , ** kwargs
75+ ):
76+ super ().__init__ (** kwargs )
1377 self .vision_encoder = vision_encoder
1478 self .text_decoder = text_decoder
79+ self .projection_dim = projection_dim
1580
16- # The Connector
1781 self .vision_projection = keras .layers .Dense (
1882 projection_dim , name = "vision_projection"
1983 )
2084
21- def call (self , inputs ):
22- images = inputs ["images" ]
23- token_ids = inputs ["token_ids" ]
24- padding_mask = inputs ["padding_mask" ]
85+ images = keras .Input (shape = (None , None , 3 ), name = "images" )
86+ token_ids = keras .Input (shape = (None ,), dtype = "int32" , name = "token_ids" )
87+ padding_mask = keras .Input (
88+ shape = (None ,), dtype = "int32" , name = "padding_mask"
89+ )
90+
91+ inputs = {
92+ "images" : images ,
93+ "token_ids" : token_ids ,
94+ "padding_mask" : padding_mask ,
95+ }
2596
26- # 1. Image Features
2797 image_features = self .vision_encoder (images )
28-
29- # 2. Project
3098 projected_images = self .vision_projection (image_features )
3199
32- # 3. Text Embeddings
33100 text_embeddings = self .text_decoder .get_input_embeddings (token_ids )
34101
35- # 4. Concatenate
36102 combined_embeddings = ops .concatenate (
37103 [projected_images , text_embeddings ], axis = 1
38104 )
39105
40- # 5. Masking
41106 batch_size = ops .shape (images )[0 ]
42107 num_patches = ops .shape (projected_images )[1 ]
43108
44- image_mask = ops .ones ((batch_size , num_patches ), dtype = "bool " )
109+ image_mask = ops .ones ((batch_size , num_patches ), dtype = "int32 " )
45110 combined_mask = ops .concatenate ([image_mask , padding_mask ], axis = 1 )
46111
47- # 6. Decoder Pass
48- # Now compatible with our Subclass Mock Decoder
49112 outputs = self .text_decoder (
50113 inputs = None ,
51114 decoder_inputs_embeds = combined_embeddings ,
52115 padding_mask = combined_mask ,
53116 )
54117
55- return outputs
118+ super (MoondreamBackbone , self ).__init__ (
119+ inputs = inputs , outputs = outputs , ** kwargs
120+ )
56121
57122 def get_config (self ):
58123 config = super ().get_config ()
@@ -61,8 +126,10 @@ def get_config(self):
61126 "vision_encoder" : keras .saving .serialize_keras_object (
62127 self .vision_encoder
63128 ),
64- "text_decoder" : keras .saving .serialize_keras_object (self .text_decoder ),
65- "projection_dim" : self .vision_projection .units ,
129+ "text_decoder" : keras .saving .serialize_keras_object (
130+ self .text_decoder
131+ ),
132+ "projection_dim" : self .projection_dim ,
66133 }
67134 )
68135 return config
0 commit comments