Skip to content

Commit 777e04f

Browse files
authored
Add header support for CSV datasets (#784)
1 parent 6a6f295 commit 777e04f

File tree

2 files changed

+48
-2
lines changed

2 files changed

+48
-2
lines changed

awswrangler/s3/_write_text.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,7 @@ def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-state
456456
if database and table:
457457
quoting: Optional[int] = csv.QUOTE_NONE
458458
escapechar: Optional[str] = "\\"
459-
header: Union[bool, List[str]] = False
459+
header: Union[bool, List[str]] = pandas_kwargs.get("header", False)
460460
date_format: Optional[str] = "%Y-%m-%d %H:%M:%S.%f"
461461
pd_kwargs: Dict[str, Any] = {}
462462
compression: Optional[str] = pandas_kwargs.get("compression", None)
@@ -529,7 +529,7 @@ def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-state
529529
catalog_table_input=catalog_table_input,
530530
catalog_id=catalog_id,
531531
compression=pandas_kwargs.get("compression"),
532-
skip_header_line_count=None,
532+
skip_header_line_count=True if header else None,
533533
serde_library=serde_library,
534534
serde_parameters=serde_parameters,
535535
)

tests/test_s3_text.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,52 @@ def test_csv(path):
119119
wr.s3.read_csv(path=paths, iterator=True)
120120

121121

122+
@pytest.mark.parametrize("header", [True, ["identifier"]])
def test_csv_dataset_header(path, header, glue_database, glue_table):
    """Write a CSV dataset with a header row and read it back unchanged.

    ``header`` is forwarded to ``wr.s3.to_csv`` either as ``True`` or as a
    list of replacement column names. When a list is given, the expected
    frame's columns are renamed to that list before the comparison.
    """
    target = f"{path}test_csv_dataset0.csv"
    expected = pd.DataFrame({"id": [1, 2, 3]})

    write_kwargs = {
        "path": target,
        "dataset": True,
        "database": glue_database,
        "table": glue_table,
        "index": False,
        "header": header,
    }
    wr.s3.to_csv(df=expected, **write_kwargs)

    actual = wr.s3.read_csv(path=target)

    # A list-valued header replaces the column names in the written file,
    # so the expected frame has to carry those names as well.
    if isinstance(header, list):
        expected.columns = header
    assert expected.equals(actual)
139+
140+
141+
@pytest.mark.parametrize("mode", ["append", "overwrite"])
def test_csv_dataset_header_modes(path, mode, glue_database, glue_table):
    """Write two frames with ``header=True`` using the given dataset ``mode``.

    For ``"append"`` the result read back must have as many rows as both
    frames combined; otherwise (``"overwrite"``) it must equal the last
    frame written.
    """
    target = f"{path}test_csv_dataset0.csv"
    first = pd.DataFrame({"id": [1, 2, 3]})
    second = pd.DataFrame({"id": [4, 5, 6]})
    frames = [first, second]

    # The same write settings apply to every frame; only ``df`` changes.
    write_kwargs = {
        "path": target,
        "dataset": True,
        "database": glue_database,
        "table": glue_table,
        "mode": mode,
        "index": False,
        "header": True,
    }
    for frame in frames:
        wr.s3.to_csv(df=frame, **write_kwargs)

    combined = pd.concat(frames)
    result = wr.s3.read_csv(path=target)

    if mode != "append":
        assert result.equals(frames[-1])
    else:
        assert len(result) == len(combined)
166+
167+
122168
def test_json(path):
123169
df0 = pd.DataFrame({"id": [1, 2, 3]})
124170
path0 = f"{path}test_json0.json"

0 commit comments

Comments
 (0)