Skip to content

Commit 7dba12e

Browse files
committed
docs: Add docstrings and usage example to MoondreamPreprocessor
1 parent 3bba424 commit 7dba12e

File tree

5 files changed

+239
-53
lines changed

5 files changed

+239
-53
lines changed

keras_hub/api/models/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -458,6 +458,15 @@
458458
from keras_hub.src.models.mobilenetv5.mobilenetv5_image_classifier_preprocessor import (
459459
MobileNetV5ImageClassifierPreprocessor as MobileNetV5ImageClassifierPreprocessor,
460460
)
461+
from keras_hub.src.models.moondream.moondream_backbone import (
462+
MoondreamBackbone as MoondreamBackbone,
463+
)
464+
from keras_hub.src.models.moondream.moondream_causal_lm import (
465+
MoondreamCausalLM as MoondreamCausalLM,
466+
)
467+
from keras_hub.src.models.moondream.moondream_preprocessor import (
468+
MoondreamPreprocessor as MoondreamPreprocessor,
469+
)
461470
from keras_hub.src.models.moonshine.moonshine_audio_to_text import (
462471
MoonshineAudioToText as MoonshineAudioToText,
463472
)
Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from keras_hub.src.models.moondream.moondream_backbone import MoondreamBackbone
2-
from keras_hub.src.models.moondream.moondream_preprocessor import \
3-
MoondreamPreprocessor
2+
from keras_hub.src.models.moondream.moondream_preprocessor import (
3+
MoondreamPreprocessor,
4+
)

keras_hub/src/models/moondream/moondream_backbone.py

Lines changed: 92 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -7,52 +7,117 @@
77

88
@keras_hub_export("keras_hub.models.MoondreamBackbone")
class MoondreamBackbone(Backbone):
    """The Moondream Backbone model.

    This model connects a vision encoder (SigLIP) and a text decoder
    (Phi-1.5) using a projection layer. It is designed for vision-language
    tasks where image features are projected into the text embedding space
    and prepended to the token embeddings before decoding.

    Args:
        vision_encoder: A Keras model (e.g., SigLIP). The vision encoder
            responsible for processing input images.
        text_decoder: A Keras model (e.g., Phi-1.5). The text decoder
            responsible for generating text tokens. It must expose a
            `get_input_embeddings` callable and accept
            `decoder_inputs_embeds` and `padding_mask` keyword arguments.
        projection_dim: int. The dimension to project image features into.
            Defaults to `2048`.
        **kwargs: Standard Keras keyword arguments.

    Example:
    ```python
    import keras
    import numpy as np
    import keras_hub

    # 1. Create mock encoders.
    # Vision encoder: maps (378, 378, 3) -> (729, 1152).
    image_input = keras.Input(shape=(378, 378, 3))
    vision_output = keras.layers.Lambda(
        lambda x: keras.ops.ones((keras.ops.shape(x)[0], 729, 1152))
    )(image_input)
    vision_encoder = keras.Model(inputs=image_input, outputs=vision_output)

    # Text decoder: maps (seq,) -> (seq, 2048).
    text_input = keras.Input(shape=(None,), dtype="int32")
    text_output = keras.layers.Lambda(
        lambda x: keras.ops.ones(
            (keras.ops.shape(x)[0], keras.ops.shape(x)[1], 2048)
        )
    )(text_input)
    text_decoder = keras.Model(inputs=text_input, outputs=text_output)

    # Helper for embeddings.
    text_decoder.get_input_embeddings = lambda x: keras.layers.Embedding(
        50000, 2048
    )(x)

    # 2. Instantiate the backbone.
    backbone = keras_hub.models.MoondreamBackbone(
        vision_encoder=vision_encoder,
        text_decoder=text_decoder,
        projection_dim=2048,
    )

    # 3. Run a forward pass.
    inputs = {
        "images": np.random.rand(2, 378, 378, 3),
        "token_ids": np.random.randint(0, 50000, (2, 10)),
        "padding_mask": np.ones((2, 10)),
    }
    outputs = backbone(inputs)
    ```
    """

    def __init__(
        self, vision_encoder, text_decoder, projection_dim=2048, **kwargs
    ):
        # === Functional model ===
        images = keras.Input(shape=(None, None, 3), name="images")
        token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids")
        padding_mask = keras.Input(
            shape=(None,), dtype="int32", name="padding_mask"
        )

        inputs = {
            "images": images,
            "token_ids": token_ids,
            "padding_mask": padding_mask,
        }

        image_features = vision_encoder(images)

        # The connector: project image features into the text embedding
        # space so they can be concatenated with token embeddings.
        self.vision_projection = keras.layers.Dense(
            projection_dim, name="vision_projection"
        )
        projected_images = self.vision_projection(image_features)

        text_embeddings = text_decoder.get_input_embeddings(token_ids)

        # Prepend the image "tokens" to the text sequence.
        combined_embeddings = ops.concatenate(
            [projected_images, text_embeddings], axis=1
        )

        batch_size = ops.shape(images)[0]
        num_patches = ops.shape(projected_images)[1]

        # Image positions are always attended to. Use int32 to match the
        # `padding_mask` dtype before concatenation.
        image_mask = ops.ones((batch_size, num_patches), dtype="int32")
        combined_mask = ops.concatenate([image_mask, padding_mask], axis=1)

        # We set inputs=None because we are passing calculated embeddings
        # directly via `decoder_inputs_embeds`.
        outputs = text_decoder(
            inputs=None,
            decoder_inputs_embeds=combined_embeddings,
            padding_mask=combined_mask,
        )

        super().__init__(inputs=inputs, outputs=outputs, **kwargs)

        # === Config ===
        self.vision_encoder = vision_encoder
        self.text_decoder = text_decoder
        self.projection_dim = projection_dim

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "vision_encoder": keras.saving.serialize_keras_object(
                    self.vision_encoder
                ),
                "text_decoder": keras.saving.serialize_keras_object(
                    self.text_decoder
                ),
                "projection_dim": self.projection_dim,
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        # The sub-models are serialized in `get_config`, so they must be
        # deserialized back into Keras objects before `__init__` runs;
        # otherwise `__init__` would receive raw config dicts.
        config = config.copy()
        config["vision_encoder"] = keras.saving.deserialize_keras_object(
            config["vision_encoder"]
        )
        config["text_decoder"] = keras.saving.deserialize_keras_object(
            config["text_decoder"]
        )
        return cls(**config)

keras_hub/src/models/moondream/moondream_causal_lm.py

Lines changed: 66 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,69 @@
1-
import keras
2-
31
from keras_hub.src.api_export import keras_hub_export
42
from keras_hub.src.models.causal_lm import CausalLM
53
from keras_hub.src.models.moondream.moondream_backbone import MoondreamBackbone
6-
from keras_hub.src.models.moondream.moondream_preprocessor import \
7-
MoondreamPreprocessor
4+
from keras_hub.src.models.moondream.moondream_preprocessor import (
5+
MoondreamPreprocessor,
6+
)
87

98

109
@keras_hub_export("keras_hub.models.MoondreamCausalLM")
class MoondreamCausalLM(CausalLM):
    """An end-to-end Moondream model for causal language modeling.

    This model wraps `MoondreamBackbone` and handles the complete flow from
    raw inputs (images + text) to generated text output. It provides a
    high-level interface for image captioning and visual question answering.

    Args:
        backbone: A `MoondreamBackbone` instance. The backbone model that
            connects the vision encoder and text decoder.
        preprocessor: A `MoondreamPreprocessor` instance or `None`. Handles
            data preprocessing (tokenization and image resizing).
        **kwargs: Standard Keras keyword arguments.

    Example:
    ```python
    import keras
    import numpy as np
    import keras_hub

    # 1. Set up a mock backbone.
    images = keras.Input(shape=(None, None, 3), name="images")
    token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids")
    padding_mask = keras.Input(
        shape=(None,), dtype="int32", name="padding_mask"
    )

    # Produce constant (batch, seq, 2048) features. Note that a `Dense`
    # layer cannot be applied to `token_ids` directly: its last dimension
    # is unknown (`None`), so the layer would fail to build.
    outputs = keras.layers.Lambda(
        lambda x: keras.ops.ones(
            (keras.ops.shape(x)[0], keras.ops.shape(x)[1], 2048)
        )
    )(token_ids)

    backbone = keras.Model(
        inputs={
            "images": images,
            "token_ids": token_ids,
            "padding_mask": padding_mask,
        },
        outputs=outputs,
    )

    # 2. Instantiate the causal LM.
    model = keras_hub.models.MoondreamCausalLM(backbone=backbone)

    # 3. Run a forward pass.
    inputs = {
        "images": np.random.rand(2, 378, 378, 3),
        "token_ids": np.random.randint(0, 100, (2, 10)),
        "padding_mask": np.ones((2, 10)),
    }
    outputs = model(inputs)
    ```
    """

    backbone_cls = MoondreamBackbone
    preprocessor_cls = MoondreamPreprocessor

    def __init__(
        self,
        backbone,
        preprocessor=None,
        **kwargs,
    ):
        # === Functional model ===
        # Reuse the backbone's symbolic inputs and call the backbone on
        # them so this task model shares the backbone's graph and weights.
        inputs = backbone.input
        outputs = backbone(inputs)

        super().__init__(
            inputs=inputs,
            outputs=outputs,
            **kwargs,
        )

        # === Config ===
        self.backbone = backbone
        self.preprocessor = preprocessor

keras_hub/src/models/moondream/moondream_preprocessor.py

Lines changed: 69 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,63 @@
66

77
@keras_hub_export("keras_hub.models.MoondreamPreprocessor")
88
class MoondreamPreprocessor(CausalLMPreprocessor):
9+
"""
10+
Moondream Causal LM Preprocessor.
11+
12+
This class handles the preprocessing of images and text for the Moondream
13+
model. It combines image resizing/rescaling logic with text tokenization
14+
to prepare inputs for the model.
15+
16+
Args:
17+
tokenizer: The tokenizer to be used for text inputs.
18+
image_converter: An optional layer or callable for image preprocessing
19+
(e.g., resizing, normalization).
20+
sequence_length: int. The context length for tokenization.
21+
Defaults to 1024.
22+
add_start_token: bool. Whether to add the start token.
23+
Defaults to True.
24+
add_end_token: bool. Whether to add the end token.
25+
Defaults to True.
26+
**kwargs: Standard Keras keyword arguments.
27+
28+
Example:
29+
```python
30+
import keras
31+
import numpy as np
32+
from keras_hub.src.models.moondream.moondream_preprocessor import (
33+
MoondreamPreprocessor
34+
)
35+
36+
# 1. Create a Mock Tokenizer
37+
class MockTokenizer:
38+
def __call__(self, x):
39+
return keras.ops.convert_to_tensor([[1, 2, 3]] * len(x))
40+
def detokenize(self, x):
41+
return x
42+
pass
43+
44+
tokenizer = MockTokenizer()
45+
46+
# 2. Create an Image Converter
47+
image_converter = keras.layers.Resizing(height=378, width=378)
48+
49+
# 3. Instantiate Preprocessor
50+
preprocessor = MoondreamPreprocessor(
51+
tokenizer=tokenizer,
52+
image_converter=image_converter,
53+
sequence_length=128
54+
)
55+
56+
# 4. Preprocess Data
57+
inputs = {
58+
"images": np.random.randint(0, 255, (2, 500, 500, 3)),
59+
"text": ["Describe this image.", "What is in the photo?"]
60+
}
61+
62+
outputs = preprocessor(inputs)
63+
```
64+
"""
65+
966
def __init__(
1067
self,
1168
tokenizer,
@@ -25,23 +82,26 @@ def __init__(
2582
self.image_converter = image_converter
2683

2784
def call(self, x, y=None, sample_weight=None):
    """Preprocess text (and optionally images) for the Moondream model.

    `x` may be a plain text input, or a dict with a `"text"` key and an
    optional `"images"` key. Text is tokenized by the parent
    `CausalLMPreprocessor`; images, when present, are passed through
    `self.image_converter` (if set) and attached to the preprocessed
    feature dict under `"images"`.
    """
    # Pull text and images apart; non-dict inputs are treated as pure
    # text with no images attached.
    images = None
    if isinstance(x, dict):
        text = x.get("text", "")
        images = x.get("images")
    else:
        text = x

    output = super().call(text, y=y, sample_weight=sample_weight)

    # `output` is either `(x, y, sample_weight)` or just `x`; the
    # feature structure is always the first element of a tuple.
    x_out = output[0] if isinstance(output, tuple) else output

    if images is not None:
        if self.image_converter:
            images = self.image_converter(images)

        # Only a dict feature structure can carry the images alongside
        # the token features.
        if isinstance(x_out, dict):
            x_out["images"] = images

    return output
47107

0 commit comments

Comments
 (0)