Skip to content

Commit 2da2329

Browse files
SpEcHiDeZeN220
andcommitted
Add generic return type for method invoke
KurimuzonAkuma/kurigram#252 Co-authored-by: ZeN220 <[email protected]>
1 parent 6b521b4 commit 2da2329

File tree

4 files changed

+36
-9
lines changed

4 files changed

+36
-9
lines changed

compiler/api/compiler.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,19 @@ def camel(s: str):
9898
return "".join([i[0].upper() + i[1:] for i in s.split("_")])
9999

100100

101+
# noinspection PyShadowingBuiltins, PyShadowingNames
102+
def get_return_type_hint(qualtype: str) -> str:
103+
"""Get return type hint for generic TLObject"""
104+
if qualtype.startswith("Vector"):
105+
# Extract inner type from Vector<Type>
106+
inner = qualtype.split("<")[1][:-1]
107+
ns, name = inner.split(".") if "." in inner else ("", inner)
108+
return f'"List[raw.base.{".".join([ns, name]).strip(".")}]"'
109+
else:
110+
ns, name = qualtype.split(".") if "." in qualtype else ("", qualtype)
111+
return f'"raw.base.{".".join([ns, name]).strip(".")}"'
112+
113+
101114
# noinspection PyShadowingBuiltins, PyShadowingNames
102115
def get_type_hint(type: str) -> str:
103116
is_flag = FLAGS_RE.match(type)
@@ -546,6 +559,12 @@ def start(format: bool = False):
546559
slots = ", ".join([f'"{i[0]}"' for i in sorted_args])
547560
return_arguments = ", ".join([f"{i[0]}={i[0]}" for i in sorted_args])
548561

562+
# Generate generic type hint for functions
563+
if c.section == "functions":
564+
generic_type = f"[{get_return_type_hint(c.qualtype)}]"
565+
else:
566+
generic_type = ""
567+
549568
compiled_combinator = combinator_tmpl.format(
550569
notice=notice,
551570
warning=WARNING,
@@ -558,7 +577,8 @@ def start(format: bool = False):
558577
fields=fields,
559578
read_types=read_types,
560579
write_types=write_types,
561-
return_arguments=return_arguments
580+
return_arguments=return_arguments,
581+
generic_type=generic_type
562582
)
563583

564584
directory = "types" if c.section == "types" else c.section

compiler/api/template/combinator.txt

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
{notice}
22

33
from io import BytesIO
4+
from typing import TYPE_CHECKING, Optional, Any
45

56
from pyrogram.raw.core.primitives import Int, Long, Int128, Int256, Bool, Bytes, String, Double, Vector
67
from pyrogram.raw.core import TLObject
7-
from pyrogram import raw
8-
from typing import Optional, Any
8+
9+
if TYPE_CHECKING:
10+
from pyrogram import raw
911

1012
{warning}
1113

1214

13-
class {name}(TLObject): # type: ignore
15+
class {name}(TLObject{generic_type}):
1416
"""{docstring}
1517
"""
1618

pyrogram/methods/advanced/invoke.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
1818

1919
import logging
20+
from typing import TypeVar
2021

2122
import pyrogram
2223
from pyrogram import raw
@@ -25,15 +26,17 @@
2526

2627
log = logging.getLogger(__name__)
2728

29+
ReturnType = TypeVar('ReturnType')
30+
2831

2932
class Invoke:
3033
async def invoke(
3134
self: "pyrogram.Client",
32-
query: TLObject,
35+
query: TLObject[ReturnType],
3336
retries: int = Session.MAX_RETRIES,
3437
timeout: float = Session.WAIT_TIMEOUT,
3538
sleep_threshold: float = None
36-
):
39+
) -> ReturnType:
3740
"""Invoke raw Telegram functions.
3841
3942
This method makes it possible to manually call every single Telegram API method in a low-level manner.

pyrogram/raw/core/tl_object.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,14 @@
1818

1919
from io import BytesIO
2020
from json import dumps
21-
from typing import cast, Any, Union
21+
from typing import cast, Any, Union, TypeVar, Generic
2222

2323
from ..all import objects
2424

25+
ReturnType = TypeVar("ReturnType")
2526

26-
class TLObject:
27+
28+
class TLObject(Generic[ReturnType]):
2729
__slots__: list[str] = []
2830

2931
QUALNAME = "Base"
@@ -78,5 +80,5 @@ def __eq__(self, other: Any) -> bool:
7880
def __len__(self) -> int:
7981
return len(self.write())
8082

81-
def __call__(self, *args: Any, **kwargs: Any) -> Any:
83+
def __call__(self, *args: Any, **kwargs: Any) -> ReturnType:
8284
pass

0 commit comments

Comments
 (0)