Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ Added
``add_argument`` (`#661
<https://github.com/omni-us/jsonargparse/pull/661>`__).

Fixed
^^^^^
- Incorrect instantiation order when instantiation targets share a parent (`#662
<https://github.com/omni-us/jsonargparse/pull/662>`__).


v4.36.0 (2025-01-17)
--------------------
Expand Down
25 changes: 22 additions & 3 deletions jsonargparse/_link_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
_find_parent_action,
filter_default_actions,
)
from ._namespace import Namespace, split_key_leaf
from ._namespace import Namespace, split_key, split_key_leaf
from ._parameter_resolvers import get_signature_parameters
from ._type_checking import ArgumentGroup, ArgumentParser

Expand Down Expand Up @@ -81,7 +81,10 @@ def add_edge(self, source, target):
for node in [source, target]:
if node not in self.nodes:
self.nodes.append(node)
self.edges_dict[self.nodes.index(source)].append(self.nodes.index(target))
source_targets_list = self.edges_dict[self.nodes.index(source)]
target_index = self.nodes.index(target)
if target_index not in source_targets_list:
source_targets_list.append(target_index)

def get_topological_order(self):
exploring = [False] * len(self.nodes)
Expand All @@ -97,7 +100,7 @@ def topological_sort(self, source, exploring, visited, order):
for target in self.edges_dict[source]:
if exploring[target]:
raise ValueError(
f"Graph has cycles, found while checking {self.nodes[source]} --> " + self.nodes[target]
f"Graph has cycles, found while checking {self.nodes[source]} --> {self.nodes[target]}"
)
elif not visited[target]:
self.topological_sort(target, exploring, visited, order)
Expand Down Expand Up @@ -406,11 +409,27 @@ def set_target_value(action: "ActionLink", value: Any, cfg: Namespace, logger) -
def instantiation_order(parser):
actions = get_link_actions(parser, "instantiate")
if actions:
targets = set()
graph = DirectedGraph()

# Add instantiation links as edges
for action in actions:
target = re.sub(r"\.init_args$", "", split_key_leaf(action.target[0])[0])
for _, source_action in action.source:
graph.add_edge(source_action.dest, target)
targets.add(target)

# Add instantiation target prefixes as edges
targets = sorted(targets, key=lambda x: len(split_key(x)))
seen_targets = {targets[0]}
for target in targets[1:]:
parts = [x.replace("|", ".") for x in target.replace("init_args.", "init_args|").split(".")]
for num in range(len(parts) - 1):
target_prefix = ".".join(parts[: num + 1])
if target_prefix in seen_targets:
graph.add_edge(target, target_prefix)
seen_targets.add(target)

return graph.get_topological_order()
return []

Expand Down
4 changes: 3 additions & 1 deletion jsonargparse/_signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,9 +351,11 @@ def _add_signature_parameter(
annotation = Any
default = None if is_required else default
is_required = False
is_required_link_target = False
if is_required and linked_targets is not None and name in linked_targets:
default = None
is_required = False
is_required_link_target = True
if (
kind in {kinds.VAR_POSITIONAL, kinds.VAR_KEYWORD}
or (not is_required and name[0] == "_")
Expand All @@ -371,7 +373,7 @@ def _add_signature_parameter(
kwargs["help"] = param.doc
if not is_required:
kwargs["default"] = default
if default is None and not is_optional(annotation, object):
if default is None and not is_optional(annotation, object) and not is_required_link_target:
annotation = Optional[annotation]
elif not as_positional:
kwargs["required"] = True
Expand Down
74 changes: 74 additions & 0 deletions jsonargparse_tests/test_link_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -826,6 +826,80 @@ def test_on_instantiate_within_deeper_subclass(parser, caplog):
assert "Applied link 'encoder.output_channels --> decoder.init_args.input_channels'" in caplog.text


class SourceA:
def __init__(self, param_a: str):
self.attr_a = param_a


class SourceB:
def __init__(self, param_b: str):
self.attr_b = param_b


class HierarchyGrandchild:
def __init__(self, param_grandchild: str):
self.attr_grandchild = param_grandchild


class HierarchyChild:
def __init__(self, param_child: str, grandchild: HierarchyGrandchild):
self.attr_child = param_child
self.grandchild = grandchild


class HierarchyRoot:
def __init__(self, child: HierarchyChild):
self.child = child


def get_source_a_attr(source_a):
assert isinstance(source_a, SourceA)
return source_a.attr_a


def test_on_instantiate_targets_share_parent(parser):
config = {
"source_a": {
"param_a": "value a",
},
"source_b": {
"param_b": "value b",
},
"root": {
"child": {
"class_path": "HierarchyChild",
"init_args": {
"grandchild": {
"class_path": "HierarchyGrandchild",
},
},
},
},
}
parser.add_argument("--config", action="config")
parser.add_class_arguments(SourceA, "source_a")
parser.add_class_arguments(SourceB, "source_b")
parser.add_class_arguments(HierarchyRoot, "root")
parser.link_arguments(
"source_a",
"root.child.init_args.grandchild.init_args.param_grandchild",
apply_on="instantiate",
compute_fn=get_source_a_attr,
)
parser.link_arguments("source_b.attr_b", "root.child.init_args.param_child", apply_on="instantiate")
cfg = parser.parse_args([f"--config={json.dumps(config)}"])
init = parser.instantiate_classes(cfg)
assert isinstance(init.source_a, SourceA)
assert isinstance(init.source_b, SourceB)
assert isinstance(init.root, HierarchyRoot)
assert isinstance(init.root.child, HierarchyChild)
assert isinstance(init.root.child.grandchild, HierarchyGrandchild)
assert init.source_a.attr_a == "value a"
assert init.root.child.grandchild.attr_grandchild is init.source_a.attr_a
assert init.source_b.attr_b == "value b"
assert init.root.child.attr_child is init.source_b.attr_b


# link creation failures


Expand Down
Loading