Skip to content

Commit 50947f6

Browse files
hanwen-sun and Strivin0311
authored and committed
update transformers example for v1.1.0 (#260)
1 parent 8ca0dd4 commit 50947f6

File tree

2 files changed

+26
-3
lines changed

2 files changed

+26
-3
lines changed

examples/transformers/README.md

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,18 @@ def _prepare_inputs():
101101
+ head_dim=self.model.config.head_dim,
102102
+ pad_size=pad_size,
103103
+ )
104+
+
105+
+ # Propagate cp_group to all attention modules (needed by magi_attention_forward)
106+
+ if not getattr(self, "_cp_group_propagated", False):
107+
+ cp_group = self.cp_group
108+
+ unwrapped_model = (
109+
+ self.model.module if hasattr(self.model, "module") else self.model
110+
+ )
111+
+ for module in unwrapped_model.modules():
112+
+ if "Attention" in type(module).__name__:
113+
+ module.cp_group = cp_group
114+
+ self._cp_group_propagated = True
115+
+
104116
+ position_ids = get_position_ids(magi_attn_key).unsqueeze(0)
105117
+
106118
+ inputs["position_ids"] = position_ids
@@ -142,7 +154,7 @@ def _prepare_inputs():
142154
+ x_padded = dispatch(inputs, key=dist_attn_runtime_key)
143155
+ x_padded = x_padded.unsqueeze(0)
144156
+
145-
+ return x_padded, dist_attn_runtime_key
157+
+ return x_padded, dist_attn_runtime_key
146158
```
147159

148160
Override `compute_loss` because we need to undispatch logits first:
@@ -152,7 +164,7 @@ def compute_loss():
152164
outputs = model(**inputs)
153165
+ logits = outputs.logits
154166

155-
+ magi_attn_key = get_magi_attention_key()
167+
+ magi_attn_key = get_most_recent_key(self.cp_group)
156168
+ if magi_attn_key is not None:
157169
+ logits = squash_batch_dim(logits)
158170

@@ -205,7 +217,7 @@ trainer.train()
205217
```
206218

207219
### Register Magi_Attention implementation
208-
The following code are all avaliable at Magi_attention.py.
220+
The following code are all available at `magi_attention_func.py`.
209221

210222
What's more, MagiAttention provides a new type of attention implementation (flexible flash_attention), so we need to register it for use:
211223
``` python

examples/transformers/magi_trainer.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,17 @@ def _prepare_inputs(
400400
pad_size=pad_size,
401401
)
402402

403+
# Propagate cp_group to all attention modules (needed by magi_attention_forward)
404+
if not getattr(self, "_cp_group_propagated", False):
405+
cp_group = self.cp_group
406+
unwrapped_model = (
407+
self.model.module if hasattr(self.model, "module") else self.model
408+
)
409+
for module in unwrapped_model.modules():
410+
if "Attention" in type(module).__name__:
411+
module.cp_group = cp_group
412+
self._cp_group_propagated = True
413+
403414
position_ids = get_position_ids(magi_attn_key).unsqueeze(0)
404415

405416
inputs["position_ids"] = position_ids

0 commit comments

Comments
 (0)