[GLM-Image] Add batch > 1 support and fix configuration defaults #43342
base: main
Conversation
Pull request overview
This PR adds comprehensive batch processing support (batch_size > 1) for the GLM-Image model, enabling efficient parallel image generation. Previously, the processor explicitly rejected batch sizes greater than 1. The implementation introduces two new tracking tensors (images_per_sample and num_source_images_per_sample) to manage packed image grids across batch samples, updates the RoPE position encoding computation to work per-sample, and modifies the generation utilities to correctly expand inputs for beam search.
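To make the tracking concrete, here is a small illustration (not code from the PR) of how a per-sample counter lets a packed grid tensor be split back into per-sample chunks:

```python
import torch

# Packed grids: every sample's (t, h, w) grid descriptors are concatenated along dim 0.
image_grid_thw = torch.tensor([[1, 4, 4], [1, 8, 8], [1, 4, 4]])
# images_per_sample records how many of those rows belong to each batch sample.
images_per_sample = torch.tensor([2, 1])  # sample 0 owns two grids, sample 1 owns one

# Splitting by the per-sample counts recovers each sample's own grids.
per_sample_grids = torch.split(image_grid_thw, images_per_sample.tolist(), dim=0)
for idx, grids in enumerate(per_sample_grids):
    print(f"sample {idx}: {grids.shape[0]} grid(s), values {grids.tolist()}")
```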
Changes:
- Removed batch size restriction in processor and added per-sample image tracking
- Updated position ID computation to handle batches with independent per-sample caching
- Modified beam search expansion logic to correctly handle packed visual inputs across batch samples
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 12 comments.
| File | Description |
|---|---|
| tests/models/glm_image/test_modeling_glm_image.py | Adds test_batch_consistency to verify batch and single processing produce identical predictions |
| src/transformers/models/glm_image/processing_glm_image.py | Removes batch size restriction, adds image counting per sample, and implements per-sample prompt/grid construction |
| src/transformers/models/glm_image/modular_glm_image.py | Updates get_rope_index() for per-sample position IDs, modifies forward() to split packed grids, updates generation utilities for batch support |
| src/transformers/models/glm_image/modeling_glm_image.py | Auto-generated from modular file with matching changes for batch support |
| src/transformers/models/glm_image/image_processing_glm_image_fast.py | Removes min_pixels/max_pixels class attributes and simplifies initialization |
| src/transformers/models/glm_image/configuration_glm_image.py | Reorders the tie_word_embeddings parameter and removes it from the super().__init__() call |
Code under review:

```python
num_target_grids = all_target_grids[0].shape[0]
image_inputs["images_per_sample"] = torch.tensor(
    [n + num_target_grids for n in images_per_sample], dtype=torch.long
```

Copilot AI (Jan 19, 2026)
This line assumes all samples have the same number of target grids by using all_target_grids[0].shape[0]. If different samples could have different numbers of target grids (e.g., due to different is_text_to_image settings), this would cause incorrect counting. Consider validating that all samples have the same num_target_grids, or handle varying target grid counts per sample.
Suggested change:

```diff
-num_target_grids = all_target_grids[0].shape[0]
-image_inputs["images_per_sample"] = torch.tensor(
-    [n + num_target_grids for n in images_per_sample], dtype=torch.long
+target_grids_per_sample = [grids.shape[0] for grids in all_target_grids]
+image_inputs["images_per_sample"] = torch.tensor(
+    [n_source + n_target for n_source, n_target in zip(images_per_sample, target_grids_per_sample)],
+    dtype=torch.long,
```
Code under review:

```python
position_ids[:, batch_idx, :] = curr_position_ids

all_prefill_lens.append(seq_len)
all_final_positions.append(current_pos)
```

Copilot AI (Jan 19, 2026)
The variable all_final_positions is populated in the loop but never used. Consider removing this unused variable to improve code clarity.
Suggested change:

```diff
-all_final_positions.append(current_pos)
```
Code under review:

```python
num_target_grids = all_target_grids[0].shape[0]
image_inputs["images_per_sample"] = torch.tensor(
    [n + num_target_grids for n in images_per_sample], dtype=torch.long
)
```

Copilot AI (Jan 19, 2026)
This line assumes all samples have the same number of target grids by using all_target_grids[0].shape[0]. If different samples could have different numbers of target grids (e.g., due to different is_text_to_image settings), this would cause incorrect counting. Consider validating that all samples have the same num_target_grids, or handle varying target grid counts per sample.
Suggested change:

```diff
-num_target_grids = all_target_grids[0].shape[0]
-image_inputs["images_per_sample"] = torch.tensor(
-    [n + num_target_grids for n in images_per_sample], dtype=torch.long
-)
+num_target_grids_per_sample = [g.shape[0] for g in all_target_grids]
+if len(set(num_target_grids_per_sample)) == 1:
+    num_target_grids = num_target_grids_per_sample[0]
+    images_per_sample_with_targets = [n + num_target_grids for n in images_per_sample]
+else:
+    images_per_sample_with_targets = [n + t for n, t in zip(images_per_sample, num_target_grids_per_sample)]
+image_inputs["images_per_sample"] = torch.tensor(images_per_sample_with_targets, dtype=torch.long)
```
Code under review:

```python
dict_to_expand[key] = _repeat_interleave_samples(
    dict_to_expand[key], lengths=lengths, repeat_times=expand_size
)
```

Copilot AI (Jan 19, 2026)
When splitting pixel_values for beam search expansion, if sum(source_image_nums) == 0 (no source images), the pixel_values tensor is not handled. While the code checks if sum > 0 before processing, it doesn't explicitly handle the else case where pixel_values should remain unchanged or be set appropriately. Consider adding an explicit else clause or handling for the case where there are no source images.
Suggested change:

```diff
         )
     )
+else:
+    # No source images: leave pixel_values unchanged
+    dict_to_expand[key] = dict_to_expand[key]
```
Code under review:

```diff
-super().__init__(
-    tie_word_embeddings=tie_word_embeddings, ignore_keys_at_rope_validation={"mrope_section"}, **kwargs
-)
+super().__init__(ignore_keys_at_rope_validation={"mrope_section"}, **kwargs)
```

Copilot AI (Jan 19, 2026)
The super().__init__() call no longer passes tie_word_embeddings. This means the parent class PreTrainedConfig won't receive this parameter. Verify that the parent class correctly handles tie_word_embeddings through **kwargs, or explicitly pass it if needed.
Suggested change:

```diff
-super().__init__(ignore_keys_at_rope_validation={"mrope_section"}, **kwargs)
+super().__init__(
+    tie_word_embeddings=tie_word_embeddings,
+    ignore_keys_at_rope_validation={"mrope_section"},
+    **kwargs,
+)
```
Code under review:

```python
position_ids[:, batch_idx, :] = curr_position_ids

all_prefill_lens.append(seq_len)
all_final_positions.append(current_pos)
```

Copilot AI (Jan 19, 2026)
The variable all_final_positions is populated in the loop but never used. Consider removing this unused variable to improve code clarity.
Suggested change:

```diff
-all_final_positions.append(current_pos)
```
Code under review:

```python
# Per-sample decode position lookup
# _cached_decode_position_ids shape: [batch_size, 3, max_decode_len]
# _prefill_lens shape: [batch_size]
position_ids_list = []
for batch_idx in range(batch_size):
    prefill_len = self._prefill_lens[batch_idx].item()
    step = cache_position[0].item() - prefill_len
    # Get position ids for this sample
    sample_pos_ids = self._cached_decode_position_ids[batch_idx, :, step : step + seq_length]
    position_ids_list.append(sample_pos_ids)
# Stack and transpose to [3, batch_size, seq_length]
position_ids = torch.stack(position_ids_list, dim=1)
```

Copilot AI (Jan 19, 2026)
There's a potential issue with varying batch sizes during decode. If the batch size differs between prefill and decode stages (e.g., due to early stopping), the code attempts to index self._prefill_lens and self._cached_decode_position_ids with batch indices that may be out of bounds. Consider adding a check to ensure batch_size matches the cached dimensions.
Suggested change (replace the snippet above with):

```python
# Only use cached decode positions if they are compatible with the current batch size.
use_cached_positions = (
    hasattr(self, "_cached_decode_position_ids")
    and hasattr(self, "_prefill_lens")
    and self._cached_decode_position_ids is not None
    and self._prefill_lens is not None
    and self._cached_decode_position_ids.shape[0] == batch_size
)
if not use_cached_positions:
    position_ids, rope_deltas = self.get_rope_index(
        input_ids,
        image_grid_thw,
        images_per_sample=images_per_sample,
        attention_mask=attention_mask_2d,
    )
    self.rope_deltas = rope_deltas
else:
    # Per-sample decode position lookup
    # _cached_decode_position_ids shape: [batch_size, 3, max_decode_len]
    # _prefill_lens shape: [batch_size]
    position_ids_list = []
    for batch_idx in range(batch_size):
        prefill_len = self._prefill_lens[batch_idx].item()
        step = cache_position[0].item() - prefill_len
        # Get position ids for this sample
        sample_pos_ids = self._cached_decode_position_ids[batch_idx, :, step : step + seq_length]
        position_ids_list.append(sample_pos_ids)
    # Stack and transpose to [3, batch_size, seq_length]
    position_ids = torch.stack(position_ids_list, dim=1)
```
Code under review (image_processing_glm_image_fast.py class attributes):

```python
patch_size = 14
temporal_patch_size = 2
merge_size = 2
min_pixels = None
max_pixels = None
valid_kwargs = GlmImageImageProcessorKwargs
```

Copilot AI (Jan 19, 2026)
The removal of min_pixels and max_pixels as class attributes could be a breaking change for users who directly access these attributes (e.g., processor.min_pixels). Consider documenting this change or maintaining backward compatibility by keeping them as None defaults in the class definition.
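If backward compatibility were desired, one hedged option (sketched with a hypothetical stand-in class, not the PR's code) is to keep the attributes as None-valued class defaults:

```python
# Hypothetical stand-in class, not the PR's GlmImageImageProcessorFast: keeping the
# removed attributes as None-valued class defaults preserves `processor.min_pixels`
# style access for downstream code.
class GlmImageImageProcessorFastCompat:
    patch_size = 14
    temporal_patch_size = 2
    merge_size = 2
    min_pixels = None  # kept only for backward compatibility
    max_pixels = None  # kept only for backward compatibility

proc = GlmImageImageProcessorFastCompat()
assert proc.min_pixels is None and proc.max_pixels is None
```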
Code under review:

```python
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
```

Copilot AI (Jan 19, 2026)
The unused position_ids parameter has been removed from the function signature. However, the docstring still references this parameter in the Args section. The docstring should be updated to reflect this change.
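A possible rewrite of the Args section without the removed parameter (sketch only; the function body is elided and any wording beyond the original docstring is a suggestion):

```python
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The dimension along which to unsqueeze `cos` and `sin` so that they
            broadcast correctly against the dimensions of `q` and `k`.
    """
    ...
```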
[For maintainers] Suggested jobs to run (before merge): run-slow: glm_image
What does this PR do?
This PR adds full batch processing support (batch_size > 1) for the GLM-Image model, fixes padding direction for autoregressive generation, and aligns configuration defaults with the official model.
Problem

- GlmImageProcessor explicitly rejected batch_size > 1
- pad_token_id and eos_token_id were not set in GlmImageTextConfig, and max_position_embeddings didn't match the official config.json

Solution
1. Batch support
Processor changes (processing_glm_image.py):

- Added an images_per_sample tensor to track the number of grids per sample
- Added a num_source_images_per_sample tensor to distinguish source images from target grids

Model changes (modeling_glm_image.py via modular_glm_image.py):

- Updated get_rope_index() to compute position IDs per sample with batch support
- Changed the _cached_decode_position_ids shape from [3, max_len] to [batch, 3, max_len]
- Updated forward() to properly handle packed image_grid_thw by splitting it per sample
- Updated _expand_inputs_for_generation() and prepare_inputs_for_generation() for beam search compatibility
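With the restriction removed, a batched processor call could look roughly like this (the checkpoint path is a placeholder and the exact set of returned keys may depend on the inputs; this is a sketch, not code from the PR):

```python
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("path/to/glm-image-checkpoint")  # placeholder path

prompts = ["a watercolor fox", "a city skyline at night"]
inputs = processor(text=prompts, padding=True, return_tensors="pt")  # batch_size=2 no longer rejected

# New tracking tensors introduced by this PR (availability depends on the inputs):
print(inputs.get("images_per_sample"))
print(inputs.get("num_source_images_per_sample"))
```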
2. Left padding for autoregressive generation

Processor changes (processing_glm_image.py):

- Pad on the left when padding=True is specified, so generation continues from real tokens rather than padding (see the sketch below)
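As an illustration of why this matters (shown with a generic tokenizer, not GLM-Image specific), left padding keeps the final position of every sample on a real token, so the next generated token is conditioned on content rather than padding:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
tok.pad_token = tok.eos_token
tok.padding_side = "left"  # pad on the left for decoder-only generation

batch = tok(
    ["a short prompt", "a noticeably longer prompt for the second sample"],
    padding=True,
    return_tensors="pt",
)
# Zeros in the attention mask sit on the left, so generation starts right after real tokens.
print(batch["attention_mask"])
```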
3. Configuration alignment with official model

Config changes (modular_glm_image.py → configuration_glm_image.py):

- Added pad_token_id=167841 to GlmImageTextConfig
- Added eos_token_id=16385 to GlmImageTextConfig
- Changed max_position_embeddings from 32768 to 131072 to match the official config.json
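Assuming the text config is exported under the name used in this PR (the import path is an assumption), the aligned defaults could be checked like this, with the expected values taken from the list above:

```python
# Assumption: GlmImageTextConfig is importable from transformers once this PR lands.
from transformers import GlmImageTextConfig

cfg = GlmImageTextConfig()
print(cfg.pad_token_id)             # expected: 167841
print(cfg.eos_token_id)             # expected: 16385
print(cfg.max_position_embeddings)  # expected: 131072
```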
Tests

All 64 modeling tests pass. Added test_batch_consistency to verify that batched and single-sample processing produce identical predictions.
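The idea behind the consistency check can be sketched as follows (a hedged illustration, not the PR's actual test code; the helper name and token count are hypothetical):

```python
import torch

def assert_batch_matches_single(model, processor, prompts, max_new_tokens=8):
    """Hypothetical helper: generations for sample i must match whether the sample
    is processed alone or inside a batch."""
    batched = processor(text=prompts, padding=True, return_tensors="pt")
    with torch.no_grad():
        batched_out = model.generate(**batched, max_new_tokens=max_new_tokens, do_sample=False)
    for i, prompt in enumerate(prompts):
        single = processor(text=[prompt], return_tensors="pt")
        with torch.no_grad():
            single_out = model.generate(**single, max_new_tokens=max_new_tokens, do_sample=False)
        # Compare only the newly generated tokens for sample i.
        assert torch.equal(batched_out[i, -max_new_tokens:], single_out[0, -max_new_tokens:])
```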
Breaking changes

None. All new parameters are optional and backward compatible.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.