@@ -54,7 +54,7 @@ def is_set(value: ty.Any) -> bool:
5454 return value not in (attrs .NOTHING , EMPTY )
5555
5656
57- @attrs .define
57+ @attrs .define ( kw_only = True , auto_attribs = False , eq = False )
5858class TaskOutputs :
5959 """Base class for all output definitions"""
6060
@@ -113,11 +113,31 @@ def __getitem__(self, name_or_index: str | int) -> ty.Any:
113113 f"{ self } doesn't have an attribute { name_or_index } "
114114 ) from None
115115
116+ def __eq__ (self , other : ty .Any ) -> bool :
117+ """Check if two task definitions are equal"""
118+ values = attrs .asdict (self )
119+ fields = list_fields (self )
120+ try :
121+ other_values = attrs .asdict (other )
122+ except AttributeError :
123+ return False
124+ try :
125+ other_fields = list_fields (other )
126+ except AttributeError :
127+ return False
128+ if fields != other_fields :
129+ return False
130+ for field in list_fields (self ):
131+ if field .hash_eq :
132+ values [field .name ] = hash_function (values [field .name ])
133+ other_values [field .name ] = hash_function (other_values [field .name ])
134+ return values == other_values
135+
116136
117137OutputsType = ty .TypeVar ("OutputType" , bound = TaskOutputs )
118138
119139
120- @attrs .define (kw_only = True , auto_attribs = False )
140+ @attrs .define (kw_only = True , auto_attribs = False , eq = False )
121141class TaskDef (ty .Generic [OutputsType ]):
122142 """Base class for all task definitions"""
123143
@@ -341,6 +361,34 @@ def __iter__(self) -> ty.Generator[str, None, None]:
341361 if not (f .name .startswith ("_" ) or f .name in self .RESERVED_FIELD_NAMES )
342362 )
343363
364+ def __eq__ (self , other : ty .Any ) -> bool :
365+ """Check if two task definitions are equal"""
366+ values = attrs .asdict (self )
367+ try :
368+ other_values = attrs .asdict (other )
369+ except AttributeError :
370+ return False
371+ if set (values ) != set (other_values ):
372+ return False # Return if attribute keys don't match
373+ for field in list_fields (self ):
374+ if field .hash_eq :
375+ values [field .name ] = hash_function (values [field .name ])
376+ other_values [field .name ] = hash_function (other_values [field .name ])
377+ if values != other_values :
378+ return False
379+ hash_cache = Cache ()
380+ if hash_function (type (self ), cache = hash_cache ) != hash_function (
381+ type (other ), cache = hash_cache
382+ ):
383+ return False
384+ try :
385+ other_outputs = other .Outputs
386+ except AttributeError :
387+ return False
388+ return hash_function (self .Outputs , cache = hash_cache ) == hash_function (
389+ other_outputs , cache = hash_cache
390+ )
391+
344392 def __getitem__ (self , name : str ) -> ty .Any :
345393 """Return the value for the given attribute, resolving any templates
346394
@@ -595,7 +643,7 @@ class RuntimeSpec:
595643 network : bool = False
596644
597645
598- @attrs .define (kw_only = True , auto_attribs = False )
646+ @attrs .define (kw_only = True , auto_attribs = False , eq = False )
599647class PythonOutputs (TaskOutputs ):
600648
601649 @classmethod
@@ -624,7 +672,7 @@ def _from_task(cls, task: "Task[PythonDef]") -> Self:
624672PythonOutputsType = ty .TypeVar ("OutputType" , bound = PythonOutputs )
625673
626674
627- @attrs .define (kw_only = True , auto_attribs = False )
675+ @attrs .define (kw_only = True , auto_attribs = False , eq = False )
628676class PythonDef (TaskDef [PythonOutputsType ]):
629677
630678 _task_type = "python"
@@ -653,7 +701,7 @@ def _run(self, task: "Task[PythonDef]") -> None:
653701 )
654702
655703
656- @attrs .define (kw_only = True , auto_attribs = False )
704+ @attrs .define (kw_only = True , auto_attribs = False , eq = False )
657705class WorkflowOutputs (TaskOutputs ):
658706
659707 @classmethod
@@ -707,7 +755,7 @@ def _from_task(cls, task: "Task[WorkflowDef]") -> Self:
707755WorkflowOutputsType = ty .TypeVar ("OutputType" , bound = WorkflowOutputs )
708756
709757
710- @attrs .define (kw_only = True , auto_attribs = False )
758+ @attrs .define (kw_only = True , auto_attribs = False , eq = False )
711759class WorkflowDef (TaskDef [WorkflowOutputsType ]):
712760
713761 _task_type = "workflow"
@@ -738,7 +786,7 @@ def construct(self) -> "Workflow":
738786STDERR_HELP = """The standard error stream produced by the command."""
739787
740788
741- @attrs .define (kw_only = True , auto_attribs = False )
789+ @attrs .define (kw_only = True , auto_attribs = False , eq = False )
742790class ShellOutputs (TaskOutputs ):
743791 """Output definition of a generic shell process."""
744792
@@ -899,7 +947,7 @@ def _resolve_value(
899947ShellOutputsType = ty .TypeVar ("OutputType" , bound = ShellOutputs )
900948
901949
902- @attrs .define (kw_only = True , auto_attribs = False )
950+ @attrs .define (kw_only = True , auto_attribs = False , eq = False )
903951class ShellDef (TaskDef [ShellOutputsType ]):
904952
905953 _task_type = "shell"
0 commit comments