@@ -797,6 +797,64 @@ def _reset(self):
797
797
task ._reset ()
798
798
799
799
800
+ def _sanitize_input_spec (
801
+ input_spec : ty .Union [SpecInfo , ty .List [str ]],
802
+ wf_name : str ,
803
+ ) -> SpecInfo :
804
+ """Makes sure the provided input specifications are valid.
805
+
806
+ If the input specification is a list of strings, this will
807
+ build a proper SpecInfo object out of it.
808
+
809
+ Parameters
810
+ ----------
811
+ input_spec : SpecInfo or List[str]
812
+ Input specification to be sanitized.
813
+
814
+ wf_name : str
815
+ The name of the workflow for which the input specifications
816
+ are sanitized.
817
+
818
+ Returns
819
+ -------
820
+ input_spec : SpecInfo
821
+ Sanitized input specifications.
822
+
823
+ Raises
824
+ ------
825
+ ValueError
826
+ If provided `input_spec` is None.
827
+ """
828
+ graph_checksum_input = ("_graph_checksums" , ty .Any )
829
+ if input_spec :
830
+ if isinstance (input_spec , SpecInfo ):
831
+ if not any ([x == BaseSpec for x in input_spec .bases ]):
832
+ raise ValueError ("Provided SpecInfo must have BaseSpec as it's base." )
833
+ if "_graph_checksums" not in {f [0 ] for f in input_spec .fields }:
834
+ input_spec .fields .insert (0 , graph_checksum_input )
835
+ return input_spec
836
+ else :
837
+ return SpecInfo (
838
+ name = "Inputs" ,
839
+ fields = [graph_checksum_input ]
840
+ + [
841
+ (
842
+ nm ,
843
+ attr .ib (
844
+ type = ty .Any ,
845
+ metadata = {
846
+ "help_string" : f"{ nm } input from { wf_name } workflow"
847
+ },
848
+ ),
849
+ )
850
+ for nm in input_spec
851
+ ],
852
+ bases = (BaseSpec ,),
853
+ )
854
+ else :
855
+ raise ValueError (f"Empty input_spec provided to Workflow { wf_name } ." )
856
+
857
+
800
858
class Workflow (TaskBase ):
801
859
"""A composite task with structure of computational graph."""
802
860
@@ -806,7 +864,7 @@ def __init__(
806
864
audit_flags : AuditFlag = AuditFlag .NONE ,
807
865
cache_dir = None ,
808
866
cache_locations = None ,
809
- input_spec : ty .Optional [ty .Union [ty .List [ty .Text ], SpecInfo , BaseSpec ]] = None ,
867
+ input_spec : ty .Optional [ty .Union [ty .List [ty .Text ], SpecInfo ]] = None ,
810
868
cont_dim = None ,
811
869
messenger_args = None ,
812
870
messengers = None ,
@@ -842,35 +900,7 @@ def __init__(
842
900
TODO
843
901
844
902
"""
845
- if input_spec :
846
- if isinstance (input_spec , BaseSpec ):
847
- self .input_spec = input_spec
848
- elif isinstance (input_spec , SpecInfo ):
849
- if not any ([x == BaseSpec for x in input_spec .bases ]):
850
- raise ValueError (
851
- "Provided SpecInfo must have BaseSpec as it's base."
852
- )
853
- self .input_spec = input_spec
854
- else :
855
- self .input_spec = SpecInfo (
856
- name = "Inputs" ,
857
- fields = [("_graph_checksums" , ty .Any )]
858
- + [
859
- (
860
- nm ,
861
- attr .ib (
862
- type = ty .Any ,
863
- metadata = {
864
- "help_string" : f"{ nm } input from { name } workflow"
865
- },
866
- ),
867
- )
868
- for nm in input_spec
869
- ],
870
- bases = (BaseSpec ,),
871
- )
872
- else :
873
- raise ValueError ("Empty input_spec provided to Workflow" )
903
+ self .input_spec = _sanitize_input_spec (input_spec , name )
874
904
875
905
self .output_spec = output_spec
876
906
0 commit comments