import torch
from transformers import AutoModel, AutoTokenizer, AutoConfig
import graph_net.torch
from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.integrations.executorch import sdpa_mask_without_vmap


def get_model_name():
    return "Qwen/Qwen2.5-3B"


def create_model(device):
    # Build the model from its config alone; from_config initializes random
    # weights and downloads nothing but the config.
    config = AutoConfig.from_pretrained(get_model_name())
    model = AutoModel.from_config(config)
    # Register an SDPA mask implementation that avoids torch.vmap, which is
    # friendlier to graph capture. See:
    # https://github.com/huggingface/transformers/blob/6b5bd117231f969713ed79fd4870903ab3c93edf/docs/source/en/attention_interface.md?plain=1#L194-L195
    ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
    ALL_ATTENTION_FUNCTIONS.register(
        "sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"]
    )
    model.config._attn_implementation = "sdpa_without_vmap"
    model.eval()
    return model.to(device)


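# A minimal alternative sketch, assuming pretrained weights are wanted rather
# than the random initialization from_config produces (the attention
# registration above must still have run first):
#
#     model = AutoModel.from_pretrained(
#         get_model_name(), attn_implementation="sdpa_without_vmap"
#     )
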
if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained(get_model_name())

    text = "Hello world"
    inputs = tokenizer(text, return_tensors="pt")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    class DummyAutocast:
        """No-op stand-in for torch's autocast context managers."""

        def __init__(self, *args, **kwargs):
            pass

        def __enter__(self):
            pass

        def __exit__(self, *args):
            pass

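    # Note: DummyAutocast only mimics the context-manager protocol; it sets no
    # autocast state, so ops inside the patched region run in their default
    # dtypes and torch.is_autocast_enabled() reports False.
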
    class PatchedModel(torch.nn.Module):
        """Wrapper that disables autocast around the wrapped model's forward."""

        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, *args, **kwargs):
            # Swap the autocast context managers for no-ops so graph
            # extraction does not trace through autocast regions, then
            # restore the originals even if forward raises.
            original_autocast1 = torch.cuda.amp.autocast
            original_autocast2 = torch.amp.autocast
            original_autocast3 = torch.autocast
            try:
                torch.cuda.amp.autocast = DummyAutocast
                torch.amp.autocast = DummyAutocast
                torch.autocast = DummyAutocast
                return self.model.forward(*args, **kwargs)
            finally:
                torch.cuda.amp.autocast = original_autocast1
                torch.amp.autocast = original_autocast2
                torch.autocast = original_autocast3

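    # A sketch of an equivalent PatchedModel.forward body using
    # unittest.mock.patch instead of manual save/restore (assumption:
    # patching these three names covers all autocast entry points used):
    #
    #     from unittest.mock import patch
    #     with patch("torch.autocast", DummyAutocast), \
    #          patch("torch.amp.autocast", DummyAutocast), \
    #          patch("torch.cuda.amp.autocast", DummyAutocast):
    #         return self.model.forward(*args, **kwargs)
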
    model = create_model(device)
    model = PatchedModel(model)
    model = graph_net.torch.extract(name=get_model_name(), dynamic=False)(model)

    print("Running inference...")
    with torch.no_grad():
        output = model(**inputs)
    print("Inference finished. Output shape:", output.last_hidden_state.shape)
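    # Sanity check: last_hidden_state should be rank 3 ([batch, seq_len,
    # hidden]); for Qwen/Qwen2.5-3B the hidden size is expected to be 2048
    # (assumption; verify against the loaded config's hidden_size).
    assert output.last_hidden_state.dim() == 3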