diff --git a/pydantic_csv/basemodel_csv_writer.py b/pydantic_csv/basemodel_csv_writer.py index d7210d5..eaa6b6c 100644 --- a/pydantic_csv/basemodel_csv_writer.py +++ b/pydantic_csv/basemodel_csv_writer.py @@ -4,6 +4,7 @@ import csv from collections.abc import Iterable +from itertools import chain from typing import Any import pydantic @@ -43,22 +44,32 @@ def __init__( self._model = model self._field_mapping: dict[str, str] = {} - if use_alias: - self._fieldnames = [field.alias or name for name, field in self._model.model_fields.items()] + fields = { + name: field + for name, field in chain(self._model.model_fields.items(), self._model.model_computed_fields.items()) + if not (getattr(field, "exclude", False) or False) + } + + self._use_alias = use_alias + + if self._use_alias: + self._fieldnames = [ + field.alias or getattr(field, "serialization_alias", None) or name for name, field in fields.items() + ] else: - self._fieldnames = model.model_fields.keys() + self._fieldnames = fields.keys() - self._writer = csv.writer(file_obj, dialect=dialect, **kwargs) + self._writer = csv.DictWriter(file_obj, self._fieldnames, dialect=dialect, **kwargs) def _add_to_mapping(self, header: str, fieldname: str) -> None: self._field_mapping[fieldname] = header - def _apply_mapping(self) -> list[str]: - mapped_fields = [] + def _apply_mapping(self) -> dict[str, str]: + mapped_fields = {} for field in self._fieldnames: mapped_item = self._field_mapping.get(field, field) - mapped_fields.append(mapped_item) + mapped_fields[field] = mapped_item return mapped_fields @@ -73,11 +84,12 @@ def write(self, skip_header: bool = False) -> None: Returns: None: well, nothing """ + if not skip_header: if self._field_mapping: - self._fieldnames = self._apply_mapping() - - self._writer.writerow(self._fieldnames) + self._writer.writerow(self._apply_mapping()) + else: + self._writer.writeheader() for item in self._data: if not isinstance(item, self._model): @@ -86,7 +98,7 @@ def write(self, skip_header: bool = False) -> None: f"{self._model.__name__}. All items on the list must be " "instances of the same type" ) - row = item.model_dump().values() + row = item.model_dump(by_alias=self._use_alias) self._writer.writerow(row) def map(self, fieldname: str) -> HeaderMapper: diff --git a/pyproject.toml b/pyproject.toml index 97a3365..05f3400 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "pydantic-csv" -version = "0.1.0" +version = "0.1.2" description = "convert CSV to pydantic.BaseModel and vice versa" authors = ["Nathan Richard "] license = "LICENSE" diff --git a/tests/models.py b/tests/models.py index bce8c3c..792854a 100644 --- a/tests/models.py +++ b/tests/models.py @@ -2,7 +2,7 @@ from typing import Optional import pydantic -from pydantic import Field +from pydantic import Field, computed_field class User(pydantic.BaseModel): @@ -51,3 +51,25 @@ def parse_start_date(cls, value): @pydantic.field_validator("end", mode="before") def parse_end_date(cls, value): return datetime.strptime(value, "%d.%m.%Y").date() + + +class ExcludedPassword(pydantic.BaseModel): + username: str = "Wagstaff" + password: str = Field(default="swordfish", exclude=True) + email: str = Field(default="wagstaff@marx.bros", serialization_alias="contact") + + +class ComputedPropertyField(pydantic.BaseModel): + username: str = "Groucho" + + @computed_field + def email(self) -> str: + return f"{self.username.lower()}@marx.bros" + + +class ComputedPropertyWithAlias(pydantic.BaseModel): + username: str = "Harpo" + + @computed_field(alias="e") + def email(self) -> str: + return f"{self.username.lower()}@marx.bros" diff --git a/tests/test_basemodel_csv_writer.py b/tests/test_basemodel_csv_writer.py index aea54bc..eac58d5 100644 --- a/tests/test_basemodel_csv_writer.py +++ b/tests/test_basemodel_csv_writer.py @@ -4,7 +4,14 @@ from pydantic_csv import BasemodelCSVWriter -from .models import NonBaseModelUser, SimpleUser, User +from .models import ( + ComputedPropertyField, + ComputedPropertyWithAlias, + ExcludedPassword, + NonBaseModelUser, + SimpleUser, + User, +) def test_create_csv_file(users_as_csv_buffer, users_from_csv): @@ -50,3 +57,32 @@ def test_with_wrong_type_in_list(user_list): def test_header_mapping(users_mapped_as_csv_buffer, users_mapped_from_csv): assert users_mapped_as_csv_buffer == users_mapped_from_csv + + +def test_excluded_field(): + output = io.StringIO() + user = ExcludedPassword() + + w = BasemodelCSVWriter(output, [user], ExcludedPassword) + w.write() + + assert output.getvalue() == "username,contact\r\nWagstaff,wagstaff@marx.bros\r\n" + + +@pytest.mark.parametrize( + ("model", "use_alias", "expected_output"), + [ + (ComputedPropertyField, True, "username,email\r\nGroucho,groucho@marx.bros\r\n"), + (ComputedPropertyWithAlias, True, "username,e\r\nHarpo,harpo@marx.bros\r\n"), + (ComputedPropertyField, False, "username,email\r\nGroucho,groucho@marx.bros\r\n"), + (ComputedPropertyWithAlias, False, "username,email\r\nHarpo,harpo@marx.bros\r\n"), + ], +) +def test_computed_property_included(model, use_alias, expected_output): + output = io.StringIO() + user = model() + + w = BasemodelCSVWriter(output, [user], model, use_alias=use_alias) + w.write() + + assert output.getvalue() == expected_output