Skip to content

Commit 78e0a72

Browse files
crafty-coder authored and HyukjinKwon committed
[SPARK-19018][SQL] Add support for custom encoding on csv writer
## What changes were proposed in this pull request?

Add support for custom encoding on the CSV writer; see https://issues.apache.org/jira/browse/SPARK-19018.

## How was this patch tested?

Added two unit tests in CSVSuite.

Author: crafty-coder <[email protected]>
Author: Carlos <[email protected]>

Closes apache#20949 from crafty-coder/master.
1 parent afb0627 commit 78e0a72

File tree

4 files changed

+50
-4
lines changed

4 files changed

+50
-4
lines changed

python/pyspark/sql/readwriter.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -859,7 +859,7 @@ def text(self, path, compression=None, lineSep=None):
859859
def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=None,
860860
header=None, nullValue=None, escapeQuotes=None, quoteAll=None, dateFormat=None,
861861
timestampFormat=None, ignoreLeadingWhiteSpace=None, ignoreTrailingWhiteSpace=None,
862-
charToEscapeQuoteEscaping=None):
862+
charToEscapeQuoteEscaping=None, encoding=None):
863863
"""Saves the content of the :class:`DataFrame` in CSV format at the specified path.
864864
865865
:param path: the path in any Hadoop supported file system
@@ -909,6 +909,8 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No
909909
the quote character. If None is set, the default value is
910910
escape character when escape and quote characters are
911911
different, ``\0`` otherwise.
912+
:param encoding: sets the encoding (charset) of saved csv files. If None is set,
913+
the default UTF-8 charset will be used.
912914
913915
>>> df.write.csv(os.path.join(tempfile.mkdtemp(), 'data'))
914916
"""
@@ -918,7 +920,8 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No
918920
dateFormat=dateFormat, timestampFormat=timestampFormat,
919921
ignoreLeadingWhiteSpace=ignoreLeadingWhiteSpace,
920922
ignoreTrailingWhiteSpace=ignoreTrailingWhiteSpace,
921-
charToEscapeQuoteEscaping=charToEscapeQuoteEscaping)
923+
charToEscapeQuoteEscaping=charToEscapeQuoteEscaping,
924+
encoding=encoding)
922925
self._jwrite.csv(path)
923926

924927
@since(1.5)

sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -629,6 +629,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
629629
* enclosed in quotes. Default is to only escape values containing a quote character.</li>
630630
* <li>`header` (default `false`): writes the names of columns as the first line.</li>
631631
* <li>`nullValue` (default empty string): sets the string representation of a null value.</li>
632+
* <li>`encoding` (by default it is not set): specifies encoding (charset) of saved csv
633+
* files. If it is not set, the UTF-8 charset will be used.</li>
632634
* <li>`compression` (default `null`): compression codec to use when saving to file. This can be
633635
* one of the known case-insensitive shorten names (`none`, `bzip2`, `gzip`, `lz4`,
634636
* `snappy` and `deflate`). </li>

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.spark.sql.execution.datasources.csv
1919

20+
import java.nio.charset.Charset
21+
2022
import org.apache.hadoop.conf.Configuration
2123
import org.apache.hadoop.fs.{FileStatus, Path}
2224
import org.apache.hadoop.mapreduce._
@@ -168,7 +170,9 @@ private[csv] class CsvOutputWriter(
168170
context: TaskAttemptContext,
169171
params: CSVOptions) extends OutputWriter with Logging {
170172

171-
private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path))
173+
private val charset = Charset.forName(params.charset)
174+
175+
private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path), charset)
172176

173177
private val gen = new UnivocityGenerator(dataSchema, writer, params)
174178

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,14 @@
1818
package org.apache.spark.sql.execution.datasources.csv
1919

2020
import java.io.File
21-
import java.nio.charset.UnsupportedCharsetException
21+
import java.nio.charset.{Charset, UnsupportedCharsetException}
22+
import java.nio.file.Files
2223
import java.sql.{Date, Timestamp}
2324
import java.text.SimpleDateFormat
2425
import java.util.Locale
2526

2627
import scala.collection.JavaConverters._
28+
import scala.util.Properties
2729

2830
import org.apache.commons.lang3.time.FastDateFormat
2931
import org.apache.hadoop.io.SequenceFile.CompressionType
@@ -514,6 +516,41 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te
514516
}
515517
}
516518

519+
test("SPARK-19018: Save csv with custom charset") {
520+
521+
// scalastyle:off nonascii
522+
val content = "µß áâä ÁÂÄ"
523+
// scalastyle:on nonascii
524+
525+
Seq("iso-8859-1", "utf-8", "utf-16", "utf-32", "windows-1250").foreach { encoding =>
526+
withTempPath { path =>
527+
val csvDir = new File(path, "csv")
528+
Seq(content).toDF().write
529+
.option("encoding", encoding)
530+
.csv(csvDir.getCanonicalPath)
531+
532+
csvDir.listFiles().filter(_.getName.endsWith("csv")).foreach({ csvFile =>
533+
val readback = Files.readAllBytes(csvFile.toPath)
534+
val expected = (content + Properties.lineSeparator).getBytes(Charset.forName(encoding))
535+
assert(readback === expected)
536+
})
537+
}
538+
}
539+
}
540+
541+
test("SPARK-19018: error handling for unsupported charsets") {
542+
val exception = intercept[SparkException] {
543+
withTempPath { path =>
544+
val csvDir = new File(path, "csv").getCanonicalPath
545+
Seq("a,A,c,A,b,B").toDF().write
546+
.option("encoding", "1-9588-osi")
547+
.csv(csvDir)
548+
}
549+
}
550+
551+
assert(exception.getCause.getMessage.contains("1-9588-osi"))
552+
}
553+
517554
test("commented lines in CSV data") {
518555
Seq("false", "true").foreach { multiLine =>
519556

0 commit comments

Comments
 (0)