Skip to content

Commit 172f9e1

Browse files
authored
Fixed nested generic dataclasses not working correctly (#709)
1 parent 7637f7f commit 172f9e1

File tree

4 files changed

+36
-2
lines changed

4 files changed

+36
-2
lines changed

CHANGELOG.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ Fixed
3636
^^^^^
3737
- ``ActionParser`` not updating ``dest`` of groups for instantiation (`#707
3838
<https://github.com/omni-us/jsonargparse/pull/707>`__).
39+
- Nested generic dataclasses not working correctly (`#709
40+
<https://github.com/omni-us/jsonargparse/pull/709>`__).
3941

4042

4143
v4.38.0 (2025-03-26)

jsonargparse/_signatures.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -537,7 +537,7 @@ def _create_group_if_requested(
537537
doc_group = str(obj[0])
538538
else:
539539
doc_group = str(obj)
540-
name = obj.__name__ if nested_key is None else nested_key
540+
name = get_object_name(obj) if nested_key is None else nested_key
541541
group = self.add_argument_group(strip_title(doc_group), name=name)
542542
if config_load and nested_key is not None:
543543
group.add_argument("--" + nested_key, action=_ActionConfigLoad(basetype=config_load_type))
@@ -548,6 +548,12 @@ def _create_group_if_requested(
548548
return group
549549

550550

551+
def get_object_name(obj) -> str:
552+
if hasattr(obj, "__name__"):
553+
return obj.__name__
554+
return str(obj).split(".")[-1].replace("[", "_").replace("]", "")
555+
556+
551557
def group_instantiate_class(group, cfg):
552558
try:
553559
value, parent, key = cfg.get_value_and_parent(group.dest)

jsonargparse/_typehints.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -644,7 +644,7 @@ def get_class_parser(val_class, sub_add_kwargs=None, skip_args=0):
644644
parser = parent_parser.get()
645645
parser = type(parser)(exit_on_error=False, logger=parser.logger, parser_mode=parser.parser_mode)
646646
remove_actions(parser, (ActionConfigFile, _ActionPrintConfig))
647-
if inspect.isclass(val_class):
647+
if inspect.isclass(val_class) or inspect.isclass(get_typehint_origin(val_class)):
648648
parser.add_class_arguments(val_class, **kwargs)
649649
else:
650650
kwargs = {k: v for k, v in kwargs.items() if k != "instantiate"}

jsonargparse_tests/test_dataclass_like.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import dataclasses
44
import json
5+
import sys
56
from typing import Any, Dict, Generic, List, Literal, Optional, Tuple, TypeVar, Union
67
from unittest.mock import patch
78

@@ -539,6 +540,31 @@ def test_nested_generic_dataclass(parser):
539540
assert "--x.y.g4 g4 (required, type: dict[str, union[float, bool]])" in help_str
540541

541542

543+
if sys.version_info >= (3, 9):
544+
V = TypeVar("V")
545+
546+
@dataclasses.dataclass(frozen=True)
547+
class GenericChild(Generic[V]):
548+
value: V
549+
550+
@dataclasses.dataclass(frozen=True)
551+
class GenericBase(Generic[V]):
552+
children: tuple[GenericChild[V], ...]
553+
554+
@dataclasses.dataclass(frozen=True)
555+
class GenericSubclass(GenericBase[str]):
556+
children: tuple[GenericChild[str], ...]
557+
558+
def test_generic_dataclass_subclass(parser):
559+
parser.add_class_arguments(GenericSubclass, "x")
560+
cfg = parser.parse_args(['--x.children=[{"value": "a"}, {"value": "b"}]'])
561+
init = parser.instantiate_classes(cfg)
562+
assert cfg.x.children == (Namespace(value="a"), Namespace(value="b"))
563+
assert isinstance(init.x, GenericSubclass)
564+
assert isinstance(init.x.children[0], GenericChild)
565+
assert isinstance(init.x.children[1], GenericChild)
566+
567+
542568
# union mixture tests
543569

544570

0 commit comments

Comments
 (0)