Skip to content

Commit 0acebf2

Browse files
committed
Rename "encoding" to "charset" for clarity, add in missing functions,
make schema inference charset-aware
1 parent ae4e659 commit 0acebf2

File tree

12 files changed

+130
-43
lines changed

12 files changed

+130
-43
lines changed

datafusion/common/src/config.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2921,7 +2921,7 @@ config_namespace! {
29212921
///
29222922
/// The default behaviour depends on the `datafusion.catalog.newlines_in_values` setting.
29232923
pub newlines_in_values: Option<bool>, default = None
2924-
pub encoding: Option<String>, default = None
2924+
pub charset: Option<String>, default = None
29252925
pub compression: CompressionTypeVariant, default = CompressionTypeVariant::UNCOMPRESSED
29262926
/// Compression level for the output file. The valid range depends on the
29272927
/// compression algorithm:
@@ -3034,6 +3034,13 @@ impl CsvOptions {
30343034
self
30353035
}
30363036

3037+
/// Specifies the character encoding the file is encoded with.
3038+
/// - defaults to UTF-8
3039+
pub fn with_charset(mut self, charset: impl Into<String>) -> Self {
3040+
self.charset = Some(charset.into());
3041+
self
3042+
}
3043+
30373044
/// Set a `CompressionTypeVariant` of CSV
30383045
/// - defaults to `CompressionTypeVariant::UNCOMPRESSED`
30393046
pub fn with_file_compression_type(

datafusion/core/src/datasource/file_format/csv.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1646,7 +1646,7 @@ mod tests {
16461646

16471647
// Read the file
16481648
let ctx = SessionContext::new();
1649-
let opts = CsvReadOptions::new().has_header(true);
1649+
let opts = CsvReadOptions::new().has_header(true).charset("SHIFT-JIS");
16501650
let batches = ctx.read_csv(path, opts).await?.collect().await?;
16511651

16521652
// Check

datafusion/core/src/datasource/file_format/options.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@ pub struct CsvReadOptions<'a> {
8585
pub file_extension: &'a str,
8686
/// Partition Columns
8787
pub table_partition_cols: Vec<(String, DataType)>,
88+
/// Character encoding
89+
pub charset: Option<&'a str>,
8890
/// File compression type
8991
pub file_compression_type: FileCompressionType,
9092
/// Indicates how the file is sorted
@@ -118,6 +120,7 @@ impl<'a> CsvReadOptions<'a> {
118120
newlines_in_values: false,
119121
file_extension: DEFAULT_CSV_EXTENSION,
120122
table_partition_cols: vec![],
123+
charset: None,
121124
file_compression_type: FileCompressionType::UNCOMPRESSED,
122125
file_sort_order: vec![],
123126
comment: None,
@@ -209,6 +212,12 @@ impl<'a> CsvReadOptions<'a> {
209212
self
210213
}
211214

215+
/// Configure the character set encoding
216+
pub fn charset(mut self, charset: &'a str) -> Self {
217+
self.charset = Some(charset);
218+
self
219+
}
220+
212221
/// Configure file compression type
213222
pub fn file_compression_type(
214223
mut self,
@@ -633,6 +642,7 @@ impl ReadOptions<'_> for CsvReadOptions<'_> {
633642
.with_terminator(self.terminator)
634643
.with_newlines_in_values(self.newlines_in_values)
635644
.with_schema_infer_max_rec(self.schema_infer_max_records)
645+
.with_charset(self.charset.map(ToOwned::to_owned))
636646
.with_file_compression_type(self.file_compression_type.to_owned())
637647
.with_null_regex(self.null_regex.clone())
638648
.with_truncated_rows(self.truncated_rows);
Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,31 @@
1616
// under the License.
1717

1818
use std::fmt::Debug;
19+
use std::io::{BufRead, Read};
1920

2021
use arrow::array::RecordBatch;
2122
use arrow::error::ArrowError;
22-
use datafusion_common::Result;
23+
use datafusion_common::{DataFusionError, Result};
2324
use datafusion_datasource::decoder::Decoder;
24-
use encoding_rs::{CoderResult, Encoding};
25+
use encoding_rs::{CoderResult, Encoding, UTF_8};
2526

2627
use self::buffer::Buffer;
2728

2829
/// Default capacity of the buffer used to decode non-UTF-8 charset streams
2930
static DECODE_BUFFER_CAP: usize = 8 * 1024;
3031

32+
pub fn lookup_charset(enc: Option<&str>) -> Result<Option<&'static Encoding>> {
33+
match enc {
34+
Some(enc) => match Encoding::for_label(enc.as_bytes()) {
35+
Some(enc) => Ok(Some(enc).filter(|enc| *enc != UTF_8)),
36+
None => Err(DataFusionError::Configuration(format!(
37+
"Unknown character set '{enc}'"
38+
)))?,
39+
},
40+
None => Ok(None),
41+
}
42+
}
43+
3144
/// A `Decoder` that decodes input bytes from the specified character encoding
3245
/// to UTF-8 before passing them onto the inner `Decoder`.
3346
pub struct CharsetDecoder<T> {
@@ -100,6 +113,54 @@ impl<T: Debug> Debug for CharsetDecoder<T> {
100113
}
101114
}
102115

116+
pub struct CharsetReader<R> {
117+
inner: R,
118+
charset_decoder: encoding_rs::Decoder,
119+
buffer: Buffer,
120+
}
121+
122+
impl<R: BufRead> CharsetReader<R> {
123+
pub fn new(inner: R, encoding: &'static Encoding) -> Self {
124+
Self {
125+
inner,
126+
charset_decoder: encoding.new_decoder(),
127+
buffer: Buffer::with_capacity(DECODE_BUFFER_CAP),
128+
}
129+
}
130+
}
131+
132+
impl<R: BufRead> Read for CharsetReader<R> {
133+
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
134+
let src = self.fill_buf()?;
135+
let len = src.len().min(buf.len());
136+
buf[..len].copy_from_slice(&src[..len]);
137+
Ok(len)
138+
}
139+
}
140+
141+
impl<R: BufRead> BufRead for CharsetReader<R> {
142+
fn fill_buf(&mut self) -> std::io::Result<&[u8]> {
143+
if self.buffer.is_empty() {
144+
self.buffer.backshift();
145+
146+
let buf = self.inner.fill_buf()?;
147+
let (_, read, written, _) = self.charset_decoder.decode_to_utf8(
148+
buf,
149+
self.buffer.write_buf(),
150+
buf.is_empty(),
151+
);
152+
self.inner.consume(read);
153+
self.buffer.advance(written);
154+
}
155+
156+
Ok(self.buffer.read_buf())
157+
}
158+
159+
fn consume(&mut self, amount: usize) {
160+
self.buffer.consume(amount);
161+
}
162+
}
163+
103164
mod buffer {
104165
/// A fixed-sized buffer that maintains both
105166
/// a read position and a write position

datafusion/datasource-csv/src/file_format.rs

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ use std::collections::{HashMap, HashSet};
2222
use std::fmt::{self, Debug};
2323
use std::sync::Arc;
2424

25+
use crate::charset::lookup_charset;
2526
use crate::source::CsvSource;
2627

2728
use arrow::array::RecordBatch;
@@ -294,6 +295,13 @@ impl CsvFormat {
294295
self
295296
}
296297

298+
/// Sets the character encoding of the CSV.
299+
/// Defaults to UTF-8 if unspecified.
300+
pub fn with_charset(mut self, charset: Option<String>) -> Self {
301+
self.options.charset = charset;
302+
self
303+
}
304+
297305
/// Set a `FileCompressionType` of CSV
298306
/// - defaults to `FileCompressionType::UNCOMPRESSED`
299307
pub fn with_file_compression_type(
@@ -540,6 +548,8 @@ impl CsvFormat {
540548

541549
pin_mut!(stream);
542550

551+
let charset = lookup_charset(self.options.charset.as_deref())?;
552+
543553
while let Some(chunk) = stream.next().await.transpose()? {
544554
record_number += 1;
545555
let first_chunk = record_number == 0;
@@ -569,8 +579,15 @@ impl CsvFormat {
569579
format = format.with_comment(comment);
570580
}
571581

572-
let (Schema { fields, .. }, records_read) =
573-
format.infer_schema(chunk.reader(), Some(records_to_read))?;
582+
let (Schema { fields, .. }, records_read) = match charset {
583+
#[cfg(feature = "encoding_rs")]
584+
Some(enc) => {
585+
use crate::charset::CharsetReader;
586+
let reader = CharsetReader::new(chunk.reader(), enc);
587+
format.infer_schema(reader, Some(records_to_read))?
588+
}
589+
None => format.infer_schema(chunk.reader(), Some(records_to_read))?,
590+
};
574591

575592
records_to_read -= records_read;
576593
total_records_read += records_read;

datafusion/datasource-csv/src/mod.rs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
#![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))]
2222

2323
#[cfg(feature = "encoding_rs")]
24-
mod encoding;
24+
mod charset;
2525
pub mod file_format;
2626
pub mod source;
2727

@@ -32,6 +32,7 @@ use datafusion_datasource::file_groups::FileGroup;
3232
use datafusion_datasource::file_scan_config::FileScanConfigBuilder;
3333
use datafusion_datasource::{file::FileSource, file_scan_config::FileScanConfig};
3434
use datafusion_execution::object_store::ObjectStoreUrl;
35+
3536
pub use file_format::*;
3637

3738
/// Returns a [`FileScanConfig`] for given `file_groups`
@@ -45,3 +46,17 @@ pub fn partitioned_csv_config(
4546
.build(),
4647
)
4748
}
49+
50+
#[cfg(not(feature = "encoding_rs"))]
51+
mod encoding {
52+
use datafusion_common::{DataFusionError, Result};
53+
54+
pub fn find_encoding(enc: Option<&str>) -> Result<Option<core::convert::Infallible>> {
55+
match enc {
56+
Some(_) => Err(DataFusionError::NotImplemented(format!(
57+
"The 'encoding_rs' feature must be enabled to decode non-UTF-8 encodings"
58+
)))?,
59+
None => Ok(None),
60+
}
61+
}
62+
}

datafusion/datasource-csv/src/source.rs

Lines changed: 6 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,7 @@ use datafusion_physical_plan::{
4646
DisplayFormatType, ExecutionPlan, ExecutionPlanProperties,
4747
};
4848

49-
#[cfg(feature = "encoding_rs")]
50-
use crate::encoding::CharsetDecoder;
49+
use crate::charset::lookup_charset;
5150
use crate::file_format::CsvDecoder;
5251
use futures::{StreamExt, TryStreamExt};
5352
use object_store::buffered::BufWriter;
@@ -232,30 +231,6 @@ impl CsvOpener {
232231
partition_index: 0,
233232
}
234233
}
235-
236-
#[cfg(feature = "encoding_rs")]
237-
fn encoding(&self) -> Result<Option<&'static encoding_rs::Encoding>> {
238-
match self.config.options.encoding.as_ref() {
239-
Some(enc) => match encoding_rs::Encoding::for_label(enc.as_bytes()) {
240-
Some(enc) => Ok(Some(enc)),
241-
None => Err(DataFusionError::Configuration(format!(
242-
"Unknown character set '{enc}'"
243-
)))?,
244-
},
245-
None => Ok(None),
246-
}
247-
}
248-
249-
#[cfg(not(feature = "encoding_rs"))]
250-
fn encoding(&self) -> Result<Option<core::convert::Infallible>> {
251-
match &self.config.options.encoding {
252-
Some(_) => Err(DataFusionError::NotImplemented(
253-
"The 'encoding_rs' feature must be enabled to decode non-UTF-8 encodings"
254-
.to_owned(),
255-
))?,
256-
None => Ok(None),
257-
}
258-
}
259234
}
260235

261236
impl From<CsvSource> for Arc<dyn FileSource> {
@@ -379,7 +354,7 @@ impl FileOpener for CsvOpener {
379354
config.options.truncated_rows = Some(config.truncate_rows());
380355

381356
let file_compression_type = self.file_compression_type.to_owned();
382-
let encoding = self.encoding()?;
357+
let charset = lookup_charset(self.config.options.charset.as_deref())?;
383358

384359
if partitioned_file.range.is_some() {
385360
assert!(
@@ -437,9 +412,10 @@ impl FileOpener for CsvOpener {
437412

438413
let reader = BufReader::new(reader);
439414

440-
let mut reader = match encoding {
415+
let mut reader = match charset {
441416
#[cfg(feature = "encoding_rs")]
442417
Some(enc) => {
418+
use crate::charset::CharsetDecoder;
443419
let decoder = CharsetDecoder::new(decoder, enc);
444420
deserialize_reader(reader, decoder)
445421
}
@@ -461,9 +437,10 @@ impl FileOpener for CsvOpener {
461437

462438
let stream = file_compression_type.convert_stream(stream)?.fuse();
463439

464-
let stream = match encoding {
440+
let stream = match charset {
465441
#[cfg(feature = "encoding_rs")]
466442
Some(enc) => {
443+
use crate::charset::CharsetDecoder;
467444
let decoder = CharsetDecoder::new(decoder, enc);
468445
deserialize_stream(stream, DecoderDeserializer::new(decoder))
469446
}

datafusion/proto-common/proto/datafusion_common.proto

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -476,7 +476,7 @@ message CsvOptions {
476476
bytes terminator = 17; // Optional terminator character as a byte
477477
bytes truncated_rows = 18; // Indicates if truncated rows are allowed
478478
optional uint32 compression_level = 19; // Optional compression level
479-
string encoding = 20; // Optional character encoding
479+
string charset = 20; // Optional character encoding
480480
}
481481

482482
// Options controlling CSV format

datafusion/proto-common/src/from_proto/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -984,7 +984,7 @@ impl TryFrom<&protobuf::CsvOptions> for CsvOptions {
984984
escape: proto_opts.escape.first().copied(),
985985
double_quote: proto_opts.double_quote.first().map(|h| *h != 0),
986986
newlines_in_values: proto_opts.newlines_in_values.first().map(|h| *h != 0),
987-
encoding: (!proto_opts.encoding.is_empty())
987+
charset: (!proto_opts.encoding.is_empty())
988988
.then(|| proto_opts.encoding.clone()),
989989
compression: proto_opts.compression().into(),
990990
compression_level: proto_opts.compression_level,

datafusion/proto-common/src/to_proto/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -986,7 +986,7 @@ impl TryFrom<&CsvOptions> for protobuf::CsvOptions {
986986
newlines_in_values: opts
987987
.newlines_in_values
988988
.map_or_else(Vec::new, |h| vec![h as u8]),
989-
encoding: opts.encoding.clone().unwrap_or_default(),
989+
encoding: opts.charset.clone().unwrap_or_default(),
990990
compression: compression.into(),
991991
schema_infer_max_rec: opts.schema_infer_max_rec.map(|h| h as u64),
992992
date_format: opts.date_format.clone().unwrap_or_default(),

0 commit comments

Comments
 (0)