@@ -798,17 +798,17 @@ def _reset(self):
798
798
799
799
800
800
def _sanitize_input_spec (
801
- input_spec : ty .Union [BaseSpec , ty .List [str ]],
801
+ input_spec : ty .Union [SpecInfo , ty .List [str ]],
802
802
wf_name : str ,
803
- ) -> BaseSpec :
803
+ ) -> SpecInfo :
804
804
"""Makes sure the provided input specifications are valid.
805
805
806
806
If the input specification is a list of strings, this will
807
807
build a proper SpecInfo object out of it.
808
808
809
809
Parameters
810
810
----------
811
- input_spec : BaseSpec or List[str]
811
+ input_spec : SpecInfo or List[str]
812
812
Input specification to be sanitized.
813
813
814
814
wf_name : str
@@ -817,7 +817,7 @@ def _sanitize_input_spec(
817
817
818
818
Returns
819
819
-------
820
- input_spec : BaseSpec
820
+ input_spec : SpecInfo
821
821
Sanitized input specifications.
822
822
823
823
Raises
@@ -827,9 +827,7 @@ def _sanitize_input_spec(
827
827
"""
828
828
graph_checksum_input = ("_graph_checksums" , ty .Any )
829
829
if input_spec :
830
- if isinstance (input_spec , BaseSpec ):
831
- return input_spec
832
- elif isinstance (input_spec , SpecInfo ):
830
+ if isinstance (input_spec , SpecInfo ):
833
831
if not any ([x == BaseSpec for x in input_spec .bases ]):
834
832
raise ValueError ("Provided SpecInfo must have BaseSpec as it's base." )
835
833
if "_graph_checksums" not in {f [0 ] for f in input_spec .fields }:
@@ -866,7 +864,7 @@ def __init__(
866
864
audit_flags : AuditFlag = AuditFlag .NONE ,
867
865
cache_dir = None ,
868
866
cache_locations = None ,
869
- input_spec : ty .Optional [ty .Union [ty .List [ty .Text ], SpecInfo , BaseSpec ]] = None ,
867
+ input_spec : ty .Optional [ty .Union [ty .List [ty .Text ], SpecInfo ]] = None ,
870
868
cont_dim = None ,
871
869
messenger_args = None ,
872
870
messengers = None ,
0 commit comments