
Commit 17592e8

[Feature Enhancement] Add a hack extractor script to work around vmap and autocast during graph extraction

1 parent: bdfbfeb

31 files changed: +51036 −0 lines
Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
import torch
from transformers import AutoModel, AutoTokenizer, AutoConfig
import graph_net.torch
from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.integrations.executorch import sdpa_mask_without_vmap


def get_model_name():
    return "Qwen/Qwen2.5-3B"


def create_model():
    config = AutoConfig.from_pretrained(get_model_name())
    model = AutoModel.from_config(config)
    # Register an SDPA variant whose attention mask is built without torch.vmap,
    # so the extractor never has to trace through vmap. See:
    # https://github.com/huggingface/transformers/blob/6b5bd117231f969713ed79fd4870903ab3c93edf/docs/source/en/attention_interface.md?plain=1#L194-L195
    ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
    ALL_ATTENTION_FUNCTIONS.register(
        "sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"]
    )
    model.config._attn_implementation = "sdpa_without_vmap"
    model.eval()
    # `device` is the module-level global assigned in the __main__ block below;
    # create_model() is only called after it has been set.
    return model.to(device)


if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained(get_model_name())

    text = "Hello world"
    inputs = tokenizer(text, return_tensors="pt")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    class DummyAutocast:
        """No-op stand-in for the torch autocast context managers."""

        def __init__(self, *args, **kwargs):
            pass

        def __enter__(self):
            pass

        def __exit__(self, *args):
            pass

    class PatchedModel(torch.nn.Module):
        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, *args, **kwargs):
            # Swap every autocast entry point for the no-op context manager
            # while the wrapped forward runs, then restore the originals, so
            # the extractor never sees an autocast region.
            original_autocast1 = torch.cuda.amp.autocast
            original_autocast2 = torch.amp.autocast
            original_autocast3 = torch.autocast
            try:
                torch.cuda.amp.autocast = DummyAutocast
                torch.amp.autocast = DummyAutocast
                torch.autocast = DummyAutocast
                return self.model.forward(*args, **kwargs)
            finally:
                torch.cuda.amp.autocast = original_autocast1
                torch.amp.autocast = original_autocast2
                torch.autocast = original_autocast3

    model = create_model()
    model = PatchedModel(model)
    model = graph_net.torch.extract(name=get_model_name(), dynamic=False)(model)

    print("Running inference...")
    output = model(**inputs)
    print("Inference finished. Output shape:", output.last_hidden_state.shape)
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
70d4001ac71bdeb9de5fe851688032c3ba4cd2b9abf22950fc2061c98358408b
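The single added line is a 64-character hex string, which looks like a SHA-256 digest of an extracted artifact (an assumption; the commit does not say what it hashes). A generic sketch for recomputing such a file digest:

import hashlib

def sha256_of_file(path: str) -> str:
    # Stream the file in chunks so large artifacts need not fit in memory.
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest()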
Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
{
    "framework": "torch",
    "num_devices_required": 1,
    "num_nodes_required": 1,
    "dynamic": false
}
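A hypothetical helper for reading this sample metadata (the path argument and the set of required keys are assumptions drawn from the JSON above, not graph_net's documented API):

import json

REQUIRED_KEYS = ("framework", "num_devices_required", "num_nodes_required", "dynamic")

def load_sample_meta(path):
    # Load the per-sample metadata file and check the keys seen in this commit.
    with open(path) as f:
        meta = json.load(f)
    missing = [k for k in REQUIRED_KEYS if k not in meta]
    if missing:
        raise ValueError(f"metadata at {path} is missing keys: {missing}")
    return meta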

samples/transformers-auto-model/Qwen/Qwen2.5-0.5B/input_meta.py

Whitespace-only changes.

samples/transformers-auto-model/Qwen/Qwen2.5-0.5B/input_tensor_constraints.py

Whitespace-only changes.

samples/transformers-auto-model/Qwen/Qwen2.5-0.5B/model.py

Lines changed: 5275 additions & 0 deletions
Large diffs are not rendered by default.
