@@ -49,6 +49,9 @@ from . import profiler
4949from . import pytree
5050from . import transfer_guard_lib
5151
52+ custom_call_targets = Any
53+ hlo_sharding_util = Any
54+
5255_LiteralSlice = Any
5356_Status = Any
5457_Dtype = Any
@@ -300,6 +303,8 @@ def register_custom_call_as_batch_partitionable(
300303 c_api : Optional [Any ] = ...,
301304) -> None : ...
302305
306+ def register_custom_type_id (type_name : str , type_id : Any ) -> None : ...
307+
303308class AutotuneCacheMode (enum .IntEnum ):
304309 UNSPECIFIED : AutotuneCacheMode
305310 UPDATE : AutotuneCacheMode
@@ -626,6 +631,7 @@ def get_gpu_client(
626631 allowed_devices : Optional [Any ] = ...,
627632 platform_name : Optional [str ] = ...,
628633 mock : Optional [bool ] = ...,
634+ mock_gpu_topology : Optional [str ] = ...,
629635) -> Client : ...
630636def get_mock_gpu_client (
631637 asynchronous : bool = ...,
@@ -645,6 +651,11 @@ def get_default_c_api_topology(
645651 topology_name : str ,
646652 options : Dict [str , Union [str , int , List [int ], float ]],
647653) -> DeviceTopology : ...
654+ def get_c_api_topology (
655+ c_api : Any ,
656+ topology_name : str ,
657+ options : Dict [str , Union [str , int , List [int ], float ]],
658+ ) -> DeviceTopology : ...
648659def get_topology_for_devices (devices : List [Device ]) -> DeviceTopology : ...
649660def load_pjrt_plugin (platform_name : str , library_path : Optional [str ], c_api : Optional [Any ]) -> _Status : ...
650661def pjrt_plugin_loaded (plugin_name : str ) -> bool : ...
0 commit comments