Skip to content

Commit 8760844

Browse files
authored
Merge pull request dfurtado#19 from zivanfi/optionals
Handle Optional typing annotations.
2 parents 1ac84ed + 8f39469 commit 8760844

File tree

3 files changed

+39
-4
lines changed

3 files changed

+39
-4
lines changed

dataclass_csv/dataclass_reader.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from datetime import datetime
55
from distutils.util import strtobool
6+
from typing import Union
67

78
from .field_mapper import FieldMapper
89
from .exceptions import CsvValueError
@@ -145,7 +146,22 @@ def _process_row(self, row):
145146
values.append(None)
146147
continue
147148

148-
if field.type is datetime:
149+
field_type = field.type
150+
# Special handling for Optional (Union of a single real type and None)
151+
if (
152+
# The first part of the condition is for Python < 3.8
153+
type(field_type).__name__ == '_Union'
154+
# The second part of the condition is for Python >= 3.8
155+
or '__origin__' in field_type.__dict__
156+
and field_type.__origin__ is Union
157+
):
158+
real_types = [
159+
t for t in field_type.__args__ if t is not type(None)
160+
]
161+
if len(real_types) == 1:
162+
field_type = real_types[0]
163+
164+
if field_type is datetime:
149165
try:
150166
transformed_value = self._parse_date_value(field, value)
151167
except ValueError as ex:
@@ -156,7 +172,7 @@ def _process_row(self, row):
156172
values.append(transformed_value)
157173
continue
158174

159-
if field.type is bool:
175+
if field_type is bool:
160176
try:
161177
transformed_value = (
162178
value
@@ -172,7 +188,7 @@ def _process_row(self, row):
172188
continue
173189

174190
try:
175-
transformed_value = field.type(value)
191+
transformed_value = field_type(value)
176192
except ValueError:
177193
raise CsvValueError(
178194
(

tests/mocks.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
from dataclass_csv import dateformat, accept_whitespaces
66

7+
from typing import Optional
8+
79

810
@dataclasses.dataclass
911
class User:
@@ -65,3 +67,10 @@ class DataclassWithBooleanValue:
6567
@dataclasses.dataclass
6668
class DataclassWithBooleanValueNoneDefault:
6769
boolValue: bool = None
70+
71+
72+
@dataclasses.dataclass
73+
class UserWithOptionalAge:
74+
name: str
75+
age: Optional[int]
76+

tests/test_dataclass_reader.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from dataclass_csv import DataclassReader, CsvValueError
55

6-
from .mocks import User, DataclassWithBooleanValue, DataclassWithBooleanValueNoneDefault
6+
from .mocks import User, UserWithOptionalAge, DataclassWithBooleanValue, DataclassWithBooleanValueNoneDefault
77

88

99
def test_reader_with_non_dataclass(create_csv):
@@ -130,3 +130,13 @@ def test_parse_bool_value_none_default(create_csv):
130130
items = list(reader)
131131
dataclass_instance = items[0]
132132
assert dataclass_instance.boolValue is None
133+
134+
135+
def test_reader_with_optional_types(create_csv):
136+
csv_file = create_csv({'name': 'User', 'age': 40})
137+
138+
with csv_file.open() as f:
139+
reader = DataclassReader(f, UserWithOptionalAge)
140+
list(reader)
141+
142+

0 commit comments

Comments
 (0)