|
35 | 35 | import torch |
36 | 36 | import torch.nn.functional as F |
37 | 37 | from easydict import EasyDict as edict |
38 | | -from frozen_clip_embedder_t3 import FrozenCLIPEmbedderT3 |
39 | 38 | from huggingface_hub import hf_hub_download |
40 | 39 | from ocr_recog.RecModel import RecModel |
41 | 40 | from PIL import Image, ImageDraw, ImageFont |
@@ -520,6 +519,222 @@ def get_ctcloss(self, preds, gt_text, weight): |
520 | 519 | return loss |
521 | 520 |
|
522 | 521 |
|
| 522 | +import torch |
| 523 | +from torch import nn |
| 524 | +from transformers import CLIPTextModel, CLIPTokenizer |
| 525 | +from transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask |
| 526 | + |
| 527 | + |
| 528 | +class AbstractEncoder(nn.Module): |
| 529 | + def __init__(self): |
| 530 | + super().__init__() |
| 531 | + |
| 532 | + def encode(self, *args, **kwargs): |
| 533 | + raise NotImplementedError |
| 534 | + |
| 535 | + |
| 536 | +class FrozenCLIPEmbedderT3(AbstractEncoder): |
| 537 | + """Uses the CLIP transformer encoder for text (from Hugging Face)""" |
| 538 | + |
| 539 | + def __init__( |
| 540 | + self, |
| 541 | + version="openai/clip-vit-large-patch14", |
| 542 | + device="cpu", |
| 543 | + max_length=77, |
| 544 | + freeze=True, |
| 545 | + use_fp16=False, |
| 546 | + ): |
| 547 | + super().__init__() |
| 548 | + self.tokenizer = CLIPTokenizer.from_pretrained(version) |
| 549 | + self.transformer = CLIPTextModel.from_pretrained( |
| 550 | + version, use_safetensors=True, torch_dtype=torch.float16 if use_fp16 else torch.float32 |
| 551 | + ).to(device) |
| 552 | + self.device = device |
| 553 | + self.max_length = max_length |
| 554 | + if freeze: |
| 555 | + self.freeze() |
| 556 | + |
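| | + # The methods defined below are bound onto the loaded Hugging Face model with |
| | + # __get__, replacing its forward implementations so that an optional |
| | + # embedding_manager can substitute custom token embeddings during encoding. |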
| 557 | + def embedding_forward( |
| 558 | + self, |
| 559 | + input_ids=None, |
| 560 | + position_ids=None, |
| 561 | + inputs_embeds=None, |
| 562 | + embedding_manager=None, |
| 563 | + ): |
| 564 | + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] |
| 565 | + if position_ids is None: |
| 566 | + position_ids = self.position_ids[:, :seq_length] |
| 567 | + if inputs_embeds is None: |
| 568 | + inputs_embeds = self.token_embedding(input_ids) |
| 569 | + if embedding_manager is not None: |
| 570 | + inputs_embeds = embedding_manager(input_ids, inputs_embeds) |
| 571 | + position_embeddings = self.position_embedding(position_ids) |
| 572 | + embeddings = inputs_embeds + position_embeddings |
| 573 | + return embeddings |
| 574 | + |
| 575 | + self.transformer.text_model.embeddings.forward = embedding_forward.__get__( |
| 576 | + self.transformer.text_model.embeddings |
| 577 | + ) |
| 578 | + |
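| | + # Replacement for the text encoder's forward: runs the transformer layers and |
| | + # returns the last hidden states tensor directly instead of a BaseModelOutput. |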
| 579 | + def encoder_forward( |
| 580 | + self, |
| 581 | + inputs_embeds, |
| 582 | + attention_mask=None, |
| 583 | + causal_attention_mask=None, |
| 584 | + output_attentions=None, |
| 585 | + output_hidden_states=None, |
| 586 | + return_dict=None, |
| 587 | + ): |
| 588 | + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
| 589 | + output_hidden_states = ( |
| 590 | + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
| 591 | + ) |
| 592 | + return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
| 593 | + encoder_states = () if output_hidden_states else None |
| 594 | + all_attentions = () if output_attentions else None |
| 595 | + hidden_states = inputs_embeds |
| 596 | + for idx, encoder_layer in enumerate(self.layers): |
| 597 | + if output_hidden_states: |
| 598 | + encoder_states = encoder_states + (hidden_states,) |
| 599 | + layer_outputs = encoder_layer( |
| 600 | + hidden_states, |
| 601 | + attention_mask, |
| 602 | + causal_attention_mask, |
| 603 | + output_attentions=output_attentions, |
| 604 | + ) |
| 605 | + hidden_states = layer_outputs[0] |
| 606 | + if output_attentions: |
| 607 | + all_attentions = all_attentions + (layer_outputs[1],) |
| 608 | + if output_hidden_states: |
| 609 | + encoder_states = encoder_states + (hidden_states,) |
| 610 | + return hidden_states |
| 611 | + |
| 612 | + self.transformer.text_model.encoder.forward = encoder_forward.__get__(self.transformer.text_model.encoder) |
| 613 | + |
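| | + # Replacement for CLIPTextTransformer.forward: builds the causal mask, threads |
| | + # embedding_manager through to the embeddings, and skips the pooled-output |
| | + # computation, returning only the final layer-normed hidden states. |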
| 614 | + def text_encoder_forward( |
| 615 | + self, |
| 616 | + input_ids=None, |
| 617 | + attention_mask=None, |
| 618 | + position_ids=None, |
| 619 | + output_attentions=None, |
| 620 | + output_hidden_states=None, |
| 621 | + return_dict=None, |
| 622 | + embedding_manager=None, |
| 623 | + ): |
| 624 | + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
| 625 | + output_hidden_states = ( |
| 626 | + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
| 627 | + ) |
| 628 | + return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
| 629 | + if input_ids is None: |
| 630 | + raise ValueError("You have to specify input_ids") |
| 631 | + input_shape = input_ids.size() |
| 632 | + input_ids = input_ids.view(-1, input_shape[-1]) |
| 633 | + hidden_states = self.embeddings( |
| 634 | + input_ids=input_ids, position_ids=position_ids, embedding_manager=embedding_manager |
| 635 | + ) |
| 636 | + # CLIP's text model uses causal mask, prepare it here. |
| 637 | + # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 |
| 638 | + causal_attention_mask = _create_4d_causal_attention_mask( |
| 639 | + input_shape, hidden_states.dtype, device=hidden_states.device |
| 640 | + ) |
| 641 | + # expand attention_mask |
| 642 | + if attention_mask is not None: |
| 643 | + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] |
| 644 | + attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) |
| 645 | + last_hidden_state = self.encoder( |
| 646 | + inputs_embeds=hidden_states, |
| 647 | + attention_mask=attention_mask, |
| 648 | + causal_attention_mask=causal_attention_mask, |
| 649 | + output_attentions=output_attentions, |
| 650 | + output_hidden_states=output_hidden_states, |
| 651 | + return_dict=return_dict, |
| 652 | + ) |
| 653 | + last_hidden_state = self.final_layer_norm(last_hidden_state) |
| 654 | + return last_hidden_state |
| 655 | + |
| 656 | + self.transformer.text_model.forward = text_encoder_forward.__get__(self.transformer.text_model) |
| 657 | + |
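| | + # Thin wrapper matching CLIPTextModel.forward that also forwards embedding_manager. |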
| 658 | + def transformer_forward( |
| 659 | + self, |
| 660 | + input_ids=None, |
| 661 | + attention_mask=None, |
| 662 | + position_ids=None, |
| 663 | + output_attentions=None, |
| 664 | + output_hidden_states=None, |
| 665 | + return_dict=None, |
| 666 | + embedding_manager=None, |
| 667 | + ): |
| 668 | + return self.text_model( |
| 669 | + input_ids=input_ids, |
| 670 | + attention_mask=attention_mask, |
| 671 | + position_ids=position_ids, |
| 672 | + output_attentions=output_attentions, |
| 673 | + output_hidden_states=output_hidden_states, |
| 674 | + return_dict=return_dict, |
| 675 | + embedding_manager=embedding_manager, |
| 676 | + ) |
| 677 | + |
| 678 | + self.transformer.forward = transformer_forward.__get__(self.transformer) |
| 679 | + |
| 680 | + def freeze(self): |
| 681 | + self.transformer = self.transformer.eval() |
| 682 | + for param in self.parameters(): |
| 683 | + param.requires_grad = False |
| 684 | + |
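| | + # Tokenize without truncation, split the ids into max_length-sized chunks, |
| | + # encode each chunk separately, and concatenate the results along the sequence |
| | + # dimension so captions longer than CLIP's 77-token context are supported. |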
| 685 | + def forward(self, text, **kwargs): |
| 686 | + batch_encoding = self.tokenizer( |
| 687 | + text, |
| 688 | + truncation=False, |
| 689 | + max_length=self.max_length, |
| 690 | + return_length=True, |
| 691 | + return_overflowing_tokens=False, |
| 692 | + padding="longest", |
| 693 | + return_tensors="pt", |
| 694 | + ) |
| 695 | + input_ids = batch_encoding["input_ids"] |
| 696 | + tokens_list = self.split_chunks(input_ids) |
| 697 | + z_list = [] |
| 698 | + for tokens in tokens_list: |
| 699 | + tokens = tokens.to(self.device) |
| 700 | + _z = self.transformer(input_ids=tokens, **kwargs) |
| 701 | + z_list += [_z] |
| 702 | + return torch.cat(z_list, dim=1) |
| 703 | + |
| 704 | + def encode(self, text, **kwargs): |
| 705 | + return self(text, **kwargs) |
| 706 | + |
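| | + # Split the token ids into groups of 75 content tokens, re-wrapping each group |
| | + # with the original BOS/EOS ids and padding the last group with EOS so every |
| | + # chunk is exactly chunk_size + 2 (= max_length) tokens long. |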
| 707 | + def split_chunks(self, input_ids, chunk_size=75): |
| 708 | + tokens_list = [] |
| 709 | + bs, n = input_ids.shape |
| 710 | + id_start = input_ids[:, 0].unsqueeze(1) # dim --> [bs, 1] |
| 711 | + id_end = input_ids[:, -1].unsqueeze(1) |
| 712 | + if n == 2: # empty caption (only BOS + EOS): emit one chunk padded entirely with EOS |
| 713 | + tokens_list.append(torch.cat((id_start,) + (id_end,) * (chunk_size + 1), dim=1)) |
| 714 | + |
| 715 | + trimmed_encoding = input_ids[:, 1:-1] |
| 716 | + num_full_groups = (n - 2) // chunk_size |
| 717 | + |
| 718 | + for i in range(num_full_groups): |
| 719 | + group = trimmed_encoding[:, i * chunk_size : (i + 1) * chunk_size] |
| 720 | + group_pad = torch.cat((id_start, group, id_end), dim=1) |
| 721 | + tokens_list.append(group_pad) |
| 722 | + |
| 723 | + remaining_columns = (n - 2) % chunk_size |
| 724 | + if remaining_columns > 0: |
| 725 | + remaining_group = trimmed_encoding[:, -remaining_columns:] |
| 726 | + padding_columns = chunk_size - remaining_group.shape[1] |
| 727 | + padding = id_end.expand(bs, padding_columns) |
| 728 | + remaining_group_pad = torch.cat((id_start, remaining_group, padding, id_end), dim=1) |
| 729 | + tokens_list.append(remaining_group_pad) |
| 730 | + return tokens_list |
| 731 | + |
| 732 | + def to(self, *args, **kwargs): |
| 733 | + self.transformer = self.transformer.to(*args, **kwargs) |
| 734 | + self.device = self.transformer.device |
| 735 | + return self |
| 736 | + |
| 737 | + |
523 | 738 | class TextEmbeddingModule(nn.Module): |
524 | 739 | def __init__(self, font_path, use_fp16=False, device="cpu"): |
525 | 740 | super().__init__() |
|