Skip to content

Commit 75b1df8

Browse files
fix: Write all record batches to the same file without overwriting rows (#195)
1 parent b639b04 commit 75b1df8

File tree

6 files changed

+93
-56
lines changed

6 files changed

+93
-56
lines changed

poetry.lock

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ python = ">=3.8"
2929
pytz = "~=2025.1"
3030
singer-sdk = "~=0.42.1"
3131

32-
[tool.poetry.dev-dependencies]
32+
[tool.poetry.group.dev.dependencies]
3333
pytest = "~=8.3"
3434

3535
[tool.poetry.scripts]
@@ -45,8 +45,6 @@ target-version = "py38"
4545

4646
[tool.ruff.lint]
4747
ignore = [
48-
"ANN101", # Missing type annotation for `self` in method
49-
"ANN102", # Missing type annotation for `cls` in class method
5048
"ANN401", # Allow `typing.Any` as parameter type
5149
]
5250
select = [

target_csv/serialization.py

Lines changed: 42 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,20 @@
1-
import csv # noqa: D100
2-
import sys
3-
from pathlib import Path
4-
from typing import Any, List, Callable, TypeVar
5-
6-
if sys.version_info < (3, 10):
7-
from typing_extensions import Concatenate, ParamSpec
8-
else:
9-
from typing import Concatenate, ParamSpec
10-
11-
P = ParamSpec("P")
12-
T = TypeVar("T")
13-
1+
"""Serialization utilities for CSV files."""
142

15-
def create_folder_if_not_exists(
16-
func: Callable[Concatenate[Path, P], T],
17-
) -> Callable[Concatenate[Path, P], T]:
18-
"""Decorator to create folder if it does not exist."""
3+
from __future__ import annotations
194

20-
def wrapper(filepath: Path, *args: P.args, **kwargs: P.kwargs) -> T:
21-
filepath.parent.mkdir(parents=True, exist_ok=True)
22-
return func(filepath, *args, **kwargs)
23-
24-
return wrapper
5+
import csv # noqa: D100
6+
import tempfile
7+
from pathlib import Path
8+
from typing import Any
259

2610

27-
@create_folder_if_not_exists
28-
def write_csv(filepath: Path, records: List[dict], schema: dict, **kwargs: Any) -> int:
11+
def write_csv(
12+
filepath: Path,
13+
records: list[dict],
14+
keys: list[str],
15+
**kwargs: Any,
16+
) -> int:
2917
"""Write a CSV file."""
30-
if "properties" not in schema:
31-
raise ValueError("Stream's schema has no properties defined.")
32-
33-
keys: List[str] = list(schema["properties"].keys())
3418
with open(filepath, "w", encoding="utf-8", newline="") as fp:
3519
writer = csv.DictWriter(fp, fieldnames=keys, dialect="excel", **kwargs)
3620
writer.writeheader()
@@ -40,9 +24,37 @@ def write_csv(filepath: Path, records: List[dict], schema: dict, **kwargs: Any)
4024
return record_count
4125

4226

43-
def read_csv(filepath: Path) -> List[dict]:
27+
def write_header(filepath: Path, keys: list[str], **kwargs: Any) -> None:
28+
"""Write a header to a CSV file.
29+
30+
Creates the parent directory if it doesn't exist.
31+
"""
32+
filepath.parent.mkdir(parents=True, exist_ok=True)
33+
with filepath.open("w", encoding="utf-8", newline="") as fp:
34+
writer = csv.DictWriter(fp, fieldnames=keys, **kwargs)
35+
writer.writeheader()
36+
37+
38+
def write_batch(
39+
filepath: Path,
40+
records: list[dict],
41+
keys: list[str],
42+
**kwargs: Any,
43+
) -> None:
44+
"""Write a batch of records to a CSV file."""
45+
with tempfile.NamedTemporaryFile("w+", encoding="utf-8", newline="") as tmp_fp:
46+
writer = csv.DictWriter(tmp_fp, fieldnames=keys, **kwargs)
47+
writer.writerows(records)
48+
49+
tmp_fp.seek(0)
50+
51+
with filepath.open("a") as f:
52+
f.write(tmp_fp.read())
53+
54+
55+
def read_csv(filepath: Path) -> list[dict]:
4456
"""Read a CSV file."""
45-
result: List[dict] = []
57+
result: list[dict] = []
4658
with open(filepath, newline="") as fp:
4759
reader = csv.DictReader(fp, delimiter=",", dialect="excel")
4860
result.extend(iter(reader))

target_csv/sinks.py

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
"""CSV target sink class, which handles writing streams."""
22

3+
from __future__ import annotations
4+
35
import datetime
6+
import functools
47
import sys
58
import warnings
69
from pathlib import Path
@@ -10,7 +13,7 @@
1013
from singer_sdk import Target
1114
from singer_sdk.sinks import BatchSink
1215

13-
from target_csv.serialization import write_csv
16+
from target_csv.serialization import write_batch, write_header
1417

1518

1619
class CSVSink(BatchSink):
@@ -77,32 +80,50 @@ def output_file(self) -> Path: # noqa: D102
7780

7881
return filepath
7982

83+
@functools.cached_property
84+
def keys(self) -> list[str]:
85+
"""Get the header keys for the CSV file."""
86+
if "properties" not in self.schema:
87+
raise ValueError("Stream's schema has no properties defined")
88+
89+
return list(self.schema["properties"].keys())
90+
91+
@functools.cached_property
92+
def escape_character(self) -> str | None:
93+
"""Get the escape character for the CSV file."""
94+
return self.config.get("escape_character")
95+
96+
def setup(self) -> None:
97+
"""Create the output file and write the header."""
98+
super().setup()
99+
output_file = self.output_file
100+
self.logger.info("Writing to destination file '%s'...", output_file.resolve())
101+
write_header(
102+
output_file,
103+
self.keys,
104+
dialect="excel",
105+
escapechar=self.escape_character,
106+
)
107+
80108
def process_batch(self, context: dict) -> None:
81109
"""Write out any prepped records and return once fully written."""
82110
output_file: Path = self.output_file
83-
self.logger.info(f"Writing to destination file '{output_file.resolve()}'...")
84-
new_contents: dict # noqa: F842
85-
create_new = (
86-
self.config["overwrite_behavior"] == "replace_file"
87-
or not output_file.exists()
88-
)
89-
if not create_new:
90-
raise NotImplementedError("Append mode is not yet supported.")
91111

92112
if not isinstance(context["records"], list):
93-
self.logger.warning(f"No values in {self.stream_name} records collection.")
113+
self.logger.warning("No values in %s records collection.", self.stream_name)
94114
context["records"] = []
95115

96116
records: List[Dict[str, Any]] = context["records"]
97117
if "record_sort_property_name" in self.config:
98118
sort_property_name = self.config["record_sort_property_name"]
99119
records = sorted(records, key=lambda x: x[sort_property_name])
100120

101-
self.logger.info(f"Writing {len(context['records'])} records to file...")
121+
self.logger.info(f"Appending {len(records)} records to file...")
102122

103-
write_csv(
123+
write_batch(
104124
output_file,
105-
context["records"],
106-
self.schema,
107-
escapechar=self.config.get("escape_character"),
125+
records,
126+
self.keys,
127+
dialect="excel",
128+
escapechar=self.escape_character,
108129
)

target_csv/target.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ class TargetCSV(Target):
107107
),
108108
th.Property(
109109
"escape_character",
110-
th.StringType,
110+
th.StringType(min_length=1, max_length=1),
111111
description="The character to use for escaping special characters.",
112112
),
113113
).to_dict()

tests/test_csv.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import pytest
77

8-
from target_csv.serialization import read_csv, write_csv
8+
from target_csv.serialization import read_csv, write_batch, write_header
99

1010
SAMPLE_DATASETS: List[Tuple[Dict, List[Dict[str, Any]]]] = [
1111
(
@@ -70,18 +70,24 @@ def test_file_paths(output_dir) -> List[Path]:
7070

7171
def test_csv_write(output_filepath) -> None:
7272
for schema, records in SAMPLE_DATASETS:
73-
write_csv(filepath=output_filepath, records=records, schema=schema)
73+
keys = list(schema["properties"].keys())
74+
write_header(filepath=output_filepath, keys=keys)
75+
write_batch(filepath=output_filepath, records=records, keys=keys)
7476

7577

7678
def test_csv_write_if_not_exists(test_file_paths) -> None:
7779
for path in test_file_paths:
7880
for schema, records in SAMPLE_DATASETS:
79-
write_csv(filepath=path, records=records, schema=schema)
81+
keys = list(schema["properties"].keys())
82+
write_header(filepath=path, keys=keys)
83+
write_batch(filepath=path, records=records, keys=keys)
8084

8185

8286
def test_csv_roundtrip(output_filepath) -> None:
8387
for schema, records in SAMPLE_DATASETS:
84-
write_csv(filepath=output_filepath, records=records, schema=schema)
88+
keys = list(schema["properties"].keys())
89+
write_header(filepath=output_filepath, keys=keys)
90+
write_batch(filepath=output_filepath, records=records, keys=keys)
8591
read_records = read_csv(filepath=output_filepath)
8692
for orig_record, new_record in zip(records, read_records):
8793
for key in orig_record.keys():

0 commit comments

Comments
 (0)