diff --git a/python/pyarrow/parquet/core.py b/python/pyarrow/parquet/core.py
index aaf15c20288..a952ecec4a6 100644
--- a/python/pyarrow/parquet/core.py
+++ b/python/pyarrow/parquet/core.py
@@ -90,6 +90,72 @@ def _check_filters(filters, check_null_strings=True):
     return filters
 
 
+def _map_arrow_to_spark_types(datatype: pa.DataType) -> str | None:
+    """Map an Arrow DataType to the type name Spark uses in its row
+    metadata, or return None if there is no reasonable mapping.
+    """
+    lookup = {
+        "NA": "null",
+        "BOOL": "boolean",
+        "INT8": "byte",
+        "INT16": "short",
+        **dict.fromkeys(
+            ["UINT8", "UINT16", "UINT32", "UINT64", "INT32"], "integer"),
+        "INT64": "long",
+        **dict.fromkeys(["HALF_FLOAT", "FLOAT"], "float"),
+        "DOUBLE": "double",
+        "BINARY": "binary",
+        "STRING": "string",
+        # DECIMAL32, DECIMAL64, DECIMAL128 and DECIMAL256 all map to
+        # Spark's single "decimal" type name.
+        **dict.fromkeys(
+            ["DECIMAL" + str(2 ** i) for i in range(5, 9)], "decimal"),
+        **dict.fromkeys(
+            ["LIST", "LARGE_LIST", "LIST_VIEW", "LARGE_LIST_VIEW",
+             "FIXED_SIZE_LIST"],
+            "array",
+        ),
+        "MAP": "map",
+        **dict.fromkeys(["DATE32", "DATE64"], "date"),
+        "TIMESTAMP": "timestamp",
+        "INTERVAL_MONTH_DAY_NANO": "Calendar Interval",  # TODO: Correct this
+    }
+
+    str_value = pa.types.TypesEnum(datatype.id).name
+    return lookup.get(str_value)
+
+
+def _substitute_spark_metadata(schema: pa.Schema) -> dict:
+    """Extend Spark's row metadata with an entry for every field of
+    ``schema`` that the stored Spark schema does not already cover.
+    """
+    metadata = schema.metadata
+    spark_key = b"org.apache.spark.sql.parquet.row.metadata"
+
+    try:
+        spark_row_metadata = json.loads(metadata.get(spark_key))
+    except (TypeError, json.JSONDecodeError):
+        # Spark's row metadata is missing or malformed; leave it untouched.
+        return metadata
+
+    spark_fields = {field["name"] for field in spark_row_metadata["fields"]}
+
+    for field in schema:
+        if field.name in spark_fields:
+            continue
+
+        spark_row_metadata["fields"].append({
+            "name": field.name,
+            "type": _map_arrow_to_spark_types(field.type),
+            "nullable": field.nullable,
+            "metadata": {},
+        })
+
+    metadata[spark_key] = json.dumps(spark_row_metadata).encode("utf-8")
+
+    return metadata
+
+
 _DNF_filter_doc = """Predicates are expressed using an ``Expression`` or using
 the disjunctive normal form (DNF), like ``[[('x', '=', 0), ...], ...]``. DNF
 allows arbitrary boolean logical combinations of single column predicates.
@@ -1954,6 +2020,13 @@ def write_table(table, where, row_group_size=None, version='2.6',
         # update it in write_to_dataset and _dataset_parquet.pyx ParquetFileWriteOptions
         row_group_size = kwargs.pop('chunk_size', row_group_size)
         use_int96 = use_deprecated_int96_timestamps
+
+    spark_metadata_key = b"org.apache.spark.sql.parquet.row.metadata"
+    if (flavor == "spark" and table.schema.metadata is not None
+            and spark_metadata_key in table.schema.metadata):
+        new_metadata = _substitute_spark_metadata(table.schema)
+        table = table.replace_schema_metadata(new_metadata)
+
     try:
         with ParquetWriter(
             where, table.schema,
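
To illustrate the intended end-to-end behavior, here is a minimal sketch, assuming the patch above is applied. The file name `example.parquet`, the column names, and the reduced Spark row-metadata JSON are made up for the example; only `write_table(..., flavor="spark")` and the metadata key are from the patch itself.

```python
import json

import pyarrow as pa
import pyarrow.parquet as pq

SPARK_KEY = b"org.apache.spark.sql.parquet.row.metadata"

# Spark's row metadata only knows about "x"; "extra" was appended to the
# table after it left Spark, so it has no entry in the stored JSON schema.
spark_schema = {
    "type": "struct",
    "fields": [
        {"name": "x", "type": "long", "nullable": True, "metadata": {}},
    ],
}

table = pa.table({"x": [1, 2], "extra": ["a", "b"]})
table = table.replace_schema_metadata(
    {SPARK_KEY: json.dumps(spark_schema).encode("utf-8")})

# With this patch, flavor="spark" appends an entry for "extra" to the row
# metadata instead of writing the stale Spark schema out unchanged.
pq.write_table(table, "example.parquet", flavor="spark")

row_metadata = json.loads(
    pq.read_schema("example.parquet").metadata[SPARK_KEY])
print([f["name"] for f in row_metadata["fields"]])  # ['x', 'extra']
```

Note that the substitution only runs when the flavor is "spark" *and* the key is already present, so tables that never passed through Spark are written exactly as before.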