Skip to content

Commit 2a4454a

Browse files
authored
Fix incorrect instantiation order when instantiation targets share a parent (#662)
1 parent eee8ea9 commit 2a4454a

File tree

4 files changed

+104
-4
lines changed

4 files changed

+104
-4
lines changed

CHANGELOG.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,11 @@ Added
2121
``add_argument`` (`#661
2222
<https://github.com/omni-us/jsonargparse/pull/661>`__).
2323

24+
Fixed
25+
^^^^^
26+
- Incorrect instantiation order when instantiation targets share a parent (`#662
27+
<https://github.com/omni-us/jsonargparse/pull/662>`__).
28+
2429

2530
v4.36.0 (2025-01-17)
2631
--------------------

jsonargparse/_link_arguments.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
_find_parent_action,
1919
filter_default_actions,
2020
)
21-
from ._namespace import Namespace, split_key_leaf
21+
from ._namespace import Namespace, split_key, split_key_leaf
2222
from ._parameter_resolvers import get_signature_parameters
2323
from ._type_checking import ArgumentGroup, ArgumentParser
2424

@@ -81,7 +81,10 @@ def add_edge(self, source, target):
8181
for node in [source, target]:
8282
if node not in self.nodes:
8383
self.nodes.append(node)
84-
self.edges_dict[self.nodes.index(source)].append(self.nodes.index(target))
84+
source_targets_list = self.edges_dict[self.nodes.index(source)]
85+
target_index = self.nodes.index(target)
86+
if target_index not in source_targets_list:
87+
source_targets_list.append(target_index)
8588

8689
def get_topological_order(self):
8790
exploring = [False] * len(self.nodes)
@@ -97,7 +100,7 @@ def topological_sort(self, source, exploring, visited, order):
97100
for target in self.edges_dict[source]:
98101
if exploring[target]:
99102
raise ValueError(
100-
f"Graph has cycles, found while checking {self.nodes[source]} --> " + self.nodes[target]
103+
f"Graph has cycles, found while checking {self.nodes[source]} --> {self.nodes[target]}"
101104
)
102105
elif not visited[target]:
103106
self.topological_sort(target, exploring, visited, order)
@@ -406,11 +409,27 @@ def set_target_value(action: "ActionLink", value: Any, cfg: Namespace, logger) -
406409
def instantiation_order(parser):
407410
actions = get_link_actions(parser, "instantiate")
408411
if actions:
412+
targets = set()
409413
graph = DirectedGraph()
414+
415+
# Add instantiation links as edges
410416
for action in actions:
411417
target = re.sub(r"\.init_args$", "", split_key_leaf(action.target[0])[0])
412418
for _, source_action in action.source:
413419
graph.add_edge(source_action.dest, target)
420+
targets.add(target)
421+
422+
# Add instantiation target prefixes as edges
423+
targets = sorted(targets, key=lambda x: len(split_key(x)))
424+
seen_targets = {targets[0]}
425+
for target in targets[1:]:
426+
parts = [x.replace("|", ".") for x in target.replace("init_args.", "init_args|").split(".")]
427+
for num in range(len(parts) - 1):
428+
target_prefix = ".".join(parts[: num + 1])
429+
if target_prefix in seen_targets:
430+
graph.add_edge(target, target_prefix)
431+
seen_targets.add(target)
432+
414433
return graph.get_topological_order()
415434
return []
416435

jsonargparse/_signatures.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -351,9 +351,11 @@ def _add_signature_parameter(
351351
annotation = Any
352352
default = None if is_required else default
353353
is_required = False
354+
is_required_link_target = False
354355
if is_required and linked_targets is not None and name in linked_targets:
355356
default = None
356357
is_required = False
358+
is_required_link_target = True
357359
if (
358360
kind in {kinds.VAR_POSITIONAL, kinds.VAR_KEYWORD}
359361
or (not is_required and name[0] == "_")
@@ -371,7 +373,7 @@ def _add_signature_parameter(
371373
kwargs["help"] = param.doc
372374
if not is_required:
373375
kwargs["default"] = default
374-
if default is None and not is_optional(annotation, object):
376+
if default is None and not is_optional(annotation, object) and not is_required_link_target:
375377
annotation = Optional[annotation]
376378
elif not as_positional:
377379
kwargs["required"] = True

jsonargparse_tests/test_link_arguments.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -826,6 +826,80 @@ def test_on_instantiate_within_deeper_subclass(parser, caplog):
826826
assert "Applied link 'encoder.output_channels --> decoder.init_args.input_channels'" in caplog.text
827827

828828

829+
class SourceA:
830+
def __init__(self, param_a: str):
831+
self.attr_a = param_a
832+
833+
834+
class SourceB:
835+
def __init__(self, param_b: str):
836+
self.attr_b = param_b
837+
838+
839+
class HierarchyGrandchild:
840+
def __init__(self, param_grandchild: str):
841+
self.attr_grandchild = param_grandchild
842+
843+
844+
class HierarchyChild:
845+
def __init__(self, param_child: str, grandchild: HierarchyGrandchild):
846+
self.attr_child = param_child
847+
self.grandchild = grandchild
848+
849+
850+
class HierarchyRoot:
851+
def __init__(self, child: HierarchyChild):
852+
self.child = child
853+
854+
855+
def get_source_a_attr(source_a):
856+
assert isinstance(source_a, SourceA)
857+
return source_a.attr_a
858+
859+
860+
def test_on_instantiate_targets_share_parent(parser):
861+
config = {
862+
"source_a": {
863+
"param_a": "value a",
864+
},
865+
"source_b": {
866+
"param_b": "value b",
867+
},
868+
"root": {
869+
"child": {
870+
"class_path": "HierarchyChild",
871+
"init_args": {
872+
"grandchild": {
873+
"class_path": "HierarchyGrandchild",
874+
},
875+
},
876+
},
877+
},
878+
}
879+
parser.add_argument("--config", action="config")
880+
parser.add_class_arguments(SourceA, "source_a")
881+
parser.add_class_arguments(SourceB, "source_b")
882+
parser.add_class_arguments(HierarchyRoot, "root")
883+
parser.link_arguments(
884+
"source_a",
885+
"root.child.init_args.grandchild.init_args.param_grandchild",
886+
apply_on="instantiate",
887+
compute_fn=get_source_a_attr,
888+
)
889+
parser.link_arguments("source_b.attr_b", "root.child.init_args.param_child", apply_on="instantiate")
890+
cfg = parser.parse_args([f"--config={json.dumps(config)}"])
891+
init = parser.instantiate_classes(cfg)
892+
assert isinstance(init.source_a, SourceA)
893+
assert isinstance(init.source_b, SourceB)
894+
assert isinstance(init.root, HierarchyRoot)
895+
assert isinstance(init.root.child, HierarchyChild)
896+
assert isinstance(init.root.child.grandchild, HierarchyGrandchild)
897+
assert init.source_a.attr_a == "value a"
898+
assert init.root.child.grandchild.attr_grandchild is init.source_a.attr_a
899+
assert init.source_b.attr_b == "value b"
900+
assert init.root.child.attr_child is init.source_b.attr_b
901+
902+
829903
# link creation failures
830904

831905

0 commit comments

Comments
 (0)