Skip to content

Commit 6ce880e

Browse files
Merge pull request #2769 from AI-Hypercomputer:hengtaoguo-nnx-vis
PiperOrigin-RevId: 839920093
2 parents 0ea9b55 + 43967ce commit 6ce880e

File tree

3 files changed

+53
-30
lines changed

3 files changed

+53
-30
lines changed

src/MaxText/layers/encoders.py

Lines changed: 49 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,55 +15,78 @@
1515
""""Module for encoder layers."""
1616

1717
import jax
18-
from flax import linen as nn
18+
from flax import nnx
1919
from jax.sharding import Mesh
2020

2121
from MaxText.common_types import Config
22-
from MaxText.layers import quantizations
22+
from MaxText.layers import nnx_wrappers
23+
from MaxText.layers import initializers
2324

24-
# Type alias for cleaner type hints
25-
Quant = quantizations.AqtQuantization
2625

27-
28-
class VisionEncoder(nn.Module):
26+
class VisionEncoder(nnx.Module):
2927
"""Vision encoder to encode images into soft tokens."""
3028

31-
config: Config
32-
mesh: Mesh
33-
34-
def setup(self):
35-
self.vision_encoder_layer = self.get_vision_encoder_layers()
29+
def __init__(self, config: Config, mesh: Mesh, *, rngs: nnx.Rngs):
30+
self.config = config
31+
self.mesh = mesh
32+
self.rngs = rngs
33+
self.encoder_name, self.projector_name = self._setup_vision_encoder_layers()
3634

37-
def get_vision_encoder_layers(self):
38-
"""Get vision encoder layers specific to the model, classes of nn.Module type."""
35+
def _setup_vision_encoder_layers(self):
36+
"""Setup vision encoder layers specific to the model, instantiate NNX modules."""
3937
if self.config.model_name in ["gemma3-4b", "gemma3-12b", "gemma3-27b"]:
4038
from MaxText.layers import gemma3 # pylint: disable=import-outside-toplevel
4139

42-
return [gemma3.gemma3visionencoder_as_linen, gemma3.visionembedder_as_linen]
40+
encoder_name = "Gemma3VisionEncoderLayer_0"
41+
projector_name = "VisionEmbedder_0"
42+
setattr(self, encoder_name, gemma3.Gemma3VisionEncoderLayer(config=self.config, mesh=self.mesh, rngs=self.rngs))
43+
setattr(self, projector_name, gemma3.VisionEmbedder(config=self.config, mesh=self.mesh, rngs=self.rngs))
44+
return encoder_name, projector_name
4345
elif self.config.model_name in ["llama4-17b-16e", "llama4-17b-128e"]:
4446
from MaxText.layers import llama4 # pylint: disable=import-outside-toplevel
4547

46-
return [llama4.llama4visionmodel_as_linen, llama4.llama4multimodalprojector_as_linen]
48+
encoder_name = "Llama4VisionModel_0"
49+
projector_name = "Llama4MultiModalProjector_0"
50+
setattr(self, encoder_name, llama4.Llama4VisionModel(config=self.config, mesh=self.mesh, rngs=self.rngs))
51+
setattr(self, projector_name, llama4.Llama4MultiModalProjector(config=self.config, mesh=self.mesh, rngs=self.rngs))
52+
return encoder_name, projector_name
4753
elif self.config.model_name in ["qwen3-omni-30b-a3b"]:
4854
from MaxText.layers import qwen3 # pylint: disable=import-outside-toplevel
4955

50-
return [qwen3.qwen3omni_visionencoder_as_linen, qwen3.qwen3omni_visionprojector_as_linen]
56+
encoder_name = "Qwen3OmniMoeVisionEncoder_0"
57+
projector_name = "Qwen3OmniMoeVisionProjector_0"
58+
setattr(self, encoder_name, qwen3.Qwen3OmniMoeVisionEncoder(config=self.config, mesh=self.mesh, rngs=self.rngs))
59+
setattr(self, projector_name, qwen3.Qwen3OmniMoeVisionProjector(config=self.config, rngs=self.rngs))
60+
return encoder_name, projector_name
5161
else:
5262
raise ValueError(f"No VisionEncoder implemented for {self.config.model_name} yet")
5363

54-
@nn.compact
5564
def __call__(self, input_images, deterministic=False):
56-
cfg = self.config
57-
mesh = self.mesh
5865
# vision encoder output, frozen params in many cases
59-
embeddings = self.vision_encoder_layer[0](config=cfg, mesh=mesh)(input_images, deterministic=deterministic)
60-
if cfg.model_name in ["qwen3-omni-30b-a3b"]:
61-
embeddings = embeddings[0] # todo(eitanporat) add deepstack support
66+
encoder = getattr(self, self.encoder_name)
67+
embeddings = encoder(input_images, deterministic=deterministic)
6268

63-
if cfg.freeze_vision_encoder_params:
69+
if self.config.freeze_vision_encoder_params:
6470
embeddings = jax.lax.stop_gradient(embeddings)
6571

66-
if len(self.vision_encoder_layer) > 1:
67-
# vision embedder / projection layer, not frozen in most cases, trained / finetuned together with main model
68-
embeddings = self.vision_encoder_layer[1](config=cfg, mesh=mesh)(embeddings)
72+
# vision embedder / projection layer, not frozen in most cases, trained / finetuned together with main model
73+
projector = getattr(self, self.projector_name)
74+
embeddings = projector(embeddings)
75+
6976
return embeddings
77+
78+
79+
def vision_encoder_as_linen(
80+
config: Config,
81+
mesh: Mesh,
82+
):
83+
"""Creates a VisionEncoder module."""
84+
module = nnx_wrappers.to_linen(
85+
VisionEncoder,
86+
config=config,
87+
mesh=mesh,
88+
name="vision_encoder",
89+
abstract_init=False,
90+
metadata_fn=initializers.variable_to_logically_partitioned,
91+
)
92+
return module

src/MaxText/layers/models.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from MaxText.layers import nnx_wrappers
3333
from MaxText.layers.decoders import Decoder
3434
from MaxText.layers.embeddings import Embed, embed_as_linen
35-
from MaxText.layers.encoders import VisionEncoder
35+
from MaxText.layers.encoders import VisionEncoder, vision_encoder_as_linen
3636
from MaxText.layers.quantizations import AqtQuantization as Quant
3737
from MaxText.layers.multi_token_prediction import MultiTokenPredictionBlock
3838
from MaxText.sharding import all_gather_over_fsdp
@@ -85,7 +85,7 @@ def setup(self):
8585
config=cfg,
8686
mesh=self.mesh,
8787
)
88-
self.vision_encoder = VisionEncoder(config=cfg, mesh=mesh) if cfg.use_multimodal else None
88+
self.vision_encoder = vision_encoder_as_linen(config=cfg, mesh=mesh) if cfg.use_multimodal else None
8989
self.decoder = Decoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode)
9090
# If MTP is enabled via config, set up the MTP block.
9191
if self.config.mtp_num_layers > 0:
@@ -304,7 +304,7 @@ def __init__(self, config: Config, mesh: Mesh, quant: Quant, *, model_mode: str
304304
config=cfg,
305305
rngs=rngs,
306306
)
307-
self.vision_encoder = VisionEncoder(config=cfg, mesh=mesh) if cfg.use_multimodal else None
307+
self.vision_encoder = VisionEncoder(config=cfg, mesh=mesh, rngs=rngs) if cfg.use_multimodal else None
308308

309309
decoder_linen = Decoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode)
310310
self.decoder = nnx_wrappers.ToNNX(decoder_linen, rngs=rngs)

tests/integration_tests/vision_encoder_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def test_image_embedding_gemma3_4b_tpu(self):
8484
input_images = images[jnp.newaxis, jnp.newaxis, ...] # pytype: disable=unsupported-operands
8585

8686
# Initialize only the vision encoder part and extract the corresponding params
87-
vision_encoder_model = models.VisionEncoder(config)
87+
vision_encoder_model = models.VisionEncoder(config, engine.mesh, rngs=engine.rng)
8888
vision_encoder_params = params["params"]["vision_encoder"]
8989

9090
# Apply the vision encoder to get the image embeddings

0 commit comments

Comments
 (0)