44import warnings
55from collections .abc import Callable , Mapping , MutableSequence , Sequence
66from functools import partial , reduce
7- from typing import TYPE_CHECKING , Literal , TypeVar , Union
7+ from typing import TYPE_CHECKING , Literal , TypeVar , Union , overload
88
99import numpy as np
1010
@@ -414,6 +414,32 @@ def Lop(
414414 return as_list_or_tuple (using_list , using_tuple , ret )
415415
416416
417+ @overload
418+ def grad (
419+ cost : Variable | None ,
420+ wrt : Variable | Sequence [Variable ],
421+ consider_constant : Sequence [Variable ] | None = ...,
422+ disconnected_inputs : Literal ["ignore" , "warn" , "raise" ] = ...,
423+ add_names : bool = ...,
424+ known_grads : Mapping [Variable , Variable ] | None = ...,
425+ return_disconnected : Literal ["zero" , "disconnected" ] = ...,
426+ null_gradients : Literal ["raise" , "return" ] = ...,
427+ ) -> Variable | None | Sequence [Variable ]: ...
428+
429+
430+ @overload
431+ def grad (
432+ cost : Variable | None ,
433+ wrt : Variable | Sequence [Variable ],
434+ consider_constant : Sequence [Variable ] | None = ...,
435+ disconnected_inputs : Literal ["ignore" , "warn" , "raise" ] = ...,
436+ add_names : bool = ...,
437+ known_grads : Mapping [Variable , Variable ] | None = ...,
438+ return_disconnected : Literal ["none" ] = ...,
439+ null_gradients : Literal ["raise" , "return" ] = ...,
440+ ) -> Variable | None | Sequence [Variable | None ]: ...
441+
442+
417443def grad (
418444 cost : Variable | None ,
419445 wrt : Variable | Sequence [Variable ],
@@ -423,7 +449,7 @@ def grad(
423449 known_grads : Mapping [Variable , Variable ] | None = None ,
424450 return_disconnected : Literal ["none" , "zero" , "disconnected" ] = "zero" ,
425451 null_gradients : Literal ["raise" , "return" ] = "raise" ,
426- ) -> Variable | None | Sequence [Variable | None ]:
452+ ) -> Variable | None | Sequence [Variable | None ] | Sequence [ Variable ] :
427453 """
428454 Return symbolic gradients of one cost with respect to one or more variables.
429455
0 commit comments