 from torch import Tensor, nn
 from torch.nn import functional as F, grad

-from truegrad.functional import add, einsum, matmul, mul, reshape
+from truegrad.functional import add, chunk, einsum, matmul, mul, reshape, split, transpose

 _torch_functional = {k: getattr(F, k) for k in dir(F)}
 _torch = {k: getattr(torch, k) for k in dir(torch)}
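The widened import pulls in the functional wrappers that the hunks below substitute for the corresponding Tensor methods. A minimal plain-torch sketch of what each wrapper call is taken to compute here; the equivalence is inferred from how the calls are used in this diff, not from the truegrad API, so treat it as an assumption:

    import torch

    E = 4
    w = torch.randn(3 * E, E)

    # chunk(w, 3, 0) is used below where w.chunk(3) was: three equal slices along dim 0
    ref_chunk = w.chunk(3, 0)
    # split(w, [E, E * 2], 0) is used below where w.split([E, E * 2]) was
    ref_split = w.split([E, E * 2], 0)
    # transpose(weight, (0, 1)) is used below to swap the two weight dimensions,
    # i.e. what weight.transpose(0, 1) computes (assumed, based on the linear() fix)
    ref_t = w.transpose(0, 1)

    print([t.shape for t in ref_chunk], [t.shape for t in ref_split], ref_t.shape)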
@@ -551,7 +551,7 @@ def leaky_relu_(input: Tensor, negative_slope: float = 0.01):

 @call_torch
 def linear(input: Tensor, weight: Tensor, bias: Optional[Tensor]):
-    input = matmul(input, weight)
+    input = matmul(input, transpose(weight, (0, 1)))
     if bias is None:
         return input
     return add(input, bias)
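For context: torch.nn.functional.linear computes input @ weight.T with weight stored as [out_features, in_features], so multiplying by the untransposed weight does not match the standard nn.Linear convention; the added transpose(weight, (0, 1)) restores it. A plain-torch sanity check of that convention (illustrative only, no truegrad involved):

    import torch
    from torch.nn import functional as F

    x = torch.randn(2, 3)        # batch of inputs, in_features = 3
    weight = torch.randn(5, 3)   # [out_features, in_features], as nn.Linear stores it
    bias = torch.randn(5)

    # F.linear computes x @ weight.T + bias, so the matmul needs the transposed weight
    expected = F.linear(x, weight, bias)
    manual = x @ weight.transpose(0, 1) + bias
    assert torch.allclose(expected, manual)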
@@ -641,21 +641,21 @@ def _in_projection_packed(
     if k is v:
         if q is k:
             # self-attention
-            return linear(q, w, b).chunk(3, dim=-1)
+            return linear(q, w, b).chunk(3, -1)
         else:
             # encoder-decoder attention
-            w_q, w_kv = w.split([E, E * 2])
+            w_q, w_kv = split(w, [E, E * 2], 0)
             if b is None:
                 b_q = b_kv = None
             else:
-                b_q, b_kv = b.split([E, E * 2])
-            return (linear(q, w_q, b_q),) + linear(k, w_kv, b_kv).chunk(2, dim=-1)
+                b_q, b_kv = split(b, [E, E * 2], 0)
+            return (linear(q, w_q, b_q),) + linear(k, w_kv, b_kv).chunk(2, -1)
     else:
-        w_q, w_k, w_v = w.chunk(3)
+        w_q, w_k, w_v = chunk(w, 3, 0)
         if b is None:
             b_q = b_k = b_v = None
         else:
-            b_q, b_k, b_v = b.chunk(3)
+            b_q, b_k, b_v = chunk(b, 3, 0)
     return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)

|
@@ -965,7 +965,7 @@ def multi_head_attention_forward(query: Tensor, key: Tensor, value: Tensor, embe
         if in_proj_bias is None:
             b_q = b_k = b_v = None
         else:
-            b_q, b_k, b_v = in_proj_bias.chunk(3)
+            b_q, b_k, b_v = chunk(in_proj_bias, 3, 0)
         q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v)

     # prep attention mask
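Same substitution for the packed in_proj_bias. Tensor.chunk defaults to dim=0, so the explicit 0 in chunk(in_proj_bias, 3, 0) preserves the original partitioning; a short plain-torch check (illustrative only):

    import torch

    in_proj_bias = torch.randn(3 * 8)   # packed q/k/v bias, embed_dim = 8
    # Tensor.chunk defaults to dim=0, so the explicit 0 passed to the wrapper
    # keeps the original partitioning of the packed bias
    assert all(torch.equal(a, b)
               for a, b in zip(in_proj_bias.chunk(3), in_proj_bias.chunk(3, 0)))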