Skip to content

Commit a185278

Browse files
author
Henry Walshaw
committed
Deal with computed fields and related alias bug
This involves switching to a dict writer, and making user whether to use aliases is passed through to the write method.
1 parent 56d2432 commit a185278

File tree

1 file changed

+18
-10
lines changed

1 file changed

+18
-10
lines changed

pydantic_csv/basemodel_csv_writer.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import csv
66
from collections.abc import Iterable
7+
from itertools import chain
78
from typing import Any
89

910
import pydantic
@@ -43,24 +44,30 @@ def __init__(
4344
self._model = model
4445
self._field_mapping: dict[str, str] = {}
4546

46-
fields = {name: field for name, field in self._model.model_fields.items() if not (field.exclude or False)}
47+
fields = {
48+
name: field
49+
for name, field in chain(self._model.model_fields.items(), self._model.model_computed_fields.items())
50+
if not (getattr(field, "exclude", False) or False)
51+
}
4752

48-
if use_alias:
53+
self._use_alias = use_alias
54+
55+
if self._use_alias:
4956
self._fieldnames = [field.alias or name for name, field in fields.items()]
5057
else:
5158
self._fieldnames = fields.keys()
5259

53-
self._writer = csv.writer(file_obj, dialect=dialect, **kwargs)
60+
self._writer = csv.DictWriter(file_obj, self._fieldnames, dialect=dialect, **kwargs)
5461

5562
def _add_to_mapping(self, header: str, fieldname: str) -> None:
5663
self._field_mapping[fieldname] = header
5764

58-
def _apply_mapping(self) -> list[str]:
59-
mapped_fields = []
65+
def _apply_mapping(self) -> dict[str, str]:
66+
mapped_fields = {}
6067

6168
for field in self._fieldnames:
6269
mapped_item = self._field_mapping.get(field, field)
63-
mapped_fields.append(mapped_item)
70+
mapped_fields[field] = mapped_item
6471

6572
return mapped_fields
6673

@@ -75,11 +82,12 @@ def write(self, skip_header: bool = False) -> None:
7582
Returns:
7683
None: well, nothing
7784
"""
85+
7886
if not skip_header:
7987
if self._field_mapping:
80-
self._fieldnames = self._apply_mapping()
81-
82-
self._writer.writerow(self._fieldnames)
88+
self._writer.writerow(self._apply_mapping())
89+
else:
90+
self._writer.writeheader()
8391

8492
for item in self._data:
8593
if not isinstance(item, self._model):
@@ -88,7 +96,7 @@ def write(self, skip_header: bool = False) -> None:
8896
f"{self._model.__name__}. All items on the list must be "
8997
"instances of the same type"
9098
)
91-
row = item.model_dump().values()
99+
row = item.model_dump(by_alias=self._use_alias)
92100
self._writer.writerow(row)
93101

94102
def map(self, fieldname: str) -> HeaderMapper:

0 commit comments

Comments
 (0)