Skip to content

Commit f9926c7

Browse files
hawkinspGoogle-ML-Automation
authored andcommitted
[XLA:Python] Add missing type stubs.
PiperOrigin-RevId: 738949498
1 parent 4dd2860 commit f9926c7

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

xla/python/xla_extension/__init__.pyi

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ from . import profiler
4949
from . import pytree
5050
from . import transfer_guard_lib
5151

52+
custom_call_targets = Any
53+
hlo_sharding_util = Any
54+
5255
_LiteralSlice = Any
5356
_Status = Any
5457
_Dtype = Any
@@ -300,6 +303,8 @@ def register_custom_call_as_batch_partitionable(
300303
c_api: Optional[Any] = ...,
301304
) -> None: ...
302305

306+
def register_custom_type_id(type_name: str, type_id: Any) -> None: ...
307+
303308
class AutotuneCacheMode(enum.IntEnum):
304309
UNSPECIFIED: AutotuneCacheMode
305310
UPDATE: AutotuneCacheMode
@@ -626,6 +631,7 @@ def get_gpu_client(
626631
allowed_devices: Optional[Any] = ...,
627632
platform_name: Optional[str] = ...,
628633
mock: Optional[bool] = ...,
634+
mock_gpu_topology: Optional[str] = ...,
629635
) -> Client: ...
630636
def get_mock_gpu_client(
631637
asynchronous: bool = ...,
@@ -645,6 +651,11 @@ def get_default_c_api_topology(
645651
topology_name: str,
646652
options: Dict[str, Union[str, int, List[int], float]],
647653
) -> DeviceTopology: ...
654+
def get_c_api_topology(
655+
c_api: Any,
656+
topology_name: str,
657+
options: Dict[str, Union[str, int, List[int], float]],
658+
) -> DeviceTopology: ...
648659
def get_topology_for_devices(devices: List[Device]) -> DeviceTopology: ...
649660
def load_pjrt_plugin(platform_name: str, library_path: Optional[str], c_api: Optional[Any]) -> _Status: ...
650661
def pjrt_plugin_loaded(plugin_name: str) -> bool: ...

0 commit comments

Comments
 (0)