4
4
5
5
import csv
6
6
from collections .abc import Iterable
7
+ from itertools import chain
7
8
from typing import Any
8
9
9
10
import pydantic
@@ -43,24 +44,30 @@ def __init__(
43
44
self ._model = model
44
45
self ._field_mapping : dict [str , str ] = {}
45
46
46
- fields = {name : field for name , field in self ._model .model_fields .items () if not (field .exclude or False )}
47
+ fields = {
48
+ name : field
49
+ for name , field in chain (self ._model .model_fields .items (), self ._model .model_computed_fields .items ())
50
+ if not (getattr (field , "exclude" , False ) or False )
51
+ }
47
52
48
- if use_alias :
53
+ self ._use_alias = use_alias
54
+
55
+ if self ._use_alias :
49
56
self ._fieldnames = [field .alias or name for name , field in fields .items ()]
50
57
else :
51
58
self ._fieldnames = fields .keys ()
52
59
53
- self ._writer = csv .writer (file_obj , dialect = dialect , ** kwargs )
60
+ self ._writer = csv .DictWriter (file_obj , self . _fieldnames , dialect = dialect , ** kwargs )
54
61
55
62
def _add_to_mapping (self , header : str , fieldname : str ) -> None :
56
63
self ._field_mapping [fieldname ] = header
57
64
58
- def _apply_mapping (self ) -> list [ str ]:
59
- mapped_fields = []
65
+ def _apply_mapping (self ) -> dict [ str , str ]:
66
+ mapped_fields = {}
60
67
61
68
for field in self ._fieldnames :
62
69
mapped_item = self ._field_mapping .get (field , field )
63
- mapped_fields . append ( mapped_item )
70
+ mapped_fields [ field ] = mapped_item
64
71
65
72
return mapped_fields
66
73
@@ -75,11 +82,12 @@ def write(self, skip_header: bool = False) -> None:
75
82
Returns:
76
83
None: well, nothing
77
84
"""
85
+
78
86
if not skip_header :
79
87
if self ._field_mapping :
80
- self ._fieldnames = self ._apply_mapping ()
81
-
82
- self ._writer .writerow ( self . _fieldnames )
88
+ self ._writer . writerow ( self ._apply_mapping () )
89
+ else :
90
+ self ._writer .writeheader ( )
83
91
84
92
for item in self ._data :
85
93
if not isinstance (item , self ._model ):
@@ -88,7 +96,7 @@ def write(self, skip_header: bool = False) -> None:
88
96
f"{ self ._model .__name__ } . All items on the list must be "
89
97
"instances of the same type"
90
98
)
91
- row = item .model_dump (). values ( )
99
+ row = item .model_dump (by_alias = self . _use_alias )
92
100
self ._writer .writerow (row )
93
101
94
102
def map (self , fieldname : str ) -> HeaderMapper :
0 commit comments