Skip to content

Commit 5a7955e

Browse files
tcloseeffigies
authored andcommitted
fixes issues with super->sub-class auto-cast and handles MultiInputObj coercion
1 parent 2d840be commit 5a7955e

File tree

1 file changed

+55
-17
lines changed

1 file changed

+55
-17
lines changed

pydra/utils/typing.py

Lines changed: 55 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,18 @@ def coerce_obj(obj, type_):
366366
f"Cannot coerce {obj!r} into {type_}{msg}{self.label_str}"
367367
) from e
368368

369-
return expand_and_coerce(object_, self.pattern)
369+
# Special handling for MultiInputObjects (which are annoying)
370+
if isinstance(self.pattern, tuple) and self.pattern[0] == MultiInputObj:
371+
try:
372+
self.check_coercible(object_, self.pattern[1][0])
373+
except TypeError:
374+
pass
375+
else:
376+
obj = [object_]
377+
else:
378+
obj = object_
379+
380+
return expand_and_coerce(obj, self.pattern)
370381

371382
def check_type(self, type_: ty.Type[ty.Any]):
372383
"""Checks the given type to see whether it matches or is a subtype of the
@@ -413,7 +424,7 @@ def expand_and_check(tp, pattern: ty.Union[type, tuple]):
413424
f"{self.pattern}{self.label_str}"
414425
)
415426
tp_args = get_args(tp)
416-
self.check_coercible(tp_origin, pattern_origin)
427+
self.check_type_coercible(tp_origin, pattern_origin)
417428
if issubclass(pattern_origin, ty.Mapping):
418429
return check_mapping(tp_args, pattern_args)
419430
if issubclass(pattern_origin, tuple):
@@ -446,7 +457,7 @@ def check_basic(tp, target):
446457
+ "\n\n".join(f"{a} -> {e}" for a, e in zip(tp_args, reasons))
447458
)
448459
if not self.is_subclass(tp, target):
449-
self.check_coercible(tp, target)
460+
self.check_type_coercible(tp, target)
450461

451462
def check_union(tp, pattern_args):
452463
if get_origin(tp) in UNION_TYPES:
@@ -526,19 +537,46 @@ def check_sequence(tp_args, pattern_args):
526537
for arg in tp_args:
527538
expand_and_check(arg, pattern_args[0])
528539

529-
return expand_and_check(type_, self.pattern)
540+
# Special handling for MultiInputObjects (which are annoying)
541+
if isinstance(self.pattern, tuple) and self.pattern[0] == MultiInputObj:
542+
pattern = (ty.Union, [self.pattern[1][0], (ty.List, self.pattern[1])])
543+
else:
544+
pattern = self.pattern
545+
return expand_and_check(type_, pattern)
546+
547+
def check_coercible(self, source: ty.Any, target: ty.Union[type, ty.Any]):
548+
"""Checks whether the source object is coercible to the target type given the coercion
549+
rules defined in the `coercible` and `not_coercible` attrs
550+
551+
Parameters
552+
----------
553+
source : object
554+
the object to be coerced
555+
target : type or typing.Any
556+
the target type for the object to be coerced to
557+
558+
Raises
559+
------
560+
TypeError
561+
If the object cannot be coerced into the target type depending on the explicit
562+
inclusions and exclusions set in the `coercible` and `not_coercible` member attrs
563+
"""
564+
self.check_type_coercible(type(source), target, source_repr=repr(source))
530565

531-
def check_coercible(
532-
self, source: ty.Union[object, type], target: ty.Union[type, ty.Any]
566+
def check_type_coercible(
567+
self,
568+
source: ty.Union[type, ty.Any],
569+
target: ty.Union[type, ty.Any],
570+
source_repr: ty.Optional[str] = None,
533571
):
534-
"""Checks whether the source object or type is coercible to the target type
572+
"""Checks whether the source type is coercible to the target type
535573
given the coercion rules defined in the `coercible` and `not_coercible` attrs
536574
537575
Parameters
538576
----------
539-
source : object or type
540-
source object or type to be coerced
541-
target : type or ty.Any
577+
source : type or typing.Any
578+
source type to be coerced
579+
target : type or typing.Any
542580
target type for the source to be coerced to
543581
544582
Raises
@@ -548,10 +586,12 @@ def check_coercible(
548586
explicit inclusions and exclusions set in the `coercible` and `not_coercible`
549587
member attrs
550588
"""
589+
if source_repr is None:
590+
source_repr = repr(source)
551591
# Short-circuit the basic cases where the source and target are the same
552592
if source is target:
553593
return
554-
if self.superclass_auto_cast and self.is_subclass(target, type(source)):
594+
if self.superclass_auto_cast and self.is_subclass(target, source):
555595
logger.info(
556596
"Attempting to coerce %s into %s due to super-to-sub class coercion "
557597
"being permitted",
@@ -563,13 +603,11 @@ def check_coercible(
563603
if source_origin is not None:
564604
source = source_origin
565605

566-
source_check = self.is_subclass if inspect.isclass(source) else self.is_instance
567-
568606
def matches_criteria(criteria):
569607
return [
570608
(src, tgt)
571609
for src, tgt in criteria
572-
if source_check(source, src) and self.is_subclass(target, tgt)
610+
if self.is_subclass(source, src) and self.is_subclass(target, tgt)
573611
]
574612

575613
def type_name(t):
@@ -580,7 +618,7 @@ def type_name(t):
580618

581619
if not matches_criteria(self.coercible):
582620
raise TypeError(
583-
f"Cannot coerce {repr(source)} into {target}{self.label_str} as the "
621+
f"Cannot coerce {source_repr} into {target}{self.label_str} as the "
584622
"coercion doesn't match any of the explicit inclusion criteria: "
585623
+ ", ".join(
586624
f"{type_name(s)} -> {type_name(t)}" for s, t in self.coercible
@@ -589,7 +627,7 @@ def type_name(t):
589627
matches_not_coercible = matches_criteria(self.not_coercible)
590628
if matches_not_coercible:
591629
raise TypeError(
592-
f"Cannot coerce {repr(source)} into {target}{self.label_str} as it is explicitly "
630+
f"Cannot coerce {source_repr} into {target}{self.label_str} as it is explicitly "
593631
"excluded by the following coercion criteria: "
594632
+ ", ".join(
595633
f"{type_name(s)} -> {type_name(t)}"
@@ -683,7 +721,7 @@ def is_instance(
683721
if inspect.isclass(obj):
684722
return candidate is type
685723
if issubtype(type(obj), candidate) or (
686-
type(obj) is dict and candidate is ty.Mapping
724+
type(obj) is dict and candidate is ty.Mapping # noqa: E721
687725
):
688726
return True
689727
else:

0 commit comments

Comments
 (0)