Skip to content

Commit 955cdf5

Browse files
authored
1 parent 2e85a70 commit 955cdf5

File tree

3 files changed

+54
-0
lines changed

3 files changed

+54
-0
lines changed

stubs/tensorflow/tensorflow/__init__.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ from tensorflow.dtypes import *
3838
from tensorflow.dtypes import DType as DType
3939
from tensorflow.experimental.dtensor import Layout
4040
from tensorflow.keras import losses as losses
41+
from tensorflow.linalg import eye as eye
4142

4243
# Most tf.math functions are exported as tf, but sadly not all are.
4344
from tensorflow.math import (

stubs/tensorflow/tensorflow/_aliases.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class KerasSerializable2(Protocol):
2727

2828
KerasSerializable: TypeAlias = KerasSerializable1 | KerasSerializable2
2929

30+
Integer: TypeAlias = tf.Tensor | int | IntArray | np.number[Any] # Here tf.Tensor and IntArray are assumed to be 0D.
3031
Slice: TypeAlias = int | slice | None
3132
FloatDataSequence: TypeAlias = Sequence[float] | Sequence[FloatDataSequence]
3233
IntDataSequence: TypeAlias = Sequence[int] | Sequence[IntDataSequence]
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
from _typeshed import Incomplete
2+
from builtins import bool as _bool
3+
from collections.abc import Iterable
4+
from typing import Literal, overload
5+
6+
import tensorflow as tf
7+
from tensorflow import RaggedTensor, Tensor, norm as norm
8+
from tensorflow._aliases import DTypeLike, IntArray, Integer, ScalarTensorCompatible, TensorCompatible
9+
from tensorflow.math import l2_normalize as l2_normalize
10+
11+
@overload
12+
def matmul(
13+
a: TensorCompatible,
14+
b: TensorCompatible,
15+
transpose_a: _bool = False,
16+
transpose_b: _bool = False,
17+
adjoint_a: _bool = False,
18+
adjoint_b: _bool = False,
19+
a_is_sparse: _bool = False,
20+
b_is_sparse: _bool = False,
21+
output_type: DTypeLike | None = None,
22+
name: str | None = None,
23+
) -> Tensor: ...
24+
@overload
25+
def matmul(
26+
a: RaggedTensor,
27+
b: RaggedTensor,
28+
transpose_a: _bool = False,
29+
transpose_b: _bool = False,
30+
adjoint_a: _bool = False,
31+
adjoint_b: _bool = False,
32+
a_is_sparse: _bool = False,
33+
b_is_sparse: _bool = False,
34+
output_type: DTypeLike | None = None,
35+
name: str | None = None,
36+
) -> RaggedTensor: ...
37+
def set_diag(
38+
input: TensorCompatible,
39+
diagonal: TensorCompatible,
40+
name: str | None = "set_diag",
41+
k: int = 0,
42+
align: Literal["RIGHT_LEFT", "RIGHT_RIGHT", "LEFT_LEFT", "LEFT_RIGHT"] = "RIGHT_LEFT",
43+
) -> Tensor: ...
44+
def eye(
45+
num_rows: ScalarTensorCompatible,
46+
num_columns: ScalarTensorCompatible | None = None,
47+
batch_shape: Iterable[int] | IntArray | tf.Tensor | None = None,
48+
dtype: DTypeLike = ...,
49+
name: str | None = None,
50+
) -> Tensor: ...
51+
def band_part(input: TensorCompatible, num_lower: Integer, num_upper: Integer, name: str | None = None) -> Tensor: ...
52+
def __getattr__(name: str) -> Incomplete: ...

0 commit comments

Comments
 (0)