Skip to content

Commit 6ed002f

Browse files
committed
added module bytes_repr serializer
1 parent 6a72657 commit 6ed002f

File tree

5 files changed

+104
-34
lines changed

5 files changed

+104
-34
lines changed

pydra/design/base.py

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
MultiOutputObj,
2424
MultiOutputFile,
2525
)
26+
from pydra.utils.hash import hash_function
2627

2728

2829
if ty.TYPE_CHECKING:
@@ -179,6 +180,8 @@ class Field:
179180
The converter for the field passed through to the attrs.field, by default it is None
180181
validator: callable | iterable[callable], optional
181182
The validator(s) for the field passed through to the attrs.field, by default it is None
183+
hash_eq: bool, optional
184+
Whether to use the hash of the value for equality comparison, by default it is False
182185
"""
183186

184187
name: str | None = None
@@ -192,8 +195,9 @@ class Field:
192195
requires: list[RequirementSet] = attrs.field(
193196
factory=list, converter=requires_converter
194197
)
195-
converter: ty.Callable | None = None
196-
validator: ty.Callable | None = None
198+
converter: ty.Callable[..., ty.Any] | None = None
199+
validator: ty.Callable[..., bool] | None = None
200+
hash_eq: bool = False
197201

198202
def requirements_satisfied(self, inputs: "TaskDef") -> bool:
199203
"""Check if all the requirements are satisfied by the inputs"""
@@ -408,6 +412,7 @@ def make_task_def(
408412
klass : type
409413
The class created using the attrs package
410414
"""
415+
411416
spec_type._check_arg_refs(inputs, outputs)
412417

413418
for inpt in inputs.values():
@@ -448,15 +453,15 @@ def make_task_def(
448453
# Now that we have saved the attributes in lists to be
449454
for arg in inputs.values():
450455
# If an outarg input then the field type should be Path not a FileSet
451-
default_kwargs = _get_default(arg)
456+
attrs_kwargs = _get_attrs_kwargs(arg)
452457
if isinstance(arg, Out) and is_fileset_or_union(arg.type):
453458
if getattr(arg, "path_template", False):
454459
if is_optional(arg.type):
455460
field_type = Path | bool | None
456461
# Will default to None and not be inserted into the command
457462
else:
458463
field_type = Path | bool
459-
default_kwargs = {"default": True}
464+
attrs_kwargs = {"default": True}
460465
elif is_optional(arg.type):
461466
field_type = Path | None
462467
else:
@@ -471,14 +476,14 @@ def make_task_def(
471476
validator=make_validator(arg, klass.__name__),
472477
metadata={PYDRA_ATTR_METADATA: arg},
473478
on_setattr=attrs.setters.convert,
474-
**default_kwargs,
479+
**attrs_kwargs,
475480
),
476481
)
477482
klass.__annotations__[arg.name] = field_type
478483

479484
# Create class using attrs package, will create attributes for all columns and
480485
# parameters
481-
attrs_klass = attrs.define(auto_attribs=False, kw_only=True)(klass)
486+
attrs_klass = attrs.define(auto_attribs=False, kw_only=True, eq=False)(klass)
482487

483488
return attrs_klass
484489

@@ -541,13 +546,15 @@ def make_outputs_spec(
541546
n: attrs.field(
542547
converter=make_converter(o, f"{spec_name}.Outputs"),
543548
metadata={PYDRA_ATTR_METADATA: o},
544-
**_get_default(o),
549+
**_get_attrs_kwargs(o),
545550
)
546551
for n, o in outputs.items()
547552
},
548553
)
549554
outputs_klass.__annotations__.update((o.name, o.type) for o in outputs.values())
550-
outputs_klass = attrs.define(auto_attribs=False, kw_only=True)(outputs_klass)
555+
outputs_klass = attrs.define(auto_attribs=False, kw_only=True, eq=False)(
556+
outputs_klass
557+
)
551558
return outputs_klass
552559

553560

@@ -972,14 +979,19 @@ def check_explicit_fields_are_none(klass, inputs, outputs):
972979
)
973980

974981

975-
def _get_default(field: Field) -> dict[str, ty.Any]:
982+
def _get_attrs_kwargs(field: Field) -> dict[str, ty.Any]:
983+
kwargs = {}
976984
if not hasattr(field, "default"):
977-
return {"factory": nothing_factory}
978-
if field.default is not EMPTY:
979-
return {"default": field.default}
980-
if is_optional(field.type):
981-
return {"default": None}
982-
return {"factory": nothing_factory}
985+
kwargs["factory"] = nothing_factory
986+
elif field.default is not EMPTY:
987+
kwargs["default"] = field.default
988+
elif is_optional(field.type):
989+
kwargs["default"] = None
990+
else:
991+
kwargs["factory"] = nothing_factory
992+
if field.hash_eq:
993+
kwargs["eq"] = hash_function
994+
return kwargs
983995

984996

985997
def nothing_factory():

pydra/design/python.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def make(wrapped: ty.Callable | type) -> PythonDef:
158158
)
159159

160160
parsed_inputs["function"] = arg(
161-
name="function", type=ty.Callable, default=function
161+
name="function", type=ty.Callable, default=function, hash_eq=True
162162
)
163163

164164
defn = make_task_def(

pydra/engine/specs.py

Lines changed: 56 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def is_set(value: ty.Any) -> bool:
5454
return value not in (attrs.NOTHING, EMPTY)
5555

5656

57-
@attrs.define
57+
@attrs.define(kw_only=True, auto_attribs=False, eq=False)
5858
class TaskOutputs:
5959
"""Base class for all output definitions"""
6060

@@ -113,11 +113,31 @@ def __getitem__(self, name_or_index: str | int) -> ty.Any:
113113
f"{self} doesn't have an attribute {name_or_index}"
114114
) from None
115115

116+
def __eq__(self, other: ty.Any) -> bool:
    """Check whether two sets of task outputs are equal

    The pydra fields of both objects must match, and so must their attribute
    values. Fields flagged with ``hash_eq`` are compared by the hash of their
    value rather than by the value itself.

    Parameters
    ----------
    other : Any
        the object to compare against

    Returns
    -------
    bool
        True if ``other`` has the same fields and equal values, False otherwise
        (including when ``other`` is not an instance of an attrs class)
    """
    # attrs.asdict raises NotAnAttrsClassError (a ValueError subclass, not an
    # AttributeError) for non-attrs objects, so rule those out explicitly
    if not attrs.has(type(other)):
        return False
    try:
        other_values = attrs.asdict(other)
        other_fields = list_fields(other)
    except AttributeError:
        return False
    fields = list_fields(self)
    if fields != other_fields:
        return False
    values = attrs.asdict(self)
    for field in fields:
        if field.hash_eq:
            # Replace values in both dicts by their hashes so that they are
            # compared by hash in the final dictionary comparison
            values[field.name] = hash_function(values[field.name])
            other_values[field.name] = hash_function(other_values[field.name])
    return values == other_values
135+
116136

117137
OutputsType = ty.TypeVar("OutputType", bound=TaskOutputs)
118138

119139

120-
@attrs.define(kw_only=True, auto_attribs=False)
140+
@attrs.define(kw_only=True, auto_attribs=False, eq=False)
121141
class TaskDef(ty.Generic[OutputsType]):
122142
"""Base class for all task definitions"""
123143

@@ -341,6 +361,34 @@ def __iter__(self) -> ty.Generator[str, None, None]:
341361
if not (f.name.startswith("_") or f.name in self.RESERVED_FIELD_NAMES)
342362
)
343363

364+
def __eq__(self, other: ty.Any) -> bool:
    """Check whether two task definitions are equal

    Two definitions are considered equal when they have the same attribute
    names and values, their classes hash to the same value, and their
    ``Outputs`` classes hash to the same value. Fields flagged with
    ``hash_eq`` are compared by the hash of their value rather than by the
    value itself.

    Parameters
    ----------
    other : Any
        the object to compare against

    Returns
    -------
    bool
        True if the definitions are equal, False otherwise (including when
        ``other`` is not an instance of an attrs class or has no ``Outputs``
        attribute)
    """
    # attrs.asdict raises NotAnAttrsClassError (a ValueError subclass, not an
    # AttributeError) for non-attrs objects, so rule those out explicitly
    if not attrs.has(type(other)):
        return False
    values = attrs.asdict(self)
    try:
        other_values = attrs.asdict(other)
    except AttributeError:
        return False
    if set(values) != set(other_values):
        return False  # Return if attribute keys don't match
    for field in list_fields(self):
        if field.hash_eq:
            # Replace values in both dicts by their hashes so that they are
            # compared by hash in the dictionary comparison below
            values[field.name] = hash_function(values[field.name])
            other_values[field.name] = hash_function(other_values[field.name])
    if values != other_values:
        return False
    # The same cache object is passed to every hash_function call below
    hash_cache = Cache()
    if hash_function(type(self), cache=hash_cache) != hash_function(
        type(other), cache=hash_cache
    ):
        return False
    try:
        other_outputs = other.Outputs
    except AttributeError:
        return False
    return hash_function(self.Outputs, cache=hash_cache) == hash_function(
        other_outputs, cache=hash_cache
    )
391+
344392
def __getitem__(self, name: str) -> ty.Any:
345393
"""Return the value for the given attribute, resolving any templates
346394
@@ -595,7 +643,7 @@ class RuntimeSpec:
595643
network: bool = False
596644

597645

598-
@attrs.define(kw_only=True, auto_attribs=False)
646+
@attrs.define(kw_only=True, auto_attribs=False, eq=False)
599647
class PythonOutputs(TaskOutputs):
600648

601649
@classmethod
@@ -624,7 +672,7 @@ def _from_task(cls, task: "Task[PythonDef]") -> Self:
624672
PythonOutputsType = ty.TypeVar("OutputType", bound=PythonOutputs)
625673

626674

627-
@attrs.define(kw_only=True, auto_attribs=False)
675+
@attrs.define(kw_only=True, auto_attribs=False, eq=False)
628676
class PythonDef(TaskDef[PythonOutputsType]):
629677

630678
_task_type = "python"
@@ -653,7 +701,7 @@ def _run(self, task: "Task[PythonDef]") -> None:
653701
)
654702

655703

656-
@attrs.define(kw_only=True, auto_attribs=False)
704+
@attrs.define(kw_only=True, auto_attribs=False, eq=False)
657705
class WorkflowOutputs(TaskOutputs):
658706

659707
@classmethod
@@ -707,7 +755,7 @@ def _from_task(cls, task: "Task[WorkflowDef]") -> Self:
707755
WorkflowOutputsType = ty.TypeVar("OutputType", bound=WorkflowOutputs)
708756

709757

710-
@attrs.define(kw_only=True, auto_attribs=False)
758+
@attrs.define(kw_only=True, auto_attribs=False, eq=False)
711759
class WorkflowDef(TaskDef[WorkflowOutputsType]):
712760

713761
_task_type = "workflow"
@@ -738,7 +786,7 @@ def construct(self) -> "Workflow":
738786
STDERR_HELP = """The standard error stream produced by the command."""
739787

740788

741-
@attrs.define(kw_only=True, auto_attribs=False)
789+
@attrs.define(kw_only=True, auto_attribs=False, eq=False)
742790
class ShellOutputs(TaskOutputs):
743791
"""Output definition of a generic shell process."""
744792

@@ -899,7 +947,7 @@ def _resolve_value(
899947
ShellOutputsType = ty.TypeVar("OutputType", bound=ShellOutputs)
900948

901949

902-
@attrs.define(kw_only=True, auto_attribs=False)
950+
@attrs.define(kw_only=True, auto_attribs=False, eq=False)
903951
class ShellDef(TaskDef[ShellOutputsType]):
904952

905953
_task_type = "shell"

pydra/engine/submitter.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ class Submitter:
8787
def __init__(
8888
self,
8989
cache_dir: os.PathLike | None = None,
90-
worker: str | ty.Type[Worker] | Worker = "debug",
90+
worker: str | ty.Type[Worker] | Worker | None = "debug",
9191
environment: "Environment | None" = None,
9292
rerun: bool = False,
9393
cache_locations: list[os.PathLike] | None = None,
@@ -98,6 +98,9 @@ def __init__(
9898
**kwargs,
9999
):
100100

101+
if worker is None:
102+
worker = "debug"
103+
101104
from . import check_latest_version
102105

103106
if Task._etelemetry_version_data is None:

pydra/utils/hash.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from filelock import SoftFileLock
2525
import attrs.exceptions
2626
from fileformats.core.fileset import FileSet, MockMixin
27+
from fileformats.generic import FsObject
2728
import fileformats.core.exceptions
2829
from . import user_cache_dir, add_exc_note
2930
from .misc import in_stdlib
@@ -332,17 +333,16 @@ def bytes_repr(obj: object, cache: Cache) -> Iterator[bytes]:
332333
elif hasattr(obj, "__slots__") and obj.__slots__ is not None:
333334
dct = {attr: getattr(obj, attr) for attr in obj.__slots__}
334335
else:
336+
337+
def is_special_or_method(n: str):
338+
return (n.startswith("__") and n.endswith("__")) or inspect.ismethod(
339+
getattr(obj, n)
340+
)
341+
335342
try:
336-
dct = obj.__dict__
343+
dct = {n: v for n, v in obj.__dict__.items() if not is_special_or_method(n)}
337344
except AttributeError:
338-
dct = {
339-
n: getattr(obj, n)
340-
for n in dir(obj)
341-
if not (
342-
(n.startswith("__") and n.endswith("__"))
343-
or inspect.ismethod(getattr(obj, n))
344-
)
345-
}
345+
dct = {n: getattr(obj, n) for n in dir(obj) if not is_special_or_method(n)}
346346
yield from bytes_repr_mapping_contents(dct, cache)
347347
yield b"}"
348348

@@ -456,6 +456,13 @@ def bytes_repr_dict(obj: dict, cache: Cache) -> Iterator[bytes]:
456456
yield b"}"
457457

458458

459+
@register_serializer
def bytes_repr_module(obj: types.ModuleType, cache: Cache) -> Iterator[bytes]:
    """Serialize a module for hashing

    Modules backed by a source file are represented by the hash of that file.
    Modules without one (built-in modules have no ``__file__`` attribute, and
    namespace packages have ``__file__`` set to None) are represented by their
    name instead, since there is no file to hash.

    Parameters
    ----------
    obj : types.ModuleType
        the module to serialize
    cache : Cache
        the hash cache passed through to ``hash_single``

    Yields
    ------
    bytes
        the chunks making up the byte representation of the module
    """
    yield b"module:("
    module_file = getattr(obj, "__file__", None)
    if module_file is None:
        # No file to hash, fall back to the module's name
        yield obj.__name__.encode()
    else:
        yield hash_single(FsObject(module_file), cache=cache)
    yield b")"
464+
465+
459466
@register_serializer(ty._GenericAlias)
460467
@register_serializer(ty._SpecialForm)
461468
@register_serializer(type)

0 commit comments

Comments
 (0)