|
| 1 | +from _typeshed import Incomplete |
| 2 | +from builtins import bool as _bool |
| 3 | +from collections.abc import Iterable |
| 4 | +from typing import Literal, overload |
| 5 | + |
| 6 | +import tensorflow as tf |
| 7 | +from tensorflow import RaggedTensor, Tensor, norm as norm |
| 8 | +from tensorflow._aliases import DTypeLike, IntArray, Integer, ScalarTensorCompatible, TensorCompatible |
| 9 | +from tensorflow.math import l2_normalize as l2_normalize |
| 10 | + |
| 11 | +@overload |
| 12 | +def matmul( |
| 13 | + a: TensorCompatible, |
| 14 | + b: TensorCompatible, |
| 15 | + transpose_a: _bool = False, |
| 16 | + transpose_b: _bool = False, |
| 17 | + adjoint_a: _bool = False, |
| 18 | + adjoint_b: _bool = False, |
| 19 | + a_is_sparse: _bool = False, |
| 20 | + b_is_sparse: _bool = False, |
| 21 | + output_type: DTypeLike | None = None, |
| 22 | + name: str | None = None, |
| 23 | +) -> Tensor: ... |
| 24 | +@overload |
| 25 | +def matmul( |
| 26 | + a: RaggedTensor, |
| 27 | + b: RaggedTensor, |
| 28 | + transpose_a: _bool = False, |
| 29 | + transpose_b: _bool = False, |
| 30 | + adjoint_a: _bool = False, |
| 31 | + adjoint_b: _bool = False, |
| 32 | + a_is_sparse: _bool = False, |
| 33 | + b_is_sparse: _bool = False, |
| 34 | + output_type: DTypeLike | None = None, |
| 35 | + name: str | None = None, |
| 36 | +) -> RaggedTensor: ... |
| 37 | +def set_diag( |
| 38 | + input: TensorCompatible, |
| 39 | + diagonal: TensorCompatible, |
| 40 | + name: str | None = "set_diag", |
| 41 | + k: int = 0, |
| 42 | + align: Literal["RIGHT_LEFT", "RIGHT_RIGHT", "LEFT_LEFT", "LEFT_RIGHT"] = "RIGHT_LEFT", |
| 43 | +) -> Tensor: ... |
| 44 | +def eye( |
| 45 | + num_rows: ScalarTensorCompatible, |
| 46 | + num_columns: ScalarTensorCompatible | None = None, |
| 47 | + batch_shape: Iterable[int] | IntArray | tf.Tensor | None = None, |
| 48 | + dtype: DTypeLike = ..., |
| 49 | + name: str | None = None, |
| 50 | +) -> Tensor: ... |
| 51 | +def band_part(input: TensorCompatible, num_lower: Integer, num_upper: Integer, name: str | None = None) -> Tensor: ... |
| 52 | +def __getattr__(name: str) -> Incomplete: ... |
0 commit comments