@@ -101,6 +101,18 @@ def _prepare_inputs():
101101+ head_dim=self.model.config.head_dim,
102102+ pad_size=pad_size,
103103+ )
104+ +
105+ + # Propagate cp_group to all attention modules (needed by magi_attention_forward)
106+ + if not getattr(self, "_cp_group_propagated", False):
107+ + cp_group = self.cp_group
108+ + unwrapped_model = (
109+ + self.model.module if hasattr(self.model, "module") else self.model
110+ + )
111+ + for module in unwrapped_model.modules():
112+ + if "Attention" in type(module).__name__:
113+ + module.cp_group = cp_group
114+ + self._cp_group_propagated = True
115+ +
104116+ position_ids = get_position_ids(magi_attn_key).unsqueeze(0)
105117+
106118+ inputs["position_ids"] = position_ids
@@ -142,7 +154,7 @@ def _prepare_inputs():
142154+ x_padded = dispatch(inputs, key=dist_attn_runtime_key)
143155+ x_padded = x_padded.unsqueeze(0)
144156+
145- + return x_padded, dist_attn_runtime_key
157+ + return x_padded, dist_attn_runtime_key
146158```
147159
148160Override `compute_loss` because we need to undispatch logits first:
@@ -152,7 +164,7 @@ def compute_loss():
152164 outputs = model(**inputs)
153165+ logits = outputs.logits
154166
155- + magi_attn_key = get_magi_attention_key( )
167+ + magi_attn_key = get_most_recent_key(self.cp_group )
156168+ if magi_attn_key is not None:
157169+ logits = squash_batch_dim(logits)
158170
@@ -205,7 +217,7 @@ trainer.train()
205217```
206218
207219### Register Magi_Attention implementation
208- The following code are all avaliable at Magi_attention .py.
220+ The following code is all available at `magi_attention_func.py`.
209221
210222What's more, MagiAttention provides a new type of attention implementation (flexible flash_attention), so we need to register it for use:
211223``` python
0 commit comments