Skip to content

Commit 9a34a17

Browse files
committed
Adding support for fields that are only included in __init__
1 parent 6b823c1 commit 9a34a17

File tree

2 files changed

+58
-10
lines changed

2 files changed

+58
-10
lines changed

argparse_dataclass.py

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -238,16 +238,21 @@
238238
Union,
239239
Any,
240240
Generic,
241+
ClassVar,
241242
)
242243
from dataclasses import (
243244
Field,
244245
is_dataclass,
245246
fields,
246247
MISSING,
248+
InitVar,
247249
dataclass as real_dataclass,
248250
)
249251
from importlib.metadata import version
250252

253+
# This is `typing._GenericAlias` but don't use non-public type names
254+
_ClassVarType = type(ClassVar[object])
255+
251256
# In Python 3.10, we can use types.NoneType
252257
NoneType = type(None)
253258

@@ -284,14 +289,17 @@ def _add_dataclass_options(
284289
if not is_dataclass(options_class):
285290
raise TypeError("cls must be a dataclass")
286291

287-
for field in fields(options_class):
292+
for field in _fields(options_class):
288293
if not field.init:
289294
continue # Ignore fields not in __init__
295+
f_type = field.type
296+
if _is_initvar(f_type):
297+
f_type = f_type.type
290298

291299
args = field.metadata.get("args", [f"--{_get_arg_name(field)}"])
292300
positional = not args[0].startswith("-")
293301
kwargs = {
294-
"type": field.metadata.get("type", field.type),
302+
"type": field.metadata.get("type", f_type),
295303
"help": field.metadata.get("help", None),
296304
}
297305

@@ -304,7 +312,7 @@ def _add_dataclass_options(
304312
kwargs["choices"] = field.metadata["choices"]
305313

306314
# Support Literal types as an alternative means of specifying choices.
307-
if get_origin(field.type) is Literal:
315+
if get_origin(f_type) is Literal:
308316
# Prohibit a potential collision with the choices field
309317
if field.metadata.get("choices") is not None:
310318
raise ValueError(
@@ -314,7 +322,7 @@ def _add_dataclass_options(
314322
)
315323

316324
# Get the types of the arguments of the Literal
317-
types = [type(arg) for arg in get_args(field.type)]
325+
types = [type(arg) for arg in get_args(f_type)]
318326

319327
# Make sure just a single type has been used
320328
if len(set(types)) > 1:
@@ -329,7 +337,7 @@ def _add_dataclass_options(
329337
# Overwrite the type kwarg
330338
kwargs["type"] = types[0]
331339
# Use the literal arguments as choices
332-
kwargs["choices"] = get_args(field.type)
340+
kwargs["choices"] = get_args(f_type)
333341

334342
if field.metadata.get("metavar") is not None:
335343
kwargs["metavar"] = field.metadata["metavar"]
@@ -343,7 +351,7 @@ def _add_dataclass_options(
343351
# did not specify the type of the elements within the list, we
344352
# try to infer it:
345353
try:
346-
kwargs["type"] = get_args(field.type)[0] # get_args returns a tuple
354+
kwargs["type"] = get_args(f_type)[0] # get_args returns a tuple
347355
except IndexError:
348356
# get_args returned an empty tuple, type cannot be inferred
349357
raise ValueError(
@@ -357,12 +365,12 @@ def _add_dataclass_options(
357365
else:
358366
kwargs["default"] = MISSING
359367

360-
if field.type is bool:
368+
if f_type is bool:
361369
_handle_bool_type(field, args, kwargs)
362-
elif get_origin(field.type) is Union:
370+
elif get_origin(f_type) is Union:
363371
if field.metadata.get("type") is None:
364372
# Optional[X] is equivalent to Union[X, None].
365-
f_args = get_args(field.type)
373+
f_args = get_args(f_type)
366374
if len(f_args) == 2 and NoneType in f_args:
367375
arg = next(a for a in f_args if a is not NoneType)
368376
kwargs["type"] = arg
@@ -439,6 +447,24 @@ def _get_arg_name(field: Field):
439447
return field.name.replace("_", "-")
440448

441449

450+
def _fields(dataclass) -> Tuple[Field]:
451+
# dataclass.fields does not return fields that are of type InitVar
452+
dc_fields = getattr(dataclass, "__dataclass_fields__", None)
453+
if dc_fields is None:
454+
return fields(dataclass)
455+
return tuple(f for f in dc_fields.values() if not _is_classvar(f.type))
456+
457+
458+
def _is_classvar(a_type):
459+
return a_type is ClassVar or (
460+
type(a_type) is _ClassVarType and a_type.__origin__ is ClassVar
461+
)
462+
463+
464+
def _is_initvar(a_type):
465+
return a_type is InitVar or type(a_type) is InitVar
466+
467+
442468
class ArgumentParser(argparse.ArgumentParser, Generic[OptionsType]):
443469
"""Command line argument parser that derives its options from a dataclass.
444470

tests/test_functional.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import sys
22
import unittest
33
import datetime as dt
4-
from dataclasses import dataclass, field
4+
from dataclasses import dataclass, field, InitVar
55

66
from typing import Optional, Union
77

@@ -358,6 +358,28 @@ def __post_init__(self):
358358
self.assertEqual(params.time, "15:35:59")
359359
self.assertEqual(params.datetime, dt.datetime(1999, 12, 31, 15, 35, 59))
360360

361+
def test_init_only(self):
362+
@dataclass
363+
class Options:
364+
cls_var: ClassVar[str] = "Hello"
365+
date: InitVar[str]
366+
time: InitVar[str] = "00:00"
367+
datetime: dt.datetime = field(init=False)
368+
369+
def __post_init__(self, date, time):
370+
self.datetime = dt.datetime.fromisoformat(f"{date}T{time}")
371+
372+
args = ["--date", "1999-12-31"]
373+
params = parse_args(Options, args)
374+
self.assertFalse(hasattr(params, "date"))
375+
# time is always set to the default value. I think this is a bug..
376+
# self.assertFalse(hasattr(params, "time"))
377+
self.assertEqual(params.datetime, dt.datetime(1999, 12, 31))
378+
379+
args = ["--date", "1999-12-31", "--time", "15:35:59"]
380+
params = parse_args(Options, args)
381+
self.assertEqual(params.datetime, dt.datetime(1999, 12, 31, 15, 35, 59))
382+
361383

362384
if __name__ == "__main__":
363385
unittest.main()

0 commit comments

Comments
 (0)