Commit 6a6d047

kimishpatel authored and facebook-github-bot committed
Leverage __call__ impl of nn Module instead of calling forward on attention
Summary: In the current llama transformer definition we explicitly invoke the forward method on the various attention impls. This prevents us from leveraging register_forward_hook, which is only dispatched via the __call__ override here: https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/module.py#L1781. By removing the explicit call to forward we allow hooks to execute as expected. Created from CodeHub with https://fburl.com/edit-in-codehub. Differential Revision: D83156099
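A minimal sketch (not part of the commit) illustrating the behavior the summary describes: a hook registered with register_forward_hook fires only when the module is invoked via __call__, not when forward is called directly. TinyAttention is a hypothetical stand-in for the real attention impl.

import torch
import torch.nn as nn

class TinyAttention(nn.Module):
    # Hypothetical stand-in for the real attention module.
    def forward(self, x):
        return x * 2

attn = TinyAttention()
hook_calls = []
attn.register_forward_hook(lambda mod, inp, out: hook_calls.append(out))

x = torch.ones(1, 4)

attn.forward(x)        # bypasses __call__, so the hook never runs
print(len(hook_calls)) # 0

attn(x)                # goes through __call__, which dispatches registered hooks
print(len(hook_calls)) # 1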
1 parent e252353 commit 6a6d047

File tree

1 file changed: +1 −1 lines changed

examples/models/llama/llama_transformer.py

Lines changed: 1 addition & 1 deletion
@@ -117,7 +117,7 @@ def from_type(cls, layer_id, args, rope) -> "TransformerBlock":
         return TransformerBlock(args, attention)

     def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions):  # x: 1xN
-        h, attn_options_update = self.attention.forward(
+        h, attn_options_update = self.attention(
             self.attention_norm(x), freqs_cos, freqs_sin, **attn_options
         )
