Skip to content

Commit 8f71ae5

Browse files
committed
use kernel
1 parent 668afb6 commit 8f71ae5

File tree

1 file changed

+4
-6
lines changed

1 file changed

+4
-6
lines changed

accelerated_scan/warp.py

Lines changed: 4 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,7 @@
2727
]
2828
)
2929
warpscan_forward = module.warpscan_forward
30+
warpscan_backward = module.warpscan_backward
3031

3132
def scan_forward(gates, tokens, reverse=False):
3233
output = torch.zeros_like(tokens)
@@ -57,13 +58,10 @@ def backward(ctx, grad_output):
5758
assert states.is_contiguous()
5859
assert gates.is_contiguous()
5960

60-
padded_shifted_gates = torch.cat([gates, torch.ones_like(gates[:, :, :1])], dim=-1)[:, :, 1:].contiguous()
61-
d_states = scan_forward(padded_shifted_gates, grad_output, reverse=True)
61+
d_gates = torch.empty_like(gates)
62+
d_tokens = torch.empty_like(gates)
63+
warpscan_backward(gates, states, grad_output, d_gates, d_tokens)
6264

63-
padded_outputs = torch.cat([torch.zeros_like(states[:, :, :1]), states], dim=-1)[:, :, :-1]
64-
d_gates = padded_outputs * d_states
65-
66-
d_tokens = d_states
6765
return d_gates, d_tokens
6866

6967

0 commit comments

Comments
 (0)