diff --git a/benchmark/recognition.py b/benchmark/recognition.py index 45ed0417..24c7364d 100644 --- a/benchmark/recognition.py +++ b/benchmark/recognition.py @@ -103,6 +103,7 @@ def normalize_text(text: str) -> str: help="Comma-separated list of languages to benchmark.", default=None, ) +@click.option("--xla_eager", is_flag=True, help="Use XLA eager mode for Surya.") def main( results_dir: str, max_rows: int, @@ -112,7 +113,13 @@ def main( tess_cpus: int, textract_cpus: int, languages: str | None, + xla_eager: bool = False, ): + if xla_eager: + import torch_xla + + torch_xla.experimental.eager_mode(True) + foundation_predictor = FoundationPredictor() rec_predictor = RecognitionPredictor(foundation_predictor) diff --git a/surya/common/adetr/decoder.py b/surya/common/adetr/decoder.py index 7c0f77e9..ef80c51a 100644 --- a/surya/common/adetr/decoder.py +++ b/surya/common/adetr/decoder.py @@ -11,7 +11,7 @@ from transformers.modeling_outputs import BaseModelOutputWithNoAttention from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS -from surya.common.util import mark_step +from surya.common.xla import mark_step _MAX_SQRT_GRADIENT = 1000.0 @@ -20,6 +20,7 @@ class WrappedEmbedding(nn.Embedding): def forward(self, input_ids, *args, **kwargs): return super().forward(input_ids) + class SuryaADETRDecoderRMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-6): super().__init__() @@ -41,9 +42,9 @@ def forward(self, x): # Clamp to float16 range f16_info = torch.finfo(x.dtype) output = output.clamp(min=f16_info.min, max=f16_info.max) - output = torch.where(torch.isnan(output), - torch.tensor(0.0, device=output.device), - output) + output = torch.where( + torch.isnan(output), torch.tensor(0.0, device=output.device), output + ) return output.type_as(x) def extra_repr(self): @@ -58,7 +59,10 @@ def __init__(self, dim, base=10000, device=None): super().__init__() self.dim = dim self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) + inv_freq = 1.0 / ( + self.base + ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim) + ) self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) @torch.no_grad() @@ -66,10 +70,14 @@ def __init__(self, dim, base=10000, device=None): def forward(self, x, position_ids, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] self.inv_freq.to(x.device) - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + inv_freq_expanded = ( + self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + ) position_ids_expanded = position_ids[:, None, :].float() - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose( + 1, 2 + ) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() @@ -119,7 +127,9 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) @@ -138,10 +148,24 @@ def __init__(self, config: PretrainedConfig): self.num_key_value_heads = config.num_key_value_heads 
self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads - self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=config.attention_bias) - self.k_proj = nn.Linear(self.config.encoder_hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.v_proj = nn.Linear(self.config.encoder_hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=True) + self.q_proj = nn.Linear( + self.hidden_size, + self.num_attention_heads * self.head_dim, + bias=config.attention_bias, + ) + self.k_proj = nn.Linear( + self.config.encoder_hidden_size, + self.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.v_proj = nn.Linear( + self.config.encoder_hidden_size, + self.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.o_proj = nn.Linear( + self.num_attention_heads * self.head_dim, self.hidden_size, bias=True + ) self.rotary_emb = SuryaADETRDecoderRotaryEmbedding( self.head_dim, base=config.rope_theta, @@ -161,13 +185,19 @@ def forward( _, v_len, _ = encoder_hidden_states.size() query_states = self.q_proj(hidden_states) - query_states = query_states.view(bsz, q_len, self.num_attention_heads, self.head_dim).transpose(1, 2) + query_states = query_states.view( + bsz, q_len, self.num_attention_heads, self.head_dim + ).transpose(1, 2) if self.key_states is None: key_states = self.k_proj(encoder_hidden_states) value_states = self.v_proj(encoder_hidden_states) - key_states = key_states.view(bsz, v_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, v_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view( + bsz, v_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, v_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) if use_cache: self._update_cache(key_states, value_states) else: @@ -223,10 +253,24 @@ def __init__(self, config: PretrainedConfig, static_cache=False, max_boxes=None) self.num_key_value_heads = config.num_key_value_heads self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads - self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=config.attention_bias) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=True) + self.q_proj = nn.Linear( + self.hidden_size, + self.num_attention_heads * self.head_dim, + bias=config.attention_bias, + ) + self.k_proj = nn.Linear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.v_proj = nn.Linear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.o_proj = nn.Linear( + self.num_attention_heads * self.head_dim, self.hidden_size, bias=True + ) self.rotary_emb = SuryaADETRDecoderRotaryEmbedding( self.head_dim, base=config.rope_theta, @@ -251,16 +295,29 @@ def forward( value_states = self.v_proj(hidden_states) # Final is bsz, num_attention_heads, seq_len, head_dim - query_states = query_states.view(bsz, q_len, self.num_attention_heads, 
self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = query_states.view( + bsz, q_len, self.num_attention_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin + ) if use_cache and hasattr(self, "key_states"): - cache_kwargs = {"cache_position": cache_position, "window_attn": window_attn} - key_states, value_states = self._update_cache(key_states, value_states, **cache_kwargs) + cache_kwargs = { + "cache_position": cache_position, + "window_attn": window_attn, + } + key_states, value_states = self._update_cache( + key_states, value_states, **cache_kwargs + ) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -268,10 +325,12 @@ def forward( causal_mask = attention_mask if attention_mask is not None: # Mask is batch, head, seq_len, kv_len - causal_mask = causal_mask[:, :, :, :key_states.shape[-2]] + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] if cache_position is not None and self.static_cache: current_pos = cache_position[-1] - causal_mask[:, :, :, current_pos + 1:] = torch.finfo(causal_mask.dtype).min + causal_mask[:, :, :, current_pos + 1 :] = torch.finfo( + causal_mask.dtype + ).min mark_step() attn_output = torch.nn.functional.scaled_dot_product_attention( @@ -299,7 +358,12 @@ def _setup_cache(self, batch_size, device, dtype=None): self.key_states = None if self.static_cache: - cache_shape = (batch_size, self.num_key_value_heads, self.max_boxes, self.head_dim) + cache_shape = ( + batch_size, + self.num_key_value_heads, + self.max_boxes, + self.head_dim, + ) self.value_states = torch.zeros(cache_shape, dtype=dtype, device=device) self.key_states = torch.zeros(cache_shape, dtype=dtype, device=device) @@ -311,7 +375,10 @@ def _clear_cache(self): def _update_static_cache(self, key_states, value_states, **cache_kwargs): cache_position = cache_kwargs.get("cache_position") - k_out, v_out = self.key_states.to(key_states.device), self.value_states.to(value_states.device) + k_out, v_out = ( + self.key_states.to(key_states.device), + self.value_states.to(value_states.device), + ) k_out[:, :, cache_position] = key_states.to(k_out.dtype) v_out[:, :, cache_position] = value_states.to(v_out.dtype) @@ -360,36 +427,50 @@ def forward(self, x): class SuryaADETRDecoderLayer(nn.Module): def __init__(self, config, layer_idx, static_cache=False, max_boxes=None): super().__init__() - self.cross_pre_norm = SuryaADETRDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.temporal_pre_norm = SuryaADETRDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.cross_pre_norm = SuryaADETRDecoderRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.temporal_pre_norm = SuryaADETRDecoderRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) self.temporal_block = None if layer_idx in config.self_attn_layers: - self.temporal_block = 
SuryaADETRDecoderSdpaAttention(config, static_cache=static_cache, max_boxes=max_boxes) + self.temporal_block = SuryaADETRDecoderSdpaAttention( + config, static_cache=static_cache, max_boxes=max_boxes + ) self.cross_attn_block = None if layer_idx in config.cross_attn_layers: self.cross_attn_block = SuryaADETRDecoderSdpaCrossAttention(config) self.window_attn = layer_idx not in config.global_attn_layers - self.channel_pre_norm = SuryaADETRDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.channel_pre_norm = SuryaADETRDecoderRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) self.mlp_block = SuryaADETRDecoderMlp(config) self.double_residual_flow = getattr(config, "double_residual_flow", False) def forward( - self, - activations: torch.Tensor, - position_ids: torch.Tensor, - attention_mask: torch.Tensor, - encoder_hidden_states: torch.Tensor = None, - encoder_attention_mask: torch.Tensor = None, - cache_position: torch.Tensor = None, - use_cache: bool = None, + self, + activations: torch.Tensor, + position_ids: torch.Tensor, + attention_mask: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + encoder_attention_mask: torch.Tensor = None, + cache_position: torch.Tensor = None, + use_cache: bool = None, ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: if self.double_residual_flow: return self.double_res_forward( - activations, position_ids, attention_mask, encoder_hidden_states, encoder_attention_mask, cache_position, use_cache + activations, + position_ids, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + cache_position, + use_cache, ) hidden_states = activations @@ -397,15 +478,25 @@ def forward( # Do cross-attention on encoder outputs cross_attn_inputs = self.cross_pre_norm(hidden_states) cross_attn_path = self.cross_attn_block( - cross_attn_inputs, encoder_hidden_states, attention_mask, encoder_attention_mask, use_cache=use_cache + cross_attn_inputs, + encoder_hidden_states, + attention_mask, + encoder_attention_mask, + use_cache=use_cache, ) hidden_states = cross_attn_path + hidden_states if self.temporal_block is not None: - temporal_inputs = self.temporal_pre_norm(hidden_states) # RMSNorm introduces slight slight differences + temporal_inputs = self.temporal_pre_norm( + hidden_states + ) # RMSNorm introduces slight differences temporal_path = self.temporal_block( - temporal_inputs, position_ids, attention_mask, cache_position=cache_position, - use_cache=use_cache, window_attn=self.window_attn + temporal_inputs, + position_ids, + attention_mask, + cache_position=cache_position, + use_cache=use_cache, + window_attn=self.window_attn, ) hidden_states = temporal_path + hidden_states @@ -433,16 +524,27 @@ def double_res_forward( # Do cross-attention on encoder outputs cross_attn_inputs = self.cross_pre_norm(activations) cross_attn_path = self.cross_attn_block( - cross_attn_inputs, encoder_hidden_states, attention_mask, encoder_attention_mask, use_cache=use_cache + cross_attn_inputs, + encoder_hidden_states, + attention_mask, + encoder_attention_mask, + use_cache=use_cache, ) cross_attn_output = cross_attn_path + raw_activations else: cross_attn_output = raw_activations if self.temporal_block is not None: - inputs_normalized = self.temporal_pre_norm(cross_attn_output) # RMSNorm introduces slight slight differences + inputs_normalized = self.temporal_pre_norm( + cross_attn_output + ) # RMSNorm introduces slight differences hidden_states = self.temporal_block( - inputs_normalized, position_ids, attention_mask, 
cache_position=cache_position, use_cache=use_cache, window_attn=self.window_attn + inputs_normalized, + position_ids, + attention_mask, + cache_position=cache_position, + use_cache=use_cache, + window_attn=self.window_attn, ) residual = hidden_states + raw_activations @@ -469,11 +571,19 @@ class SuryaADETRDecoderPreTrainedModel(PreTrainedModel): def _init_weights(self, module): if isinstance(module, SuryaADETRDecoderSdpaAttention): - torch.nn.init.normal_(module.q_proj.weight, mean=0.0, std=self.config.init_std) - torch.nn.init.normal_(module.k_proj.weight, mean=0.0, std=self.config.init_std) - torch.nn.init.normal_(module.v_proj.weight, mean=0.0, std=self.config.init_std) + torch.nn.init.normal_( + module.q_proj.weight, mean=0.0, std=self.config.init_std + ) + torch.nn.init.normal_( + module.k_proj.weight, mean=0.0, std=self.config.init_std + ) + torch.nn.init.normal_( + module.v_proj.weight, mean=0.0, std=self.config.init_std + ) - torch.nn.init.normal_(module.o_proj.weight, mean=0.0, std=self.config.init_std) + torch.nn.init.normal_( + module.o_proj.weight, mean=0.0, std=self.config.init_std + ) elif isinstance(module, nn.Linear): torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std) if getattr(module, "bias", None) is not None: @@ -518,11 +628,11 @@ class SuryaADETRDecoderModel(SuryaADETRDecoderPreTrainedModel): """ def __init__( - self, - config: PretrainedConfig, - embedder: nn.Module = None, - max_boxes: int = None, - static_cache: bool = False + self, + config: PretrainedConfig, + embedder: nn.Module = None, + max_boxes: int = None, + static_cache: bool = False, ): super().__init__(config) self.padding_idx = config.pad_token_id @@ -534,13 +644,22 @@ def __init__( self.static_cache = static_cache self.layers = nn.ModuleList( - [SuryaADETRDecoderLayer(config, layer_idx, static_cache=static_cache, max_boxes=max_boxes) for layer_idx in range(config.num_hidden_layers)] + [ + SuryaADETRDecoderLayer( + config, layer_idx, static_cache=static_cache, max_boxes=max_boxes + ) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.final_norm = SuryaADETRDecoderRMSNorm( + config.hidden_size, eps=config.rms_norm_eps ) - self.final_norm = SuryaADETRDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False self.register_buffer( - "normalizer", torch.tensor(self.config.hidden_size**0.5, dtype=torch.float32), persistent=False + "normalizer", + torch.tensor(self.config.hidden_size**0.5, dtype=torch.float32), + persistent=False, ) # Initialize weights and apply final processing self.post_init() @@ -566,10 +685,12 @@ def forward( use_cache: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - prefill: bool = False + prefill: bool = False, ) -> Union[Tuple, BaseModelOutputWithNoAttention]: use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) if self.gradient_checkpointing and self.training and use_cache: use_cache = False @@ -578,14 +699,23 @@ def forward( hidden_states = inputs_embeds if use_cache and prefill: - self._setup_cache(self.config, hidden_states.shape[0], hidden_states.device, hidden_states.dtype) + self._setup_cache( + self.config, + hidden_states.shape[0], + hidden_states.device, + hidden_states.dtype, + ) if cache_position is None: - cache_position = 
torch.arange(hidden_states.shape[1], device=hidden_states.device) + cache_position = torch.arange( + hidden_states.shape[1], device=hidden_states.device + ) if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position + ) all_hidden_states = () if output_hidden_states else None for i, residual_block in enumerate(self.layers): @@ -593,10 +723,25 @@ def forward( all_hidden_states += (hidden_states,) if self.gradient_checkpointing and self.training: hidden_states = self._gradient_checkpointing_func( - residual_block.__call__, hidden_states, position_ids, causal_mask, encoder_hidden_states, encoder_attention_mask, cache_position, use_cache + residual_block.__call__, + hidden_states, + position_ids, + causal_mask, + encoder_hidden_states, + encoder_attention_mask, + cache_position, + use_cache, ) else: - hidden_states = residual_block(hidden_states, position_ids, causal_mask, encoder_hidden_states, encoder_attention_mask, cache_position, use_cache) + hidden_states = residual_block( + hidden_states, + position_ids, + causal_mask, + encoder_hidden_states, + encoder_attention_mask, + cache_position, + use_cache, + ) hidden_states = self.final_norm(hidden_states) @@ -626,27 +771,44 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position): sequence_length = input_tensor.shape[1] target_length = max(self.max_boxes, sequence_length) - diagonal = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + diagonal = torch.full( + (sequence_length, target_length), + fill_value=min_dtype, + dtype=dtype, + device=device, + ) causal_mask = diagonal if sequence_length != 1: # Select the upper triangular part of the matrix, but unmask current token (the diagonal) # triu will be the min_dtype, everything else is 0 (attended to) causal_mask = torch.triu(diagonal, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + causal_mask *= torch.arange( + target_length, device=device + ) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand( + input_tensor.shape[0], 1, -1, -1 + ) if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + causal_mask = ( + causal_mask.clone() + ) # copy to contiguous memory for in-place edit if attention_mask.dim() == 2: # Mask positions in the causal mask that are masked in the attention mask mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) - causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype) + padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[ + :, None, None, : + ].eq(0.0) + causal_mask[..., :mask_length] = causal_mask[ + ..., :mask_length + ].masked_fill(padding_mask, min_dtype) if attention_mask is not None and attention_mask.device.type == "cuda": # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
# Details: https://github.com/pytorch/pytorch/issues/110213 - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + causal_mask = AttentionMaskConverter._unmask_unattended( + causal_mask, min_dtype + ) - return causal_mask \ No newline at end of file + return causal_mask diff --git a/surya/common/donut/encoder.py b/surya/common/donut/encoder.py index 8d59a8ba..0dafd96c 100644 --- a/surya/common/donut/encoder.py +++ b/surya/common/donut/encoder.py @@ -17,7 +17,7 @@ from transformers.utils import ModelOutput from transformers import DonutSwinConfig -from surya.common.util import mark_step +from surya.common.xla import mark_step _EXPECTED_OUTPUT_SHAPE = [1, 49, 1024] diff --git a/surya/common/predictor.py b/surya/common/predictor.py index efb3a437..4f0b2385 100644 --- a/surya/common/predictor.py +++ b/surya/common/predictor.py @@ -9,15 +9,16 @@ class BasePredictor: model_loader_cls = ModelLoader batch_size: Optional[int] = None - default_batch_sizes = { - "cpu": 1, - "mps": 1, - "cuda": 1 - } + default_batch_sizes = {"cpu": 1, "mps": 1, "cuda": 1} disable_tqdm: bool = settings.DISABLE_TQDM torch_dtype = settings.MODEL_DTYPE - def __init__(self, checkpoint: Optional[str] = None, device: torch.device | str | None = settings.TORCH_DEVICE_MODEL, dtype: Optional[torch.dtype | str] = None): + def __init__( + self, + checkpoint: Optional[str] = None, + device: torch.device | str | None = settings.TORCH_DEVICE_MODEL, + dtype: Optional[torch.dtype | str] = None, + ): if dtype is None: dtype = self.torch_dtype @@ -43,7 +44,9 @@ def get_batch_size(self): return batch_size @staticmethod - def pad_to_batch_size(tensor: torch.Tensor, batch_size: int): + def pad_to_batch_size( + tensor: torch.Tensor, batch_size: int, pad_value: int = 0 + ) -> torch.Tensor: current_batch_size = tensor.shape[0] if current_batch_size >= batch_size: return tensor @@ -51,7 +54,7 @@ def pad_to_batch_size(tensor: torch.Tensor, batch_size: int): pad_size = batch_size - current_batch_size padding = (0, 0) * (tensor.dim() - 1) + (0, pad_size) - return F.pad(tensor, padding, mode='constant', value=0) + return F.pad(tensor, padding, mode="constant", value=pad_value) def __call__(self, *args, **kwargs): - raise NotImplementedError() \ No newline at end of file + raise NotImplementedError() diff --git a/surya/common/surya/__init__.py b/surya/common/surya/__init__.py index 3f73151a..e37e68c5 100644 --- a/surya/common/surya/__init__.py +++ b/surya/common/surya/__init__.py @@ -1,5 +1,4 @@ from typing import Optional, Tuple, TypedDict -import warnings from dataclasses import dataclass import torch @@ -56,6 +55,7 @@ class FlashAttentionKwargs(TypedDict, total=False): class KwargsForCausalLM(FlashAttentionKwargs): ... 
+ class DistanceProjection(nn.Module): def __init__(self, in_features: int, out_features: int): super().__init__() @@ -75,6 +75,7 @@ def init_weights(self): nn.init.zeros_(self.fc1.bias) nn.init.zeros_(self.fc2.bias) + class SuryaModel(S3DownloaderMixin, PreTrainedModel): config_class = SuryaModelConfig supports_gradient_checkpointing = True @@ -166,88 +167,115 @@ def maybe_static_pad_image_inputs( chunk_pixels: torch.Tensor, chunk_grid_thw: torch.Tensor, actual_chunk_len: int, - encoder_chunk_size: int + encoder_chunk_size: int, ) -> Tuple[torch.Tensor, torch.Tensor]: - valid_embed_len = actual_chunk_len // (self.vision_encoder.spatial_merge_size ** 2) + valid_embed_len = actual_chunk_len // ( + self.vision_encoder.spatial_merge_size**2 + ) if settings.FOUNDATION_STATIC_CACHE and actual_chunk_len < encoder_chunk_size: padding_len = encoder_chunk_size - actual_chunk_len - padding = torch.zeros( - padding_len, - *chunk_pixels.shape[1:], - device=chunk_pixels.device, - dtype=chunk_pixels.dtype + chunk_pixels = F.pad( + chunk_pixels, + (0, 0, 0, padding_len), + mode="constant", + value=0.0, ) - chunk_pixels = torch.cat([chunk_pixels, padding], dim=0) - + padding_grid = torch.tensor( [[1, 2, padding_len // 2]], device=chunk_grid_thw.device, - dtype=chunk_grid_thw.dtype + dtype=chunk_grid_thw.dtype, ) chunk_grid_thw = torch.cat([chunk_grid_thw, padding_grid], dim=0) return chunk_pixels, chunk_grid_thw, valid_embed_len - def get_image_embeddings( self, pixel_values: torch.Tensor, grid_thw: torch.Tensor, encoder_chunk_size: int | None, + image_tile_length: torch.Tensor | None = None, + valid_batch_size: torch.Tensor | None = None, ): # embed all images with the vision encoder after they have already been tiled and flattened into a single batch chunks = [0] grid_chunks = [0] curr_chunk_len = 0 curr_seq_len = 0 - for i in range(len(grid_thw)): - curr_chunk_len += (grid_thw[i][0] * grid_thw[i][1] * grid_thw[i][2]).item() - if curr_chunk_len > encoder_chunk_size: + chunk_tokens = [] + grid_chunk_size = [] + curr_grid_len = 0 + for i in range(len(grid_thw[:valid_batch_size])): + curr_sample_len = grid_thw[i][0] * grid_thw[i][1] * grid_thw[i][2] + + if ( + curr_chunk_len > (encoder_chunk_size - curr_sample_len) + and curr_chunk_len > 0 + ): + chunk_tokens.append(curr_chunk_len) chunks.append(curr_chunk_len + curr_seq_len) curr_seq_len += curr_chunk_len curr_chunk_len = 0 - grid_chunks.append(i + 1) + grid_chunks.append(i) + grid_chunk_size.append(curr_grid_len) + curr_grid_len = 0 + + curr_chunk_len += curr_sample_len + curr_grid_len += 1 if curr_chunk_len > 0: - chunks.append(pixel_values.shape[0]) - grid_chunks.append(len(grid_thw)) + chunks.append(image_tile_length) + grid_chunks.append(valid_batch_size) + chunk_tokens.append(curr_chunk_len) + grid_chunk_size.append(curr_grid_len) - assert curr_chunk_len + curr_seq_len == pixel_values.shape[0], ( - f"Mismatch in encoder chunking, {curr_chunk_len} + {curr_seq_len} != {pixel_values.shape[0]}" + assert curr_chunk_len + curr_seq_len == image_tile_length, ( + f"Mismatch in encoder chunking, {curr_chunk_len} + {curr_seq_len} != {image_tile_length}" ) logger.debug( f"Chunking encoder sequence into {len(chunks) - 1} chunks of size {encoder_chunk_size} with lengths {chunks} and grids {grid_chunks}" ) - embeddings = [] + + final_length = image_tile_length // 4 + embeddings = torch.zeros( + final_length, + self.vision_encoder.config.hidden_size, + dtype=pixel_values.dtype, + device=self.device, + ) + out_start = 0 + grid_thw = grid_thw.to(self.device) for i in 
range(len(chunks) - 1): start = chunks[i] end = chunks[i + 1] grid_start = grid_chunks[i] grid_end = grid_chunks[i + 1] - + chunk_pixels = pixel_values[start:end] chunk_grid_thw = grid_thw[grid_start:grid_end] actual_chunk_len = end - start - chunk_pixels, chunk_grid_thw, valid_embed_len = self.maybe_static_pad_image_inputs(chunk_pixels, chunk_grid_thw, actual_chunk_len, encoder_chunk_size) - - chunk_embeddings = self.vision_encoder.embed_images( - image_batch=chunk_pixels, - grid_thw=chunk_grid_thw + chunk_pixels, chunk_grid_thw, valid_embed_len = ( + self.maybe_static_pad_image_inputs( + chunk_pixels, chunk_grid_thw, actual_chunk_len, encoder_chunk_size + ) + ) + logger.debug( + f"Inferencing chunk {i} with size {chunk_pixels.shape} and grid {chunk_grid_thw.shape}" ) - embeddings.append(chunk_embeddings[:valid_embed_len]) - if len(embeddings) == 0: - raise ValueError( - "No image embeddings were generated. Check the input images and grid sizes." + chunk_embeddings = self.vision_encoder.embed_images( + image_batch=chunk_pixels.to(self.device), grid_thw=chunk_grid_thw ) - elif len(embeddings) == 1: - embeddings = embeddings[0] - else: - embeddings = torch.cat(embeddings, dim=0) + embeddings[out_start : (out_start + valid_embed_len)] = chunk_embeddings[ + :valid_embed_len + ] + out_start += valid_embed_len.item() encoding_2d = self.get_2d_learned_embeddings( grid_thw, + valid_batch_size=valid_batch_size, device=embeddings.device, bbox_size=self.config.image_embed_encoding_multiplier, ) @@ -259,11 +287,14 @@ def get_image_embeddings( ) embeddings = embeddings + encoding_2d - return embeddings def embed_ids_boxes_images( - self, input_ids, pixel_values, grid_thw, encoder_chunk_size: int + self, + input_ids, + image_embeddings, + encoder_chunk_size: int, + valid_batch_size: torch.Tensor | None = None, ): """ Insert embedded image tiles into the corresponding positions into the full input sequence @@ -272,25 +303,23 @@ def embed_ids_boxes_images( """ # This is batched in the inner call inputs_embeds = self.embedder.embed(input_tokens=input_ids) - if pixel_values is not None: - image_features = self.get_image_embeddings( - pixel_values=pixel_values, - grid_thw=grid_thw, - encoder_chunk_size=encoder_chunk_size, - ) - special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds) - if inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_tokens = torch.sum((input_ids == self.config.image_token_id)) - n_image_features = image_features.shape[0] * image_features.shape[1] - warnings.warn( - f"Image features and image tokens do not match: tokens {n_image_tokens}, features {n_image_features}. 
This may lead to unexpected results" - ) - image_features = image_features.to(inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter( - special_image_mask, image_features + if image_embeddings is not None: + image_token_id_tensor = torch.tensor( + self.config.image_token_id, + device=image_embeddings.device, + dtype=torch.long, + ) + mask = input_ids == image_token_id_tensor + flat = inputs_embeds.view(-1, inputs_embeds.size(-1)) # (B·L) x D + flat_mask = mask.view(-1) # (B·L) + flat.index_copy_( + 0, + flat_mask.nonzero(as_tuple=False).squeeze(1).to(torch.long), + image_embeddings.to(flat.dtype), ) + + inputs_embeds = flat.view_as(inputs_embeds) else: assert (input_ids == self.config.image_token_id).sum() == 0, ( "Image tokens were present in the input but no input images were provided" @@ -301,74 +330,70 @@ def embed_ids_boxes_images( def get_2d_learned_embeddings( self, grid_thw, + valid_batch_size: int, device: str | torch.device = "cpu", bbox_size: int = 256, ): - all_embeddings = [] - for grid_t, grid_h, grid_w in grid_thw: - llm_grid_h, llm_grid_w = ( - grid_h // self.config.merge_size, - grid_w // self.config.merge_size, - ) - - # Scale to 0-1024 - llm_grid_h = ( - torch.arange(llm_grid_h, device=device) - / max(1, (llm_grid_h - 1)) - * bbox_size - ) - llm_grid_w = ( - torch.arange(llm_grid_w, device=device) - / max(1, (llm_grid_w - 1)) - * bbox_size - ) - - llm_grid_w_idx = llm_grid_w.to(torch.long) - llm_grid_h_idx = llm_grid_h.to(torch.long) - - llm_grid_w = self.img_w_embed(llm_grid_w_idx) - llm_grid_h = self.img_h_embed(llm_grid_h_idx) - - full_grid = llm_grid_h[:, None] + llm_grid_w[None, :] - - flattened = full_grid.flatten( - 0, 1 - ) # Flatten first dimension, so they are seq_len x embed_dim - all_embeddings.append(flattened) - return torch.concat( - all_embeddings, dim=0 - ) # Shape is num_image_tokens x embed_dim + grid_thw = grid_thw[:valid_batch_size] # ────── (B,3) + dev = grid_thw.device + merge = self.config.merge_size + + # per-sample grid sizes after merge + H = (grid_thw[:, 1] // merge).long() # (B,) + W = (grid_thw[:, 2] // merge).long() # (B,) + + row_coords = torch.cat( + [ + torch.linspace(0, bbox_size, steps=int(h), device=dev) + .round() + .repeat_interleave(w) # repeat each row value w times + for h, w in zip(H.tolist(), W.tolist()) + ] + ) # (full_grid_size,) + + col_coords = torch.cat( + [ + torch.linspace(0, bbox_size, steps=int(w), device=dev) + .round() + .repeat(int(h)) # tile the column vector h times + for h, w in zip(H.tolist(), W.tolist()) + ] + ) # (full_grid_size,) + + emb = self.img_h_embed(row_coords.long()) + self.img_w_embed(col_coords.long()) + return emb def get_logits(self, hidden_states): - assert hidden_states.shape[1] == 1, "Multi output predictions only applied on the last token" + assert hidden_states.shape[1] == 1, ( + "Multi output predictions only applied on the last token" + ) all_lm_logits = [] all_bbox_logits = [] - + current_hidden = hidden_states - + # Loop includes initial prediction (i=0) plus multi_out_distance additional predictions for i in range(self.config.multi_output_distance + 1): if i > 0: - current_hidden = self.multi_output_projections[i-1](current_hidden) - + current_hidden = self.multi_output_projections[i - 1](current_hidden) + lm_logits = self.lm_head(current_hidden) bbox_logits = F.sigmoid(self.bbox_head(current_hidden)) - + all_lm_logits.append(lm_logits) all_bbox_logits.append(bbox_logits) - + # Concatenate along sequence dimension (dim=1) final_lm_logits = torch.cat(all_lm_logits, dim=1) 
final_bbox_logits = torch.cat(all_bbox_logits, dim=1) - + return final_lm_logits, final_bbox_logits def forward( self, input_ids=None, - image_tiles=None, - grid_thw=None, + image_embeddings=None, inputs_embeds=None, attention_mask=None, position_ids=None, @@ -382,20 +407,21 @@ def forward( num_valid_tokens=None, prefill=False, text_lengths=None, + valid_batch_size: torch.Tensor = None, **kwargs: KwargsForCausalLM, ): # Process the mixed batch if provided if inputs_embeds is None: inputs_embeds = self.embed_ids_boxes_images( - input_ids, image_tiles, grid_thw, encoder_chunk_size + input_ids, + image_embeddings, + encoder_chunk_size, + valid_batch_size, ) # Handling flash attention kwargs outside the decoder to speed up + avoid graph breaks inside the decoder # Skipped during decoding since not required - if ( - self.decoder.config._attn_implementation == "flash_attention_2" - and prefill - ): + if self.decoder.config._attn_implementation == "flash_attention_2" and prefill: batch_size, query_length, _ = inputs_embeds.shape indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data( attention_mask @@ -580,4 +606,4 @@ def _prepare_4d_causal_attention_mask_with_cache_position( causal_mask[:, :, :, :mask_length] = causal_mask[ :, :, :, :mask_length ].masked_fill(padding_mask, min_dtype) - return causal_mask \ No newline at end of file + return causal_mask diff --git a/surya/common/surya/config.py b/surya/common/surya/config.py index 5fbd6730..4cf8797d 100644 --- a/surya/common/surya/config.py +++ b/surya/common/surya/config.py @@ -2,8 +2,6 @@ from transformers import PretrainedConfig from surya.common.s3 import S3DownloaderMixin -from surya.common.surya.encoder.config import SuryaEncoderConfig -from surya.common.surya.decoder.config import SuryaDecoderConfig class SuryaModelConfig(S3DownloaderMixin, PretrainedConfig): @@ -39,6 +37,9 @@ def __init__( max_multi_out: int = 8, **kwargs, ): + from surya.common.surya.encoder.config import SuryaEncoderConfig + from surya.common.surya.decoder.config import SuryaDecoderConfig + super().__init__(**kwargs) self.is_encoder_decoder = False self.vocab_size = vocab_size @@ -63,6 +64,9 @@ def __init__( self.num_beacon_tokens = num_beacon_tokens self.beacon_token_interval = beacon_token_interval self.sliding_window = sliding_window + if self.sliding_window is None: + self.sliding_window = 512 # Default to 512 + self.multi_output_distance = multi_output_distance self.max_multi_out = max_multi_out diff --git a/surya/common/surya/decoder/__init__.py b/surya/common/surya/decoder/__init__.py index d81cef58..cfce38a1 100644 --- a/surya/common/surya/decoder/__init__.py +++ b/surya/common/surya/decoder/__init__.py @@ -14,20 +14,19 @@ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from transformers.processing_utils import Unpack -from transformers.utils import ( - logging, -) from surya.common.surya.decoder.config import SuryaDecoderConfig from transformers.utils import is_flash_attn_2_available +from surya.logging import get_logger + if is_flash_attn_2_available(): from surya.common.surya.flash_attn_utils import ( flash_attn_decode, flash_attn_prefill, ) -logger = logging.get_logger(__name__) +logger = get_logger() class Qwen2MLP(nn.Module): @@ -127,10 +126,9 @@ def eager_attention_forward( class Qwen2Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: SuryaDecoderConfig, layer_idx: int): + def 
__init__(self, config: SuryaDecoderConfig): super().__init__() self.config = config - self.layer_idx = layer_idx self.head_dim = getattr( config, "head_dim", config.hidden_size // config.num_attention_heads ) @@ -160,7 +158,7 @@ def forward( attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, - cache_idxs: Optional[List[int]] = None, + cache_idxs: Optional[torch.Tensor] = None, # padded num_valid_tokens: Optional[List[int]] = None, text_lengths: Optional[List[int]] = None, prefill: bool = False, @@ -180,7 +178,14 @@ def forward( if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache - # cache_idxs, num_valid_tokens, and prefill add support for our new caching mechanism + # cache_idxs, num_valid_tokens, and prefill add support for our new caching mechanism + + # Recompiles without this + # We pass in a padded cache_idxs, so we need to compute the length of the cache + cache_idx_length = ( + torch.count_nonzero(cache_idxs > -1) if cache_idxs is not None else 0 + ) + cache_kwargs = { "sin": sin, "cos": cos, @@ -188,10 +193,11 @@ def forward( "cache_idxs": cache_idxs, "num_valid_tokens": num_valid_tokens, "prefill": prefill, - "text_lengths": text_lengths + "text_lengths": text_lengths, + "cache_idx_length": cache_idx_length, } key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, cache_kwargs + key_states, value_states, cache_kwargs ) attention_interface: Callable = eager_attention_forward @@ -263,10 +269,10 @@ def extra_repr(self): class Qwen2DecoderLayer(nn.Module): - def __init__(self, config: SuryaDecoderConfig, layer_idx: int): + def __init__(self, config: SuryaDecoderConfig): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx) + self.self_attn = Qwen2Attention(config=config) self.mlp = Qwen2MLP(config) self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Qwen2RMSNorm( @@ -282,7 +288,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - cache_idxs: Optional[List[int]] = None, + cache_idxs: Optional[torch.Tensor] = None, num_valid_tokens: Optional[List[int]] = None, text_lengths: Optional[List[int]] = None, prefill: bool = False, @@ -332,7 +338,7 @@ class Qwen2RotaryEmbedding(nn.Module): def __init__(self, config: SuryaDecoderConfig, device=None): super().__init__() # BC: "rope_type" was originally "type" - if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): self.rope_type = config.rope_scaling.get( "rope_type", config.rope_scaling.get("type") ) @@ -348,60 +354,18 @@ def __init__(self, config: SuryaDecoderConfig, device=None): self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq - def _dynamic_frequency_update(self, position_ids, device): - """ - dynamic RoPE layers should recompute `inv_freq` in the following situations: - 1 - growing beyond the cached sequence length (allow scaling) - 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) - """ - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - 
self.config, device, seq_len=seq_len - ) - self.register_buffer( - "inv_freq", inv_freq, persistent=False - ) # TODO joao: may break with compilation - self.max_seq_len_cached = seq_len - - if ( - seq_len < self.original_max_seq_len - and self.max_seq_len_cached > self.original_max_seq_len - ): # reset - # This .to() is needed if the model has been moved to a device after being initialized (because - # the buffer is automatically moved, but not the original copy) - self.original_inv_freq = self.original_inv_freq.to(device) - self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) - self.max_seq_len_cached = self.original_max_seq_len - - @torch.no_grad() def forward(self, x, position_ids): - if "dynamic" in self.rope_type: - self._dynamic_frequency_update(position_ids, device=x.device) - - # Core RoPE block inv_freq_expanded = ( - self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + self.inv_freq[None, :, None] + .expand(position_ids.shape[0], -1, 1) + .to(dtype=x.dtype) ) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 (see https://github.com/huggingface/transformers/pull/29285) - device_type = x.device.type - device_type = ( - device_type - if isinstance(device_type, str) and device_type != "mps" - else "cpu" - ) - with torch.autocast(device_type=device_type, enabled=False): - freqs = ( - inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float() - ).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - - # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention - cos = cos * self.attention_scaling - sin = sin * self.attention_scaling + position_ids_expanded = position_ids[:, None, :].to(dtype=x.dtype) + + freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) @@ -447,10 +411,7 @@ def __init__(self, config: SuryaDecoderConfig): self.vocab_size = config.vocab_size self.layers = nn.ModuleList( - [ - Qwen2DecoderLayer(config, layer_idx) - for layer_idx in range(config.num_hidden_layers) - ] + [Qwen2DecoderLayer(config) for _ in range(config.num_hidden_layers)] ) self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = Qwen2RotaryEmbedding(config=config) @@ -470,7 +431,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - cache_idxs: Optional[List[int]] = None, + cache_idxs: Optional[torch.Tensor] = None, num_valid_tokens: Optional[List[int]] = None, text_lengths: Optional[List[int]] = None, prefill: bool = False, @@ -482,33 +443,30 @@ def forward( ) if inputs_embeds is None: - raise ValueError( - "You must specify inputs_embeds" - ) + raise ValueError("You must specify inputs_embeds") if cache_position is None: - raise ValueError( - "You must specify cache_position" - ) + raise ValueError("You must specify cache_position") if position_ids is None: - raise ValueError( - "You must specify position_ids" - ) + raise ValueError("You must specify position_ids") hidden_states = inputs_embeds - causal_mask = attention_mask # We make the 4D mask in the combined model when needed + causal_mask = ( + attention_mask # We make the 4D mask in the combined model when needed + ) # create position embeddings to be shared 
across the decoder layers position_embeddings = self.rotary_emb(hidden_states, position_ids) # decoder layers - for decoder_layer in self.layers[: self.config.num_hidden_layers]: + for i in range(self.config.num_hidden_layers): + decoder_layer = self.layers[i] layer_outputs = decoder_layer( hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_value=past_key_values.layer_caches[i], output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -528,4 +486,4 @@ def forward( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, ) - return output if return_dict else output.to_tuple() \ No newline at end of file + return output if return_dict else output.to_tuple() diff --git a/surya/common/surya/encoder/__init__.py b/surya/common/surya/encoder/__init__.py index 96045987..b6315666 100644 --- a/surya/common/surya/encoder/__init__.py +++ b/surya/common/surya/encoder/__init__.py @@ -1,22 +1,34 @@ -import logging import math from typing import Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F +from torch.nn.attention import SDPBackend, sdpa_kernel from transformers import PreTrainedModel from transformers.activations import ACT2FN from transformers.utils import is_flash_attn_2_available from surya.common.surya.encoder.config import SuryaEncoderConfig +from surya.common.xla import get_nearest_pad +from surya.logging import get_logger +from surya.settings import settings if is_flash_attn_2_available(): from flash_attn import flash_attn_varlen_func from flash_attn.layers.rotary import apply_rotary_emb # noqa +# This is for the custom xla kernel for flash attention +try: + import torch_xla + import torch_xla.distributed.spmd + import torch_xla.experimental.custom_kernel + import torch_xla.experimental.splash_attention +except ImportError: + pass -logger = logging.getLogger(__name__) + +logger = get_logger() class Qwen2_5_VLMLP(nn.Module): @@ -133,6 +145,176 @@ def apply_rotary_pos_emb_flashatt( return q_embed, k_embed +class Qwen2_5_VLVisionXLASDPAFlashAttention2(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.proj = nn.Linear(dim, dim) + self.head_dim = dim // num_heads + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q, k, v = ( + self.qkv(hidden_states) + .reshape(seq_length, 3, self.num_heads, -1) + .permute(1, 0, 2, 3) + .unbind(0) + ) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " + "removed and `position_embeddings` will be mandatory." 
+ ) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + cos = emb.cos() + sin = emb.sin() + else: + cos, sin = position_embeddings + q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) + + attention_mask = torch.zeros( + [1, seq_length, seq_length], device=q.device, dtype=torch.bool + ) + for i in range(1, len(cu_seqlens)): + attention_mask[ + ..., + cu_seqlens[i - 1] : cu_seqlens[i], + cu_seqlens[i - 1] : cu_seqlens[i], + ] = True + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + + with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]): + attn_output = F.scaled_dot_product_attention( + q.unsqueeze(0), + k.unsqueeze(0), + v.unsqueeze(0), + attention_mask, + dropout_p=0.0, + ) + attn_output = attn_output.squeeze(0).transpose(0, 1) + attn_output = attn_output.reshape(seq_length, -1) + attn_output = self.proj(attn_output) + return attn_output + + +class Qwen2_5_VLVisionXLASplashAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.proj = nn.Linear(dim, dim) + self.head_dim = dim // num_heads + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + + # Single reshape to target layout - avoid multiple operations + qkv = self.qkv(hidden_states).view(seq_length, 3, self.num_heads, self.head_dim) + + # More efficient unbind - no permute needed + q, k, v = qkv.unbind(dim=1) # [seq_len, num_heads, head_dim] + + # Apply rotary embeddings if provided + if position_embeddings is not None: + cos, sin = position_embeddings + q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) + + # Single reshape to flash attention format [batch, num_heads, seq_len, head_dim] + q = q.unsqueeze(0).transpose(1, 2) # [1, num_heads, seq_len, head_dim] + k = k.unsqueeze(0).transpose(1, 2) + v = v.unsqueeze(0).transpose(1, 2) + + total_seqlen = cu_seqlens[-1].item() + # from cu_seqlens to segment ids for each position in dim 0 + segment_ids = torch.zeros((total_seqlen,), dtype=torch.int32, device=q.device) + for i in range(1, cu_seqlens.shape[0]): + segment_ids[cu_seqlens[i - 1] : cu_seqlens[i]] = i - 1 + segment_ids = segment_ids.reshape(1, -1) + + mesh = torch_xla.distributed.spmd.Mesh( + device_ids=[0], mesh_shape=(1, 1), axis_names=("data", "fsdp") + ) + + partition_spec = (None, None, None, None) + segment_ids_partition_spec = (None, None) + + splash_config = torch_xla.experimental.splash_attention.SplashAttentionConfig( + mesh=str(mesh), + qkv_partition_spec=partition_spec, + segment_ids_partition_spec=segment_ids_partition_spec, + ) + + attn_output = torch_xla.experimental.splash_attention.splash_attention( + q, k, v, splash_config.to_json(), decoder_segment_ids=segment_ids + ).reshape(seq_length, -1) + attn_output = self.proj(attn_output) + return attn_output + + +class Qwen2_5_VLVisionXLAFlashAttention2(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.proj = nn.Linear(dim, dim) + self.head_dim = dim // num_heads + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, 
torch.Tensor]] = None, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + + # Single reshape to target layout - avoid multiple operations + qkv = self.qkv(hidden_states).view(seq_length, 3, self.num_heads, self.head_dim) + + # More efficient unbind - no permute needed + q, k, v = qkv.unbind(dim=1) # [seq_len, num_heads, head_dim] + + # Apply rotary embeddings if provided + if position_embeddings is not None: + cos, sin = position_embeddings + q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) + + # Single reshape to flash attention format [batch, num_heads, seq_len, head_dim] + q = q.unsqueeze(0).transpose(1, 2) # [1, num_heads, seq_len, head_dim] + k = k.unsqueeze(0).transpose(1, 2) + v = v.unsqueeze(0).transpose(1, 2) + + total_seqlen = cu_seqlens[-1].item() + # from cu_seqlens to segment ids for each position in dim 0 + segment_ids = torch.zeros((total_seqlen,), dtype=torch.int32, device=q.device) + for i in range(1, cu_seqlens.shape[0]): + segment_ids[cu_seqlens[i - 1] : cu_seqlens[i]] = i - 1 + segment_ids = segment_ids.reshape(1, -1) + + attn_output = torch_xla.experimental.custom_kernel.flash_attention( + q, k, v, q_segment_ids=segment_ids, kv_segment_ids=segment_ids + ).reshape(seq_length, -1) + attn_output = self.proj(attn_output) + return attn_output + + class Qwen2_5_VLVisionFlashAttention2(nn.Module): def __init__(self, dim: int, num_heads: int = 16) -> None: super().__init__() @@ -292,77 +474,78 @@ def unpack_qkv_with_mask(self, q, k, v, cu_seqlens): num_heads = q.shape[1] head_dim = q.shape[2] - seq_lengths = cu_seqlens[1:] - cu_seqlens[:-1] - max_seq_len = seq_lengths.max().item() + seq_lengths = cu_seqlens[1:] - cu_seqlens[:-1] # Keep as tensor + max_seq_len = seq_lengths.max().item() # Use .max() on tensor + + if settings.FOUNDATION_STATIC_CACHE: + # Pad max_seq_len to the nearest multiple for compilation + max_seq_len = get_nearest_pad(max_seq_len, pad_multiple=16) + + # Pad batch_size to the nearest multiple for compilation + batch_size = get_nearest_pad(batch_size, pad_multiple=2) + + # Ensure seq_lengths is a tensor of the correct size + seq_lengths = F.pad( + seq_lengths, (0, batch_size - seq_lengths.size(0)), "constant", 0 + ) + + # some day, you may look at this, and think: "what if I used repeat_interleave or some other fancy torch instead"? + # don't do this - it's a path to madness. 
For some reason, this loop is optimal on TPU batch_indices = [] position_indices = [] - for i, seq_len in enumerate(seq_lengths): + for i, seq_len in enumerate( + seq_lengths.tolist() + ): # Convert to list only for iteration batch_indices.extend([i] * seq_len) position_indices.extend(list(range(seq_len))) batch_indices = torch.tensor(batch_indices, device=device) position_indices = torch.tensor(position_indices, device=device) - batched_q = torch.zeros((batch_size, max_seq_len, num_heads, head_dim), device=device, dtype=dtype) + batched_q = torch.zeros( + (batch_size, max_seq_len, num_heads, head_dim), device=device, dtype=dtype + ) batched_k = torch.zeros_like(batched_q) batched_v = torch.zeros_like(batched_q) - # Create additive attention mask: shape (batch_size, 1, max_seq_len, max_seq_len) - # Each batch has a (max_seq_len, max_seq_len) matrix: - # - Rows = queries, Columns = keys - # - If query or key is padding, set to -inf + # Create additive attention mask attention_mask = torch.full( (batch_size, max_seq_len, max_seq_len), - fill_value=float('-inf'), + fill_value=float("-inf"), device=device, - dtype=dtype + dtype=dtype, ) - for b in range(batch_size): - valid_len = seq_lengths[b].item() - attention_mask[b, :valid_len, :valid_len] = 0 # Unmasked - attention_mask = attention_mask.unsqueeze(1) # (batch_size, 1, max_seq_len, max_seq_len) + # Create mask for valid positions + seq_range = torch.arange(max_seq_len, device=device) + valid_mask = seq_range.unsqueeze(0) < seq_lengths.unsqueeze( + 1 + ) # (batch_size, max_seq_len) + valid_2d = valid_mask.unsqueeze(2) & valid_mask.unsqueeze( + 1 + ) # (batch_size, max_seq_len, max_seq_len) + + # Simply use boolean indexing to set valid positions to 0 + attention_mask[valid_2d] = 0 + + attention_mask = attention_mask.unsqueeze( + 1 + ) # (batch_size, 1, max_seq_len, max_seq_len) batched_q[batch_indices, position_indices] = q batched_k[batch_indices, position_indices] = k batched_v[batch_indices, position_indices] = v - return batched_q, batched_k, batched_v, attention_mask - - def repack_hidden_states(self, batched_output, cu_seqlens): - """ - Reverses the unpacking operation using indexing to convert batched outputs - back to a flat tensor of shape (total_seq_len, hidden_dim). 
- - Args: - batched_output: Tensor of shape (batch_size, max_seq_len, hidden_dim) - cu_seqlens: Tensor of shape (batch_size + 1,) with cumulative sequence lengths - - Returns: - packed_output: Tensor of shape (total_seq_len, hidden_dim) - """ - device = batched_output.device - dtype = batched_output.dtype - - batch_size, max_seq_len, hidden_dim = batched_output.shape - seq_lengths = cu_seqlens[1:] - cu_seqlens[:-1] - total_seq_len = seq_lengths.sum().item() - - batch_indices = [] - position_indices = [] - - for i, seq_len in enumerate(seq_lengths): - batch_indices.extend([i] * seq_len) - position_indices.extend(list(range(seq_len))) - - batch_indices = torch.tensor(batch_indices, device=device) - position_indices = torch.tensor(position_indices, device=device) - - packed_output = batched_output[batch_indices, position_indices] - - return packed_output # Shape: (total_seq_len, hidden_dim) + return ( + batched_q, + batched_k, + batched_v, + attention_mask, + batch_indices, + position_indices, + ) def forward( self, @@ -392,11 +575,14 @@ def forward( cos, sin = position_embeddings q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) - q, k, v, attention_mask = self.unpack_qkv_with_mask(q, k, v, cu_seqlens) + q, k, v, attention_mask, batch_indices, position_indices = ( + self.unpack_qkv_with_mask(q, k, v, cu_seqlens) + ) batch_size, max_seqlen = q.shape[:2] q = q.transpose(1, 2) k = k.transpose(1, 2) v = v.transpose(1, 2) + attn_output = F.scaled_dot_product_attention( q, k, @@ -404,9 +590,11 @@ def forward( attention_mask, dropout_p=0.0, ) - attn_output = attn_output.permute(0, 2, 1, 3).reshape(batch_size, max_seqlen, -1) # Bring back to (batch_size, max_seqlen, hidden_dim) + attn_output = attn_output.permute(0, 2, 1, 3).reshape( + batch_size, max_seqlen, -1 + ) # Bring back to (batch_size, max_seqlen, hidden_dim) + attn_output = attn_output[batch_indices, position_indices] attn_output = self.proj(attn_output) - attn_output = self.repack_hidden_states(attn_output, cu_seqlens) return attn_output @@ -414,6 +602,7 @@ def forward( QWEN2_5_VL_VISION_ATTENTION_CLASSES = { "eager": Qwen2_5_VLVisionAttention, "flash_attention_2": Qwen2_5_VLVisionFlashAttention2, + "flash_attention_xla": Qwen2_5_VLVisionSdpaAttention, "sdpa": Qwen2_5_VLVisionSdpaAttention, } @@ -522,7 +711,8 @@ def __init__(self, config, *inputs, **kwargs) -> None: def rot_pos_emb(self, grid_thw): pos_ids = [] - for t, h, w in grid_thw: + grid_thw_list = grid_thw.tolist() + for t, h, w in grid_thw_list: hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) hpos_ids = hpos_ids.reshape( h // self.spatial_merge_size, @@ -550,6 +740,7 @@ def rot_pos_emb(self, grid_thw): return rotary_pos_emb def get_window_index(self, grid_thw): + grid_thw_list = grid_thw.tolist() window_index: list = [] cu_window_seqlens: list = [0] window_index_id = 0 @@ -557,7 +748,7 @@ def get_window_index(self, grid_thw): self.window_size // self.spatial_merge_size // self.patch_size ) - for grid_t, grid_h, grid_w in grid_thw: + for grid_t, grid_h, grid_w in grid_thw_list: llm_grid_h, llm_grid_w = ( grid_h // self.spatial_merge_size, grid_w // self.spatial_merge_size, @@ -591,13 +782,16 @@ def get_window_index(self, grid_thw): seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1] ) cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) - window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() - window_index = torch.cat(window_index, dim=0) + window_index_id += grid_t * llm_grid_h * llm_grid_w + window_index = torch.cat(window_index, 
dim=0).to(device=grid_thw.device) return window_index, cu_window_seqlens def forward( - self, hidden_states: torch.Tensor, grid_thw: torch.Tensor + self, + hidden_states: torch.Tensor, + grid_thw: torch.Tensor, + valid_grid_len: torch.Tensor | None = None, ) -> torch.Tensor: """ Args: @@ -609,7 +803,11 @@ def forward( Returns: `torch.Tensor`: hidden_states. """ + if valid_grid_len is not None: + grid_thw = grid_thw[:valid_grid_len] + hidden_states = self.patch_embed(hidden_states) + rotary_pos_emb = self.rot_pos_emb(grid_thw) window_index, cu_window_seqlens = self.get_window_index(grid_thw) cu_window_seqlens = torch.tensor( @@ -666,6 +864,7 @@ def forward( ) hidden_states = self.merger(hidden_states) + reverse_indices = torch.argsort(window_index) hidden_states = hidden_states[reverse_indices, :] @@ -691,9 +890,13 @@ def hidden_size(self) -> int: return config.hidden_size def embed_images( - self, image_batch: torch.Tensor, grid_thw: torch.Tensor + self, + image_batch: torch.Tensor, + grid_thw: torch.Tensor, + valid_grid_len: torch.Tensor | None = None, ) -> torch.Tensor: return super().forward( hidden_states=image_batch, grid_thw=grid_thw, + valid_grid_len=valid_grid_len, ) diff --git a/surya/common/surya/processor/__init__.py b/surya/common/surya/processor/__init__.py index d093f83a..5fe6f2cf 100644 --- a/surya/common/surya/processor/__init__.py +++ b/surya/common/surya/processor/__init__.py @@ -100,9 +100,7 @@ def __init__( TaskNames.block_without_boxes: self.special_token_mapping.get( BLOCK_WITHOUT_BOXES_TOKEN ), - TaskNames.layout: self.special_token_mapping.get( - LAYOUT_BOS_TOKEN - ) + TaskNames.layout: self.special_token_mapping.get(LAYOUT_BOS_TOKEN), } if self.image_token_id is None: @@ -379,7 +377,7 @@ def __call__( mixed_batch: List[dict], padding_side: Optional[str] = "left", device: Optional[torch.device] = None, - pad_to_multiple: Optional[int] = None + pad_to_multiple: Optional[int] = None, ): all_image_tiles = [] all_input_ids = [] @@ -405,18 +403,18 @@ def __call__( padding_side=padding_side, padding_value=self.pad_token_id, ) - + if pad_to_multiple is not None: current_len = batched_input_ids.shape[1] # Calculate the next multiple of pad_to_multiple - padded_len = ((current_len + pad_to_multiple - 1) // pad_to_multiple) * pad_to_multiple - + padded_len = ( + (current_len + pad_to_multiple - 1) // pad_to_multiple + ) * pad_to_multiple + if padded_len > current_len: pad_len = padded_len - current_len batched_input_ids = torch.nn.functional.pad( - batched_input_ids, - (pad_len, 0), - value=self.pad_token_id + batched_input_ids, (pad_len, 0), value=self.pad_token_id ) attention_mask = batched_input_ids.ne(self.pad_token_id) @@ -435,7 +433,7 @@ def __call__( batched_grid_thw = torch.from_numpy(np.array(all_grid_thw)) # Pin memory for CUDA - if device == torch.device("cuda"): + if device == torch.device("cuda") or device == torch.device("xla"): batched_image_tiles = batched_image_tiles.pin_memory() batched_grid_thw = batched_grid_thw.pin_memory() attention_mask = attention_mask.pin_memory() diff --git a/surya/common/util.py b/surya/common/util.py index f75d3e9f..bfa6c6d2 100644 --- a/surya/common/util.py +++ b/surya/common/util.py @@ -3,7 +3,6 @@ import torch from surya.common.polygon import PolygonBox -from surya.settings import settings def clean_boxes(boxes: List[PolygonBox]) -> List[PolygonBox]: @@ -84,14 +83,3 @@ def is_flash_attn_2_supported(device: str | torch.device) -> bool: return False return True - - -if settings.TORCH_DEVICE_MODEL == "xla": - import 
torch_xla.core.xla_model as xm -else: - xm = None - - -def mark_step(): - if xm is not None: - xm.mark_step() diff --git a/surya/common/xla.py b/surya/common/xla.py new file mode 100644 index 00000000..5db33c5e --- /dev/null +++ b/surya/common/xla.py @@ -0,0 +1,28 @@ +import math + +from surya.settings import settings + +if settings.TORCH_DEVICE_MODEL == "xla": + import torch_xla.core.xla_model as xm +else: + xm = None + + +def get_nearest_pad( + length: int, pad_multiple: int = settings.FOUNDATION_PAD_TO_NEAREST +): + return math.ceil(length / pad_multiple) * pad_multiple + + +def mark_step(): + if xm is not None: + xm.mark_step() + + +def get_compile_args(device: str) -> dict: + if device != "xla": + return {} + + return { + "backend": "openxla", + } diff --git a/surya/debug/draw.py b/surya/debug/draw.py index 8fb3b95d..38055557 100644 --- a/surya/debug/draw.py +++ b/surya/debug/draw.py @@ -1,26 +1,33 @@ -import re from PIL import ImageDraw, ImageFont from surya.debug.fonts import get_font_path from surya.debug.text import get_text_size -def draw_bboxes_on_image(bboxes, image, labels=None, label_font_size=10, color: str | list = 'red'): +def draw_bboxes_on_image( + bboxes, image, labels=None, label_font_size=10, color: str | list = "red" +): polys = [] for bb in bboxes: # Clockwise polygon - poly = [ - [bb[0], bb[1]], - [bb[2], bb[1]], - [bb[2], bb[3]], - [bb[0], bb[3]] - ] + poly = [[bb[0], bb[1]], [bb[2], bb[1]], [bb[2], bb[3]], [bb[0], bb[3]]] polys.append(poly) - return draw_polys_on_image(polys, image, labels, label_font_size=label_font_size, color=color) + return draw_polys_on_image( + polys, image, labels, label_font_size=label_font_size, color=color + ) -def draw_polys_on_image(corners, image, labels=None, box_padding=-1, label_offset=1, label_font_size=10, color: str | list = 'red'): +def draw_polys_on_image( + corners, + image, + labels=None, + box_padding=-1, + label_offset=1, + label_font_size=10, + color: str | list = "red", + line_width: int = 1, +): draw = ImageDraw.Draw(image) font_path = get_font_path() label_font = ImageFont.truetype(font_path, label_font_size) @@ -28,27 +35,31 @@ def draw_polys_on_image(corners, image, labels=None, box_padding=-1, label_offse for i in range(len(corners)): poly = corners[i] poly = [(int(p[0]), int(p[1])) for p in poly] - draw.polygon(poly, outline=color[i] if isinstance(color, list) else color, width=1) + draw.polygon( + poly, + outline=color[i] if isinstance(color, list) else color, + width=line_width, + ) if labels is not None: label = labels[i] text_position = ( min([p[0] for p in poly]) + label_offset, - min([p[1] for p in poly]) + label_offset + min([p[1] for p in poly]) + label_offset, ) text_size = get_text_size(label, label_font) box_position = ( text_position[0] - box_padding + label_offset, text_position[1] - box_padding + label_offset, text_position[0] + text_size[0] + box_padding + label_offset, - text_position[1] + text_size[1] + box_padding + label_offset + text_position[1] + text_size[1] + box_padding + label_offset, ) draw.rectangle(box_position, fill="white") draw.text( text_position, label, fill=color[i] if isinstance(color, list) else color, - font=label_font + font=label_font, ) return image diff --git a/surya/detection/__init__.py b/surya/detection/__init__.py index 513a6444..7715b4ba 100644 --- a/surya/detection/__init__.py +++ b/surya/detection/__init__.py @@ -9,7 +9,7 @@ from tqdm import tqdm from surya.common.predictor import BasePredictor -from surya.common.util import mark_step +from surya.common.xla import 
mark_step from surya.detection.loader import DetectionModelLoader from surya.detection.parallel import FakeExecutor diff --git a/surya/detection/loader.py b/surya/detection/loader.py index a3a94e52..793725bf 100644 --- a/surya/detection/loader.py +++ b/surya/detection/loader.py @@ -3,6 +3,7 @@ import torch from surya.common.load import ModelLoader +from surya.common.xla import get_compile_args from surya.detection.processor import SegformerImageProcessor from surya.detection.model.config import EfficientViTConfig @@ -47,7 +48,7 @@ def model( logger.info( f"Compiling detection model {self.checkpoint} on device {device} with dtype {dtype}" ) - compile_args = {"backend": "openxla"} if device == "xla" else {} + compile_args = get_compile_args(device) model = torch.compile(model, **compile_args) logger.debug( diff --git a/surya/foundation/__init__.py b/surya/foundation/__init__.py index fef77754..2b82f20d 100644 --- a/surya/foundation/__init__.py +++ b/surya/foundation/__init__.py @@ -1,19 +1,19 @@ from __future__ import annotations +import time from dataclasses import dataclass from typing import List, Optional, Tuple from collections import deque import cv2 +import math import numpy as np import torch from PIL import Image from tqdm import tqdm import torch.nn.functional as F -from transformers import QuantizedCacheConfig -from surya.common.surya import SuryaModelOutput -from surya.common.util import mark_step +from surya.common.xla import mark_step, get_nearest_pad from surya.common.predictor import BasePredictor from surya.foundation.loader import FoundationModelLoader @@ -21,9 +21,9 @@ detect_repeat_token, ) from surya.common.surya.schema import TaskNames -from surya.foundation.cache import ( - ContinuousBatchingCache, -) + +from surya.foundation.cache.static import StaticOpsContinuousBatchingCache +from surya.foundation.cache.dynamic import DynamicOpsContinuousBatchingCache from surya.settings import settings from surya.logging import get_logger, configure_logging @@ -61,13 +61,14 @@ class FoundationPrompt: class FoundationPredictor(BasePredictor): model_loader_cls = FoundationModelLoader - batch_size = settings.RECOGNITION_BATCH_SIZE # Default to the recognition batch size + batch_size = ( + settings.RECOGNITION_BATCH_SIZE + ) # Default to the recognition batch size torch_dtype = None # No default, loader picks the dtype based on device properties - bf16/fp16 - default_batch_sizes = {"cpu": 32, "mps": 64, "cuda": 256, "xla": 128} + default_batch_sizes = {"cpu": 32, "mps": 64, "cuda": 256, "xla": 64} encoder_chunk_size: int = 4096 # Default chunk size - encoder_chunk_sizes = {"cpu": 4096, "mps": 4096, "cuda": 32768, "xla": 32768} - min_prefill_ratio: int = 0.2 - min_trim_length: int = 50 + encoder_chunk_sizes = {"cpu": 4096, "mps": 4096, "cuda": 32768, "xla": 4096} + min_prefill_ratio: int = 0.8 tasks = { TaskNames.ocr_with_boxes: { "needs_bboxes": True, @@ -86,7 +87,7 @@ class FoundationPredictor(BasePredictor): }, TaskNames.layout: { "needs_bboxes": False, - "img_size": (1024, 1024), + "img_size": (1024, 1024), # 1369 max tokens "max_tokens": 200, }, } @@ -111,7 +112,11 @@ def __init__(self, checkpoint=None, device=settings.TORCH_DEVICE_MODEL, dtype=No device=self.model.device, ) - self.pad_to_multiple = 512 if settings.FOUNDATION_STATIC_CACHE else None + self.pad_to_multiple = ( + settings.FOUNDATION_PAD_TO_NEAREST + if settings.FOUNDATION_STATIC_CACHE + else None + ) def get_encoder_chunk_size(self) -> int: if settings.FOUNDATION_CHUNK_SIZE is not None: @@ -123,14 +128,21 @@ def 
get_encoder_chunk_size(self) -> int: chunk_size = self.encoder_chunk_sizes[settings.TORCH_DEVICE_MODEL] return chunk_size - def setup_cache(self, batch_size: int, max_cache_len: int): - self.kv_cache = ContinuousBatchingCache( + def setup_cache(self, batch_size: int, max_image_tokens: int, max_text_tokens: int): + # Pad to multiple of sliding_window + max_image_tokens = ( + math.ceil(max_image_tokens / self.model.config.sliding_window) + * self.model.config.sliding_window + ) + max_cache_len = max_image_tokens + max_text_tokens + kv_cache_cls = StaticOpsContinuousBatchingCache if settings.TORCH_DEVICE_MODEL == "xla" else DynamicOpsContinuousBatchingCache + self.kv_cache = kv_cache_cls( self.model.config, batch_size, max_cache_len, - text_sliding_window=self.model.config.sliding_window, + text_sliding_window=max_text_tokens, device=self.model.device, - dtype=self.model.dtype + dtype=self.model.dtype, ) self.prompt_queue.clear() self.batch_prompt_mapping = {i: None for i in range(batch_size)} @@ -178,15 +190,20 @@ def prepare_input( return batch - def process_outputs(self, outputs: SuryaModelOutput, max_lookahead_tokens: Optional[int]=None) -> ContinuousBatchOutput: + def process_outputs( + self, outputs, max_lookahead_tokens: Optional[int] = None + ) -> ContinuousBatchOutput: # Predictions are multi-token lm_logits = outputs["lm_logits"].float() # shape: [batch_size, seq_len, V] bbox_logits = outputs["bbox_logits"].float() # shape: [batch_size, seq_len, 6] - if max_lookahead_tokens is not None and lm_logits.shape[1] > max_lookahead_tokens + 1: - lm_logits = lm_logits[:, :max_lookahead_tokens + 1, :] - bbox_logits = bbox_logits[:, :max_lookahead_tokens + 1, :] - + if ( + max_lookahead_tokens is not None + and lm_logits.shape[1] > max_lookahead_tokens + 1 + ): + lm_logits = lm_logits[:, : max_lookahead_tokens + 1, :] + bbox_logits = bbox_logits[:, : max_lookahead_tokens + 1, :] + # Get predictions preds = torch.argmax(lm_logits, dim=-1) input_ids = preds.to(torch.long) @@ -204,7 +221,7 @@ def process_outputs(self, outputs: SuryaModelOutput, max_lookahead_tokens: Optio preds=preds, bbox_preds=box_preds, scores=scores, - token_probs=token_probs + token_probs=token_probs, ) # Make space for beacon tokens to be inserted while keeping the same seq len across all batch elements @@ -212,53 +229,79 @@ def process_outputs(self, outputs: SuryaModelOutput, max_lookahead_tokens: Optio # with the causal mask of flash attention, and we are careful to ignore this pad token when inserting # into cache def maybe_insert_beacon_tokens( - self, - input_ids: torch.Tensor, - num_predicted_tokens: torch.Tensor + self, input_ids: torch.Tensor, num_predicted_tokens: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: - batch_size, seq_len = input_ids.shape # seq_len can be >1 - In case of multi-token predictions - + batch_size, seq_len = ( + input_ids.shape + ) # seq_len can be >1 - In case of multi-token predictions + # num_predicted tokens **does not include** the current new input_ids, this number is updated **after beacon tokens are inserted** - token_positions = num_predicted_tokens + torch.arange(1, seq_len + 1, device=input_ids.device).unsqueeze(0) - beacon_positions = (token_positions % self.beacon_token_interval == 0) + token_positions = num_predicted_tokens + torch.arange( + 1, seq_len + 1, device=input_ids.device + ).unsqueeze(0) + beacon_positions = token_positions % self.beacon_token_interval == 0 # If no beacons needed, return original input needs_beacon = beacon_positions.any(dim=1) # shape: 
[batch_size] if not needs_beacon.any(): - return input_ids, torch.ones(batch_size, dtype=torch.long, device=input_ids.device) * seq_len - - beacon_insert_pos = torch.zeros(batch_size, dtype=torch.long, device=input_ids.device) - for i in range(batch_size): - if needs_beacon[i]: - # Find first position that needs beacon - beacon_insert_pos[i] = torch.where(beacon_positions[i])[0] - + return input_ids, torch.ones( + batch_size, dtype=torch.long, device=input_ids.device + ) * seq_len + + beacon_insert_pos = torch.argmax(beacon_positions.float(), dim=1) + # Padded input ids. - new_input_ids = torch.full((batch_size, seq_len + 1), self.device_pad_token, - dtype=input_ids.dtype, device=input_ids.device) - - # Fill in tokens for each sequence - for i in range(batch_size): - if needs_beacon[i]: - insert_pos = beacon_insert_pos[i] - new_input_ids[i, insert_pos] = self.device_beacon_token - if insert_pos > 0: - new_input_ids[i, :insert_pos] = input_ids[i, :insert_pos] - new_input_ids[i, insert_pos+1:] = input_ids[i, insert_pos:] - else: - new_input_ids[i, 1:] = input_ids[i, :] - + new_input_ids = torch.full( + (batch_size, seq_len + 1), + self.device_pad_token, + dtype=input_ids.dtype, + device=input_ids.device, + ) + + # Create indices for vectorized operations + batch_indices = torch.arange(batch_size, device=input_ids.device).unsqueeze(1) + seq_indices = torch.arange(seq_len, device=input_ids.device).unsqueeze(0) + + # Calculate target positions for each token + # For sequences with beacons: tokens at pos >= beacon_insert_pos get shifted by +1 + # For sequences without beacons: all tokens get shifted by +1 (start at position 1) + beacon_pos_expanded = beacon_insert_pos.unsqueeze(1) + needs_beacon_expanded = needs_beacon.unsqueeze(1) + + # Calculate shift amount for each position + shift_amount = torch.where( + needs_beacon_expanded, + ( + seq_indices >= beacon_pos_expanded + ).long(), # Shift by 1 if pos >= beacon_pos + torch.ones_like( + seq_indices + ), # Shift by 1 for all positions (no beacon case) + ) + + target_positions = seq_indices + shift_amount + + # Use advanced indexing to place all tokens at once + new_input_ids[batch_indices, target_positions] = input_ids + + # Insert beacon tokens at the correct positions + new_input_ids[needs_beacon, beacon_insert_pos[needs_beacon]] = ( + self.device_beacon_token + ) + # Calculate valid token counts for both padded and non padded sequences - valid_token_counts = torch.where( - needs_beacon, - torch.tensor(seq_len + 1, device=input_ids.device), - torch.tensor(seq_len, device=input_ids.device) + valid_token_counts = torch.where(needs_beacon, seq_len + 1, seq_len).to( + dtype=torch.long, device=input_ids.device ) - + return new_input_ids, valid_token_counts - def decode(self, current_inputs: Optional[ContinuousBatchInput] = None, max_lookahead_tokens: Optional[int] = None): - # Note - If we want to use the outputs from the non-last token, we + def decode( + self, + current_inputs: Optional[ContinuousBatchInput] = None, + max_lookahead_tokens: Optional[int] = None, + ): + # Note - If we want to use the outputs from the non-last token, we # need to set the cache position manually to ensure causality. 
The default # behavior only works for the last token currently input_ids = current_inputs.input_ids @@ -280,19 +323,25 @@ def decode(self, current_inputs: Optional[ContinuousBatchInput] = None, max_look use_cache=True, past_key_values=self.kv_cache, prefill=False, - num_valid_tokens=num_valid_tokens + num_valid_tokens=num_valid_tokens, ) - processed_output: ContinuousBatchOutput = self.process_outputs(outputs, max_lookahead_tokens=max_lookahead_tokens) - + processed_output: ContinuousBatchOutput = self.process_outputs( + outputs, max_lookahead_tokens=max_lookahead_tokens + ) + input_ids = processed_output.input_ids # Update this **before** inserting beacon tokens num_new_tokens = input_ids.shape[1] num_predicted_tokens += num_new_tokens - input_ids, num_valid_tokens = self.maybe_insert_beacon_tokens(input_ids, num_predicted_tokens) - position_ids = position_ids[:, -1:] + torch.arange(1, input_ids.shape[1] + 1, device=input_ids.device) + input_ids, num_valid_tokens = self.maybe_insert_beacon_tokens( + input_ids, num_predicted_tokens + ) + position_ids = position_ids[:, -1:] + torch.arange( + 1, input_ids.shape[1] + 1, device=input_ids.device + ) # Some of the input sequences may now have left padding tokens, so we want to account for that # offset is a per-batch offset of the position_ids offset = (input_ids.shape[1] - num_valid_tokens).unsqueeze(1) @@ -302,7 +351,7 @@ def decode(self, current_inputs: Optional[ContinuousBatchInput] = None, max_look input_ids=input_ids, position_ids=position_ids, num_valid_tokens=num_valid_tokens, - num_predicted_tokens=num_predicted_tokens + num_predicted_tokens=num_predicted_tokens, ) return new_input, processed_output @@ -323,20 +372,30 @@ def pad_and_shift_input_ids_position_ids( """ # No padding if new_seq_len == input_ids.shape[1]: - return input_ids, position_ids[:, -1:] + torch.arange(1, new_seq_len +1, device=self.model.device) + return input_ids, position_ids[:, -1:] + torch.arange( + 1, new_seq_len + 1, device=self.model.device + ) pad_len = new_seq_len - input_ids.shape[1] - padded_input_ids = torch.nn.functional.pad(input_ids, (pad_len, 0), value=self.device_pad_token) + padded_input_ids = torch.nn.functional.pad( + input_ids, (pad_len, 0), value=self.device_pad_token + ) # Since we have **left padding**, offset the new position_ids by the amount of padding # This ensures that the **true tokens** get the correct position_ids # The position_ids assigned to pad tokens do not matter. 
They are not cached, and not used for outputs - updated_position_ids = position_ids[:, -1:] + torch.arange(1, new_seq_len + 1, device=self.model.device) + updated_position_ids = position_ids[:, -1:] + torch.arange( + 1, new_seq_len + 1, device=self.model.device + ) updated_position_ids -= pad_len return padded_input_ids, updated_position_ids - def prefill(self, current_inputs: Optional[ContinuousBatchInput] = None, max_lookahead_tokens: Optional[int] = None): + def prefill( + self, + current_inputs: Optional[ContinuousBatchInput] = None, + max_lookahead_tokens: Optional[int] = None, + ): logger.debug(f"Prefilling {self.num_empty_slots} slots") prompts: List[FoundationPrompt] = [ self.prompt_queue.popleft() @@ -356,64 +415,134 @@ def prefill(self, current_inputs: Optional[ContinuousBatchInput] = None, max_loo ], # Pass math mode to the processor ) processed_inputs = self.processor( - batch_input, padding_side="left", device=self.model.device, pad_to_multiple=self.pad_to_multiple - ).to(device=self.model.device) + batch_input, + padding_side="left", + device=self.model.device, + pad_to_multiple=self.pad_to_multiple, + ) - input_ids = processed_inputs["input_ids"].to(dtype=torch.long) + input_ids = processed_inputs["input_ids"].to( + device=self.model.device, dtype=torch.long + ) image_tiles = processed_inputs["image_tiles"].to(dtype=self.model.dtype) grid_thw = processed_inputs["grid_thw"].to(dtype=torch.long) - attention_mask = processed_inputs["attention_mask"].to(dtype=torch.long) - position_ids = processed_inputs["position_ids"].to(dtype=torch.long) + attention_mask = processed_inputs["attention_mask"].to( + device=self.model.device, dtype=torch.long + ) + position_ids = processed_inputs["position_ids"].to( + device=self.model.device, dtype=torch.long + ) valid_batch_size = len(idxs_to_merge) + cache_idxs = torch.tensor( + idxs_to_merge, device=self.model.device, dtype=torch.long + ) + cache_idxs_padded = cache_idxs + image_tile_length = torch.tensor(image_tiles.shape[0]).to( + dtype=torch.long, device=self.model.device + ) if settings.FOUNDATION_STATIC_CACHE: - input_ids = self.pad_to_batch_size(input_ids, batch_size=self.kv_cache.max_batch_size) - attention_mask = self.pad_to_batch_size(attention_mask, batch_size=self.kv_cache.max_batch_size) - position_ids = self.pad_to_batch_size(position_ids, batch_size=self.kv_cache.max_batch_size) - - # Find text lengths of each - is_special = (input_ids.unsqueeze(-1) == self.special_token_ids).any(-1) # (batch, seq_len) - text_lengths = [] - for i in range(input_ids.shape[0]): - special_positions = is_special[i].nonzero(as_tuple=True)[0] - if len(special_positions) > 0: - # Assuming special tokens are contiguous at the start - prefix_len = special_positions[-1].item() + 1 - else: - prefix_len = 0 - text_lengths.append(input_ids.shape[1] - prefix_len) + input_ids = self.pad_to_batch_size( + input_ids, batch_size=self.kv_cache.max_batch_size + ) + attention_mask = self.pad_to_batch_size( + attention_mask, batch_size=self.kv_cache.max_batch_size + ) + position_ids = self.pad_to_batch_size( + position_ids, batch_size=self.kv_cache.max_batch_size + ) + cache_idxs_padded = self.pad_to_batch_size( + cache_idxs, batch_size=self.kv_cache.max_batch_size, pad_value=-1 + ) + + image_tile_pad = get_nearest_pad( + image_tiles.shape[0], settings.FOUNDATION_PAD_TO_NEAREST * 32 + ) + image_tiles = self.pad_to_batch_size(image_tiles, batch_size=image_tile_pad) + grid_thw = self.pad_to_batch_size( + grid_thw, batch_size=self.kv_cache.max_batch_size + ) + + 
is_special = ( + input_ids.unsqueeze(-1).eq(self.special_token_ids).any(-1) + ) # (B, L) bool + + idx = ( + torch.arange(input_ids.size(1), device=input_ids.device, dtype=torch.long) + + 1 + ) # 1…L + special_length = is_special.sum(dim=1) # (B,) number of special tokens + last_special_plus1 = (is_special * idx).max(dim=1).values # (B,) 0 if none + + image_lengths = special_length.to(dtype=torch.long) # (B,) + text_lengths = (input_ids.size(1) - last_special_plus1).to( + dtype=torch.long + ) # (B,) with settings.INFERENCE_MODE(): + logger.debug( + f"Prefill shapes: input_ids={input_ids.shape}, image_tiles={image_tiles.shape}, grid_thw={grid_thw.shape}, attention_mask={attention_mask.shape}, position_ids={position_ids.shape}, cache_idxs_padded={cache_idxs_padded.shape}" + ) + + start = time.time() + image_embeddings = self.model.get_image_embeddings( + pixel_values=image_tiles, + grid_thw=grid_thw, + encoder_chunk_size=self.get_encoder_chunk_size(), + image_tile_length=image_tile_length, + valid_batch_size=valid_batch_size, + ) + print(f"Image embedding took {time.time() - start:.2f} seconds") + outputs = self.model( input_ids=input_ids, - image_tiles=image_tiles, - grid_thw=grid_thw, + image_embeddings=image_embeddings, attention_mask=attention_mask, position_ids=position_ids, inputs_embeds=None, past_key_values=self.kv_cache, use_cache=True, encoder_chunk_size=self.get_encoder_chunk_size(), - cache_idxs=idxs_to_merge, + cache_idxs=cache_idxs_padded, prefill=True, - num_valid_tokens=None, # Not required during prefill + num_valid_tokens=None, # Not required during prefill text_lengths=text_lengths, + valid_batch_size=torch.tensor( + valid_batch_size, device=self.model.device, dtype=torch.long + ), ) - + # import torch_xla.core.xla_model as xm + # print(xm.get_metrics_report()) + # Process outputs - processed_outputs = self.process_outputs(outputs, max_lookahead_tokens=max_lookahead_tokens) + processed_outputs = self.process_outputs( + outputs, max_lookahead_tokens=max_lookahead_tokens + ) # Multi-token prediction predicted_tokens = processed_outputs.input_ids.shape[1] - num_valid_tokens = torch.ones((input_ids.shape[0]), device=self.model.device, dtype=torch.long) * predicted_tokens - num_predicted_tokens = torch.ones((input_ids.shape[0], 1), device=self.model.device, dtype=torch.long) * predicted_tokens + num_valid_tokens = ( + torch.ones((input_ids.shape[0]), device=self.model.device, dtype=torch.long) + * predicted_tokens + ) + num_predicted_tokens = ( + torch.ones( + (input_ids.shape[0], 1), device=self.model.device, dtype=torch.long + ) + * predicted_tokens + ) - self.kv_cache.prefill_attention_mask_update(attention_mask, idxs_to_merge, text_lengths[:valid_batch_size]) - self.kv_cache.update_text_counts(idxs_to_merge, text_lengths[:valid_batch_size]) + self.kv_cache.prefill_attention_mask_update( + attention_mask, + cache_idxs_padded, # unpadded + text_lengths, + ) if current_inputs is None: new_seq_len = processed_outputs.input_ids.shape[1] # No padding tokens - So we can safely set position_ids this way - position_ids = position_ids[:, -1:] + torch.arange(1, new_seq_len + 1, device=position_ids.device) + position_ids = position_ids[:, -1:] + torch.arange( + 1, new_seq_len + 1, device=position_ids.device + ) new_input = ContinuousBatchInput( input_ids=processed_outputs.input_ids, position_ids=position_ids, @@ -431,35 +560,69 @@ def prefill(self, current_inputs: Optional[ContinuousBatchInput] = None, max_loo current_input_ids = current_inputs.input_ids current_position_ids = 
current_inputs.position_ids - assert(current_input_ids.shape[1] == current_position_ids.shape[1]) + assert current_input_ids.shape[1] == current_position_ids.shape[1] input_ids, position_ids = self.pad_and_shift_input_ids_position_ids( - processed_outputs.input_ids, position_ids, new_seq_len=current_input_ids.shape[1] + processed_outputs.input_ids, + position_ids, + new_seq_len=current_input_ids.shape[1], ) current_input_ids[idxs_to_merge] = input_ids[:valid_batch_size] current_position_ids[idxs_to_merge] = position_ids[:valid_batch_size] current_num_valid_tokens = current_inputs.num_valid_tokens - current_num_valid_tokens[idxs_to_merge] = num_valid_tokens[:valid_batch_size] + for i, idx_to_merge in enumerate(idxs_to_merge): + current_num_valid_tokens[idx_to_merge] = num_valid_tokens[i] current_num_predicted_tokens = current_inputs.num_predicted_tokens - current_num_predicted_tokens[idxs_to_merge] = num_predicted_tokens[:valid_batch_size] + current_num_predicted_tokens[idxs_to_merge] = num_predicted_tokens[ + :valid_batch_size + ] new_input = ContinuousBatchInput( input_ids=current_input_ids, position_ids=current_position_ids, num_valid_tokens=current_num_valid_tokens, - num_predicted_tokens=current_num_predicted_tokens + num_predicted_tokens=current_num_predicted_tokens, ) return new_input, processed_outputs, idxs_to_merge - def get_max_image_token_count(self, task_name: TaskNames) -> int: - dummy_image = np.zeros(shape=(*self.tasks[task_name]["img_size"], 3)) - tiles, _ = self.processor._process_and_tile(dummy_image) - num_image_tokens = tiles.shape[0] / self.processor.merge_size**2 - - # Extra 1 to account for rotation token when present. - return 1 + self.processor.num_register_tokens + int(num_image_tokens) + def get_max_image_token_count(self, images: list[np.ndarray], tasks: List[TaskNames]) -> int: + def compute_scaled_size(H: int, W: int, max_size: Tuple[int, int]) -> Tuple[int, int]: + max_W, max_H = max_size + min_W, min_H = (168, 168) + + current_pixels = H * W + max_pixels = max_H * max_W + min_pixels = min_H * min_W + + if current_pixels > max_pixels: + scale = (max_pixels / current_pixels) ** 0.5 + return math.floor(H * scale), math.floor(W * scale) + elif current_pixels < min_pixels: + scale = (min_pixels / current_pixels) ** 0.5 + return math.ceil(H * scale), math.ceil(W * scale) + return H, W + + def get_tile_count(H: int, W: int, factor: int) -> int: + H_bar = math.ceil(H / factor) * factor + W_bar = math.ceil(W / factor) * factor + grid_h = H_bar / self.processor.patch_size + grid_w = W_bar // self.processor.patch_size + return grid_h * grid_w + + max_tokens = 0 + factor = self.processor.patch_size * self.processor.merge_size + + for image, task in zip(images, tasks): + H, W = image.shape[:2] + max_size = self.tasks[task]["img_size"] + scaled_H, scaled_W = compute_scaled_size(H, W, max_size) + token_count = get_tile_count(scaled_H, scaled_W, factor) / (self.processor.merge_size ** 2) + max_tokens = max(max_tokens, token_count) + + # Extra 10 to account for EOS/BOS/Rotation token etc. 
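A rough worked example of the token-count arithmetic above; patch_size=14 and merge_size=2 are assumed processor values here, chosen because they make the 1024x1024 layout image size line up with the "1369 max tokens" note earlier in this diff:

    import math

    patch_size, merge_size = 14, 2        # assumed processor values
    factor = patch_size * merge_size      # 28
    H = W = 1024                          # layout task img_size

    H_bar = math.ceil(H / factor) * factor          # 1036
    W_bar = math.ceil(W / factor) * factor          # 1036
    grid_h = H_bar // patch_size                    # 74
    grid_w = W_bar // patch_size                    # 74
    tokens = (grid_h * grid_w) // (merge_size ** 2)
    print(tokens)                                   # 1369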
+ return 10 + self.processor.num_register_tokens + int(max_tokens) def prediction_loop( self, @@ -470,7 +633,7 @@ def prediction_loop( math_mode: bool = True, drop_repeated_tokens: bool = True, max_lookahead_tokens: Optional[int] = None, - top_k: int = 0 + top_k: Optional[int] = None ) -> tuple: allowed_tasks = self.tasks.keys() assert all([task_name in allowed_tasks for task_name in task_names]), ( @@ -482,18 +645,23 @@ def prediction_loop( topk_probs = [[] for _ in range(len(images))] if batch_size is None: - batch_size = self.get_batch_size() - - batch_size = min(len(images), batch_size) + batch_size = ( + self.get_batch_size() + if settings.FOUNDATION_STATIC_CACHE + else len(images) + ) + current_inputs = None - - max_image_tokens = max(self.get_max_image_token_count(task) for task in set(task_names)) - self.setup_cache(batch_size, max_cache_len=max_image_tokens + self.model.config.sliding_window) + + max_image_tokens = self.get_max_image_token_count(images, task_names) + self.setup_cache( + batch_size, + max_image_tokens=max_image_tokens, + max_text_tokens=self.model.config.sliding_window, + ) batch_max_tokens = {} - for idx, (img, txt, task) in enumerate( - zip(images, input_texts, task_names) - ): + for idx, (img, txt, task) in enumerate(zip(images, input_texts, task_names)): self.prompt_queue.append( FoundationPrompt( id=idx, task_name=task, text=txt, image=img, math_mode=math_mode @@ -518,11 +686,18 @@ def prediction_loop( if ( self.num_empty_slots / batch_size ) > self.min_prefill_ratio and self.prompt_queue: - updated_inputs, outputs, merge_idxs = self.prefill(current_inputs, max_lookahead_tokens=max_lookahead_tokens) + start = time.time() + updated_inputs, outputs, merge_idxs = self.prefill( + current_inputs, max_lookahead_tokens=max_lookahead_tokens + ) + mark_step() + + logger.debug(f"Prefill took {time.time() - start:.2f} seconds") predicted_tokens_cpu = outputs.preds.cpu() scores_cpu = outputs.scores.cpu() - if top_k > 0: + bbox_preds_cpu = outputs.bbox_preds.cpu() + if top_k is not None: batch_top_k_probs, batch_top_k_indices = torch.topk(outputs.token_probs, k=top_k, dim=-1) batch_top_k_probs_cpu = batch_top_k_probs.cpu() batch_top_k_indices_cpu = batch_top_k_indices.cpu() @@ -534,15 +709,16 @@ def prediction_loop( for t_idx in range(seq_len): token = predicted_tokens_cpu[temp_idx, t_idx].item() predicted_tokens[p_idx].append(token) - batch_bboxes[p_idx, batch_pos[p_idx]] = outputs.bbox_preds[ + + batch_bboxes[p_idx, batch_pos[p_idx]] = bbox_preds_cpu[ temp_idx, t_idx ] batch_pos[p_idx] += 1 scores[p_idx].append(scores_cpu[temp_idx, t_idx].item()) - if top_k > 0: + if top_k is not None: top_k_scores = { - batch_top_k_indices[k].item(): batch_top_k_probs[k].item() + batch_top_k_indices_cpu[temp_idx, t_idx, k].item(): batch_top_k_probs_cpu[temp_idx, t_idx, k].item() for k in range(top_k) } topk_probs[p_idx].append(top_k_scores) @@ -555,10 +731,15 @@ def prediction_loop( pbar.update(1) break else: - updated_inputs, outputs = self.decode(current_inputs, max_lookahead_tokens=max_lookahead_tokens) + updated_inputs, outputs = self.decode( + current_inputs, max_lookahead_tokens=max_lookahead_tokens + ) + mark_step() + predicted_tokens_cpu = outputs.preds.cpu() scores_cpu = outputs.scores.cpu() - if top_k > 0: + bbox_preds_cpu = outputs.bbox_preds.cpu() + if top_k is not None: batch_top_k_probs, batch_top_k_indices = torch.topk(outputs.token_probs, k=top_k, dim=-1) batch_top_k_probs_cpu = batch_top_k_probs.cpu() batch_top_k_indices_cpu = batch_top_k_indices.cpu() @@ -571,22 
+752,25 @@ def prediction_loop( for t_idx in range(seq_len): token = predicted_tokens_cpu[b_idx, t_idx].item() predicted_tokens[p_idx].append(token) - batch_bboxes[p_idx, batch_pos[p_idx]] = outputs.bbox_preds[ + + batch_bboxes[p_idx, batch_pos[p_idx]] = bbox_preds_cpu[ b_idx, t_idx ] batch_pos[p_idx] += 1 scores[p_idx].append(scores_cpu[b_idx, t_idx].item()) - if top_k > 0: + if top_k is not None: top_k_scores = { - batch_top_k_indices[k].item(): batch_top_k_probs[k].item() + batch_top_k_indices_cpu[temp_idx, t_idx, k].item(): batch_top_k_probs_cpu[temp_idx, t_idx, k].item() for k in range(top_k) } topk_probs[p_idx].append(top_k_scores) - repeats = ( - len(predicted_tokens[p_idx]) >= batch_max_tokens[p_idx] - or (drop_repeated_tokens and detect_repeat_token(predicted_tokens[p_idx])) + repeats = len(predicted_tokens[p_idx]) >= batch_max_tokens[ + p_idx + ] or ( + drop_repeated_tokens + and detect_repeat_token(predicted_tokens[p_idx]) ) if ( token @@ -605,11 +789,10 @@ def prediction_loop( # Update inputs and mark XLA step current_inputs = updated_inputs - mark_step() pbar.close() del self.kv_cache self.kv_cache = None torch.cuda.empty_cache() - return predicted_tokens, batch_bboxes, scores, topk_probs \ No newline at end of file + return predicted_tokens, batch_bboxes, scores, topk_probs diff --git a/surya/foundation/cache/__init__.py b/surya/foundation/cache/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/surya/foundation/cache.py b/surya/foundation/cache/dynamic.py similarity index 78% rename from surya/foundation/cache.py rename to surya/foundation/cache/dynamic.py index 7bb13002..65100ea6 100644 --- a/surya/foundation/cache.py +++ b/surya/foundation/cache/dynamic.py @@ -3,126 +3,75 @@ from transformers import PretrainedConfig """ -Special cache class for the surya foundation model that supports - +Special cache implementation for the surya foundation model that supports - 1) Static shape 2) A custom sliding window, where image tokens stay in cache, and text tokens are popped 3) Continuous batching - merging etc 4) Attention mask management - To match with what's currently in the cache +This carefully manages the prefix image cache, while keeping fast updates to the text cache, supporting sliding +window in the text cache, and maintaining support for the format expected by FA2 Heavily inspired from https://github.com/huggingface/transformers/blob/0725cd6953803b8aacfc85288cbfb83dea30c469/src/transformers/cache_utils.py#L1079 """ -class ContinuousBatchingCache(): +class DynamicOpsContinuousBatchingLayerCache(): def __init__( self, config: PretrainedConfig, batch_size: int, max_cache_len: int, text_sliding_window: int, - device: int, - dtype: int, + device: torch.device, + dtype: torch.dtype, ): - self.text_sliding_window = text_sliding_window - self.num_layers = config.num_hidden_layers - self.max_batch_size = batch_size self.max_cache_len = max_cache_len - self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads - self._dtype = dtype + self.max_batch_size = batch_size + self.head_dim = ( + getattr(config, "head_dim", None) + or config.hidden_size // config.num_attention_heads + ) self.num_key_value_heads = ( config.num_attention_heads if getattr(config, "num_key_value_heads", None) is None else config.num_key_value_heads ) + self.text_sliding_window = text_sliding_window + self.dtype = dtype + self.device = device - # Cache init is taken from huggingface StaticCache - 
https://github.com/huggingface/transformers/blob/67ddc82fbc7e52c6f42a395b4a6d278c55b77a39/src/transformers/cache_utils.py#L1125 - self.key_cache: list[torch.Tensor] = [] - self.value_cache: list[torch.Tensor] = [] cache_shape = (self.max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim) - device = torch.device(device) if device is not None else None - for _ in range(config.num_hidden_layers): - new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device) - new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device) - torch._dynamo.mark_static_address(new_layer_key_cache) - torch._dynamo.mark_static_address(new_layer_value_cache) - self.key_cache.append(new_layer_key_cache) - self.value_cache.append(new_layer_value_cache) - - self.attention_mask = torch.zeros( - (self.max_batch_size, self.max_cache_len), device=device, dtype=torch.long + self.key_cache = torch.zeros(cache_shape, dtype=dtype, device=device) + self.value_cache = torch.zeros(cache_shape, dtype=dtype, device=device) + self.text_token_counts = torch.zeros( + self.max_batch_size, dtype=torch.long, device=device ) - self.text_token_counts = [ - torch.zeros(self.max_batch_size, dtype=torch.long, device=device) - for _ in range(self.num_layers) - ] - self.dtype = dtype - self.device = device + + # The attention mask managed by our kv cache automatically masks the tokens + # in the cache, so we can return full length for HF to use in other places + # This is mainly utilized in the cache_positions creation + def get_seq_length(self) -> int: + return self.max_cache_len + def update( self, key_states: torch.Tensor, value_states: torch.Tensor, - layer_idx: int, cache_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: prefill = cache_kwargs.get("prefill", False) update_fn = self._prefill_update if prefill else self._decode_update return update_fn( - self.key_cache[layer_idx], - self.value_cache[layer_idx], + self.key_cache, + self.value_cache, key_states, value_states, - self.text_token_counts[layer_idx], - cache_kwargs, - ) - - def update_text_counts(self, cache_idxs: List[int], new_text_lens: List[int]): - assert len(cache_idxs) == len(new_text_lens) - new_text_len_tensor = torch.tensor( - new_text_lens, dtype=torch.long, device=self.device + self.text_token_counts, + cache_kwargs ) - for layer_idx in range(self.num_layers): - self.text_token_counts[layer_idx][cache_idxs] = new_text_len_tensor - - # Mirrors the logic from _prefill_update - # Logic is better explained in this funcrtion - def prefill_attention_mask_update( - self, - prefill_attention_mask: torch.Tensor, - cache_idxs: List[int], - text_lengths: List[int], - ): - seq_len = prefill_attention_mask.shape[1] - sliding_window = self.text_sliding_window - total_cache_len = self.max_cache_len - prefix_cache_space = total_cache_len - sliding_window - - for batch_idx, cache_idx in enumerate(cache_idxs): - text_len = text_lengths[batch_idx] - prefix_len = seq_len - text_len - self.attention_mask[cache_idx] = 0 # Set default - - assert prefix_len > 0, "There are no prefix (image) tokens!" 
- - end_pos = prefix_cache_space - # Handle prefix part - Which may be left padded - if prefix_len <= prefix_cache_space: - start_pos = prefix_cache_space - prefix_len - self.attention_mask[cache_idx, start_pos: end_pos] = prefill_attention_mask[batch_idx, :prefix_len] - else: - self.attention_mask[cache_idx, :end_pos] = prefill_attention_mask[batch_idx, prefix_len - prefix_cache_space: prefix_len] - - # Handle text part, keeping sliding window in consideration - # All of the left padding is before the prefix, so we can ignore the prefill_attention_mask here - if text_len > 0: - text_cache_start = prefix_cache_space - if text_len <= sliding_window: - self.attention_mask[cache_idx, text_cache_start: text_cache_start + text_len] = 1 - else: - self.attention_mask[cache_idx, -sliding_window: ] = 1 - # Slow impl for now - Prefill time is dominated by the large sequence length forward pass def _prefill_update( self, @@ -181,54 +130,6 @@ def _prefill_update( # Return the full key/value states (not just cached) for use in subsequent layers return key_states, value_states - # """ - # Matches the logic of the decode update, but needs to be called before the updates - # since some parts of the model depend on the attention mask - # """ - def decode_attention_mask_update( - self, num_valid_tokens: torch.Tensor, cache_idxs: List[int] - ): - sliding_window = self.text_sliding_window - text_cache_start = self.max_cache_len - sliding_window - - # Using text_token_counts of first layer, should be same for all though - current_text_lens = self.text_token_counts[0] - cache_idxs_tensor = torch.tensor(cache_idxs, device=current_text_lens.device) - - # Get current text lengths for the relevant cache indices - current_lens = current_text_lens[cache_idxs_tensor] - new_text_lens = current_lens + num_valid_tokens - is_full = new_text_lens > sliding_window - - # Handle full caches - set entire sliding window to 1 - if is_full.any(): - full_mask = is_full - full_cache_idxs = cache_idxs_tensor[full_mask] - self.attention_mask[full_cache_idxs, text_cache_start:] = 1 - - # Handle non-full caches - set specific ranges to 1 - if (~is_full).any(): - non_full_mask = ~is_full - non_full_cache_idxs = cache_idxs_tensor[non_full_mask] - non_full_current_lens = current_lens[non_full_mask] - non_full_valid_tokens = num_valid_tokens[non_full_mask] - - max_valid_tokens = non_full_valid_tokens.max().item() if len(non_full_valid_tokens) > 0 else 0 - if max_valid_tokens > 0: - batch_size = len(non_full_cache_idxs) - offset_range = torch.arange(max_valid_tokens, device=current_text_lens.device) - batch_offsets = offset_range.unsqueeze(0).expand(batch_size, -1) - start_positions = non_full_current_lens.unsqueeze(1) - valid_token_counts = non_full_valid_tokens.unsqueeze(1) - - position_indices = start_positions + batch_offsets - valid_mask = batch_offsets < valid_token_counts - - row_indices = non_full_cache_idxs.unsqueeze(1).expand(-1, max_valid_tokens)[valid_mask] - col_indices = text_cache_start + position_indices[valid_mask] - - self.attention_mask[row_indices, col_indices] = 1 - """ Static cache update - respects per-batch text token limits @@ -279,26 +180,33 @@ def _decode_update( value_cache[:, :, -sliding_window:] = v_slice_rolled # Insert only **valid tokens** into the cache. 
These are **right aligned** within the input sequence - seq_indices = torch.arange(seq_len, device=device)[None, :] - start_indices = seq_len - num_valid_tokens[:, None] - source_mask = seq_indices >= start_indices - source_mask_expanded = source_mask[:, None, :, None].expand(batch_size, num_head, seq_len, head_dim) - insert_positions = torch.where( needs_rotate, max_cache_len - num_valid_tokens, text_token_counts + cache_text_start ) - # Step 2: Create target mask in cache coordinates - cache_indices = torch.arange(max_cache_len, device=device)[None, :] - insert_start = insert_positions[:, None] - insert_end = insert_start + num_valid_tokens[:, None] - cache_target_mask = ((cache_indices >= insert_start) & - (cache_indices < insert_end)) - cache_target_mask_expanded = cache_target_mask[:, None, :, None].expand(batch_size, num_head, max_cache_len, head_dim) - - key_cache[cache_target_mask_expanded] = key_states[source_mask_expanded] - value_cache[cache_target_mask_expanded] = value_states[source_mask_expanded] + + max_tokens = num_valid_tokens.max().item() + offsets = torch.arange(max_tokens, device=device).unsqueeze(0) # [1, max_T] + valid_mask = offsets < num_valid_tokens.unsqueeze(1) # [B, max_T] + src_indices = (seq_len - num_valid_tokens).unsqueeze(1) + offsets # [B, max_T] + src_indices = src_indices.clamp(max=seq_len - 1) # safety + + tgt_indices = insert_positions.unsqueeze(1) + offsets # [B, max_T] + tgt_indices = tgt_indices.clamp(max=max_cache_len - 1) # safety + + src_idx_exp = src_indices.unsqueeze(1).unsqueeze(-1).expand(batch_size, num_head, max_tokens, head_dim) + tgt_idx_exp = tgt_indices.unsqueeze(1).unsqueeze(-1).expand(batch_size, num_head, max_tokens, head_dim) + valid_mask_exp = valid_mask.unsqueeze(1).unsqueeze(-1).expand(batch_size, num_head, max_tokens, head_dim) + + k_src = torch.gather(key_states, 2, src_idx_exp) + v_src = torch.gather(value_states, 2, src_idx_exp) + k_src = k_src * valid_mask_exp + v_src = v_src * valid_mask_exp + + # Write into cache + key_cache.scatter_(2, tgt_idx_exp, k_src) + value_cache.scatter_(2, tgt_idx_exp, v_src) # In-place edit - Mutates text_token_counts += num_valid_tokens @@ -306,6 +214,125 @@ def _decode_update( return key_cache, value_cache + +class DynamicOpsContinuousBatchingCache: + def __init__( + self, + config: PretrainedConfig, + batch_size: int, + max_cache_len: int, + text_sliding_window: int, + device: int, + dtype: int, + ): + self.text_sliding_window = text_sliding_window + self.num_layers = config.num_hidden_layers + self.max_cache_len = max_cache_len + self.max_batch_size = batch_size + + self.attention_mask = torch.zeros( + (self.max_batch_size, self.max_cache_len), device=device, dtype=torch.long + ) + self.layer_caches = [ + DynamicOpsContinuousBatchingLayerCache( + config, + batch_size=batch_size, + max_cache_len=max_cache_len, + text_sliding_window=text_sliding_window, + device=device, + dtype=dtype, + ) + for _ in range(self.num_layers) + ] + + self.dtype = dtype + self.device = device + + # """ + # Matches the logic of the layer decode update, but needs to be called before the updates + # since some parts of the model depend on the attention mask + # """ + def decode_attention_mask_update( + self, num_valid_tokens: torch.Tensor, cache_idxs: List[int] + ): + sliding_window = self.text_sliding_window + text_cache_start = self.max_cache_len - sliding_window + + # Using text_token_counts of first layer, should be same for all though + current_text_lens = self.layer_caches[0].text_token_counts + cache_idxs_tensor 
= torch.tensor(cache_idxs, device=current_text_lens.device) + + # Get current text lengths for the relevant cache indices + current_lens = current_text_lens[cache_idxs_tensor] + new_text_lens = current_lens + num_valid_tokens + is_full = new_text_lens > sliding_window + + # Handle full caches - set entire sliding window to 1 + if is_full.any(): + full_mask = is_full + full_cache_idxs = cache_idxs_tensor[full_mask] + self.attention_mask[full_cache_idxs, text_cache_start:] = 1 + + # Handle non-full caches - set specific ranges to 1 + if (~is_full).any(): + non_full_mask = ~is_full + non_full_cache_idxs = cache_idxs_tensor[non_full_mask] + non_full_current_lens = current_lens[non_full_mask] + non_full_valid_tokens = num_valid_tokens[non_full_mask] + + max_valid_tokens = non_full_valid_tokens.max().item() if len(non_full_valid_tokens) > 0 else 0 + if max_valid_tokens > 0: + batch_size = len(non_full_cache_idxs) + offset_range = torch.arange(max_valid_tokens, device=current_text_lens.device) + batch_offsets = offset_range.unsqueeze(0).expand(batch_size, -1) + start_positions = non_full_current_lens.unsqueeze(1) + valid_token_counts = non_full_valid_tokens.unsqueeze(1) + + position_indices = start_positions + batch_offsets + valid_mask = batch_offsets < valid_token_counts + + row_indices = non_full_cache_idxs.unsqueeze(1).expand(-1, max_valid_tokens)[valid_mask] + col_indices = text_cache_start + position_indices[valid_mask] + + self.attention_mask[row_indices, col_indices] = 1 + + + # Mirrors the logic from _prefill_update - Check that function for clearer explanation + def prefill_attention_mask_update( + self, + prefill_attention_mask: torch.Tensor, + cache_idxs: List[int], + text_lengths: List[int], + ): + seq_len = prefill_attention_mask.shape[1] + sliding_window = self.text_sliding_window + total_cache_len = self.max_cache_len + prefix_cache_space = total_cache_len - sliding_window + + for batch_idx, cache_idx in enumerate(cache_idxs): + text_len = text_lengths[batch_idx] + prefix_len = seq_len - text_len + self.attention_mask[cache_idx] = 0 # Set default + + assert prefix_len > 0, "There are no prefix (image) tokens!" 
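A small sketch of the cache layout these mask updates assume, with made-up sizes: image/prefix tokens sit right-aligned in the first max_cache_len - text_sliding_window slots, and text tokens fill the trailing sliding-window slots:

    import torch

    max_cache_len, text_sliding_window = 10, 4                  # made-up sizes
    prefix_cache_space = max_cache_len - text_sliding_window    # 6 slots reserved for image tokens

    attention_mask = torch.zeros(max_cache_len, dtype=torch.long)
    prefix_len, text_len = 4, 2                                 # e.g. 4 image tokens, 2 text tokens so far
    attention_mask[prefix_cache_space - prefix_len : prefix_cache_space] = 1   # right-aligned prefix
    attention_mask[prefix_cache_space : prefix_cache_space + text_len] = 1     # text region (sliding window)
    # attention_mask -> tensor([0, 0, 1, 1, 1, 1, 1, 1, 0, 0])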
+ + end_pos = prefix_cache_space + # Handle prefix part - Which may be left padded + if prefix_len <= prefix_cache_space: + start_pos = prefix_cache_space - prefix_len + self.attention_mask[cache_idx, start_pos: end_pos] = prefill_attention_mask[batch_idx, :prefix_len] + else: + self.attention_mask[cache_idx, :end_pos] = prefill_attention_mask[batch_idx, prefix_len - prefix_cache_space: prefix_len] + + # Handle text part, keeping sliding window in consideration + # All of the left padding is before the prefix, so we can ignore the prefill_attention_mask here + if text_len > 0: + text_cache_start = prefix_cache_space + if text_len <= sliding_window: + self.attention_mask[cache_idx, text_cache_start: text_cache_start + text_len] = 1 + else: + self.attention_mask[cache_idx, -sliding_window: ] = 1 + # The attention mask managed by our kv cache automatically masks the tokens # in the cache, so we can return full length for HF to use in other places # This is mainly utilized in the cache_positions creation diff --git a/surya/foundation/cache/static.py b/surya/foundation/cache/static.py new file mode 100644 index 00000000..b7cbccfc --- /dev/null +++ b/surya/foundation/cache/static.py @@ -0,0 +1,226 @@ +from typing import Any, Dict, List, Optional, Tuple +import torch +from transformers import StaticCache +from transformers import PretrainedConfig + +""" +Special cache class for the surya foundation model that supports - +1) Static shape +2) A custom sliding window, where image tokens stay in cache, and text tokens are popped +3) Continuous batching - merging etc +4) Attention mask management - To match with what's currently in the cache + +Heavily inspired from https://github.com/huggingface/transformers/blob/0725cd6953803b8aacfc85288cbfb83dea30c469/src/transformers/cache_utils.py#L1079 +""" + + +class StaticOpsContinuousBatchingLayerCache(StaticCache): + def __init__( + self, + config: PretrainedConfig, + batch_size: int, + max_cache_len: int, + text_sliding_window: int, + device: torch.device, + dtype: torch.dtype, + ): + # No need for the super class call, it just overwrites the caches + # At some point, we should consider not inheriting from StaticCache + self.max_cache_len = max_cache_len + self.max_batch_size = batch_size + + self.head_dim = ( + getattr(config, "head_dim", None) + or config.hidden_size // config.num_attention_heads + ) + self.num_key_value_heads = ( + config.num_attention_heads + if getattr(config, "num_key_value_heads", None) is None + else config.num_key_value_heads + ) + + cache_shape = ( + self.max_batch_size, + self.num_key_value_heads, + self.max_cache_len, + self.head_dim, + ) + self.key_cache = torch.zeros(cache_shape, dtype=dtype, device=device) + self.value_cache = torch.zeros(cache_shape, dtype=dtype, device=device) + + self.text_sliding_window = text_sliding_window + self.text_token_counts = torch.zeros( + self.max_batch_size, dtype=torch.long, device=device + ) + self.cache_image_end = self.max_cache_len - self.text_sliding_window + + self.dtype = dtype + self.device = device + + # This is used by HF models to determine the causal relationship between new tokens and cache + # Our cache is left padded - So all tokens should always be visible to new tokens + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + return self.max_cache_len + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + prefill = cache_kwargs.get("prefill", False) 
+ if prefill: + return self._prefill_update( + self.key_cache, + self.value_cache, + key_states, + value_states, + cache_kwargs, + ) + else: + return self._decode_update(key_states, value_states, cache_kwargs) + + def _prefill_update( + self, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + cache_kwargs: Optional[Dict[str, Any]] = None, + ): + cache_idxs: torch.tensor = cache_kwargs.get("cache_idxs", None) + text_lengths: List[int] = cache_kwargs.get("text_lengths", None) + cache_idx_length: int = cache_kwargs.get("cache_idxs_length", None) + assert cache_idxs is not None, "cache_idxs must be specified during prefill" + assert text_lengths is not None, "text_lengths must be specified during prefill" + + cache_idxs = cache_idxs[ + :cache_idx_length + ] # Ensure we only use the valid indices + + # Insert key and value states at the end of the cache + new_tokens = key_states.shape[2] + + # Direct right-aligned assignment + key_cache[cache_idxs, :, -new_tokens:] = key_states[:cache_idx_length] + value_cache[cache_idxs, :, -new_tokens:] = value_states[:cache_idx_length] + + return key_states, value_states + + def _decode_update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Static cache update + - respects per-batch text token limits + - per-batch valid token lengths (right-padded inputs) + + kv states are expected to have shape [batch_size, kv_heads, T_pad, head_dim] + They may have different `true` lengths, to account for multi token preds, or beacon tokens + Expects `num_valid_tokens` in cache_kwargs: a tensor of shape (B,) indicating the number + of actual (non-padded) tokens to add per batch element. 
+ """ + + num_valid_tokens: torch.Tensor = cache_kwargs.get( + "num_valid_tokens" + ) # shape: (B,) + assert num_valid_tokens is not None, ( + "`num_valid_tokens` must be provided in `cache_kwargs`" + ) + # (B, H, L, D) + max_valid_tokens = num_valid_tokens.max().item() + + self.key_cache = torch.roll(self.key_cache, -max_valid_tokens, dims=2) + self.value_cache = torch.roll(self.value_cache, -max_valid_tokens, dims=2) + + new_k = key_states[:, :, -max_valid_tokens:, :] + new_v = value_states[:, :, -max_valid_tokens:, :] + + self.key_cache[:, :, -max_valid_tokens:, :] = new_k + self.value_cache[:, :, -max_valid_tokens:, :] = new_v + return self.key_cache, self.value_cache + + +class StaticOpsContinuousBatchingCache: + def __init__( + self, + config: PretrainedConfig, + batch_size: int, + max_cache_len: int, + text_sliding_window: int, + device: int, + dtype: int, + ): + self.text_sliding_window = text_sliding_window + self.num_layers = config.num_hidden_layers + self.max_cache_len = max_cache_len + self.max_batch_size = batch_size + self.cache_image_end = self.max_cache_len - self.text_sliding_window + + self.attention_mask = torch.zeros( + (self.max_batch_size, self.max_cache_len), device=device, dtype=torch.long + ) + self.layer_caches = [ + StaticOpsContinuousBatchingLayerCache( + config, + batch_size=batch_size, + max_cache_len=max_cache_len, + text_sliding_window=text_sliding_window, + device=device, + dtype=dtype, + ) + for _ in range(self.num_layers) + ] + + self.dtype = dtype + self.device = device + + def decode_attention_mask_update( + self, num_valid_tokens: torch.Tensor, cache_idxs: List[int] + ): + max_valid_tokens = num_valid_tokens.max().item() + if max_valid_tokens == 0: + # If no valid tokens, we don't need to update the attention mask + return + + # Shift the attention mask to the left by max_valid_tokens + self.attention_mask = self.attention_mask.roll(-1 * max_valid_tokens, dims=1) + self.attention_mask[:, -max_valid_tokens:] = 0 # Mask out all new tokens + + seq_len = self.attention_mask.shape[1] + positions = torch.arange(seq_len, device=self.attention_mask.device).unsqueeze( + 0 + ) + + # Since cache_idxs is padded, num_valid_tokens should also be padded with zeros + # for inactive positions, so we can process the full batch uniformly + valid_mask = (positions >= (seq_len - num_valid_tokens.unsqueeze(1))).to( + dtype=self.attention_mask.dtype + ) + + # Update the attention mask for the current batch elements + self.attention_mask = self.attention_mask | valid_mask + + # Mirrors the logic from _prefill_update + def prefill_attention_mask_update( + self, + attention_mask: torch.Tensor, + cache_idxs: torch.Tensor, + text_lengths: List[int], + ): + # Set from -(image_length + text_length) to end to 1 for each batch element + seq_len = attention_mask.shape[1] + self.attention_mask[cache_idxs] = ( + 0 # Reset the attention mask for the current batch elements + ) + self.attention_mask[cache_idxs, -seq_len:] = attention_mask[ + : cache_idxs.size(0) + ] + + # This is used by HF models to determine the causal relationship between new tokens and cache + # Our cache is left padded - So all tokens should always be visible to new tokens + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + return self.max_cache_len diff --git a/surya/foundation/loader.py b/surya/foundation/loader.py index f96c07ce..fc8e3f7c 100644 --- a/surya/foundation/loader.py +++ b/surya/foundation/loader.py @@ -9,11 +9,13 @@ from surya.common.surya.processor import SuryaOCRProcessor from 
surya.common.surya.processor.tokenizer import SuryaOCRTokenizer from surya.common.util import is_flash_attn_2_supported +from surya.common.xla import get_compile_args from surya.logging import get_logger from surya.settings import settings logger = get_logger() + class FoundationModelLoader(ModelLoader): def __init__(self, checkpoint: Optional[str] = None): super().__init__(checkpoint) @@ -42,6 +44,8 @@ def model( if is_flash_attn_2_available() and is_flash_attn_2_supported(device): config.decoder._attn_implementation = "flash_attention_2" config.vision_encoder._attn_implementation = "flash_attention_2" + elif "xla" in str(device): + config.vision_encoder._attn_implementation = "flash_attention_xla" else: config.decoder._attn_implementation = "sdpa" config.vision_encoder._attn_implementation = "sdpa" @@ -51,6 +55,20 @@ def model( ).to(device) model = model.eval() + if settings.COMPILE_ALL or settings.COMPILE_FOUNDATION: + torch._dynamo.config.cache_size_limit = 1000 + torch._dynamo.config.suppress_errors = False + torch._dynamo.config.specialize_int = False + torch._dynamo.config.recompile_limit = 32 + + logger.info( + f"Compiling foundation model {self.checkpoint} on device {device} with dtype {dtype}" + ) + compile_args = get_compile_args(device) + model.encoder = torch.compile(model.vision_encoder, **compile_args) + model.decoder = torch.compile(model.decoder, **compile_args) + model.embedder = torch.compile(model.embedder, **compile_args) + logger.debug( f"Loaded recognition model {self.checkpoint} from {SuryaModel.get_local_path(self.checkpoint)} onto device {model.device} with dtype {dtype}, using decoder attention mechanism {model.config.decoder._attn_implementation}, encoder attention mechanism {model.config.vision_encoder._attn_implementation}." 
) @@ -74,7 +92,7 @@ def processor( merge_size=config.vision_encoder.spatial_merge_size, model_device=device, num_beacon_tokens=config.num_beacon_tokens, - beacon_token_interval=config.beacon_token_interval + beacon_token_interval=config.beacon_token_interval, ) - return processor \ No newline at end of file + return processor diff --git a/surya/layout/__init__.py b/surya/layout/__init__.py index 7964368b..7792070c 100644 --- a/surya/layout/__init__.py +++ b/surya/layout/__init__.py @@ -1,12 +1,8 @@ from typing import List -import numpy as np -import torch from PIL import Image -from tqdm import tqdm from surya.common.predictor import BasePredictor -from surya.common.util import clean_boxes from surya.layout.schema import LayoutBox, LayoutResult from surya.settings import settings from surya.foundation import FoundationPredictor, TaskNames @@ -14,14 +10,10 @@ from surya.input.processing import convert_if_not_rgb from surya.layout.label import LAYOUT_PRED_RELABEL + class LayoutPredictor(BasePredictor): batch_size = settings.LAYOUT_BATCH_SIZE - default_batch_sizes = { - "cpu": 4, - "mps": 4, - "cuda": 32, - "xla": 16 - } + default_batch_sizes = {"cpu": 4, "mps": 4, "cuda": 32, "xla": 16} # Override base init - Do not load model def __init__(self, foundation_predictor: FoundationPredictor): @@ -31,14 +23,9 @@ def __init__(self, foundation_predictor: FoundationPredictor): self.tasks = self.foundation_predictor.tasks def __call__( - self, - images: List[Image.Image], - batch_size: int | None = None, - top_k: int = 5 + self, images: List[Image.Image], batch_size: int | None = None, top_k: int = 5 ) -> List[LayoutResult]: assert all([isinstance(image, Image.Image) for image in images]) - if batch_size is None: - batch_size = self.get_batch_size() images = convert_if_not_rgb(images) images = [self.processor.image_processor(image) for image in images] @@ -48,18 +35,22 @@ def __call__( input_texts=["" for _ in range(len(images))], task_names=[TaskNames.layout for _ in range(len(images))], batch_size=batch_size, - max_lookahead_tokens=0 # Do not do MTP for layout + max_lookahead_tokens=0, # Do not do MTP for layout + top_k=5 ) - + image_sizes = [img.shape for img in images] predicted_polygons = prediction_to_polygon_batch( batch_bboxes, image_sizes, self.bbox_size, self.bbox_size // 2 ) - layout_results = [] - for image, image_tokens, image_polygons, image_scores, image_topk_scores in zip(images, predicted_tokens, predicted_polygons, scores, topk_scores): + for image, image_tokens, image_polygons, image_scores, image_topk_scores in zip( + images, predicted_tokens, predicted_polygons, scores, topk_scores + ): layout_boxes = [] - for z, (tok, poly, score, tok_topk) in enumerate(zip(image_tokens, image_polygons, image_scores, image_topk_scores)): + for z, (tok, poly, score, tok_topk) in enumerate( + zip(image_tokens, image_polygons, image_scores, image_topk_scores) + ): if tok == self.processor.eos_token_id: break @@ -68,23 +59,23 @@ def __call__( top_k_dict = {} for k, v in tok_topk.items(): - l = self.processor.decode([k], "layout") - if l in LAYOUT_PRED_RELABEL: - l = LAYOUT_PRED_RELABEL[l] - top_k_dict.update({l: v}) - layout_boxes.append(LayoutBox( - polygon=poly.tolist(), - label=label, - position=z, - top_k=top_k_dict, - confidence=score - )) + topk_label = self.processor.decode([k], "layout") + if topk_label in LAYOUT_PRED_RELABEL: + topk_label = LAYOUT_PRED_RELABEL[topk_label] + top_k_dict.update({topk_label: v}) + layout_boxes.append( + LayoutBox( + polygon=poly.tolist(), + label=label, + 
position=z, + top_k=top_k_dict, + confidence=score, + ) + ) # layout_boxes = clean_boxes(layout_boxes) - layout_results.append(LayoutResult( - bboxes=layout_boxes, - image_bbox=[0, 0, *image.shape] - )) - + layout_results.append( + LayoutResult(bboxes=layout_boxes, image_bbox=[0, 0, *image.shape]) + ) assert len(layout_results) == len(images) - return layout_results \ No newline at end of file + return layout_results diff --git a/surya/ocr_error/__init__.py b/surya/ocr_error/__init__.py index 9bac7b83..782b6c7b 100644 --- a/surya/ocr_error/__init__.py +++ b/surya/ocr_error/__init__.py @@ -8,7 +8,7 @@ from surya.ocr_error.model.config import ID2LABEL from surya.ocr_error.schema import OCRErrorDetectionResult from surya.settings import settings -from surya.common.util import mark_step +from surya.common.xla import mark_step class OCRErrorPredictor(BasePredictor): diff --git a/surya/ocr_error/loader.py b/surya/ocr_error/loader.py index 07851303..b21c6e4a 100644 --- a/surya/ocr_error/loader.py +++ b/surya/ocr_error/loader.py @@ -3,6 +3,7 @@ import torch from surya.common.load import ModelLoader +from surya.common.xla import get_compile_args from surya.logging import get_logger from surya.ocr_error.model.config import DistilBertConfig from surya.ocr_error.model.encoder import DistilBertForSequenceClassification @@ -46,7 +47,7 @@ def model( logger.info( f"Compiling detection model {self.checkpoint} from {DistilBertForSequenceClassification.get_local_path(self.checkpoint)} onto device {device} with dtype {dtype}" ) - compile_args = {"backend": "openxla"} if device == "xla" else {} + compile_args = get_compile_args(device) model = torch.compile(model, **compile_args) return model diff --git a/surya/recognition/__init__.py b/surya/recognition/__init__.py index 09bb2030..1d65fd74 100644 --- a/surya/recognition/__init__.py +++ b/surya/recognition/__init__.py @@ -25,6 +25,7 @@ clean_close_polygons, unwrap_math, clean_math_tags, + filter_blacklist_tags, words_from_chars ) from surya.foundation.util import detect_repeat_token, prediction_to_polygon_batch @@ -347,6 +348,8 @@ def __call__( math_mode: bool = True, return_words: bool = False, drop_repeated_text: bool = False, + max_sliding_window: int | None = None, + max_tokens: int | None = None, ) -> List[OCRResult]: if task_names is None: task_names = [TaskNames.ocr_with_boxes] * len(images) @@ -402,7 +405,12 @@ def __call__( # No images passed, or no boxes passed, or no text detected in the images if len(flat["slices"]) == 0: - return [] + return [ + OCRResult( + text_lines=[], image_bbox=[0, 0, im.size[0], im.size[1]] + ) + for im in images + ] # Sort by line widths. 
Negative so that longer images come first, fits in with continuous batching better sorted_pairs = sorted(enumerate(flat["slices"]), key=lambda x: -x[1].shape[1]) @@ -421,6 +429,9 @@ def __call__( batch_size=recognition_batch_size, math_mode=math_mode, drop_repeated_tokens=True, + max_lookahead_tokens=0, + max_sliding_window=max_sliding_window, + max_tokens=max_tokens, ) # Get text and bboxes in structured form @@ -481,6 +492,7 @@ def __call__( text_line = fix_unbalanced_tags( text_line, self.processor.ocr_tokenizer.special_tokens ) + text_line = filter_blacklist_tags(text_line) text = "".join([char.text for char in text_line]) text = unwrap_math(text) text = clean_math_tags(text) diff --git a/surya/recognition/util.py b/surya/recognition/util.py index 7131e83d..51d4a4b2 100644 --- a/surya/recognition/util.py +++ b/surya/recognition/util.py @@ -29,6 +29,43 @@ def unwrap_math(text: str) -> str: MATH_BLOCK = re.compile(r"(]*>)(.*?)", flags=re.I | re.S) STRIP_TAGS = re.compile(r"]*>", flags=re.I | re.S) +BLACKLIST_TAGS = {"p", "li", "ul", "ol", "table", "td", "tr", "th"} + +def filter_blacklist_tags(text_chars: List[TextChar]) -> List[TextChar]: + filtered_chars = [] + char_buffer = [] + in_tag = False + + for text_char in text_chars: + char = text_char.text + + if char == "<": + in_tag = True + char_buffer = [text_char] + elif in_tag: + char_buffer.append(text_char) + if char == ">": + full_tag = ''.join(c.text for c in char_buffer) + inner = full_tag[1:-1].strip() # remove < > + tag_name_candidate = inner.strip("/").split()[0] # remove '/' and any attributes + + if tag_name_candidate in BLACKLIST_TAGS: + # Discard tag + pass + else: + # Keep tag + filtered_chars.extend(char_buffer) + + in_tag = False + char_buffer = [] + else: + filtered_chars.append(text_char) + + # Flush buffer if we never reached a tag close + if char_buffer: + filtered_chars.extend(char_buffer) + + return filtered_chars def clean_math_tags(html: str) -> str: diff --git a/surya/settings.py b/surya/settings.py index e987ca4e..2142d0b8 100644 --- a/surya/settings.py +++ b/surya/settings.py @@ -76,10 +76,11 @@ def TORCH_DEVICE_MODEL(self) -> str: COMPILE_DETECTOR: bool = False # Text recognition - FOUNDATION_MODEL_CHECKPOINT: str = "datalab-to/foundation-2.10" + FOUNDATION_MODEL_CHECKPOINT: str = "datalab-to/foundation-alpha-nocce-continue" FOUNDATION_MODEL_QUANTIZE: bool = False FOUNDATION_MAX_TOKENS: Optional[int] = None FOUNDATION_CHUNK_SIZE: Optional[int] = None + FOUNDATION_PAD_TO_NEAREST: int = 256 # Pad to the nearest multiple of this value COMPILE_FOUNDATION: bool = False RECOGNITION_BATCH_SIZE: Optional[int] = ( diff --git a/surya/table_rec/__init__.py b/surya/table_rec/__init__.py index ddeb6bb8..0ea88399 100644 --- a/surya/table_rec/__init__.py +++ b/surya/table_rec/__init__.py @@ -7,57 +7,72 @@ from PIL import Image from tqdm import tqdm -from surya.common.util import mark_step +from surya.common.xla import mark_step from surya.common.predictor import BasePredictor from surya.table_rec.schema import TableCell, TableRow, TableCol, TableResult from surya.common.polygon import PolygonBox from surya.settings import settings from surya.table_rec.loader import TableRecModelLoader -from surya.table_rec.model.config import BOX_PROPERTIES, SPECIAL_TOKENS, BOX_DIM, CATEGORY_TO_ID, MERGE_KEYS, \ - MERGE_VALUES +from surya.table_rec.model.config import ( + BOX_PROPERTIES, + SPECIAL_TOKENS, + BOX_DIM, + CATEGORY_TO_ID, + MERGE_KEYS, + MERGE_VALUES, +) from surya.table_rec.shaper import LabelShaper class 
TableRecPredictor(BasePredictor): model_loader_cls = TableRecModelLoader batch_size = settings.TABLE_REC_BATCH_SIZE - default_batch_sizes = { - "cpu": 8, - "mps": 8, - "cuda": 32, - "xla": 16 - } - - def __call__(self, images: List[Image.Image], batch_size: int | None = None) -> List[TableResult]: + default_batch_sizes = {"cpu": 8, "mps": 8, "cuda": 32, "xla": 16} + + def __call__( + self, images: List[Image.Image], batch_size: int | None = None + ) -> List[TableResult]: return self.batch_table_recognition(images, batch_size) def inference_loop( - self, - encoder_hidden_states: torch.Tensor, - batch_input_ids: torch.Tensor, - current_batch_size: int, - batch_size: int + self, + encoder_hidden_states: torch.Tensor, + batch_input_ids: torch.Tensor, + current_batch_size: int, + batch_size: int, ): shaper = LabelShaper() batch_predictions = [[] for _ in range(current_batch_size)] max_tokens = settings.TABLE_REC_MAX_BOXES - decoder_position_ids = torch.ones_like(batch_input_ids[0, :, 0], dtype=torch.int64, device=self.model.device).cumsum( - 0) - 1 + decoder_position_ids = ( + torch.ones_like( + batch_input_ids[0, :, 0], dtype=torch.int64, device=self.model.device + ).cumsum(0) + - 1 + ) inference_token_count = batch_input_ids.shape[1] if settings.TABLE_REC_STATIC_CACHE: - encoder_hidden_states = self.pad_to_batch_size(encoder_hidden_states, batch_size) + encoder_hidden_states = self.pad_to_batch_size( + encoder_hidden_states, batch_size + ) batch_input_ids = self.pad_to_batch_size(batch_input_ids, batch_size) # Move to device after padding for XLA encoder_hidden_states = encoder_hidden_states.to(self.model.device) batch_input_ids = batch_input_ids.to(self.model.device) - self.model.decoder.model._setup_cache(self.model.config, batch_size, self.model.device, self.model.dtype) + self.model.decoder.model._setup_cache( + self.model.config, batch_size, self.model.device, self.model.dtype + ) with settings.INFERENCE_MODE(): token_count = 0 - all_done = torch.zeros(encoder_hidden_states.shape[0], dtype=torch.bool, device=self.model.device) + all_done = torch.zeros( + encoder_hidden_states.shape[0], + dtype=torch.bool, + device=self.model.device, + ) while token_count < max_tokens: is_prefill = token_count == 0 @@ -66,7 +81,7 @@ def inference_loop( encoder_hidden_states=encoder_hidden_states, cache_position=decoder_position_ids, use_cache=True, - prefill=is_prefill + prefill=is_prefill, ) decoder_position_ids = decoder_position_ids[-1:] + 1 @@ -78,13 +93,17 @@ def inference_loop( # Pre-process all logits at once processed_logits = {} for k, _, mode in BOX_PROPERTIES: - k_logits = return_dict["box_property_logits"][k][:, -1, :] # Get all batch logits at once - + k_logits = return_dict["box_property_logits"][k][ + :, -1, : + ] # Get all batch logits at once + if mode == "classification": # Process all classification logits in one operation items = torch.argmax(k_logits, dim=-1) if k == "category": - done = (items == self.model.decoder.config.eos_token_id) | (items == self.model.decoder.config.pad_token_id) + done = (items == self.model.decoder.config.eos_token_id) | ( + items == self.model.decoder.config.pad_token_id + ) items = items - SPECIAL_TOKENS processed_logits[k] = items elif mode == "regression": @@ -114,10 +133,16 @@ def inference_loop( if all_done_cpu[:current_batch_size].all(): break - batch_input_ids = torch.tensor(shaper.dict_to_labels(box_properties), dtype=torch.long) - batch_input_ids = batch_input_ids.unsqueeze(1) # Add sequence length dimension + batch_input_ids = torch.tensor( + 
shaper.dict_to_labels(box_properties), dtype=torch.long + ) + batch_input_ids = batch_input_ids.unsqueeze( + 1 + ) # Add sequence length dimension - for j, (box_property, status) in enumerate(zip(box_properties, all_done_cpu)): + for j, (box_property, status) in enumerate( + zip(box_properties, all_done_cpu) + ): if not status: batch_predictions[j].append(box_property) @@ -125,16 +150,17 @@ def inference_loop( inference_token_count = batch_input_ids.shape[1] if settings.TABLE_REC_STATIC_CACHE: - batch_input_ids = self.pad_to_batch_size(batch_input_ids, batch_size) + batch_input_ids = self.pad_to_batch_size( + batch_input_ids, batch_size + ) # Move to device after padding for XLA batch_input_ids = batch_input_ids.to(self.model.device) return batch_predictions def batch_table_recognition( - self, - images: List, - batch_size=None) -> List[TableResult]: + self, images: List, batch_size=None + ) -> List[TableResult]: assert all([isinstance(image, Image.Image) for image in images]) if batch_size is None: batch_size = self.get_batch_size() @@ -144,33 +170,52 @@ def batch_table_recognition( query_items = [] for image in images: - query_items.append({ - "polygon": [[0, 0], [image.width, 0], [image.width, image.height], [0, image.height]], - "category": CATEGORY_TO_ID["Table"], - "colspan": 0, - "merges": 0, - "is_header": 0 - }) + query_items.append( + { + "polygon": [ + [0, 0], + [image.width, 0], + [image.width, image.height], + [0, image.height], + ], + "category": CATEGORY_TO_ID["Table"], + "colspan": 0, + "merges": 0, + "is_header": 0, + } + ) output_order = [] - for i in tqdm(range(0, len(images), batch_size), desc="Recognizing tables", disable=self.disable_tqdm): - batch_query_items = query_items[i:i + batch_size] - - batch_images = images[i:i + batch_size] - batch_images = [image.convert("RGB") for image in batch_images] # also copies the images + for i in tqdm( + range(0, len(images), batch_size), + desc="Recognizing tables", + disable=self.disable_tqdm, + ): + batch_query_items = query_items[i : i + batch_size] + + batch_images = images[i : i + batch_size] + batch_images = [ + image.convert("RGB") for image in batch_images + ] # also copies the images current_batch_size = len(batch_images) orig_sizes = [image.size for image in batch_images] - model_inputs = self.processor(images=batch_images, query_items=batch_query_items) + model_inputs = self.processor( + images=batch_images, query_items=batch_query_items + ) batch_pixel_values = model_inputs["pixel_values"] batch_input_ids = model_inputs["input_ids"] - batch_pixel_values = torch.tensor(np.array(batch_pixel_values), dtype=self.model.dtype) + batch_pixel_values = torch.tensor( + np.array(batch_pixel_values), dtype=self.model.dtype + ) if settings.TABLE_REC_STATIC_CACHE: - batch_pixel_values = self.pad_to_batch_size(batch_pixel_values, batch_size) + batch_pixel_values = self.pad_to_batch_size( + batch_pixel_values, batch_size + ) # Move to device after padding for XLA batch_pixel_values = batch_pixel_values.to(self.model.device) @@ -179,14 +224,13 @@ def batch_table_recognition( # We only need to process each image once with settings.INFERENCE_MODE(): - encoder_hidden_states = self.model.encoder(pixel_values=batch_pixel_values).last_hidden_state + encoder_hidden_states = self.model.encoder( + pixel_values=batch_pixel_values + ).last_hidden_state # Inference to get rows and columns rowcol_predictions = self.inference_loop( - encoder_hidden_states, - batch_input_ids, - current_batch_size, - batch_size + encoder_hidden_states, 
batch_input_ids, current_batch_size, batch_size ) mark_step() @@ -198,90 +242,125 @@ def batch_table_recognition( for row_prediction in img_predictions: polygon = shaper.convert_bbox_to_polygon(row_prediction["bbox"]) if row_prediction["category"] == CATEGORY_TO_ID["Table-row"]: - row_query_items.append({ - "polygon": polygon, - "category": row_prediction["category"], - "colspan": 0, - "merges": 0, - "is_header": int(row_prediction["is_header"] == 1) - }) + row_query_items.append( + { + "polygon": polygon, + "category": row_prediction["category"], + "colspan": 0, + "merges": 0, + "is_header": int(row_prediction["is_header"] == 1), + } + ) row_encoder_hidden_states.append(encoder_hidden_states[j]) idx_map.append(j) elif row_prediction["category"] == CATEGORY_TO_ID["Table-column"]: - columns.append({ - "polygon": polygon, - "category": row_prediction["category"], - "colspan": 0, - "merges": 0, - "is_header": int(row_prediction["is_header"] == 1) - }) + columns.append( + { + "polygon": polygon, + "category": row_prediction["category"], + "colspan": 0, + "merges": 0, + "is_header": int(row_prediction["is_header"] == 1), + } + ) # Re-inference to predict cells row_encoder_hidden_states = torch.stack(row_encoder_hidden_states) - row_inputs = self.processor(images=None, query_items=row_query_items, columns=columns, convert_images=False) + row_inputs = self.processor( + images=None, + query_items=row_query_items, + columns=columns, + convert_images=False, + ) row_input_ids = row_inputs["input_ids"] cell_predictions = [] for j in range(0, len(row_input_ids), batch_size): - cell_batch_hidden_states = row_encoder_hidden_states[j:j + batch_size] - cell_batch_input_ids = row_input_ids[j:j + batch_size] + cell_batch_hidden_states = row_encoder_hidden_states[j : j + batch_size] + cell_batch_input_ids = row_input_ids[j : j + batch_size] cell_batch_size = len(cell_batch_input_ids) cell_predictions.extend( - self.inference_loop(cell_batch_hidden_states, cell_batch_input_ids, cell_batch_size, batch_size) + self.inference_loop( + cell_batch_hidden_states, + cell_batch_input_ids, + cell_batch_size, + batch_size, + ) ) mark_step() - result = self.decode_batch_predictions(rowcol_predictions, cell_predictions, orig_sizes, idx_map, shaper) + result = self.decode_batch_predictions( + rowcol_predictions, cell_predictions, orig_sizes, idx_map, shaper + ) output_order.extend(result) return output_order - - def decode_batch_predictions(self, rowcol_predictions, cell_predictions, orig_sizes, idx_map, shaper): + def decode_batch_predictions( + self, rowcol_predictions, cell_predictions, orig_sizes, idx_map, shaper + ): results = [] - for j, (img_predictions, orig_size) in enumerate(zip(rowcol_predictions, orig_sizes)): - row_cell_predictions = [c for i, c in enumerate(cell_predictions) if idx_map[i] == j] + for j, (img_predictions, orig_size) in enumerate( + zip(rowcol_predictions, orig_sizes) + ): + row_cell_predictions = [ + c for i, c in enumerate(cell_predictions) if idx_map[i] == j + ] # Each row prediction matches a cell prediction rows = [] cells = [] columns = [] cell_id = 0 - row_predictions = [pred for pred in img_predictions if pred["category"] == CATEGORY_TO_ID["Table-row"]] - col_predictions = [pred for pred in img_predictions if pred["category"] == CATEGORY_TO_ID["Table-column"]] + row_predictions = [ + pred + for pred in img_predictions + if pred["category"] == CATEGORY_TO_ID["Table-row"] + ] + col_predictions = [ + pred + for pred in img_predictions + if pred["category"] == CATEGORY_TO_ID["Table-column"] 
+ ] # Generate table columns for z, col_prediction in enumerate(col_predictions): polygon = shaper.convert_bbox_to_polygon(col_prediction["bbox"]) - polygon = self.processor.resize_polygon(polygon, (BOX_DIM, BOX_DIM), orig_size) + polygon = self.processor.resize_polygon( + polygon, (BOX_DIM, BOX_DIM), orig_size + ) columns.append( TableCol( polygon=polygon, col_id=z, - is_header=col_prediction["is_header"] == 1 + is_header=col_prediction["is_header"] == 1, ) ) # Generate table rows for z, row_prediction in enumerate(row_predictions): polygon = shaper.convert_bbox_to_polygon(row_prediction["bbox"]) - polygon = self.processor.resize_polygon(polygon, (BOX_DIM, BOX_DIM), orig_size) + polygon = self.processor.resize_polygon( + polygon, (BOX_DIM, BOX_DIM), orig_size + ) row = TableRow( polygon=polygon, row_id=z, - is_header=row_prediction["is_header"] == 1 + is_header=row_prediction["is_header"] == 1, ) rows.append(row) # Get cells that span multiple columns within a row spanning_cells = [] - for l, spanning_cell in enumerate(row_cell_predictions[z]): + for col_idx, spanning_cell in enumerate(row_cell_predictions[z]): polygon = shaper.convert_bbox_to_polygon(spanning_cell["bbox"]) - polygon = self.processor.resize_polygon(polygon, (BOX_DIM, BOX_DIM), orig_size) + polygon = self.processor.resize_polygon( + polygon, (BOX_DIM, BOX_DIM), orig_size + ) colspan = max(1, int(spanning_cell["colspan"])) if colspan == 1 and spanning_cell["merges"] not in MERGE_VALUES: # Skip single column cells if they don't merge continue - if PolygonBox(polygon=polygon).height < row.height * .85: + if PolygonBox(polygon=polygon).height < row.height * 0.85: # Spanning cell must cover most of the row continue @@ -291,12 +370,13 @@ def decode_batch_predictions(self, rowcol_predictions, cell_predictions, orig_si row_id=z, rowspan=1, cell_id=cell_id, - within_row_id=l, + within_row_id=col_idx, colspan=colspan, - merge_up=spanning_cell["merges"] in [MERGE_KEYS["merge_up"], MERGE_KEYS["merge_both"]], - merge_down=spanning_cell["merges"] in [MERGE_KEYS["merge_down"], - MERGE_KEYS["merge_both"]], - is_header=row.is_header or z == 0 + merge_up=spanning_cell["merges"] + in [MERGE_KEYS["merge_up"], MERGE_KEYS["merge_both"]], + merge_down=spanning_cell["merges"] + in [MERGE_KEYS["merge_down"], MERGE_KEYS["merge_both"]], + is_header=row.is_header or z == 0, ) ) cell_id += 1 @@ -304,7 +384,7 @@ def decode_batch_predictions(self, rowcol_predictions, cell_predictions, orig_si # Add cells - either add spanning cells (multiple cols), or generate a cell based on row/col used_spanning_cells = set() skip_columns = 0 - for l, col in enumerate(columns): + for col_idx, col in enumerate(columns): if skip_columns: skip_columns -= 1 continue @@ -312,19 +392,30 @@ def decode_batch_predictions(self, rowcol_predictions, cell_predictions, orig_si cell_added = False for zz, spanning_cell in enumerate(spanning_cells): cell_polygonbox = PolygonBox(polygon=cell_polygon) - intersection_pct = cell_polygonbox.intersection_pct(spanning_cell) + intersection_pct = cell_polygonbox.intersection_pct( + spanning_cell + ) # Make sure cells intersect, and that the spanning cell is wider than the current cell (takes up multiple columns) - correct_col_width = sum([col.width for col in columns[l:l + spanning_cell.colspan]]) - if intersection_pct > .9: - if spanning_cell.width > (correct_col_width * .85): + correct_col_width = sum( + [ + col.width + for col in columns[ + col_idx : col_idx + spanning_cell.colspan + ] + ] + ) + if intersection_pct > 0.9: + if 
spanning_cell.width > (correct_col_width * 0.85): cell_added = True if zz not in used_spanning_cells: used_spanning_cells.add(zz) - spanning_cell.col_id = l + spanning_cell.col_id = col_idx cells.append(spanning_cell) - skip_columns = spanning_cell.colspan - 1 # Skip columns that are part of the spanning cell + skip_columns = ( + spanning_cell.colspan - 1 + ) # Skip columns that are part of the spanning cell else: - used_spanning_cells.add(zz) # Skip this spanning cell + used_spanning_cells.add(zz) # Skip this spanning cell if not cell_added: cells.append( @@ -333,39 +424,40 @@ def decode_batch_predictions(self, rowcol_predictions, cell_predictions, orig_si row_id=z, rowspan=1, cell_id=cell_id, - within_row_id=l, + within_row_id=col_idx, colspan=1, merge_up=False, merge_down=False, - col_id=l, - is_header=row.is_header or col.is_header or z == 0 + col_id=col_idx, + is_header=row.is_header or col.is_header or z == 0, ) ) cell_id += 1 # Turn cells into a row grid - grid_cells = deepcopy([ - [cell for cell in cells if cell.row_id == row.row_id] - for row in rows - ]) + grid_cells = deepcopy( + [[cell for cell in cells if cell.row_id == row.row_id] for row in rows] + ) # Merge cells across rows for z, grid_row in enumerate(grid_cells[1:]): prev_row = grid_cells[z] - for l, cell in enumerate(grid_row): - if l >= len(prev_row): + for col_idx, cell in enumerate(grid_row): + if col_idx >= len(prev_row): continue - above_cell = prev_row[l] - if all([ - above_cell.merge_down, - cell.merge_up, - above_cell.col_id == cell.col_id, - above_cell.colspan == cell.colspan, - ]): + above_cell = prev_row[col_idx] + if all( + [ + above_cell.merge_down, + cell.merge_up, + above_cell.col_id == cell.col_id, + above_cell.colspan == cell.colspan, + ] + ): above_cell.merge(cell) above_cell.rowspan += cell.rowspan - grid_row[l] = above_cell + grid_row[col_idx] = above_cell merged_cells_all = list(chain.from_iterable(grid_cells)) used_ids = set() diff --git a/surya/table_rec/loader.py b/surya/table_rec/loader.py index 75ee8554..5fc52eae 100644 --- a/surya/table_rec/loader.py +++ b/surya/table_rec/loader.py @@ -3,6 +3,7 @@ import torch from surya.common.load import ModelLoader +from surya.common.xla import get_compile_args from surya.logging import get_logger from surya.settings import settings from surya.table_rec.model.config import ( @@ -55,7 +56,7 @@ def model( logger.info( f"Compiling table recognition model {self.checkpoint} on device {device} with dtype {dtype}" ) - compile_args = {"backend": "openxla"} if device == "xla" else {} + compile_args = get_compile_args(device) model.encoder = torch.compile(model.encoder, **compile_args) model.decoder = torch.compile(model.decoder, **compile_args)
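
Note for reviewers: `get_compile_args` and `mark_step` are imported from the new `surya.common.xla` module in several files above, but that module's body is not part of this diff. Based on the inline expression the loaders used to contain (`{"backend": "openxla"} if device == "xla" else {}`), a minimal sketch might look like the following; the function bodies are assumptions for illustration only, and the real helpers presumably also consult the configured device via `surya.settings`.

# Hypothetical sketch of surya/common/xla.py -- illustrative only, not the actual module.
from typing import Any, Dict


def get_compile_args(device) -> Dict[str, Any]:
    # XLA devices need the "openxla" dynamo backend; other devices fall back to
    # torch.compile defaults, matching the expression this diff replaces.
    if "xla" in str(device):
        return {"backend": "openxla"}
    return {}


def mark_step() -> None:
    # Flush pending lazy-tensor operations on XLA; stay a no-op when torch_xla is
    # absent so callers such as table_rec and ocr_error remain device-agnostic.
    try:
        import torch_xla.core.xla_model as xm
    except ImportError:
        return
    xm.mark_step()

Either way, the loaders above only rely on the returned mapping being unpackable into `torch.compile(model, **compile_args)`.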