diff --git a/eole/bin/convert/HF_mappings.py b/eole/bin/convert/HF_mappings.py index be30524a..be5167f3 100644 --- a/eole/bin/convert/HF_mappings.py +++ b/eole/bin/convert/HF_mappings.py @@ -37,11 +37,81 @@ MODEL_OVERRIDES = { "LlamaForCausalLM": {}, # default "MistralForCausalLM": {}, - "Qwen2ForCausalLM": { + "Bagel": { # bagel's arch is actually Qwen2, but requires specific mapping + "decoder_layer_prefix": "language_model.model.layers.", + "decoder.layer_norm.weight": "language_model.model.norm.weight", + "decoder.layer_norm_moe_gen.weight": "language_model.model.norm_moe_gen.weight", + "encoder_layer_prefix": "vit_model.vision_model.encoder.layers.", + "encoder.patch_conv.weight": "vit_model.vision_model.embeddings.patch_embedding.weight", + "encoder.patch_conv.bias": "vit_model.vision_model.embeddings.patch_embedding.bias", + "encoder.position_embeddings.weight": "vit_model.vision_model.embeddings.position_embedding.weight", + "encoder.post_layernorm.weight": "vit_model.vision_model.post_layernorm.weight", + "encoder.post_layernorm.bias": "vit_model.vision_model.post_layernorm.bias", + "tgt_emb.embeddings.weight": "language_model.model.embed_tokens.weight", + "generator.weight": "language_model.lm_head.weight", + # vision_adapter + "adapter.w_in.weight": "connector.fc1.weight", + "adapter.w_in.bias": "connector.fc1.bias", + "adapter.w_out.weight": "connector.fc2.weight", + "adapter.w_out.bias": "connector.fc2.bias", + # additional stuff, mostly replicated as-is for now + "vit_pos_embed.pos_embed": "vit_pos_embed.pos_embed", + "latent_pos_embed.pos_embed": "latent_pos_embed.pos_embed", + "time_embedder.mlp.0.weight": "time_embedder.mlp.0.weight", + "time_embedder.mlp.0.bias": "time_embedder.mlp.0.bias", + "time_embedder.mlp.2.weight": "time_embedder.mlp.2.weight", + "time_embedder.mlp.2.bias": "time_embedder.mlp.2.bias", + "vae2llm.weight": "vae2llm.weight", + "vae2llm.bias": "vae2llm.bias", + "llm2vae.weight": "llm2vae.weight", + "llm2vae.bias": "llm2vae.bias", + # TODO: not sure how to properly grab VAE stuff + "decoder": { + ".self_attn.q_norm.": ".self_attn.q_norm.", + ".self_attn.k_norm.": ".self_attn.k_norm.", + # MOE GEN (simplify with loop?) 
+ ".self_attn.q_norm_moe_gen.": ".self_attn.q_norm_moe_gen.", + ".self_attn.k_norm_moe_gen.": ".self_attn.k_norm_moe_gen.", + ".self_attn.linear_query_moe_gen.": ".self_attn.q_proj_moe_gen.", + ".self_attn.linear_keys_moe_gen.": ".self_attn.k_proj_moe_gen.", + ".self_attn.linear_values_moe_gen.": ".self_attn.v_proj_moe_gen.", + ".self_attn.final_linear_moe_gen.": ".self_attn.o_proj_moe_gen.", + ".mlp_moe_gen.gate_up_proj.": ".mlp_moe_gen.gate_proj.", + ".mlp_moe_gen.down_proj.": ".mlp_moe_gen.down_proj.", + ".mlp_moe_gen.up_proj.": ".mlp_moe_gen.up_proj.", + ".input_layernorm_moe_gen.": ".input_layernorm_moe_gen.", + ".post_attention_layernorm_moe_gen.": ".post_attention_layernorm_moe_gen.", + }, + "encoder": { + ".self_attn.linear_query.": ".self_attn.q_proj.", + ".self_attn.linear_keys.": ".self_attn.k_proj.", + ".self_attn.linear_values.": ".self_attn.v_proj.", + ".self_attn.final_linear.": ".self_attn.out_proj.", + ".mlp.gate_up_proj.": ".mlp.fc1.", + ".mlp.down_proj.": ".mlp.fc2.", + ".input_layernorm.": ".layer_norm1.", + ".post_attention_layernorm.": ".layer_norm2.", + }, "config": { "add_qkvbias": True, "add_final_linear_bias": False, - } + "adapter": "bagel", + "vit_position_embeddings": True, + "decoder": { + "query_norm": True, + "key_norm": True, + }, + "encoder": { + "mlp_activation_fn": "gelu-tanh", + "add_ffnbias": True, + "add_final_linear_bias": True, + "add_qkvbias": True, + "layer_norm": "standard", + "patch_conv_bias": True, + "patch_conv_linear": True, + "layernorm_pre": False, # implies post layernorm + }, + }, }, "Qwen3ForCausalLM": { "decoder": { @@ -386,6 +456,7 @@ "Mistral3ForConditionalGeneration": VisionTransformerLMModelConfig, "Gemma3ForConditionalGeneration": VisionTransformerLMModelConfig, "M2M100ForConditionalGeneration": TransformerModelConfig, + "Bagel": VisionTransformerLMModelConfig, }, ) diff --git a/eole/bin/convert/convert_HF.py b/eole/bin/convert/convert_HF.py index f9d0d32d..9c1de51d 100755 --- a/eole/bin/convert/convert_HF.py +++ b/eole/bin/convert/convert_HF.py @@ -66,6 +66,8 @@ class HuggingfaceFiles: wmap_path: Optional[str] = None model_path: Optional[str] = None special_tokens_json: Optional[str] = None + vision_config_path: Optional[str] = None + ae_model_path: Optional[str] = None # Unified dictionary to cache loaded files _loaded_files: dict = field(default_factory=dict, init=False) @@ -117,7 +119,7 @@ def download_file_from_hub(file_name, required=True): # Fetch required and optional files paths = { - "config_path": get_file_fn("config.json", required=True), + "config_path": get_file_fn("llm_config.json", required=False) or get_file_fn("config.json", required=True), "tokenizer_config_json": get_file_fn("tokenizer_config.json", required=True), "generation_config_json": get_file_fn("generation_config.json", required=False), "tokenizer_model": get_file_fn("tokenizer.model", required=False) @@ -126,8 +128,11 @@ def download_file_from_hub(file_name, required=True): "wmap_path": get_file_fn("model.safetensors.index.json", required=False) or get_file_fn("pytorch_model.bin.index.json", required=False), "model_path": get_file_fn("model.safetensors", required=False) - or get_file_fn("pytorch_model.bin", required=False), + or get_file_fn("pytorch_model.bin", required=False) + or get_file_fn("ema.safetensors", required=False), "special_tokens_json": get_file_fn("special_tokens_map.json", required=False), + "vision_config_path": get_file_fn("vit_config.json", required=False), + "ae_model_path": get_file_fn("ae.safetensors", required=False), } return 
cls(**paths, model_dir=args.model_dir, token=args.token) @@ -158,6 +163,8 @@ def __getattr__(self, name): @property def arch(self): + if self.model_dir == "ByteDance-Seed/BAGEL-7B-MoT": + return "Bagel" return self.config["architectures"][0] @property @@ -270,9 +277,11 @@ def build_config_dict(hf): arch = hf.arch print("Architecture: ", arch) - vision_config = config.get("vision_config", None) - other_config = config # save what is not text/vision for later use - config = config.get("text_config", config) + vision_config = getattr(hf, "vision_config", None) + if vision_config is None: + vision_config = config.get("vision_config", None) + other_config = config # save what is not text/vision for later use + config = config.get("text_config", config) model_config = {} training_config = {} @@ -289,6 +298,7 @@ def build_config_dict(hf): "transformer_ff_moe": config.get("moe_intermediate_size", None), "mlp_activation_fn": ACT_TABLE[arch], "layer_norm": LN_TABLE[arch], + # TODO: this can break encoder (e.g. bagel) "heads_kv": config.get("multi_query", False) or config.get( "num_key_value_heads", @@ -351,6 +361,24 @@ def build_config_dict(hf): model_config["projector_activation_fn"] = other_config.get("projector_hidden_act", "gelu") model_config["spatial_merge_size"] = other_config.get("spatial_merge_size", None) + if arch == "Bagel": + model_config["encoder"] = { + "hidden_size": vision_config.get("hidden_size", 1152), + "image_size": 1024, # 980 for VIT (vit_config.json), 1024 for VAE + "patch_size": vision_config["patch_size"], + "heads": vision_config["num_attention_heads"], + "heads_kv": vision_config["num_attention_heads"], + "layers": 26, # 27 in config, but actually 26 in safetensors + "transformer_ff": vision_config["intermediate_size"], + # siglip style learned position embeddings (like gemma3) + "position_encoding_type": PositionEncodingType.Learned, + "n_positions": (vision_config["image_size"] // vision_config["patch_size"]) ** 2, + "image_token_id": 151654, + "image_start_token_id": 151652, + "image_end_token_id": 151653, + "max_patches_per_side": 70, + } + if arch == "Gemma3ForConditionalGeneration": if model_config.get("head_dim", None) is None: model_config["head_dim"] = 256 # src/transformers/models/gemma3/configuration_gemma3.py#L61 @@ -646,6 +674,13 @@ def build_shards(model_config, hf, args, params): eole_safetensor = {} def build_first_shard(hf, eole_safetensor): + # let's add AE here (visual autoencoder for image generation) + if hf.ae_model_path is not None: + ae_checkpoint = hf.get_load_ckpt(*os.path.split(hf.ae_model_path)) + ae_params = safetensors.torch.load_file(ae_checkpoint) + for key, value in ae_params.items(): + eole_safetensor[f"image_autoencoder.{key}"] = value + for target in KEY_MAPS[hf.arch].keys(): if model_config["share_decoder_embeddings"] and target == "generator.weight": continue diff --git a/eole/config/inference.py b/eole/config/inference.py index 1d4063e1..07eb19b8 100644 --- a/eole/config/inference.py +++ b/eole/config/inference.py @@ -81,8 +81,40 @@ class DecodingConfig(Config): align_debug: bool = Field(default=False, description="Print best align for each word.") +class ImageGenerationConfig(Config): + """ + Let's centralize image generation related stuff here. + This is not a complete config, but rather a subset of options + that are relevant for image generation tasks. + Used as mixin for InferenceConfig for now, but might be properly nested at some point. 
+ """ + + # image generation specific stuff, might move elsewhere + image_generation: bool | None = Field( + default=False, + description="Generate image from text input. " + "This will only work if the model is trained for image generation.", + ) + image_width: int | None = Field( + default=1024, + description="Width of the generated image. " + "This will only work if the model is trained for image generation.", + ) + image_height: int | None = Field( + default=1024, + description="Height of the generated image. " + "This will only work if the model is trained for image generation.", + ) + cfg_text_scale: float | None = Field(default=1.0, description="Classifier-free guidance scale for text input. ") + cfg_image_scale: float | None = Field(default=1.0, description="Classifier-free guidance scale for image input. ") + cfg_interval_min: float | None = Field(default=0.0, description="Minimum classifier-free guidance interval. ") + cfg_interval_max: float | None = Field(default=1.0, description="Maximum classifier-free guidance interval. ") + timestep_shift: float | None = Field(default=1.0, description="Shift the timestep for image generation. ") + num_timesteps: int | None = Field(default=50, description="Number of timesteps for image generation. ") + + # in legacy opts, decoding config is separated (probably to be used elsewhere) -class InferenceConfig(RunningConfig, DecodingConfig, LoRaConfig, QuantizeConfig): +class InferenceConfig(RunningConfig, DecodingConfig, LoRaConfig, QuantizeConfig, ImageGenerationConfig): model_config = get_config_dict() model_config["arbitrary_types_allowed"] = True # to allow torch.dtype diff --git a/eole/config/models.py b/eole/config/models.py index 4f639c81..1c3af3f1 100644 --- a/eole/config/models.py +++ b/eole/config/models.py @@ -371,10 +371,17 @@ class VisionEncoderConfig(TransformerConfig, EncoderConfig): num_channels: int | None = 3 image_size: int | None = 1024 patch_size: int | None = 16 + max_patches_per_side: int | None = None + max_latent_size: int | None = 64 # bagel + latent_patch_size: int | None = 2 + latent_channel: int | None = 16 image_token_id: int | None = 10 # pixtral uses 10, gemma3 uses 262144 + image_start_token_id: int | None = None + image_end_token_id: int | None = None mm_tokens_per_image: int | None = 256 # added for gemma3 layernorm_pre: bool = True # True for pixtral/mistral False for gemma3 patch_conv_bias: bool = False # False for pixtral/mistral True for gemma3 + patch_conv_linear: bool = False # False for pixtral/gemma3 True for bagel # use Field with default= + description would be more readable @@ -771,6 +778,10 @@ class VisionTransformerLMModelConfig(TransformerConfig, BaseModelConfig): adapter: str | None = Field(default="llava", description="Adapter type to use in the model.") + vit_position_embeddings: bool = Field( + default=False, description="Additional position embeddings for images, introduced for Bagel." + ) + @model_validator(mode="before") @classmethod def encoder_decoder_type(cls, data: Any) -> Any: diff --git a/eole/decoders/decoder.py b/eole/decoders/decoder.py index f879dd58..12916fd9 100644 --- a/eole/decoders/decoder.py +++ b/eole/decoders/decoder.py @@ -14,6 +14,10 @@ def __init__(self, attentional=True): # Decoder state self.state = {} + @property + def device(self): + return next(self.parameters()).device + @classmethod def from_config(cls, decoder_config, running_config=None, with_cross_attn=False): """Alternate constructor. 
diff --git a/eole/decoders/transformer.py b/eole/decoders/transformer.py index 0ca6b03d..112c7e57 100644 --- a/eole/decoders/transformer.py +++ b/eole/decoders/transformer.py @@ -76,6 +76,63 @@ def __init__(self, decoder_config, running_config=None, with_cross_attn=False): running_config=running_config, ) + self._maybe_init_image_generation(decoder_config, running_config) + + def _maybe_init_image_generation(self, decoder_config, running_config): + self.image_generation = getattr(running_config, "image_generation", False) + if self.image_generation: + self.input_layernorm_moe_gen = LayerNorm[decoder_config.layer_norm]( + decoder_config.hidden_size, eps=decoder_config.norm_eps + ) + if decoder_config.post_attention_layernorm: + self.post_attention_layernorm_moe_gen = LayerNorm[decoder_config.layer_norm]( + decoder_config.hidden_size, eps=decoder_config.norm_eps + ) + self.mlp_moe_gen = MLP( + decoder_config, + running_config=running_config, + ) + + def _input_layernorm(self, x, **kwargs): + text_indices = kwargs.pop("text_indices", None) + image_indices = kwargs.pop("image_indices", None) + + if self.image_generation: + # TODO: reduce frequency of such assert calls + assert text_indices is not None, "Text indices must be provided for image generation" + assert image_indices is not None, "Image indices must be provided for image generation" + out = torch.zeros_like(x, dtype=x.dtype, device=x.device) + out[:, text_indices, :] = self.input_layernorm(x[:, text_indices, :]) + out[:, image_indices, :] = self.input_layernorm_moe_gen(x[:, image_indices, :]) + return out + else: + return self.input_layernorm(x) + + def _context_attention(self, layer_in, norm_layer_in, self_attn, enc_out, src_pad_mask, return_attn): + if not self.context_attn: + return 0, None + + if self.parallel_residual: + ctx_attn, attns = self.context_attn( + enc_out, + enc_out, + norm_layer_in, + attn_mask=~src_pad_mask, + return_attn=return_attn, + ) + else: + norm_query = self.precontext_layernorm(self_attn + layer_in) + ctx_attn, attns = self.context_attn( + enc_out, + enc_out, + norm_query, + attn_mask=~src_pad_mask, + return_attn=return_attn, + ) + if self.dropout_p > 0: + ctx_attn = self.dropout(ctx_attn) + return ctx_attn, attns + def _mlp(self, hidden_states): if self.ffn_layernorm: hidden_states = self.pre_feedforward_layernorm(hidden_states) @@ -85,6 +142,23 @@ def _mlp(self, hidden_states): hidden_states = self.mlp(hidden_states) return hidden_states + def _prepare_ff_input(self, layer_in, norm_layer_in, self_attn, ctx_attn, text_indices=None, image_indices=None): + """Prepare input for feedforward network, handling different residual configurations""" + if self.parallel_residual: + if not self.shared_layer_norm: + norm_res_layer_in = self.residual_layernorm(layer_in) + return norm_res_layer_in + else: + return norm_layer_in + else: + sequence = ctx_attn + self_attn + layer_in + if self.image_generation: + text_sequence = self.post_attention_layernorm(sequence[:, text_indices, :]) + image_sequence = self.post_attention_layernorm_moe_gen(sequence[:, image_indices, :]) + return text_sequence, image_sequence + else: + return self.post_attention_layernorm(sequence) + def forward(self, layer_in, **kwargs): """ Args: @@ -110,7 +184,14 @@ def forward(self, layer_in, **kwargs): return_attn = kwargs.pop("return_attn", False) position_embeddings = kwargs.pop("position_embeddings", None) - norm_layer_in = self.input_layernorm(layer_in) + text_indices = kwargs.pop("text_indices", None) + image_indices = 
kwargs.pop("image_indices", None) + + norm_layer_in = self._input_layernorm( + layer_in, + text_indices=text_indices, + image_indices=image_indices, + ) self_attn, attns = self.self_attn( norm_layer_in, @@ -118,6 +199,8 @@ def forward(self, layer_in, **kwargs): step=step, return_attn=return_attn, position_embeddings=position_embeddings, + text_indices=text_indices, + image_indices=image_indices, ) if self.dropout_p > 0: @@ -129,39 +212,26 @@ def forward(self, layer_in, **kwargs): layer_out = ff_in + self._mlp(ff_in) return layer_out, attns - if self.parallel_residual: - if self.context_attn: - ctx_attn, attns = self.context_attn( - enc_out, - enc_out, - norm_layer_in, - attn_mask=~src_pad_mask, - return_attn=return_attn, - ) - else: - ctx_attn = 0 - if not self.shared_layer_norm: - norm_res_layer_in = self.residual_layernorm(layer_in) - ff_in = norm_res_layer_in - else: - ff_in = norm_layer_in + ctx_attn, attns = self._context_attention( + layer_in, + norm_layer_in, + self_attn, + enc_out, + src_pad_mask, + return_attn, + ) + + ff_in = self._prepare_ff_input(layer_in, norm_layer_in, self_attn, ctx_attn, text_indices, image_indices) + if self.image_generation: + text_sequence, image_sequence = ff_in + MLP = torch.zeros_like(layer_in, dtype=layer_in.dtype, device=layer_in.device) + MLP[:, text_indices, :] = self.mlp(text_sequence) + MLP[:, image_indices, :] = self.mlp_moe_gen(image_sequence) else: - if self.context_attn: - norm_query = self.precontext_layernorm(self_attn + layer_in) - ctx_attn, attns = self.context_attn( - enc_out, - enc_out, - norm_query, - attn_mask=~src_pad_mask, - return_attn=return_attn, - ) - if self.dropout_p > 0: - ctx_attn = self.dropout(ctx_attn) - else: - ctx_attn = 0 - ff_in = self.post_attention_layernorm(ctx_attn + self_attn + layer_in) + MLP = self.mlp(ff_in) + # we apply residual with un-normed - layer_out = self.mlp(ff_in) + layer_in + self_attn + ctx_attn + layer_out = MLP + layer_in + self_attn + ctx_attn return layer_out, attns @@ -225,7 +295,13 @@ def __init__( for i in range(decoder_config.layers) ] ) + self.image_generation = getattr(running_config, "image_generation", False) self.layer_norm = LayerNorm[decoder_config.layer_norm](decoder_config.hidden_size, eps=decoder_config.norm_eps) + if self.image_generation: + # MOE GEN params + self.layer_norm_moe_gen = LayerNorm[decoder_config.layer_norm]( + decoder_config.hidden_size, eps=decoder_config.norm_eps + ) self._disable_cache() @classmethod @@ -311,7 +387,10 @@ def forward(self, emb, **kwargs): step = kwargs.pop("step", None) with_align = kwargs.pop("with_align", False) return_attn = with_align or kwargs.pop("return_attn", False) - position_embeddings = self.rope.update(emb.size(1), step=step) + positions = kwargs.pop("positions", None) + position_embeddings = self.rope.update( + emb.size(1), step=step, positions=positions, device=self.device, dtype=emb.dtype + ) if self.rope_local is not None: position_embeddings_local = self.rope_local.update(emb.size(1), step=step) else: @@ -339,7 +418,10 @@ def forward(self, emb, **kwargs): # we need to adapt the mask for gemma3, TODO: find another condition? # SEEMS OK TO MASK IMAGES FOR LLAVA TOO ? 
if decoder_in is not None and attn_mask is not None: - attn_mask = self._update_causal_mask(attn_mask, decoder_in == image_token_id) + attn_mask = self._update_causal_mask( + attn_mask, (decoder_in == image_token_id) | (decoder_in == 151652) | (decoder_in == 151653) + ) + if self.sliding_window > 0 and step >= self.sliding_window and attn_mask is not None: attn_mask = attn_mask[:, :, :, -self.sliding_window :] @@ -354,6 +436,7 @@ def forward(self, emb, **kwargs): position_embeddings=( position_embeddings_local if (i + 1) % self.interleave_local else position_embeddings ), + **kwargs, ) if with_align: attn_align = layer.get_attn_align( @@ -371,7 +454,18 @@ def forward(self, emb, **kwargs): if attn_align is not None: attn_aligns.append(attn_align) - emb = self.layer_norm(emb) + # TODO apply MOE logic here + if self.image_generation: + emb_ = torch.zeros_like(emb, dtype=emb.dtype, device=emb.device) + text_indices = kwargs.get("text_indices", None) + image_indices = kwargs.get("image_indices", None) + assert text_indices is not None, "Text indices must be provided for image generation" + assert image_indices is not None, "Image indices must be provided for image generation" + emb_[:, text_indices, :] = self.layer_norm(emb[:, text_indices, :]) + emb_[:, image_indices, :] = self.layer_norm_moe_gen(emb[:, image_indices, :]) + emb = emb_ + else: + emb = self.layer_norm(emb) # we take the first head top_attn = None if attn is None else attn[:, 0, :, :].contiguous() diff --git a/eole/encoders/encoder.py b/eole/encoders/encoder.py index 25a92ad5..d3464db1 100644 --- a/eole/encoders/encoder.py +++ b/eole/encoders/encoder.py @@ -16,6 +16,10 @@ class EncoderBase(nn.Module): def from_config(cls, encoder_config, running_config=None): raise NotImplementedError + @property + def device(self): + return next(self.parameters()).device + def forward(self, emb, **kwargs): """ Args: diff --git a/eole/encoders/transformer.py b/eole/encoders/transformer.py index 0e023ae3..83f99770 100644 --- a/eole/encoders/transformer.py +++ b/eole/encoders/transformer.py @@ -45,6 +45,7 @@ def __init__( self.mlp = MLP( encoder_config, running_config=running_config, + is_decoder=False, ) def forward(self, layer_in, pad_mask, position_embeddings=None): @@ -131,7 +132,7 @@ def forward(self, emb, **kwargs): """ pad_mask = kwargs.pop("pad_mask", None) assert pad_mask is not None, "TransformerEncoder requires a src pad mask" - position_embeddings = self.rope.update(emb.size(1), step=None) + position_embeddings = self.rope.update(emb.size(1), step=None, device=self.device, dtype=emb.dtype) pad_mask = pad_mask.unsqueeze(1) # batch x 1 x 1 x maxlen # dim 1 (heads) and 2 (src_len) will be broadcasted automatically in MHA diff --git a/eole/encoders/vision.py b/eole/encoders/vision.py index 01396464..a1aa9d94 100644 --- a/eole/encoders/vision.py +++ b/eole/encoders/vision.py @@ -53,18 +53,28 @@ def forward(self, image_features: torch.Tensor, image_sizes: torch.Tensor) -> to return image_features -def position_ids_in_meshgrid(patch_embeds_list, max_width, flatten=True): - positions = [] - for patch in patch_embeds_list: - height, width = patch.shape[-2:] - mesh = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij") - h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1) - ids = h_grid * max_width + v_grid - positions.append(ids[:, 0]) - if flatten: - return torch.cat(positions) - else: - return torch.stack(positions) +# replaced by get_flattened_position_ids_extrapolate +# def 
position_ids_in_meshgrid(patch_embeds_list, max_width, flatten=True): +# positions = [] +# for patch in patch_embeds_list: +# height, width = patch.shape[-2:] +# mesh = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij") +# h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1) +# ids = h_grid * max_width + v_grid +# positions.append(ids[:, 0]) +# if flatten: +# return torch.cat(positions) +# else: +# return torch.stack(positions) + + +# from bagel +def get_flattened_position_ids_extrapolate(img_h, img_w, patch_size, max_num_patches_per_side): + num_patches_h, num_patches_w = img_h // patch_size, img_w // patch_size + coords_h = torch.arange(0, num_patches_h) + coords_w = torch.arange(0, num_patches_w) + pos_ids = (coords_h[:, None] * max_num_patches_per_side + coords_w).flatten() + return pos_ids def create_block_diagonal_mask(lengths, device): @@ -88,6 +98,19 @@ def create_block_diagonal_mask(lengths, device): return mask.to(device) +# grabbed from bagel repo + + +def patchify(image, patch_size): + p = patch_size + c, h, w = image.shape + assert h % p == 0 and w % p == 0 + image = image.reshape(c, h // p, p, w // p, p) + image = torch.einsum("chpwq->hwpqc", image) + image = image.reshape(-1, p**2 * c) + return image + + class VisionEncoder(nn.Module): def __init__(self, encoder_config, running_config=None): super(VisionEncoder, self).__init__() @@ -99,13 +122,24 @@ def __init__(self, encoder_config, running_config=None): ) else: self.rope = build_rope(encoder_config, mode="2d") - self.patch_conv = nn.Conv2d( - in_channels=encoder_config.num_channels, - out_channels=encoder_config.hidden_size, - kernel_size=encoder_config.patch_size, - stride=encoder_config.patch_size, - bias=encoder_config.patch_conv_bias, - ) + + self.patch_conv_linear = encoder_config.patch_conv_linear + if self.patch_conv_linear: + # linear patch conv for bagel + self.patch_conv = nn.Linear( + encoder_config.patch_size * encoder_config.patch_size * encoder_config.num_channels, + encoder_config.hidden_size, + bias=True, + ) + else: + self.patch_conv = nn.Conv2d( + in_channels=encoder_config.num_channels, + out_channels=encoder_config.hidden_size, + kernel_size=encoder_config.patch_size, + stride=encoder_config.patch_size, + bias=encoder_config.patch_conv_bias, + ) + if encoder_config.layernorm_pre: self.ln_pre = RMSNorm(encoder_config.hidden_size, eps=1e-5) else: @@ -133,6 +167,8 @@ def from_config(cls, encoder_config, running_config=None): @property def max_patches_per_side(self): + if self.encoder_config.max_patches_per_side is not None: + return self.encoder_config.max_patches_per_side # hardcoded bagel value return self.encoder_config.image_size // self.encoder_config.patch_size @property @@ -152,7 +188,14 @@ def forward(self, images): dtype = next(self.parameters()).dtype # pass images through initial convolution independently (because they may have different sizes) - patch_embeds_list = [self.patch_conv(img.to(dtype)) for img in images] + + if self.patch_conv_linear: + print("patch_conv_linear") + # TODO: this is a patch condition for bagel, should be improved + pixel_values = [patchify(img, self.encoder_config.patch_size) for img in images] + patch_embeds_list = [self.patch_conv(pv.to(dtype)).transpose(0, 1) for pv in pixel_values] + else: + patch_embeds_list = [self.patch_conv(img.to(dtype)) for img in images] if self.ln_pre is not None: # pixtral / mistral # flatten H+W then change to (H+W, C) and stack all images of ex @@ -172,22 +215,37 @@ def forward(self, images): mask 
= None # positional embeddings - positions = position_ids_in_meshgrid( - patch_embeds_list, - max_width=self.encoder_config.image_size // self.encoder_config.patch_size, - flatten=self.ln_pre is not None, # dirty flag need to improve - ).to(self.device) + positions = ( + torch.cat( + [ + get_flattened_position_ids_extrapolate( + img.shape[-2], + img.shape[-1], + self.encoder_config.patch_size, + self.max_patches_per_side, + ) + for img in images + ], + axis=0, + ) + .unsqueeze(0) + .to(self.device) + ) + # TODO: make this cleaner if hasattr(self, "position_embeddings"): # this is only used for rope position_embeddings = None - patch_embeds += self.position_embeddings(positions) + pos_embeds = self.position_embeddings(positions) + patch_embeds += pos_embeds else: position_embeddings = self.rope.update( patch_embeds.size(1), step=0, reset=True, - positions=positions, + positions=positions, # enable for bagel only? + device=self.device, + dtype=dtype, ) out = patch_embeds @@ -197,7 +255,7 @@ def forward(self, images): if self.post_layernorm is not None: out = self.post_layernorm(out) - return out + return out, positions # Multi-Modal Projector @@ -266,4 +324,5 @@ def from_config(cls, model_config, running_config=None): str2adapter = { "llava": VisionLanguageAdapter, "gemma3": Gemma3MultiModalProjector, + "bagel": VisionLanguageAdapter, } diff --git a/eole/inputters/bagel_utils.py b/eole/inputters/bagel_utils.py new file mode 100644 index 00000000..ce89e7e0 --- /dev/null +++ b/eole/inputters/bagel_utils.py @@ -0,0 +1,110 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 + +import torch +from torchvision import transforms +from torchvision.transforms import functional as F +from torchvision.transforms import InterpolationMode + + +class MaxLongEdgeMinShortEdgeResize(torch.nn.Module): + """Resize the input image so that its longest side and shortest side are within a specified range, + ensuring that both sides are divisible by a specified stride. + + Args: + max_size (int): Maximum size for the longest edge of the image. + min_size (int): Minimum size for the shortest edge of the image. + stride (int): Value by which the height and width of the image must be divisible. + max_pixels (int): Maximum pixels for the full image. + interpolation (InterpolationMode): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``, + ``InterpolationMode.BILINEAR``, and ``InterpolationMode.BICUBIC`` are supported. + The corresponding Pillow integer constants, e.g., ``PIL.Image.BILINEAR`` are also accepted. + antialias (bool, optional): Whether to apply antialiasing (default is True). 
+ """ + + def __init__( + self, + max_size: int, + min_size: int, + stride: int, + max_pixels: int, + interpolation=InterpolationMode.BICUBIC, + antialias=True, + ): + super().__init__() + self.max_size = max_size + self.min_size = min_size + self.stride = stride + self.max_pixels = max_pixels + self.interpolation = interpolation + self.antialias = antialias + + def _make_divisible(self, value, stride): + """Ensure the value is divisible by the stride.""" + return max(stride, int(round(value / stride) * stride)) + + def _apply_scale(self, width, height, scale): + new_width = round(width * scale) + new_height = round(height * scale) + new_width = self._make_divisible(new_width, self.stride) + new_height = self._make_divisible(new_height, self.stride) + return new_width, new_height + + def forward(self, img, img_num=1): + """ + Args: + img (PIL Image): Image to be resized. + img_num (int): Number of images, used to change max_tokens. + Returns: + PIL Image or Tensor: Rescaled image with divisible dimensions. + """ + if isinstance(img, torch.Tensor): + height, width = img.shape[-2:] + else: + width, height = img.size + + scale = min(self.max_size / max(width, height), 1.0) + scale = max(scale, self.min_size / min(width, height)) + new_width, new_height = self._apply_scale(width, height, scale) + + # Ensure the number of pixels does not exceed max_pixels + if new_width * new_height > self.max_pixels / img_num: + scale = self.max_pixels / img_num / (new_width * new_height) + new_width, new_height = self._apply_scale(new_width, new_height, scale) + + # Ensure longest edge does not exceed max_size + if max(new_width, new_height) > self.max_size: + scale = self.max_size / max(new_width, new_height) + new_width, new_height = self._apply_scale(new_width, new_height, scale) + + return F.resize(img, (new_height, new_width), self.interpolation, antialias=self.antialias) + + +class ImageTransform: + def __init__( + self, + max_image_size, + min_image_size, + image_stride, + max_pixels=14 * 14 * 9 * 1024, + image_mean=[0.5, 0.5, 0.5], + image_std=[0.5, 0.5, 0.5], + ): + self.stride = image_stride + + self.resize_transform = MaxLongEdgeMinShortEdgeResize( + max_size=max_image_size, + min_size=min_image_size, + stride=image_stride, + max_pixels=max_pixels, + ) + self.to_tensor_transform = transforms.ToTensor() + self.normalize_transform = transforms.Normalize(mean=image_mean, std=image_std, inplace=True) + + def __call__(self, img, img_num=1): + img = self.resize_transform(img, img_num=img_num) + img = self.to_tensor_transform(img) + img = self.normalize_transform(img) + return img diff --git a/eole/inputters/image_utils.py b/eole/inputters/image_utils.py index fe015a57..67b056f0 100644 --- a/eole/inputters/image_utils.py +++ b/eole/inputters/image_utils.py @@ -55,8 +55,7 @@ def transform_image(image: Image.Image, new_size: Tuple[int, int]) -> np.ndarray return normalize_llava(np_image, DATASET_MEAN, DATASET_STD) -def image_to_num_tokens(img, image_size=1024, image_patch_size=16): - w, h = img.size +def image_to_num_tokens(w, h, image_size=1024, image_patch_size=16): ratio = max(h / image_size, w / image_size) if ratio > 1: w = round(w / ratio) @@ -417,7 +416,8 @@ def to_channel_dimension_format( def process_image(image_path, adapter="llava", image_size=1024, image_patch_size=16): if adapter == "llava": image = Image.open(image_path) - w, h = image_to_num_tokens(image, image_size=image_size, image_patch_size=image_patch_size) + w, h = image.size + w, h = image_to_num_tokens(w, h, image_size=image_size, 
image_patch_size=image_patch_size) new_image_size = (w * image_patch_size, h * image_patch_size) # TODO retrieve from model config / vocab / tokenizer image_tokens = (["[IMG]"] * w + ["[IMG_BREAK]"]) * h @@ -437,5 +437,19 @@ def process_image(image_path, adapter="llava", image_size=1024, image_patch_size # TODO: make this configurable? image_tokens = "" + "" * 256 + "" return {"image": image, "tokens": image_tokens} + elif adapter == "bagel": + from eole.inputters.bagel_utils import ImageTransform + + vae_transform = ImageTransform(1024, 512, 16) + vit_transform = ImageTransform(980, 224, 14) + image = Image.open(image_path) + image = _convert_to_rgb(image) + image = vae_transform.resize_transform(image) + image = vit_transform(image) + # return image + w, h = image.shape[-2:] + w, h = image_to_num_tokens(w, h, image_size=image_size, image_patch_size=image_patch_size) + image_tokens = ["<|vision_start|>"] + ["<|vision_pad|>"] * w * h + ["<|vision_end|>"] + return {"image": image, "tokens": image_tokens} else: raise ValueError("Unsupported Adapter type: {}".format(adapter)) diff --git a/eole/models/model.py b/eole/models/model.py index 6dbec78b..e8dfdf62 100644 --- a/eole/models/model.py +++ b/eole/models/model.py @@ -24,10 +24,14 @@ from eole.modules.embeddings import Embeddings from eole.models.model_saver import load_checkpoint from eole.modules.estimator import FeedForward +from eole.modules.bagel_autoencoder import AutoEncoder, AutoEncoderParams, BagelPositionEmbedding, TimestepEmbedder from eole.encoders.vision import str2adapter from eole.encoders.vision import VisionEncoder +from PIL import Image +from tqdm import tqdm + def build_encoder(model_config, running_config=None): """ @@ -909,12 +913,57 @@ def __init__(self, **kwargs): super(VisionEncoderDecoderModel, self).__init__(**kwargs) self.tgt_shift = 1 self.image_token_id = kwargs.get("image_token_id", None) + self.image_start_token_id = kwargs.get("image_start_token_id", None) + self.image_end_token_id = kwargs.get("image_end_token_id", None) if self.encoder is None or self.decoder is None: raise ValueError("A EncoderDecoderModel requires both an Encoder and a Decoder") # TODO: make this compatible? if self.add_estimator: self.estimator = FeedForward(self.hidden_size) + self.vit_position_embeddings = kwargs.pop("vit_position_embeddings", False) + if self.vit_position_embeddings: + self.vit_pos_embed = BagelPositionEmbedding( + 70, # todo grab from config + self.hidden_size, + ) + + def _maybe_init_image_generation(self, model_config, running_config): + self.image_generation = getattr(running_config, "image_generation", False) + if self.image_generation: + # NOTE: this is 100% mapped on ByteDance Bagel for now + self.latent_pos_embed = BagelPositionEmbedding( + # NOTE: deduced from weights, though config seems to tell 32... + 64, # todo grab from config # max_latent_size, + self.hidden_size, + ) + # TODO: implement properly? 
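        # Illustrative arithmetic for the defaults used below: the VAE downsamples by 8
        # and latents are grouped into 2x2 patches, so latent_downsample = 8 * 2 = 16.
        # A 1024x1024 generation therefore uses a 64x64 grid of latent patches
        # (4096 image tokens), each flattened to latent_patch_size**2 * latent_channel
        # = 4 * 16 = 64 features, which is the in/out dimension of vae2llm / llm2vae.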
+ ae_params = AutoEncoderParams( + resolution=256, + in_channels=3, + downsample=8, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ) + self.image_autoencoder = AutoEncoder(ae_params) + self.time_embedder = TimestepEmbedder(self.hidden_size) + latent_patch_size = 2 # TODO: check this + latent_channel = 16 # TODO: check this + self.patch_latent_dim = latent_patch_size**2 * latent_channel + self.vae2llm = nn.Linear(self.patch_latent_dim, self.hidden_size) + self.llm2vae = nn.Linear(self.hidden_size, self.patch_latent_dim) + + # settings + downsample = 8 # (grab from ae_params?) + self.latent_downsample = downsample * latent_patch_size + self.latent_patch_size = latent_patch_size + self.latent_channel = latent_channel + @classmethod def build_blocks(cls, model_config, vocabs, running_config=None): encoder = build_encoder(model_config, running_config=running_config) @@ -926,7 +975,7 @@ def build_blocks(cls, model_config, vocabs, running_config=None): share_embeddings=model_config.share_embeddings, ) decoder = build_decoder(model_config, running_config=running_config) - return cls( + model = cls( encoder=encoder, decoder=decoder, adapter=adapter, @@ -934,24 +983,48 @@ def build_blocks(cls, model_config, vocabs, running_config=None): add_estimator=model_config.add_estimator, hidden_size=model_config.decoder.hidden_size, image_token_id=model_config.encoder.image_token_id, + image_start_token_id=model_config.encoder.image_start_token_id, + image_end_token_id=model_config.encoder.image_end_token_id, + vit_position_embeddings=model_config.vit_position_embeddings, ) + model._maybe_init_image_generation(model_config, running_config) # from there, the base blocks exist, and the rest is done in the from_opt from base class + return model def embed_vision_language_features(self, src, images): # TODO: test with batch > 1? 
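        # Rough shape walk-through (as the code below stands): src is [batch, seq] token
        # ids, images a list of [C, H, W] tensors. Text tokens are embedded with tgt_emb,
        # images go through the vision encoder (which now also returns patch positions)
        # and the adapter, optionally plus Bagel's vit_pos_embed, and the two streams are
        # merged back into a [batch, seq, hidden] tensor via masked_scatter on the image
        # token positions. The function now returns (combined_features, positions).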
batch_size = src.size(0) text_locations = src != self.image_token_id image_locations = src == self.image_token_id + # build text_positions + non_text_ids = [] + if self.image_token_id is not None: + non_text_ids.append(self.image_token_id) + if self.image_start_token_id is not None: + non_text_ids.append(self.image_start_token_id) + if self.image_end_token_id is not None: + non_text_ids.append(self.image_end_token_id) + text_positions = ( + ~torch.isin(src, torch.tensor(non_text_ids, device=src.device)) + if non_text_ids + else torch.ones_like(src, dtype=torch.bool) + ) + text_features = self.tgt_emb(src[text_locations].view(batch_size, -1)) if len(images) == 0: return text_features image_sizes = torch.tensor([[images[i].size(1), images[i].size(2)] for i in range(len(images))]) # images is a list of tensors, each being [channel, H, W] - encoded_images = self.encoder(images) + encoded_images, positions = self.encoder(images) # encoded_images is [N_img x seq x hidden_size] image_features = self.adapter(encoded_images, image_sizes=image_sizes) + # Introduced for BAGEL, need to factorize/make configurable properly + if self.vit_position_embeddings: + pos_emb = self.vit_pos_embed(positions) + image_features = image_features + pos_emb + seq_len = src.shape[1] batch, N_txt, D_txt = text_features.shape N_img, tokperimg, D_img = image_features.shape @@ -972,7 +1045,16 @@ def embed_vision_language_features(self, src, images): image_mask = image_locations.unsqueeze(-1).expand_as(combined_features) combined_features = combined_features.masked_scatter(image_mask, image_features.view(-1, D_img)) - return combined_features + # deduce positions for bagel ([0] * image_positions + range(1, len(text_locations) + 1))]) + # NOTE: only works for single image, need to adapt if several images + positions = torch.zeros((batch, seq_len), dtype=torch.long, device=text_features.device) + positions[text_positions] = torch.arange(1, text_positions.sum().item() + 1, device=text_features.device) + positions = positions[0] + + # only return positions for bagel ? + # positions = None + + return combined_features, positions def forward(self, src, tgt, src_len, bptt=False, with_align=False, images=[]): """A DecoderModel forward the src side to the decoder along @@ -1003,6 +1085,232 @@ def forward(self, src, tgt, src_len, bptt=False, with_align=False, images=[]): return dec_out, attns, estim + def prepare_image_generation(self, image_width=1024, image_height=1024, current_position_id=0): + """ + Prepare necessary stuff for image generation: + - noise + - seqlens + - ... + """ + h, w = image_height // self.latent_downsample, image_width // self.latent_downsample + num_image_tokens = h * w + init_noise = torch.randn( + num_image_tokens, + self.latent_channel * self.latent_patch_size**2, + # device="cuda", # TODO: deduce from config or running_config + # dtype=torch.bfloat16, # TODO: deduce from config or running_config, or cast later? + ).to(device="cuda", dtype=torch.bfloat16) + seqlens = num_image_tokens + 2 + position_ids = [current_position_id] * seqlens + + # how to handle prompt + image tokens? + # official bagel code does two steps: first fill kv cache based on text prompt, then generate image + # can we do this in a single forward? (full forward then select only image token positions...) + + return ( + init_noise, + # seqlens, + position_ids, + ) + + def generate_image(self, text_src, init_noise, position_ids, num_timesteps=20, timestep_shift=1.0): + """ + Generate an image from the input features. 
+ # TODO: proper docstring + # TODO: might need to move this to another subclass for clarity + Stuff needed: + - timesteps + - packed_init_noises + - packed_text_ids (x) + - packed_vae_position_ids + - packed_seqlens (flatten/patched image size) + - ... + """ + + device = init_noise.device + + x_t = init_noise + + timesteps = torch.linspace(1, 0, num_timesteps, device=device, dtype=torch.bfloat16) # TODO: deduce dtype + timesteps = timestep_shift * timesteps / (1 + (timestep_shift - 1) * timesteps) + dts = timesteps[:-1] - timesteps[1:] + timesteps = timesteps[:-1] + + num_image_tokens, latent_dim = x_t.shape + # TODO: make this configurable? + # min_cfg, max_cfg = 0.0, 1.0 + # cfg_text_scale = 1.0 + # cfg_img_scale = 1.0 + + text_ids = torch.tensor([self.image_start_token_id, self.image_end_token_id], device=device) + text_indices = [0, num_image_tokens + 1] # does this always work? + image_indices = list(range(1, num_image_tokens + 1)) # [1, ..., num_image_tokens] + + image_position_ids = torch.arange(0, num_image_tokens, device=device) + + for i, t in tqdm(enumerate(timesteps)): + timestep = torch.tensor([t] * num_image_tokens, device=device) + # TODO: check in official implementation if this is needed/used + # if t > min_cfg and t <= max_cfg: + # cfg_text_scale_ = cfg_text_scale + # cfg_img_scale_ = cfg_img_scale + # else: + # cfg_text_scale_ = 1.0 + # cfg_img_scale_ = 1.0 + + v_t = self.forward_image_gen( + text_src, + x_t, + timestep, + text_ids, + text_indices, + image_indices, + torch.tensor([num_image_tokens + 2], device=device), # seqlens is always num_image_tokens + 2? + image_position_ids, + position_ids, + ) + + x_t = x_t - v_t.to(device) * dts[i] + + output = x_t.split([num_image_tokens]) + # no real multi image support for now, so we just return the first one + return output[0] + + def forward_image_gen( + self, + text_src, + x_t, + timestep, + text_ids, + text_indices, + image_indices, + seqlens, + image_position_ids, + position_ids, + # cfg_text_scale=1.0, + cfg_text_scale=4.0, + cfg_img_scale=1.0, + cfg_renorm_type="global", + cfg_renorm_min=0.0, + ): + """ + (Somewhat corresponds to bagel._forward_flow at high level.) 
+ """ + text_embeddings = self.tgt_emb(text_ids) + text_prompt_emb = self.tgt_emb(text_src) + + sequence = text_embeddings.new_zeros((sum(seqlens), self.hidden_size)) + sequence[text_indices] = text_embeddings + + position_embeddings = self.latent_pos_embed(image_position_ids) + timestep_embeddings = self.time_embedder(timestep) + + x_t = self.vae2llm(x_t) + timestep_embeddings + position_embeddings + sequence[image_indices] = x_t + + sequence = sequence.unsqueeze(0) + + # used for CFG + sequence_without_text = sequence.clone() + sequence = torch.cat((text_prompt_emb, sequence), dim=1) + + offset_image_indices = [i + text_src.size(1) for i in image_indices] + offset_text_indices = list(range(text_src.size(1))) + [i + text_src.size(1) for i in text_indices] + output, _ = self.decoder( + sequence, + step=0, # not sure + enc_out=None, + src_len=seqlens, + # tgt_pad_mask=None, # TODO: handle padding mask properly + tgt_pad_mask=torch.zeros((sequence.size(0), sequence.size(1))).to( + dtype=torch.bool, device=sequence.device + ), # no padding + text_indices=offset_text_indices, + image_indices=offset_image_indices, + # dirty patch to allow update_causal_mask in decoder + decoder_in=torch.cat( + (text_src, torch.tensor([[self.image_token_id] * (len(image_indices) + 2)], device=text_src.device)), + dim=1, + ), + image_token_id=self.image_token_id, + positions=torch.tensor(list(range(text_src.size(1))) + position_ids, device=text_src.device), + ) + v_t = self.llm2vae(output) + v_t = v_t.squeeze(0) + v_t = v_t[offset_image_indices] # select only image tokens + + # TODO: additional conditions for cfg_text_scale / cfg_img_scale ? + if cfg_text_scale > 1.0: + cfg_text_output, _ = self.decoder( + sequence_without_text, + step=0, + enc_out=None, + src_len=sequence_without_text.size(1), + tgt_pad_mask=torch.zeros((sequence_without_text.size(0), sequence_without_text.size(1))).to( + dtype=torch.bool, device=sequence_without_text.device + ), # no padding + text_indices=text_indices, + image_indices=image_indices, + decoder_in=torch.tensor([[self.image_token_id] * (len(image_indices) + 2)], device=text_src.device), + image_token_id=self.image_token_id, + positions=torch.zeros((sequence_without_text.size(1)), device=text_src.device), + # TODO: find a way to disable cache update for such calls + # (might be an issue for more complex queries downstream) + ) + cfg_text_v_t = self.llm2vae(cfg_text_output) + cfg_text_v_t = cfg_text_v_t.squeeze(0) + cfg_text_v_t = cfg_text_v_t[image_indices] # select only image tokens + + if cfg_img_scale > 1.0: + cfg_img_v_t = v_t.clone() + # this is actually useful only for the image editing case (input=text+image, output=image), + # which is still to be investigated + pass + + # cfg renorm stuff + if cfg_text_scale > 1.0: + if cfg_renorm_type == "text_channel": + v_t_text_ = cfg_text_v_t + cfg_text_scale * (v_t - cfg_text_v_t) + norm_v_t = torch.norm(v_t, dim=-1, keepdim=True) + norm_v_t_text_ = torch.norm(v_t_text_, dim=-1, keepdim=True) + scale = (norm_v_t / (norm_v_t_text_ + 1e-8)).clamp(min=cfg_renorm_min, max=1.0) + v_t_text = v_t_text_ * scale + if cfg_img_scale > 1.0: + v_t = cfg_img_v_t + cfg_img_scale * (v_t_text - cfg_img_v_t) + else: + v_t = v_t_text + else: + v_t_text_ = cfg_text_v_t + cfg_text_scale * (v_t - cfg_text_v_t) + + if cfg_img_scale > 1.0: + v_t_ = cfg_img_v_t + cfg_img_scale * (v_t_text_ - cfg_img_v_t) + else: + v_t_ = v_t_text_ + + # NOTE norm is computed over all dimensions, thus currently only supports batch_size = 1 with navit + if cfg_renorm_type == 
"global": + norm_v_t = torch.norm(v_t) + norm_v_t_ = torch.norm(v_t_) + elif cfg_renorm_type == "channel": + norm_v_t = torch.norm(v_t, dim=-1, keepdim=True) + norm_v_t_ = torch.norm(v_t_, dim=-1, keepdim=True) + else: + raise NotImplementedError(f"{cfg_renorm_type} is not suppoprted") + scale = (norm_v_t / (norm_v_t_ + 1e-8)).clamp(min=cfg_renorm_min, max=1.0) + v_t = v_t_ * scale + + return v_t + + def decode_image(self, latent, image_height, image_width): + h, w = image_height // self.latent_downsample, image_width // self.latent_downsample + latent = latent.reshape(1, h, w, self.latent_patch_size, self.latent_patch_size, self.latent_channel) + latent = torch.einsum("nhwpqc->nchpwq", latent) + latent = latent.reshape(1, self.latent_channel, h * self.latent_patch_size, w * self.latent_patch_size) + image = self.image_autoencoder.decode(latent) + image = (image * 0.5 + 0.5).clamp(0, 1)[0].permute(1, 2, 0) * 255 + image = Image.fromarray((image).to(torch.uint8).cpu().numpy()) + return image + def update_dropout(self, dropout, attention_dropout): self.encoder.update_dropout(dropout, attention_dropout) self.src_emb.update_dropout(dropout) diff --git a/eole/modules/bagel_autoencoder.py b/eole/modules/bagel_autoencoder.py new file mode 100644 index 00000000..da0782b1 --- /dev/null +++ b/eole/modules/bagel_autoencoder.py @@ -0,0 +1,475 @@ +# Copyright (c) 2024 Black Forest Labs. +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +# +# This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20. +# +# Original file was released under Apache-2.0, with the full license text +# available at https://github.com/black-forest-labs/flux/blob/main/LICENSE. +# +# This modified file is released under the same license. 
+ +from dataclasses import dataclass + +import torch +from einops import rearrange +from torch import Tensor, nn +from safetensors.torch import load_file as load_sft +import numpy as np +import math + + +@dataclass +class AutoEncoderParams: + resolution: int + in_channels: int + downsample: int + ch: int + out_ch: int + ch_mult: list[int] + num_res_blocks: int + z_channels: int + scale_factor: float + shift_factor: float + + +def swish(x: Tensor) -> Tensor: + return x * torch.sigmoid(x) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.in_channels = in_channels + + self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1) + + def attention(self, h_: Tensor) -> Tensor: + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous() + k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous() + v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous() + h_ = nn.functional.scaled_dot_product_attention(q, k, v) + + return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) + + def forward(self, x: Tensor) -> Tensor: + return x + self.proj_out(self.attention(x)) + + +class ResnetBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + + self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + h = x + h = self.norm1(h) + h = swish(h) + h = self.conv1(h) + + h = self.norm2(h) + h = swish(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + + return x + h + + +class Downsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + # no asymmetric padding in torch conv, must do it ourselves + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x: Tensor): + pad = (0, 1, 0, 1) + x = nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class Upsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: Tensor): + x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + x = self.conv(x) + return x + + +class Encoder(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + ch: int, + ch_mult: list[int], + num_res_blocks: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + # 
downsampling + self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + block_in = self.ch + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # end + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: Tensor) -> Tensor: + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + ch: int, + out_ch: int, + ch_mult: list[int], + num_res_blocks: int, + in_channels: int, + resolution: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.ffactor = 2 ** (self.num_resolutions - 1) + + # compute in_ch_mult, block_in and curr_res at lowest res + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + + # z to block_in + self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, z: Tensor) -> Tensor: + # z to block_in + h = self.conv_in(z) + + # middle + h 
= self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class DiagonalGaussian(nn.Module): + def __init__(self, sample: bool = True, chunk_dim: int = 1): + super().__init__() + self.sample = sample + self.chunk_dim = chunk_dim + + def forward(self, z: Tensor) -> Tensor: + mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim) + if self.sample: + std = torch.exp(0.5 * logvar) + return mean + std * torch.randn_like(mean) + else: + return mean + + +class AutoEncoder(nn.Module): + def __init__(self, params: AutoEncoderParams): + super().__init__() + self.encoder = Encoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.decoder = Decoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + out_ch=params.out_ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.reg = DiagonalGaussian() + + self.scale_factor = params.scale_factor + self.shift_factor = params.shift_factor + + def encode(self, x: Tensor) -> Tensor: + z = self.reg(self.encoder(x)) + z = self.scale_factor * (z - self.shift_factor) + return z + + def decode(self, z: Tensor) -> Tensor: + z = z / self.scale_factor + self.shift_factor + return self.decoder(z) + + def forward(self, x: Tensor) -> Tensor: + return self.decode(self.encode(x)) + + +def print_load_warning(missing: list[str], unexpected: list[str]) -> None: + if len(missing) > 0 and len(unexpected) > 0: + print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) + print("\n" + "-" * 79 + "\n") + print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) + elif len(missing) > 0: + print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) + elif len(unexpected) > 0: + print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) + + +def load_ae(local_path: str) -> AutoEncoder: + ae_params = AutoEncoderParams( + resolution=256, + in_channels=3, + downsample=8, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ) + + # Loading the autoencoder + ae = AutoEncoder(ae_params) + + if local_path is not None: + sd = load_sft(local_path) + missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True) + print_load_warning(missing, unexpected) + return ae, ae_params + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + 
+ # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +# -------------------------------------------------------- +# 2D sine-cosine position embedding +# References: +# DiT: https://github.com/facebookresearch/DiT/blob/main/models.py +# -------------------------------------------------------- +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +class BagelPositionEmbedding(nn.Module): + def __init__(self, max_num_patch_per_side, hidden_size): + super().__init__() + self.max_num_patch_per_side = max_num_patch_per_side + self.hidden_size = hidden_size + self.pos_embed = nn.Parameter(torch.zeros(max_num_patch_per_side**2, hidden_size), requires_grad=False) + self._init_weights() + + def _init_weights(self): + # Initialize (and freeze) pos_embed by sin-cos embedding: + pos_embed = get_2d_sincos_pos_embed(self.hidden_size, self.max_num_patch_per_side) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float()) + + def forward(self, position_ids): + return self.pos_embed[position_ids] + + +# -------------------------------------------------------- +# TimestepEmbedder +# Reference: +# DiT: https://github.com/facebookresearch/DiT/blob/main/models.py +# -------------------------------------------------------- +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. 
+ """ + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( + device=t.device + ) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_freq = t_freq.to(dtype=t.dtype) # TODO: not sure about casting here + t_emb = self.mlp(t_freq) + return t_emb diff --git a/eole/modules/multi_headed_attn.py b/eole/modules/multi_headed_attn.py index 7d1e1f13..c9815aa4 100644 --- a/eole/modules/multi_headed_attn.py +++ b/eole/modules/multi_headed_attn.py @@ -107,14 +107,15 @@ def __init__(self, model_config, running_config=None, is_decoder: bool = True) - out_features=self.dim_per_head * self.heads // self.parallel_gpu, bias=model_config.add_qkvbias, ) + self.dropout_p = getattr(running_config, "attention_dropout", [0.0])[0] self.dropout = nn.Dropout(self.dropout_p) # introduced for gemma3 if model_config.query_norm: - self.q_norm = LayerNorm[model_config.layer_norm](model_config.head_dim, eps=model_config.norm_eps) + self.q_norm = LayerNorm[model_config.layer_norm](model_config.dim_per_head, eps=model_config.norm_eps) if model_config.key_norm: - self.k_norm = LayerNorm[model_config.layer_norm](model_config.head_dim, eps=model_config.norm_eps) + self.k_norm = LayerNorm[model_config.layer_norm](model_config.dim_per_head, eps=model_config.norm_eps) self.final_linear = skip_init( nn.Linear, @@ -122,7 +123,11 @@ def __init__(self, model_config, running_config=None, is_decoder: bool = True) - out_features=model_config.hidden_size, bias=model_config.add_final_linear_bias, ) + self.is_decoder = is_decoder + + self._maybe_init_image_generation(model_config, running_config) + self.scale = self.attn_scaling**-0.5 if self.is_decoder and self.attn_scaling is not None else None self.relative_positions_buckets = model_config.relative_positions_buckets self.kcache, self.kcache = None, None @@ -160,6 +165,45 @@ def __init__(self, model_config, running_config=None, is_decoder: bool = True) - else: self.flash = False + def _maybe_init_image_generation(self, model_config, running_config): + self.image_generation = getattr(running_config, "image_generation", False) + + if self.image_generation and self.is_decoder: + # initialize MOE GEN params + self.linear_keys_moe_gen = skip_init( + nn.Linear, + in_features=model_config.hidden_size, + out_features=self.dim_per_head * self.heads_kv // self.parallel_gpu, + bias=model_config.add_qkvbias, + ) + self.linear_values_moe_gen = skip_init( + nn.Linear, + in_features=model_config.hidden_size, + out_features=self.dim_per_head * self.heads_kv // self.parallel_gpu, + bias=model_config.add_qkvbias, + ) + self.linear_query_moe_gen = skip_init( + nn.Linear, + in_features=model_config.hidden_size, + out_features=self.dim_per_head * self.heads // self.parallel_gpu, + bias=model_config.add_qkvbias, + ) + if model_config.query_norm: + self.q_norm_moe_gen = LayerNorm[model_config.layer_norm]( + model_config.dim_per_head, eps=model_config.norm_eps + ) + if model_config.key_norm: + self.k_norm_moe_gen = LayerNorm[model_config.layer_norm]( + model_config.dim_per_head, eps=model_config.norm_eps + ) + + self.final_linear_moe_gen = skip_init( + nn.Linear, + in_features=self.dim_per_head * self.heads // self.parallel_gpu, + out_features=model_config.hidden_size, + 
bias=model_config.add_final_linear_bias, + ) + def update_dropout(self, dropout: float) -> None: self.dropout.p = dropout self.dropout_p = dropout @@ -199,8 +243,23 @@ def _prepare_inputs( seqlen = query.size(2) cos, sin = position_embeddings[0][:seqlen], position_embeddings[1][:seqlen] query, key = apply_rotary_emb(query, key, (cos, sin), interleave=self.rotary_interleave) + return key, value, query + def _final_linear(self, context, text_indices, image_indices): + if self.image_generation: + assert ( + text_indices is not None and image_indices is not None + ), "text_indices and image_indices must be provided for image generation" + attn_output = torch.zeros_like(context) + attn_output[:, text_indices, :] = self.final_linear(context[:, text_indices, :]) + attn_output[:, image_indices, :] = self.final_linear_moe_gen(context[:, image_indices, :]) + elif self.kcache is not None: + attn_output = self.final_linear(context) + else: + attn_output = self.maybe_ckpt(self.final_linear, context) + return attn_output + def _compute_attention( self, key: Tensor, @@ -208,6 +267,8 @@ def _compute_attention( query: Tensor, attn_mask: Optional[Tensor] = None, return_attn: Optional[bool] = False, + text_indices: Optional[Tensor] = None, + image_indices: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor]: """ Compute the context vector and the attention vectors. @@ -229,6 +290,7 @@ def _compute_attention( """ b, h, l, d = key.size() + # replaced by enable_gqa? if self.heads_kv < self.heads: qh = query.size(1) # expand key on heads dimension when it's less than query heads (multi-query variant) @@ -252,6 +314,7 @@ def _compute_attention( attn_mask=attn_mask, dropout_p=self.dropout_p, scale=self.scale, + # enable_gqa=True, # not memory efficient -- https://github.com/pytorch/pytorch/issues/154363 ) attn = None else: @@ -315,10 +378,7 @@ def _compute_attention( context = unshape(attn_output) - if self.kcache is not None: - attn_output = self.final_linear(context) - else: - attn_output = self.maybe_ckpt(self.final_linear, context) + attn_output = self._final_linear(context, text_indices, image_indices) if self.parallel_gpu > 1: # all_reduce is an inplace op - not easily backprop @@ -375,7 +435,9 @@ def _prepare_inputs_w_cache( if self.position_encoding_type == PositionEncodingType.Rotary: seqlen = query.size(2) - cos, sin = position_embeddings[0][step : step + seqlen], position_embeddings[1][step : step + seqlen] + cos, sin = position_embeddings[0][step : step + seqlen].to(query.dtype), position_embeddings[1][ + step : step + seqlen + ].to(query.dtype) query, key = apply_rotary_emb(query, key, (cos, sin), interleave=self.rotary_interleave) if step == 0: @@ -388,19 +450,52 @@ def _prepare_inputs_w_cache( self.vcache[:, :, cache_len - 1, :] = value[:, :, 0, :] return self.kcache[:, :, :cache_len, :], self.vcache[:, :, :cache_len, :], query - def forward( - self, - query: Tensor, - attn_mask: Optional[Tensor] = None, - step: Optional[int] = 0, - return_attn: Optional[bool] = False, - position_embeddings=None, - ) -> Tuple[Tensor, Tensor]: - if self.kcache is not None: - # Inference step decoding + def _qkv_image_generation(self, query, text_indices, image_indices): + text_query = query[:, text_indices, :] + image_query = query[:, image_indices, :] + + text_key = self.linear_keys(text_query) + image_key = self.linear_keys_moe_gen(image_query) + text_value = self.linear_values(text_query) + image_value = self.linear_values_moe_gen(image_query) + text_query = self.linear_query(text_query) + image_query = 
self.linear_query_moe_gen(image_query) + + text_query = shape(text_query, self.dim_per_head) + image_query = shape(image_query, self.dim_per_head) + text_key = shape(text_key, self.dim_per_head) + image_key = shape(image_key, self.dim_per_head) + text_value = shape(text_value, self.dim_per_head) + image_value = shape(image_value, self.dim_per_head) + + key = query.new_zeros((query.size(0), self.heads_kv, query.size(1), self.dim_per_head // self.parallel_gpu)) + value = query.new_zeros((query.size(0), self.heads_kv, query.size(1), self.dim_per_head // self.parallel_gpu)) + query = query.new_zeros((query.size(0), self.heads, query.size(1), self.dim_per_head // self.parallel_gpu)) + + key[:, :, text_indices, :] = text_key + key[:, :, image_indices, :] = image_key + value[:, :, text_indices, :] = text_value + value[:, :, image_indices, :] = image_value + + query[:, :, text_indices, :] = text_query + query[:, :, image_indices, :] = image_query + + if hasattr(self, "q_norm") and hasattr(self, "q_norm_moe_gen"): + query[:, :, text_indices, :] = self.q_norm(query[:, :, text_indices, :]) + query[:, :, image_indices, :] = self.q_norm_moe_gen(query[:, :, image_indices, :]) + if hasattr(self, "k_norm") and hasattr(self, "k_norm_moe_gen"): + key[:, :, text_indices, :] = self.k_norm(key[:, :, text_indices, :]) + key[:, :, image_indices, :] = self.k_norm_moe_gen(key[:, :, image_indices, :]) + return query, key, value + + def _qkv(self, query, text_indices, image_indices): + if self.image_generation: + return self._qkv_image_generation(query, text_indices, image_indices) + else: key = self.linear_keys(query) value = self.linear_values(query) query = self.linear_query(query) + key = shape(key, self.dim_per_head) value = shape(value, self.dim_per_head) query = shape(query, self.dim_per_head) @@ -410,6 +505,21 @@ def forward( if hasattr(self, "k_norm"): key = self.k_norm(key) + return query, key, value + + def forward( + self, + query: Tensor, + attn_mask: Optional[Tensor] = None, + step: Optional[int] = 0, + return_attn: Optional[bool] = False, + position_embeddings=None, + text_indices: Optional[Tensor] = None, + image_indices: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor]: + if self.kcache is not None: + # Inference step decoding + query, key, value = self._qkv(query, text_indices, image_indices) if ( not self.flash or self.position_encoding_type in [PositionEncodingType.Relative, PositionEncodingType.Alibi] @@ -469,6 +579,8 @@ def forward( query, attn_mask=attn_mask, return_attn=return_attn, + text_indices=text_indices, + image_indices=image_indices, ) diff --git a/eole/modules/rope.py b/eole/modules/rope.py index d00b0c18..a9e2ecfc 100644 --- a/eole/modules/rope.py +++ b/eole/modules/rope.py @@ -125,9 +125,10 @@ def __init__(self, model_config, mode="1d", variant="global"): self.llama3_scaling() if getattr(self.model_config.rope_config, "scaling_type", None) == "gemma3" and variant == "global": self.gemma3_scaling() - cos, sin = self.update(1024) - self.register_buffer("cos", cos, persistent=False) - self.register_buffer("sin", sin, persistent=False) + # disable auto initialization (e.g. bagel depends on image positions...) 
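+ # (the cos/sin buffers are now registered lazily by the first call to update(),
+ # which can be given explicit `positions`, `device` and `dtype`; see update() below)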
+ # cos, sin = self.update(1024) + # self.register_buffer("cos", cos, persistent=False) + # self.register_buffer("sin", sin, persistent=False) def init_2d_inv_freq(self, inv_freq): """ @@ -183,13 +184,20 @@ def gemma3_scaling(self): factor = rope_config.scaling_factor # `8` in the original implementation self.inv_freq /= factor - def forward_1d(self, maxseqlen, step=0, prefetch=1024, offset=32): + def forward_1d( + self, maxseqlen, step=0, prefetch=1024, offset=32, positions=None, device="cpu", dtype=torch.float32 + ): maxseqlen += prefetch - device = self.cos.device if hasattr(self, "cos") else torch.device("cpu") - dtype = self.cos.dtype if hasattr(self, "cos") else torch.float32 + # device = self.cos.device if hasattr(self, "cos") else torch.device("cuda") + # dtype = self.cos.dtype if hasattr(self, "cos") else torch.float32 + + if positions is None: + tmax = torch.arange(maxseqlen, device=device) + else: + tmax = positions.to(device) - tmax = torch.arange(max(offset + step, 0) + maxseqlen, device=device) tmax += self.model_config.rope_config.tmax_index + rope = torch.outer(tmax, self.inv_freq.to(device)) cos = torch.cos(rope) sin = torch.sin(rope) @@ -219,10 +227,12 @@ def forward_1d(self, maxseqlen, step=0, prefetch=1024, offset=32): # self.inv_freq = self.original_inv_freq # self.max_seq_len_cached = self.original_max_seq_len - def forward_2d(self, maxseqlen, step=0, prefetch=1024, offset=32, positions=None): + def forward_2d( + self, maxseqlen, step=0, prefetch=1024, offset=32, positions=None, device="cpu", dtype=torch.float32 + ): # TODO: maybe do scaling here - device = self.cos.device if hasattr(self, "cos") else torch.device("cpu") - dtype = self.cos.dtype if hasattr(self, "cos") else torch.float32 + # device = self.cos.device if hasattr(self, "cos") else torch.device("cuda") + # dtype = self.cos.dtype if hasattr(self, "cos") else torch.float32 if positions is None: tmax = torch.arange(maxseqlen, device=self.inv_freq.device) @@ -237,7 +247,7 @@ def forward_2d(self, maxseqlen, step=0, prefetch=1024, offset=32, positions=None return cos, sin - def update(self, maxseqlen, step=0, prefetch=1024, reset=False, positions=None): + def update(self, maxseqlen, step=0, prefetch=1024, reset=False, positions=None, device="cpu", dtype=torch.float32): """ Computes the rotary position embeddings for a given input. Args: @@ -261,10 +271,25 @@ def update(self, maxseqlen, step=0, prefetch=1024, reset=False, positions=None): maxseqlen = max(maxseqlen, 1024) # reset as in init() with self.update(1024) elif hasattr(self, "cos") and self.cos.size(0) >= max(offset + (step or 0), 0) + maxseqlen: return self.cos, self.sin + if positions is not None: + # apply offset... 
+ positions = torch.cat( + (positions, torch.arange(positions[-1], positions[-1] + offset, device=positions.device)), dim=0 + ) if self.mode == "1d": - cos, sin = self.forward_1d(maxseqlen, step=(step or 0), prefetch=prefetch, offset=offset) + cos, sin = self.forward_1d( + maxseqlen, + step=(step or 0), + prefetch=prefetch, + offset=offset, + positions=positions, + device=device, + dtype=dtype, + ) elif self.mode == "2d": - cos, sin = self.forward_2d(maxseqlen, step=(step or 0), prefetch=prefetch, positions=positions) + cos, sin = self.forward_2d( + maxseqlen, step=(step or 0), prefetch=prefetch, positions=positions, device=device, dtype=dtype + ) else: raise NotImplementedError self.register_buffer("cos", cos, persistent=False) diff --git a/eole/modules/transformer_mlp.py b/eole/modules/transformer_mlp.py index 2063ee29..065a15a4 100644 --- a/eole/modules/transformer_mlp.py +++ b/eole/modules/transformer_mlp.py @@ -15,12 +15,7 @@ class MLP(nn.Module): running_config: TrainingConfig or InferenceConfig derived from RunningConfig """ - def __init__( - self, - model_config, - running_config=None, - is_moe=False, - ): + def __init__(self, model_config, running_config=None, is_moe=False, is_decoder: bool = True): self.parallel_gpu = getattr(running_config, "parallel_gpu", 1) super(MLP, self).__init__() if is_moe: diff --git a/eole/predict/inference.py b/eole/predict/inference.py index 23671b91..7c067065 100644 --- a/eole/predict/inference.py +++ b/eole/predict/inference.py @@ -89,6 +89,11 @@ def __init__( optional_eos=[], id_tokenization=False, image_token_id=10, + image_generation=False, + image_width=1024, + image_height=1024, + num_timesteps=20, + output=None, ): self.model = model self.vocabs = vocabs @@ -165,6 +170,15 @@ def __init__( self.id_tokenization = id_tokenization self.image_token_id = image_token_id + self.positions = None + + # image generation + self.image_generation = image_generation + self.image_width = image_width + self.image_height = image_height + self.num_timesteps = num_timesteps + self.output = output + @classmethod def from_config( cls, @@ -240,6 +254,11 @@ def from_config( optional_eos=config.optional_eos, id_tokenization=id_tokenization, image_token_id=image_token_id, + image_generation=config.image_generation, + image_width=config.image_width, + image_height=config.image_height, + num_timesteps=config.num_timesteps, + output=config.output, ) def _log(self, msg): @@ -634,9 +653,37 @@ def _decode_and_generate( src_pad_mask = None if images is not None and step == 0: - emb = self.model.embed_vision_language_features(decoder_in, images) + emb, positions = self.model.embed_vision_language_features(decoder_in, images) + self.positions = positions + # "simple" image generation case + elif self.image_generation: + init_noise, position_ids = self.model.prepare_image_generation( + image_width=self.image_width, + image_height=self.image_height, + current_position_id=src_len.max().item(), # TODO: not sure + ) + latent = self.model.generate_image( + decoder_in, + init_noise, + position_ids, + num_timesteps=self.num_timesteps, + ) + image = self.model.decode_image(latent, self.image_height, self.image_width) + # TODO: should this logic be moved elsewhere? + if self.output is not None: + image.save(self.output) + exit() + + # image edition case + elif images is not None and self.image_generation: + pass else: emb = self.model.tgt_emb(decoder_in, step=step) + if self.positions is not None: + # add position + # NOTE: this does not work if image is after text... 
+ next_pos = self.positions[-1] + 1 + self.positions = torch.cat([self.positions, next_pos.unsqueeze(0)]) tgt_pad_mask = decoder_in.eq(self._tgt_pad_idx).unsqueeze(1) # [B, 1, T_tgt] dec_out, dec_attn = self.model.decoder( @@ -650,6 +697,8 @@ def _decode_and_generate( left_pad=left_pad, decoder_in=decoder_in, image_token_id=self.image_token_id, + # TODO: retrieve proper positions for bagel ([0] * image_tokens + [1, 2, ...]) + positions=self.positions, ) # Generator forward. if "std" in dec_attn: diff --git a/recipes/bagel/README.md b/recipes/bagel/README.md new file mode 100644 index 00000000..cc265ae2 --- /dev/null +++ b/recipes/bagel/README.md @@ -0,0 +1 @@ +eole convert HF --model_dir ByteDance-Seed/BAGEL-7B-MoT --output ./bagel --token $HF_TOKEN \ No newline at end of file diff --git a/recipes/bagel/generated_image_42.png b/recipes/bagel/generated_image_42.png new file mode 100644 index 00000000..19232daf Binary files /dev/null and b/recipes/bagel/generated_image_42.png differ diff --git a/recipes/bagel/test_bagel_generation.py b/recipes/bagel/test_bagel_generation.py new file mode 100644 index 00000000..0828561d --- /dev/null +++ b/recipes/bagel/test_bagel_generation.py @@ -0,0 +1,74 @@ +# flake8: noqa + +from rich import print +from eole.config.run import * +from eole.inference_engine import InferenceEnginePY + +seed = 42 + +config = PredictConfig( + model_path="./bagel", + src="dummy", + max_length=600, + gpu_ranks=[0], + # quant_type="bnb_NF4", + quant_type="bnb_FP4", # HF default, using it for initial reproducibility checks + quant_layers=[ + "gate_up_proj", + "down_proj", + "up_proj", + "linear_values", + "linear_query", + "linear_keys", + "final_linear", + "w_in", + "w_out", + ], + compute_dtype="bf16", + top_k=0, + top_p=0.0, + # top_p=0.8, + # temperature=0.35, + # beam_size=5, + beam_size=1, + # temperature=0.3, + seed=seed, + batch_size=1, + batch_type="sents", + self_attn_backend="pytorch", + image_generation=True, + image_width=1024, + image_height=1024, + # num_timesteps=10, + num_timesteps=30, + # num_timesteps=50, + # self_attn_backend="flash", # not properly supported (mixed masking) + output=f"generated_image_{seed}.png", +) + +print(config) + +# config.data_type = "image" +config.data_type = "text" +engine = InferenceEnginePY(config) + +print(engine.predictor.model) +engine.predictor.model.count_parameters() + +# prompt = "A female cosplayer portraying an ethereal fairy or elf, wearing a flowing dress made of delicate fabrics in soft, mystical colors like emerald green and silver. She has pointed ears, a gentle, enchanting expression, and her outfit is adorned with sparkling jewels and intricate patterns. The background is a magical forest with glowing plants, mystical creatures, and a serene atmosphere." + +prompt = "A breathtaking photorealistic landscape of a windswept coastal cliff at golden hour. The scene features jagged rocks covered in moss, waves crashing below with mist rising, and seabirds flying overhead. The lighting is warm and natural, casting long shadows and reflecting on wet surfaces. The level of detail is ultra high, with textures of stone, water, and clouds rendered realistically, evoking a feeling of awe and solitude." 
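+
+# Optional: besides the `seed` passed to PredictConfig above, the global RNGs can be
+# seeded explicitly, mirroring what test_bagel_understanding.py does; this is only a
+# reproducibility sketch, not something the engine requires.
+import random
+
+import numpy as np
+import torch
+
+random.seed(seed)
+np.random.seed(seed)
+torch.manual_seed(seed)
+if torch.cuda.is_available():
+    torch.cuda.manual_seed_all(seed)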
+ +# test_input = [{ +# "text": f"<|im_start|>{prompt}<|im_end|><|im_start|>" +# }] #not fully sure about prompt structure + +test_input = [f"<|im_start|>{prompt}<|im_end|>"] + +import torch +import numpy as np +import random + +pred = engine.infer_list(test_input) + +print(pred) diff --git a/recipes/bagel/test_bagel_understanding.py b/recipes/bagel/test_bagel_understanding.py new file mode 100644 index 00000000..ca8d9ac9 --- /dev/null +++ b/recipes/bagel/test_bagel_understanding.py @@ -0,0 +1,130 @@ +# flake8: noqa + +from rich import print +from eole.config.run import * +from eole.inference_engine import InferenceEnginePY + +config = PredictConfig( + model_path="./bagel", + src="dummy", + max_length=600, + gpu_ranks=[0], + # quant_type="bnb_NF4", + quant_type="bnb_FP4", # HF default, using it for initial reproducibility checks + quant_layers=[ + "gate_up_proj", + "down_proj", + "up_proj", + "linear_values", + "linear_query", + "linear_keys", + "final_linear", + "w_in", + "w_out", + ], + compute_dtype="bf16", + top_k=0, + top_p=0.0, + # top_p=0.8, + # temperature=0.35, + # beam_size=5, + beam_size=1, + # temperature=0.3, + seed=42, + batch_size=1, + batch_type="sents", + self_attn_backend="pytorch", + # self_attn_backend="flash", # not properly supported (mixed masking) +) + +print(config) + +config.data_type = "image" +engine = InferenceEnginePY(config) + +print(engine.predictor.model) +engine.predictor.model.count_parameters() + +test_input = [ + { + # "text": "<|im_start|>List the top 5 countries in Europe with the highest GDP from this image<|im_end|>\n{image1}\n", + "text": "{image1}<|im_start|>List the top 5 countries in Europe with the highest GDP from this image<|im_end|><|im_start|>", + # "text": "{image1}", # replicate first pass of bagel with image only + "images": {"image1": "../../eole/tests/data/images/gdp.png"}, + }, + # { + # # "text": "{image1}<|im_start|>When did things start to go wrong for dark dragon?<|im_end|>", + # # "text": "{image1}", + # # "text": "{image1}<|im_start|>Describe<|im_end|><|im_start|>", # bagel weirdly starts decoding by adding a <|im_start|> token + # "text": "{image1}<|im_start|>When did things start to go wrong for dark dragon?<|im_end|><|im_start|>", + # "images": { + # "image1": "../../eole/tests/data/images/loss_curve.jpg" + # } + # }, + # { + # "text": "{image1}<|im_start|>Which model is best?<|im_end|>", + # "images": { + # "image1": "../../eole/tests/data/images/loss_curve.jpg" + # } + # }, + # { + # "text": "{image1}<|im_start|>Can someone explain what’s funny about this meme??<|im_end|>", + # "images": { + # "image1": "./BAGEL/test_images/meme.jpg" + # } + # } + # { + # "text": "{image1}<|im_start|>Is this person really big, or is this building just super small?<|im_end|>", + # "images": { + # "image1": "../../eole/tests/data/images/pisa_2.jpg" + # } + # }, + # { + # "text": "user\nCombine information in both the tables into a single markdown table\n{image1}\n{image2}model\n", + # "images": { + # "image1": "../../eole/tests/data/images/table1.png", + # "image2": "../../eole/tests/data/images/table2.png", + # }, + # }, + # { + # "text": "user\nCombine information in both the tables into a single markdown table\n{image1}model\n", + # "images": { + # "image1": "../../eole/tests/data/images/multi-images.png" + # } + # }, + # { + # "text": "user\nDescribe the images.\n{image1}\n{image2}\n{image3}\n{image4}model\n", + # "images": { + # "image1": "../../eole/tests/data/images/image1.png", + # "image2": "../../eole/tests/data/images/image2.png", 
+ # "image3": "../../eole/tests/data/images/image3.png", + # "image4": "../../eole/tests/data/images/image4.png", + # } + # }, + # { + # "text": "user\nCombine information in both the tables into a single markdown table\n{image1}{image2}model\n", + # "images": { + # "image1": "../../eole/tests/data/images/table1.png", + # "image2": "../../eole/tests/data/images/table2.png" + # } + # }, +] + +import torch +import numpy as np +import random + +seed = 42 +random.seed(seed) +np.random.seed(seed) +torch.manual_seed(seed) +if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) +torch.backends.cudnn.deterministic = True +torch.backends.cudnn.benchmark = False + +pred = engine.infer_list(test_input) + +print(pred) +print(pred[2][0][0].replace("⦅newline⦆", "\n")) diff --git a/setup.py b/setup.py index 15cfdf10..6e3d53eb 100644 --- a/setup.py +++ b/setup.py @@ -20,6 +20,7 @@ install_requires=[ "configargparse", "ctranslate2>=4,<5", + "einops", "fastapi", "fasttext-wheel", "huggingface_hub",
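For readers skimming the attention changes above: the MoT-style routing added to eole/modules/multi_headed_attn.py boils down to indexing the sequence dimension with `text_indices` / `image_indices`, running each slice through its own projection (the `*_moe_gen` variants for image tokens), and scattering the results back into a single tensor. Below is a minimal standalone sketch of that pattern; the class and variable names are illustrative only and are not part of the eole API.

# toy_moe_gen_routing.py -- illustrative sketch, not part of this patch
import torch
import torch.nn as nn


class TextImageRouter(nn.Module):
    """Project text-token positions with one linear layer and image-token
    positions with a separate "generation expert" layer, mirroring the shape
    of MultiHeadedAttention._final_linear / _qkv_image_generation."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.final_linear = nn.Linear(hidden_size, hidden_size)
        self.final_linear_moe_gen = nn.Linear(hidden_size, hidden_size)

    def forward(self, context, text_indices, image_indices):
        # scatter each expert's output back into a single [batch, seq, hidden] tensor
        out = torch.zeros_like(context)
        out[:, text_indices, :] = self.final_linear(context[:, text_indices, :])
        out[:, image_indices, :] = self.final_linear_moe_gen(context[:, image_indices, :])
        return out


if __name__ == "__main__":
    torch.manual_seed(0)
    batch, seq, hidden = 2, 6, 16
    context = torch.randn(batch, seq, hidden)
    text_indices = torch.tensor([0, 1])         # e.g. prompt tokens
    image_indices = torch.tensor([2, 3, 4, 5])  # e.g. VAE latent (image) tokens
    router = TextImageRouter(hidden)
    print(router(context, text_indices, image_indices).shape)  # torch.Size([2, 6, 16])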