Skip to content

Commit 9082638

Browse files
Fix return type hints
Closes #40
1 parent b887bc1 commit 9082638

File tree

3 files changed

+10
-16
lines changed

3 files changed

+10
-16
lines changed

pytensor_federated/common.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,7 @@ def __call__(self, *inputs: Sequence[np.ndarray]) -> np.ndarray:
7878
"""Alias for ``.evaluate(*inputs)``."""
7979
return self.evaluate(*inputs)
8080

81-
def evaluate(
82-
self, *inputs: Sequence[np.ndarray], use_stream=True
83-
) -> Tuple[np.ndarray, Sequence[np.ndarray]]:
81+
def evaluate(self, *inputs: Sequence[np.ndarray], use_stream=True) -> np.ndarray:
8482
"""Evaluate the federated blackbox logp gradient function on inputs.
8583
8684
Parameters
@@ -99,9 +97,7 @@ def evaluate(
9997
(logp,) = self._client.evaluate(*inputs, use_stream=use_stream)
10098
return logp
10199

102-
async def evaluate_async(
103-
self, *inputs: Sequence[np.ndarray], use_stream=True
104-
) -> Tuple[np.ndarray, Sequence[np.ndarray]]:
100+
async def evaluate_async(self, *inputs: Sequence[np.ndarray], use_stream=True) -> np.ndarray:
105101
(logp,) = await self._client.evaluate_async(*inputs, use_stream=use_stream)
106102
return logp
107103

pytensor_federated/service.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,7 @@
66
import uuid
77
from typing import (
88
TYPE_CHECKING,
9-
Any,
109
AsyncIterator,
11-
Coroutine,
1210
Dict,
1311
List,
1412
Optional,
@@ -190,7 +188,7 @@ async def get_load_async(host: str, port: int, timeout: float = 5) -> Optional[G
190188

191189
async def get_loads_async(
192190
hosts_and_ports: Sequence[Tuple[str, int]], *, timeout: float = 5
193-
) -> Sequence[Optional[GetLoadResult]]:
191+
) -> List[Optional[GetLoadResult]]:
194192
"""Retrieve load information from all servers that respond within a timeout.
195193
196194
Parameters
@@ -366,18 +364,18 @@ def __del__(self):
366364
del _privates[_id]
367365
return
368366

369-
def __call__(self, *inputs: Sequence[np.ndarray]) -> Sequence[np.ndarray]:
367+
def __call__(self, *inputs: Sequence[np.ndarray]) -> List[np.ndarray]:
370368
"""Alias for ``.evaluate(*inputs)``."""
371369
return self.evaluate(*inputs)
372370

373-
def evaluate(self, *inputs: Sequence[np.ndarray], **kwargs) -> Sequence[np.ndarray]:
371+
def evaluate(self, *inputs: Sequence[np.ndarray], **kwargs) -> List[np.ndarray]:
374372
loop = get_useful_event_loop()
375373
eval_coro = self.evaluate_async(*inputs, **kwargs)
376374
return loop.run_until_complete(eval_coro)
377375

378376
async def evaluate_async(
379377
self, *inputs: Sequence[np.ndarray], use_stream: bool = True, retries: int = 2
380-
) -> Sequence[np.ndarray]:
378+
) -> List[np.ndarray]:
381379
"""Evaluate the federated compute function on inputs.
382380
383381
Parameters
@@ -390,8 +388,8 @@ async def evaluate_async(
390388
391389
Returns
392390
-------
393-
*outputs
394-
Sequence of ``ndarray``s returned by the federated compute function.
391+
outputs
392+
List of ``ndarray``s returned by the federated compute function.
395393
"""
396394
if retries < 0:
397395
raise ValueError("Number of retries must be >= 0.")

pytensor_federated/test_service.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import multiprocessing
22
import platform
33
import time
4-
from typing import Sequence
4+
from typing import Sequence, Tuple
55
from unittest import mock
66

77
import grpclib
@@ -42,7 +42,7 @@ def compute_fun(a, b):
4242
pass
4343

4444

45-
def product_func(*inputs: Sequence[np.ndarray]) -> Sequence[np.ndarray]:
45+
def product_func(*inputs: Sequence[np.ndarray]) -> Tuple[np.ndarray]:
4646
"""Calculates the product of NumPy arrays"""
4747
return (np.prod(inputs),)
4848

0 commit comments

Comments
 (0)