|
15 | 15 | """Module for encoder layers.""" |
16 | 16 |
|
17 | 17 | import jax |
18 | | -from flax import linen as nn |
| 18 | +from flax import nnx |
19 | 19 | from jax.sharding import Mesh |
20 | 20 |
|
21 | 21 | from MaxText.common_types import Config |
22 | | -from MaxText.layers import quantizations |
| 22 | +from MaxText.layers import nnx_wrappers |
| 23 | +from MaxText.layers import initializers |
23 | 24 |
|
24 | | -# Type alias for cleaner type hints |
25 | | -Quant = quantizations.AqtQuantization |
26 | 25 |
|
27 | | - |
class VisionEncoder(nnx.Module):
  """Vision encoder to encode images into soft tokens."""

  def __init__(self, config: Config, mesh: Mesh, *, rngs: nnx.Rngs):
    self.config = config
    self.mesh = mesh
    self.rngs = rngs
    # Record the attribute names the sub-modules are registered under so that
    # __call__ can fetch them later. The "<ClassName>_0" naming mirrors the
    # auto-generated Linen layer names — presumably to keep parameter-tree
    # paths checkpoint-compatible; confirm before renaming.
    self.encoder_name, self.projector_name = self._setup_vision_encoder_layers()

  def _setup_vision_encoder_layers(self):
    """Instantiate the model-specific encoder/projector NNX modules.

    Layer classes are imported lazily so only the configured model family's
    code is loaded.

    Returns:
      Tuple of (encoder_attr_name, projector_attr_name) — the attribute names
      on ``self`` that the freshly created modules were bound to.

    Raises:
      ValueError: if ``config.model_name`` has no vision encoder implemented.
    """
    model_name = self.config.model_name
    if model_name in ["gemma3-4b", "gemma3-12b", "gemma3-27b"]:
      from MaxText.layers import gemma3  # pylint: disable=import-outside-toplevel

      names = ("Gemma3VisionEncoderLayer_0", "VisionEmbedder_0")
      modules = (
          gemma3.Gemma3VisionEncoderLayer(config=self.config, mesh=self.mesh, rngs=self.rngs),
          gemma3.VisionEmbedder(config=self.config, mesh=self.mesh, rngs=self.rngs),
      )
    elif model_name in ["llama4-17b-16e", "llama4-17b-128e"]:
      from MaxText.layers import llama4  # pylint: disable=import-outside-toplevel

      names = ("Llama4VisionModel_0", "Llama4MultiModalProjector_0")
      modules = (
          llama4.Llama4VisionModel(config=self.config, mesh=self.mesh, rngs=self.rngs),
          llama4.Llama4MultiModalProjector(config=self.config, mesh=self.mesh, rngs=self.rngs),
      )
    elif model_name in ["qwen3-omni-30b-a3b"]:
      from MaxText.layers import qwen3  # pylint: disable=import-outside-toplevel

      names = ("Qwen3OmniMoeVisionEncoder_0", "Qwen3OmniMoeVisionProjector_0")
      # Note: the qwen3 projector takes no mesh argument.
      modules = (
          qwen3.Qwen3OmniMoeVisionEncoder(config=self.config, mesh=self.mesh, rngs=self.rngs),
          qwen3.Qwen3OmniMoeVisionProjector(config=self.config, rngs=self.rngs),
      )
    else:
      raise ValueError(f"No VisionEncoder implemented for {self.config.model_name} yet")

    for attr_name, module in zip(names, modules):
      setattr(self, attr_name, module)
    return names

  def __call__(self, input_images, deterministic=False):
    """Encode images into projected embedding tokens.

    Args:
      input_images: batch of input images to encode.
      deterministic: forwarded to the underlying encoder layer.

    Returns:
      Projected image embeddings (soft tokens).
    """
    # Vision encoder output; its params are frozen in many setups.
    embeddings = getattr(self, self.encoder_name)(input_images, deterministic=deterministic)

    if self.config.freeze_vision_encoder_params:
      embeddings = jax.lax.stop_gradient(embeddings)

    # Vision embedder / projection layer; not frozen in most cases — it is
    # trained / finetuned together with the main model.
    return getattr(self, self.projector_name)(embeddings)
| 77 | + |
| 78 | + |
def vision_encoder_as_linen(
    config: Config,
    mesh: Mesh,
):
  """Build a Linen-compatible wrapper around the NNX ``VisionEncoder``.

  Args:
    config: model configuration passed through to ``VisionEncoder``.
    mesh: device mesh passed through to ``VisionEncoder``.

  Returns:
    A Linen module named ``"vision_encoder"`` wrapping ``VisionEncoder``.
  """
  return nnx_wrappers.to_linen(
      VisionEncoder,
      config=config,
      mesh=mesh,
      name="vision_encoder",
      abstract_init=False,
      metadata_fn=initializers.variable_to_logically_partitioned,
  )
0 commit comments