5
5
from typing import Self
6
6
import attrs .validators
7
7
from pydra .utils .typing import is_optional , is_fileset_or_union
8
- from pydra .utils .general import task_fields
8
+ from pydra .utils .general import get_fields
9
9
from pydra .utils .typing import StateArray , is_lazy
10
10
from pydra .utils .hash import hash_function
11
11
import os
@@ -67,7 +67,7 @@ def _from_job(cls, job: "Job[TaskType]") -> Self:
67
67
The outputs of the job
68
68
"""
69
69
defaults = {}
70
- for output in task_fields (cls ):
70
+ for output in get_fields (cls ):
71
71
if output .mandatory :
72
72
default = attrs .NOTHING
73
73
elif isinstance (output .default , attrs .Factory ):
@@ -116,18 +116,18 @@ def __getitem__(self, name_or_index: str | int) -> ty.Any:
116
116
def __eq__ (self , other : ty .Any ) -> bool :
117
117
"""Check if two tasks are equal"""
118
118
values = attrs .asdict (self )
119
- fields = task_fields (self )
119
+ fields = get_fields (self )
120
120
try :
121
121
other_values = attrs .asdict (other )
122
122
except AttributeError :
123
123
return False
124
124
try :
125
- other_fields = task_fields (other )
125
+ other_fields = get_fields (other )
126
126
except AttributeError :
127
127
return False
128
128
if fields != other_fields :
129
129
return False
130
- for field in task_fields (self ):
130
+ for field in get_fields (self ):
131
131
if field .hash_eq :
132
132
values [field .name ] = hash_function (values [field .name ])
133
133
other_values [field .name ] = hash_function (other_values [field .name ])
@@ -137,7 +137,7 @@ def __repr__(self) -> str:
137
137
"""A string representation of the task"""
138
138
fields_str = ", " .join (
139
139
f"{ f .name } ={ getattr (self , f .name )!r} "
140
- for f in task_fields (self )
140
+ for f in get_fields (self )
141
141
if getattr (self , f .name ) != f .default
142
142
)
143
143
return f"{ self .__class__ .__name__ } ({ fields_str } )"
@@ -407,7 +407,7 @@ def __repr__(self) -> str:
407
407
"""A string representation of the task"""
408
408
fields_str = ", " .join (
409
409
f"{ f .name } ={ getattr (self , f .name )!r} "
410
- for f in task_fields (self )
410
+ for f in get_fields (self )
411
411
if getattr (self , f .name ) != f .default
412
412
)
413
413
return f"{ self .__class__ .__name__ } ({ fields_str } )"
@@ -416,7 +416,7 @@ def __iter__(self) -> ty.Generator[str, None, None]:
416
416
"""Iterate through all the names in the task"""
417
417
return (
418
418
f .name
419
- for f in task_fields (self )
419
+ for f in get_fields (self )
420
420
if not (f .name .startswith ("_" ) or f .name in self .RESERVED_FIELD_NAMES )
421
421
)
422
422
@@ -431,7 +431,7 @@ def __eq__(self, other: ty.Any) -> bool:
431
431
return False
432
432
if set (values ) != set (other_values ):
433
433
return False # Return if attribute keys don't match
434
- for field in task_fields (self ):
434
+ for field in get_fields (self ):
435
435
if field .hash_eq :
436
436
values [field .name ] = hash_function (values [field .name ])
437
437
other_values [field .name ] = hash_function (other_values [field .name ])
@@ -486,7 +486,7 @@ def _hash_changes(self):
486
486
def _compute_hashes (self ) -> ty .Tuple [bytes , ty .Dict [str , bytes ]]:
487
487
"""Compute a basic hash for any given set of fields."""
488
488
inp_dict = {}
489
- for field in task_fields (self ):
489
+ for field in get_fields (self ):
490
490
if isinstance (field , Out ):
491
491
continue # Skip output fields
492
492
# removing values that are not set from hash calculation
@@ -508,7 +508,7 @@ def _rule_violations(self) -> list[str]:
508
508
509
509
field : Arg
510
510
errors = []
511
- for field in task_fields (self ):
511
+ for field in get_fields (self ):
512
512
value = self [field .name ]
513
513
514
514
if is_lazy (value ):
@@ -625,7 +625,7 @@ def _check_resolved(self):
625
625
@register_serializer
626
626
def bytes_repr_task (obj : Task , cache : Cache ) -> ty .Iterator [bytes ]:
627
627
yield f"task[{ obj ._task_type ()} ]:(" .encode ()
628
- for field in task_fields (obj ):
628
+ for field in get_fields (obj ):
629
629
yield f"{ field .name } =" .encode ()
630
630
yield hash_single (getattr (obj , field .name ), cache )
631
631
yield b","
0 commit comments