|
9 | 9 | vecdot as _aliases_vecdot)
|
10 | 10 | from .._internal import get_xp
|
11 | 11 |
|
| 12 | +import torch |
| 13 | + |
12 | 14 | from typing import TYPE_CHECKING
|
13 | 15 | if TYPE_CHECKING:
|
14 | 16 | from typing import List, Optional, Sequence, Tuple, Union
|
15 | 17 | from ..common._typing import Device
|
16 | 18 | from torch import dtype as Dtype
|
17 | 19 |
|
18 |
| -import torch |
19 |
| -array = torch.Tensor |
| 20 | +array = torch.Tensor |
20 | 21 |
|
21 | 22 | _int_dtypes = {
|
22 | 23 | torch.uint8,
|
@@ -547,6 +548,14 @@ def empty(shape: Union[int, Tuple[int, ...]],
|
547 | 548 | **kwargs) -> array:
|
548 | 549 | return torch.empty(shape, dtype=dtype, device=device, **kwargs)
|
549 | 550 |
|
| 551 | +# torch.tril and torch.triu name this argument `diagonal` rather than the array API's `k` |
| 552 | + |
def tril(x: array, /, *, k: int = 0) -> array:
    """Return the lower-triangular part of ``x``.

    Elements above the ``k``-th diagonal are zeroed; ``k`` maps to
    torch's ``diagonal`` argument.
    """
    return torch.tril(x, diagonal=k)
| 555 | + |
def triu(x: array, /, *, k: int = 0) -> array:
    """Return the upper-triangular part of ``x``.

    Elements below the ``k``-th diagonal are zeroed; ``k`` maps to
    torch's ``diagonal`` argument.
    """
    return torch.triu(x, diagonal=k)
| 558 | + |
# Functions that aren't in torch https://github.com/pytorch/pytorch/issues/58742
def expand_dims(x: array, /, *, axis: int = 0) -> array:
    """Insert a size-1 dimension into ``x`` at position ``axis``."""
    return torch.unsqueeze(x, dim=axis)
|
@@ -651,6 +660,7 @@ def isdtype(
|
651 | 660 | 'subtract', 'max', 'min', 'sort', 'prod', 'sum', 'any', 'all',
|
652 | 661 | 'mean', 'std', 'var', 'concat', 'squeeze', 'flip', 'roll',
|
653 | 662 | 'nonzero', 'where', 'arange', 'eye', 'linspace', 'full', 'ones',
|
654 |
| - 'zeros', 'empty', 'expand_dims', 'astype', 'broadcast_arrays', |
655 |
| - 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', |
656 |
| - 'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype'] |
| 663 | + 'zeros', 'empty', 'tril', 'triu', 'expand_dims', 'astype', |
| 664 | + 'broadcast_arrays', 'unique_all', 'unique_counts', |
| 665 | + 'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose', |
| 666 | + 'vecdot', 'tensordot', 'isdtype'] |
0 commit comments