
Commit a42e68c

Add columns parameter to Pandas.to_csv()
1 parent 60ee9ae commit a42e68c
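
With this change, callers can restrict which DataFrame columns are written to S3 and registered in the Glue Catalog. A minimal usage sketch, mirroring the new test added in this commit (the bucket and database names below are placeholders):

import pandas as pd
import awswrangler as wr

df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6], "C": ["foo", "boo", "bar"]})

# Write only columns A and B; column C is dropped from both the CSV objects
# and the Glue table schema.
wr.pandas.to_csv(dataframe=df,
                 database="my_database",                       # placeholder
                 path="s3://my-bucket/test_to_csv_columns",    # placeholder
                 columns=["A", "B"],
                 mode="overwrite",
                 preserve_index=False)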

File tree

4 files changed (+108, -52 lines)


awswrangler/data_types.py

Lines changed: 4 additions & 4 deletions
@@ -370,8 +370,8 @@ def extract_pyarrow_schema_from_pandas(dataframe: pd.DataFrame,
     :param indexes_position: "right" or "left"
     :return: Pyarrow schema (e.g. [("col name": "bigint"), ("col2 name": "int")]
     """
-    cols = []
-    cols_dtypes = {}
+    cols: List[str] = []
+    cols_dtypes: Dict[str, str] = {}
     if indexes_position not in ("right", "left"):
         raise ValueError(f"indexes_position must be \"right\" or \"left\"")

@@ -384,10 +384,10 @@ def extract_pyarrow_schema_from_pandas(dataframe: pd.DataFrame,
         cols.append(name)

     # Filling cols_dtypes and indexes
-    indexes = []
+    indexes: List[str] = []
     for field in pa.Schema.from_pandas(df=dataframe[cols], preserve_index=preserve_index):
         name = str(field.name)
-        dtype = field.type
+        dtype = str(field.type)
         cols_dtypes[name] = dtype
         if name not in dataframe.columns:
             indexes.append(name)

awswrangler/glue.py

Lines changed: 53 additions & 39 deletions
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Dict, Optional, Any, Iterator, List, Union
+from typing import TYPE_CHECKING, Dict, Optional, Any, Iterator, List, Union, Tuple
 from math import ceil
 from itertools import islice
 import re
@@ -55,16 +55,16 @@ def get_table_python_types(self, database: str, table: str) -> Dict[str, Optiona
     def metadata_to_glue(self,
                          dataframe,
                          path: str,
-                         objects_paths,
-                         file_format,
-                         database=None,
-                         table=None,
-                         partition_cols=None,
-                         preserve_index=True,
+                         objects_paths: List[str],
+                         file_format: str,
+                         database: str,
+                         table: Optional[str],
+                         partition_cols: Optional[List[str]] = None,
+                         preserve_index: bool = True,
                          mode: str = "append",
-                         compression=None,
-                         cast_columns=None,
-                         extra_args: Optional[Dict[str, Optional[Union[str, int]]]] = None,
+                         compression: Optional[str] = None,
+                         cast_columns: Optional[Dict[str, str]] = None,
+                         extra_args: Optional[Dict[str, Optional[Union[str, int, List[str]]]]] = None,
                          description: Optional[str] = None,
                          parameters: Optional[Dict[str, str]] = None,
                          columns_comments: Optional[Dict[str, str]] = None) -> None:
@@ -88,6 +88,8 @@ def metadata_to_glue(self,
         :return: None
         """
         indexes_position = "left" if file_format == "csv" else "right"
+        schema: List[Tuple[str, str]]
+        partition_cols_schema: List[Tuple[str, str]]
         schema, partition_cols_schema = Glue._build_schema(dataframe=dataframe,
                                                            partition_cols=partition_cols,
                                                            preserve_index=preserve_index,
@@ -138,14 +140,14 @@ def does_table_exists(self, database, table):
         return False

     def create_table(self,
-                     database,
-                     table,
-                     schema,
-                     path,
-                     file_format,
-                     compression,
-                     partition_cols_schema=None,
-                     extra_args=None,
+                     database: str,
+                     table: str,
+                     schema: List[Tuple[str, str]],
+                     path: str,
+                     file_format: str,
+                     compression: Optional[str],
+                     partition_cols_schema: List[Tuple[str, str]],
+                     extra_args: Optional[Dict[str, Union[str, int, List[str], None]]] = None,
                      description: Optional[str] = None,
                      parameters: Optional[Dict[str, str]] = None,
                      columns_comments: Optional[Dict[str, str]] = None) -> None:
@@ -166,13 +168,17 @@ def create_table(self,
         :return: None
         """
         if file_format == "parquet":
-            table_input = Glue.parquet_table_definition(table, partition_cols_schema, schema, path, compression)
+            table_input: Dict[str, Any] = Glue.parquet_table_definition(table=table,
+                                                                        partition_cols_schema=partition_cols_schema,
+                                                                        schema=schema,
+                                                                        path=path,
+                                                                        compression=compression)
         elif file_format == "csv":
-            table_input = Glue.csv_table_definition(table,
-                                                    partition_cols_schema,
-                                                    schema,
-                                                    path,
-                                                    compression,
+            table_input = Glue.csv_table_definition(table=table,
+                                                    partition_cols_schema=partition_cols_schema,
+                                                    schema=schema,
+                                                    path=path,
+                                                    compression=compression,
                                                     extra_args=extra_args)
         else:
             raise UnsupportedFileFormat(file_format)
@@ -223,19 +229,23 @@ def get_connection_details(self, name):
         return self._client_glue.get_connection(Name=name, HidePassword=False)["Connection"]

     @staticmethod
-    def _build_schema(dataframe, partition_cols, preserve_index, indexes_position, cast_columns=None):
+    def _build_schema(
+            dataframe,
+            partition_cols: Optional[List[str]],
+            preserve_index: bool,
+            indexes_position: str,
+            cast_columns: Optional[Dict[str, str]] = None) -> Tuple[List[Tuple[str, str]], List[Tuple[str, str]]]:
         if cast_columns is None:
             cast_columns = {}
         logger.debug(f"dataframe.dtypes:\n{dataframe.dtypes}")
-        if not partition_cols:
+        if partition_cols is None:
             partition_cols = []

-        pyarrow_schema = data_types.extract_pyarrow_schema_from_pandas(dataframe=dataframe,
-                                                                       preserve_index=preserve_index,
-                                                                       indexes_position=indexes_position)
+        pyarrow_schema: List[Tuple[str, str]] = data_types.extract_pyarrow_schema_from_pandas(
+            dataframe=dataframe, preserve_index=preserve_index, indexes_position=indexes_position)

-        schema_built = []
-        partition_cols_types = {}
+        schema_built: List[Tuple[str, str]] = []
+        partition_cols_types: Dict[str, str] = {}
         for name, dtype in pyarrow_schema:
             if (cast_columns is not None) and (name in cast_columns.keys()):
                 if name in partition_cols:
@@ -256,7 +266,7 @@ def _build_schema(dataframe, partition_cols, preserve_index, indexes_position, c
             else:
                 schema_built.append((name, athena_type))

-        partition_cols_schema_built = [(name, partition_cols_types[name]) for name in partition_cols]
+        partition_cols_schema_built: List = [(name, partition_cols_types[name]) for name in partition_cols]

         logger.debug(f"schema_built:\n{schema_built}")
         logger.debug(f"partition_cols_schema_built:\n{partition_cols_schema_built}")
@@ -269,12 +279,12 @@ def parse_table_name(path):
         return path.rpartition("/")[2]

     @staticmethod
-    def csv_table_definition(table,
-                             partition_cols_schema,
-                             schema,
-                             path,
-                             compression,
-                             extra_args: Optional[Dict[str, Optional[Union[str, int]]]] = None):
+    def csv_table_definition(table: str,
+                             partition_cols_schema: List[Tuple[str, str]],
+                             schema: List[Tuple[str, str]],
+                             path: str,
+                             compression: Optional[str],
+                             extra_args: Optional[Dict[str, Optional[Union[str, int, List[str]]]]] = None):
         if extra_args is None:
             extra_args = {"sep": ","}
         if partition_cols_schema is None:
@@ -301,6 +311,9 @@ def csv_table_definition(table,
             refined_schema = [(name, dtype) if dtype in dtypes_allowed else (name, "string") for name, dtype in schema]
         else:
             raise InvalidSerDe(f"{serde} in not in the valid SerDe list.")
+        if "columns" in extra_args:
+            refined_schema = [(name, dtype) for name, dtype in refined_schema
+                              if name in extra_args["columns"]]  # type: ignore
         return {
             "Name": table,
             "PartitionKeys": [{
@@ -378,7 +391,8 @@ def csv_partition_definition(partition, compression, extra_args=None):
         }

     @staticmethod
-    def parquet_table_definition(table, partition_cols_schema, schema, path, compression):
+    def parquet_table_definition(table: str, partition_cols_schema: List[Tuple[str, str]],
+                                 schema: List[Tuple[str, str]], path: str, compression: Optional[str]):
         if not partition_cols_schema:
             partition_cols_schema = []
         compressed = False if compression is None else True
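
For illustration only (not part of the diff): how the new "columns" entry in extra_args narrows the refined schema inside csv_table_definition, assuming refined_schema is a list of (column name, Athena type) tuples as built just above.

refined_schema = [("a", "bigint"), ("b", "string"), ("c", "double")]
extra_args = {"sep": ",", "columns": ["a", "b"]}

# Keep only the columns the caller asked for; everything else is excluded
# from the Glue table definition.
if "columns" in extra_args:
    refined_schema = [(name, dtype) for name, dtype in refined_schema
                      if name in extra_args["columns"]]

print(refined_schema)  # [('a', 'bigint'), ('b', 'string')]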

awswrangler/pandas.py

Lines changed: 23 additions & 9 deletions
@@ -696,6 +696,7 @@ def to_csv(self,
                path: str,
                sep: Optional[str] = None,
                na_rep: Optional[str] = None,
+               columns: Optional[List[str]] = None,
                quoting: Optional[int] = None,
                escapechar: Optional[str] = None,
                serde: Optional[str] = "OpenCSVSerDe",
@@ -718,6 +719,7 @@ def to_csv(self,
         :param path: AWS S3 path (E.g. s3://bucket-name/folder_name/
         :param sep: Same as pandas.to_csv()
         :param na_rep: Same as pandas.to_csv()
+        :param columns: Same as pandas.to_csv()
         :param quoting: Same as pandas.to_csv()
         :param escapechar: Same as pandas.to_csv()
         :param serde: SerDe library name (e.g. OpenCSVSerDe, LazySimpleSerDe) (For Athena/Glue Catalog only)
@@ -738,9 +740,10 @@ def to_csv(self,
             raise InvalidSerDe(f"{serde} in not in the valid SerDe list ({Pandas.VALID_CSV_SERDES})")
         if (database is not None) and (serde is None):
             raise InvalidParameters(f"It is not possible write to a Glue Database without a SerDe.")
-        extra_args: Dict[str, Optional[Union[str, int]]] = {
+        extra_args: Dict[str, Optional[Union[str, int, List[str]]]] = {
             "sep": sep,
             "na_rep": na_rep,
+            "columns": columns,
             "serde": serde,
             "escapechar": escapechar,
             "quoting": quoting
@@ -822,14 +825,14 @@ def to_s3(self,
               file_format: str,
               database: Optional[str] = None,
               table: Optional[str] = None,
-              partition_cols=None,
-              preserve_index=True,
+              partition_cols: Optional[List[str]] = None,
+              preserve_index: bool = True,
               mode: str = "append",
-              compression=None,
-              procs_cpu_bound=None,
-              procs_io_bound=None,
-              cast_columns=None,
-              extra_args: Optional[Dict[str, Optional[Union[str, int]]]] = None,
+              compression: Optional[str] = None,
+              procs_cpu_bound: Optional[int] = None,
+              procs_io_bound: Optional[int] = None,
+              cast_columns: Optional[Dict[str, str]] = None,
+              extra_args: Optional[Dict[str, Optional[Union[str, int, List[str]]]]] = None,
               inplace: bool = True,
               description: Optional[str] = None,
               parameters: Optional[Dict[str, str]] = None,
@@ -866,6 +869,8 @@ def to_s3(self,
         logger.debug(f"cast_columns: {cast_columns}")
         partition_cols = [Athena.normalize_column_name(x) for x in partition_cols]
         logger.debug(f"partition_cols: {partition_cols}")
+        if extra_args is not None and "columns" in extra_args:
+            extra_args["columns"] = [Athena.normalize_column_name(x) for x in extra_args["columns"]]
         dataframe = Pandas.drop_duplicated_columns(dataframe=dataframe, inplace=inplace)
         if compression is not None:
             compression = compression.lower()
@@ -1112,6 +1117,9 @@ def write_csv_dataframe(dataframe, path, preserve_index, compression, fs, extra_
         sep = extra_args.get("sep")
         if sep is not None:
             csv_extra_args["sep"] = sep
+        columns = extra_args.get("columns")
+        if columns is not None:
+            csv_extra_args["columns"] = columns

         serde = extra_args.get("serde")
         if serde is None:
@@ -1519,7 +1527,10 @@ def _read_parquet_path(session_primitives: "SessionPrimitives",
         fs.invalidate_cache()
         table = pq.read_table(source=path, columns=columns, filters=filters, filesystem=fs, use_threads=use_threads)
         # Check if we lose some integer during the conversion (Happens when has some null value)
-        integers = [field.name for field in table.schema if str(field.type).startswith("int") and field.name != "__index_level_0__"]
+        integers = [
+            field.name for field in table.schema
+            if str(field.type).startswith("int") and field.name != "__index_level_0__"
+        ]
         logger.debug(f"Converting to Pandas: {path}")
         df = table.to_pandas(use_threads=use_threads, integer_object_nulls=True)
         logger.debug(f"Casting Int64 columns: {path}")
@@ -1612,6 +1623,7 @@ def to_aurora(self,
                   temp_s3_path: Optional[str] = None,
                   preserve_index: bool = False,
                   mode: str = "append",
+                  columns: Optional[List[str]] = None,
                   procs_cpu_bound: Optional[int] = None,
                   procs_io_bound: Optional[int] = None,
                   inplace=True) -> None:
@@ -1626,6 +1638,7 @@ def to_aurora(self,
         :param temp_s3_path: S3 path to write temporary files (E.g. s3://BUCKET_NAME/ANY_NAME/)
         :param preserve_index: Should we preserve the Dataframe index?
         :param mode: append or overwrite
+        :param columns: List of columns to load
         :param procs_cpu_bound: Number of cores used for CPU bound tasks
         :param procs_io_bound: Number of cores used for I/O bound tasks
         :param inplace: True is cheapest (CPU and Memory) but False leaves your DataFrame intact
@@ -1654,6 +1667,7 @@ def to_aurora(self,
                      serde=None,
                      sep=",",
                      na_rep=na_rep,
+                     columns=columns,
                      quoting=csv.QUOTE_MINIMAL,
                      escapechar="\"",
                      preserve_index=preserve_index,
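
On the write side, write_csv_dataframe now forwards the filter to pandas. A small standalone sketch of the pandas-level effect, assuming csv_extra_args is ultimately passed through to DataFrame.to_csv:

import pandas as pd

df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6], "C": ["foo", "boo", "bar"]})
csv_extra_args = {"columns": ["A", "B"]}

# Only the listed columns end up in the CSV payload.
print(df.to_csv(index=False, **csv_extra_args))
# A,B
# 1,4
# 2,5
# 3,6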

testing/test_awswrangler/test_pandas.py

Lines changed: 28 additions & 0 deletions
@@ -2185,6 +2185,7 @@ def test_to_parquet_categorical_partitions(bucket):
     x['Year'] = x['Year'].astype('category')
     wr.pandas.to_parquet(x[x.Year == 1990], path=path, partition_cols=["Year"])
     y = wr.pandas.read_parquet(path=path)
+    wr.s3.delete_objects(path=path)
     assert len(x[x.Year == 1990].index) == len(y.index)


@@ -2197,5 +2198,32 @@ def test_range_index(bucket, database):
     print(x)
     wr.pandas.to_parquet(dataframe=x, path=path, database=database)
     df = wr.pandas.read_parquet(path=path)
+    wr.s3.delete_objects(path=path)
     assert len(x.columns) == len(df.columns)
     assert len(x.index) == len(df.index)
+
+
+def test_to_csv_columns(bucket, database):
+    path = f"s3://{bucket}/test_to_csv_columns"
+    wr.s3.delete_objects(path=path)
+    df = pd.DataFrame({
+        "A": [1, 2, 3],
+        "B": [4, 5, 6],
+        "C": ["foo", "boo", "bar"]
+    })
+    wr.s3.delete_objects(path=path)
+    wr.pandas.to_csv(
+        dataframe=df,
+        database=database,
+        path=path,
+        columns=["A", "B"],
+        mode="overwrite",
+        preserve_index=False,
+        procs_cpu_bound=1,
+        inplace=False
+    )
+    sleep(10)
+    df2 = wr.pandas.read_sql_athena(database=database, sql="SELECT * FROM test_to_csv_columns")
+    wr.s3.delete_objects(path=path)
+    assert len(df.columns) == len(df2.columns) + 1
+    assert len(df.index) == len(df2.index)
