From eb50eb4249c592b2b248a3ff155f88f5c9d0001f Mon Sep 17 00:00:00 2001 From: plotor Date: Fri, 6 Mar 2026 14:59:08 +0800 Subject: [PATCH] feat: Add support for reading whole text files to `read_text` Signed-off-by: plotor --- Cargo.lock | 1 + daft/daft/__init__.pyi | 10 +- daft/io/_text.py | 15 +- .../src/sources/scan_task_reader.rs | 1 + src/daft-scan/src/file_format_config.rs | 19 +- src/daft-text/Cargo.toml | 3 + src/daft-text/src/options.rs | 5 +- src/daft-text/src/read.rs | 290 ++++++++++++++---- tests/io/test_text.py | 95 ++++++ 9 files changed, 374 insertions(+), 65 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b50d58724c..75cab2d4cc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3352,6 +3352,7 @@ dependencies = [ "daft-core", "daft-io", "daft-recordbatch", + "flate2", "futures", "serde", "tokio", diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index 7fa68d1e93..e1b2cc5f32 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -324,10 +324,18 @@ class TextSourceConfig: encoding: str skip_blank_lines: bool + whole_text: bool buffer_size: int | None chunk_size: int | None - def __init__(self, encoding: str, skip_blank_lines: bool, buffer_size: int | None, chunk_size: int | None): ... + def __init__( + self, + encoding: str, + skip_blank_lines: bool, + whole_text: bool, + buffer_size: int | None, + chunk_size: int | None, + ): ... class FileFormatConfig: """Configuration for parsing a particular file format (Parquet, CSV, JSON).""" diff --git a/daft/io/_text.py b/daft/io/_text.py index 8332de1076..9241f765f8 100755 --- a/daft/io/_text.py +++ b/daft/io/_text.py @@ -14,6 +14,7 @@ def read_text( *, encoding: str = "utf-8", skip_blank_lines: bool = True, + whole_text: bool = False, file_path_column: str | None = None, hive_partitioning: bool = False, io_config: IOConfig | None = None, @@ -26,15 +27,21 @@ def read_text( path: Path to text file(s). Supports wildcards and remote URLs such as ``s3://`` or ``gs://``. encoding: Encoding of the input files, defaults to ``"utf-8"``. skip_blank_lines: Whether to skip empty lines (after stripping whitespace). Defaults to ``True``. + When ``whole_text=True``, this skips files that are entirely blank. + whole_text: Whether to read each file as a single row. Defaults to ``False``. + When ``False``, each line in the file becomes a row in the DataFrame. + When ``True``, the entire content of each file becomes a single row in the DataFrame. file_path_column: Include the source path(s) as a column with this name. Defaults to ``None``. hive_partitioning: Whether to infer hive-style partitions from file paths and include them as columns in the DataFrame. Defaults to ``False``. io_config: IO configuration for the native downloader. _buffer_size: Optional tuning parameter for the underlying streaming reader buffer size (bytes). _chunk_size: Optional tuning parameter for the underlying streaming reader chunk size (rows). + Has no effect when ``whole_text=True``. Returns: - DataFrame: A DataFrame with a single ``"text"`` column containing lines from the input files. + DataFrame: A DataFrame with a single ``"text"`` column containing lines from the input files + (when ``whole_text=False``) or entire file contents (when ``whole_text=True``). Examples: Read a text file from a local path: @@ -49,6 +56,11 @@ def read_text( >>> io_config = IOConfig(s3=S3Config(region="us-west-2", anonymous=True)) >>> df = daft.read_text("s3://path/to/files-*.txt", io_config=io_config) >>> df.show() + + Read multiple small files, each as a single row: + + >>> df = daft.read_text("/path/to/files/*.txt", whole_text=True) + >>> df.show() """ if isinstance(path, list) and len(path) == 0: raise ValueError("Cannot read DataFrame from empty list of text filepaths") @@ -57,6 +69,7 @@ def read_text( text_config = TextSourceConfig( encoding=encoding, skip_blank_lines=skip_blank_lines, + whole_text=whole_text, buffer_size=_buffer_size, chunk_size=_chunk_size, ) diff --git a/src/daft-local-execution/src/sources/scan_task_reader.rs b/src/daft-local-execution/src/sources/scan_task_reader.rs index 71bec2e640..841de77036 100644 --- a/src/daft-local-execution/src/sources/scan_task_reader.rs +++ b/src/daft-local-execution/src/sources/scan_task_reader.rs @@ -257,6 +257,7 @@ async fn read_text( let convert_options = TextConvertOptions::new( &cfg.encoding, cfg.skip_blank_lines, + cfg.whole_text, Some(schema_of_file), scan_task.pushdowns.limit, ); diff --git a/src/daft-scan/src/file_format_config.rs b/src/daft-scan/src/file_format_config.rs index 2bcad88a4d..59ca03b753 100644 --- a/src/daft-scan/src/file_format_config.rs +++ b/src/daft-scan/src/file_format_config.rs @@ -442,6 +442,7 @@ impl_bincode_py_state_serialization!(WarcSourceConfig); pub struct TextSourceConfig { pub encoding: String, pub skip_blank_lines: bool, + pub whole_text: bool, pub buffer_size: Option, pub chunk_size: Option, } @@ -455,18 +456,21 @@ impl TextSourceConfig { #[pyo3(signature = ( encoding, skip_blank_lines, + whole_text=false, buffer_size=None, - chunk_size=None + chunk_size=None, ))] fn new( encoding: String, skip_blank_lines: bool, + whole_text: bool, buffer_size: Option, chunk_size: Option, ) -> PyResult { Ok(Self { encoding, skip_blank_lines, + whole_text, buffer_size, chunk_size, }) @@ -479,6 +483,7 @@ impl TextSourceConfig { let mut res = vec![]; res.push(format!("Encoding = {}", self.encoding)); res.push(format!("Skip blank lines = {}", self.skip_blank_lines)); + res.push(format!("Whole text = {}", self.whole_text)); if let Some(buffer_size) = self.buffer_size { res.push(format!("Buffer size = {buffer_size}")); } @@ -489,4 +494,16 @@ impl TextSourceConfig { } } +impl Default for TextSourceConfig { + fn default() -> Self { + Self { + encoding: "utf-8".to_string(), + skip_blank_lines: true, + whole_text: false, + buffer_size: None, + chunk_size: None, + } + } +} + impl_bincode_py_state_serialization!(TextSourceConfig); diff --git a/src/daft-text/Cargo.toml b/src/daft-text/Cargo.toml index e6a3a3ae1f..a93aecf1ea 100644 --- a/src/daft-text/Cargo.toml +++ b/src/daft-text/Cargo.toml @@ -11,6 +11,9 @@ tokio = {workspace = true} tokio-stream = {workspace = true, features = ["io-util"]} tokio-util = {workspace = true} +[dev-dependencies] +flate2 = {version = "1.1", features = ["zlib-rs"], default-features = false} + [features] python = [ "common-error/python", diff --git a/src/daft-text/src/options.rs b/src/daft-text/src/options.rs index a71c1e1102..213e103ca4 100644 --- a/src/daft-text/src/options.rs +++ b/src/daft-text/src/options.rs @@ -6,6 +6,7 @@ use serde::{Deserialize, Serialize}; pub struct TextConvertOptions { pub encoding: String, pub skip_blank_lines: bool, + pub whole_text: bool, pub schema: Option, pub limit: Option, } @@ -15,12 +16,14 @@ impl TextConvertOptions { pub fn new( encoding: &str, skip_blank_lines: bool, + whole_text: bool, schema: Option, limit: Option, ) -> Self { Self { encoding: encoding.to_string(), skip_blank_lines, + whole_text, schema, limit, } @@ -29,7 +32,7 @@ impl TextConvertOptions { impl Default for TextConvertOptions { fn default() -> Self { - Self::new("utf-8", true, None, None) + Self::new("utf-8", true, false, None, None) } } diff --git a/src/daft-text/src/read.rs b/src/daft-text/src/read.rs index 225b7124c4..95c5904476 100644 --- a/src/daft-text/src/read.rs +++ b/src/daft-text/src/read.rs @@ -9,12 +9,40 @@ use daft_recordbatch::RecordBatch; use futures::{Stream, StreamExt, stream::BoxStream}; use tokio::{ fs::File, - io::{AsyncBufRead, AsyncBufReadExt, BufReader}, + io::{AsyncBufRead, AsyncBufReadExt, AsyncReadExt, BufReader}, }; use tokio_util::io::StreamReader; use crate::options::{TextConvertOptions, TextReadOptions}; +async fn open_reader( + uri: &str, + buffer_size: usize, + io_client: Arc, + io_stats: Option, +) -> DaftResult> { + let reader: Box = match io_client + .single_url_get(uri.to_string(), None, io_stats) + .await? + { + GetResult::File(file) => Box::new(BufReader::with_capacity( + buffer_size, + File::open(file.path).await?, + )), + GetResult::Stream(stream, ..) => Box::new(BufReader::with_capacity( + buffer_size, + StreamReader::new(stream), + )), + }; + Ok(match CompressionCodec::from_uri(uri) { + Some(codec) => Box::new(BufReader::with_capacity( + buffer_size, + codec.to_decoder(reader), + )), + None => reader, + }) +} + /// Stream text lines from a URI into `RecordBatch` chunks with a single Utf8 column named "text". /// /// The `encoding` argument is currently restricted to UTF-8 (case-insensitive). Any other encoding @@ -40,6 +68,19 @@ pub async fn stream_text( .clone() .unwrap_or_else(|| Arc::new(Schema::new(vec![Field::new("text", DataType::Utf8)]))); + // Check if we're reading the whole file as a single row + if convert_options.whole_text { + let whole_text_stream = + read_into_whole_text_stream(uri, convert_options, read_options, io_client, io_stats) + .await?; + return Ok(Box::pin(whole_text_stream.map(move |content_res| { + let content = content_res?; + let array = Utf8Array::from_values("text", std::iter::once(content.as_str())); + let series = array.into_series(); + RecordBatch::new_with_size(schema.clone(), vec![series], 1) + }))); + } + // Build a stream of line chunks let line_chunk_stream = read_into_line_chunk_stream(uri, convert_options, read_options, io_client, io_stats) @@ -60,6 +101,34 @@ pub async fn stream_text( Ok(Box::pin(table_stream)) } +async fn read_into_whole_text_stream( + uri: String, + convert_options: TextConvertOptions, + read_options: TextReadOptions, + io_client: Arc, + io_stats: Option, +) -> DaftResult> + Send> { + let buffer_size = read_options.buffer_size.unwrap_or(8 * 1024 * 1024); + let mut reader = open_reader(&uri, buffer_size, io_client, io_stats).await?; + + Ok(try_stream! { + // Check limit first, and skip read if limit is 0 + if convert_options.limit == Some(0) { + return; + } + + let mut content = String::new(); + reader.read_to_string(&mut content).await?; + + // Apply skip_blank_lines if needed (for whole file, this means skip if entire content is blank) + if convert_options.skip_blank_lines && content.trim().is_empty() { + return; + } + + yield content; + }) +} + async fn read_into_line_chunk_stream( uri: String, convert_options: TextConvertOptions, @@ -67,40 +136,9 @@ async fn read_into_line_chunk_stream( io_client: Arc, io_stats: Option, ) -> DaftResult>> + Send> { - let (reader, buffer_size, chunk_size): (Box, usize, usize) = - match io_client - .single_url_get(uri.clone(), None, io_stats) - .await? - { - GetResult::File(file) => { - // Use user-provided buffer size, otherwise falling back to 256KiB as the default. - let buffer_size = read_options.buffer_size.unwrap_or(256 * 1024); - let chunk_size = read_options.chunk_size.unwrap_or(64 * 1024); - ( - Box::new(BufReader::with_capacity( - buffer_size, - File::open(file.path).await?, - )), - buffer_size, - chunk_size, - ) - } - GetResult::Stream(stream, ..) => { - // Use user-provided buffer size, otherwise falling back to 8MiB as the default. - let buffer_size = read_options.buffer_size.unwrap_or(8 * 1024 * 1024); - let chunk_size = read_options.chunk_size.unwrap_or(64 * 1024); - (Box::new(StreamReader::new(stream)), buffer_size, chunk_size) - } - }; - - // If file is compressed, wrap stream in decoding stream. - let reader: Box = match CompressionCodec::from_uri(&uri) { - Some(compression) => Box::new(BufReader::with_capacity( - buffer_size, - compression.to_decoder(reader), - )), - None => reader, - }; + let buffer_size = read_options.buffer_size.unwrap_or(8 * 1024 * 1024); + let chunk_size = read_options.chunk_size.unwrap_or(64 * 1024); + let reader = open_reader(&uri, buffer_size, io_client, io_stats).await?; let line_stream = tokio_stream::wrappers::LinesStream::new(reader.lines()); Ok(try_stream! { @@ -140,50 +178,180 @@ async fn read_into_line_chunk_stream( mod tests { use std::{ fs, + io::Write, sync::Arc, time::{SystemTime, UNIX_EPOCH}, }; use daft_io::{IOConfig, get_io_client}; + use flate2::{Compression, write::GzEncoder}; use futures::StreamExt; use super::*; - #[tokio::test] - async fn read_local_text_file() { - // Create a uniquely named temporary file in the system temp directory. + fn unique_temp_path(extension: &str) -> std::path::PathBuf { + use std::sync::atomic::{AtomicU64, Ordering}; + static COUNTER: AtomicU64 = AtomicU64::new(0); + let mut path = std::env::temp_dir(); let unique = SystemTime::now() .duration_since(UNIX_EPOCH) .expect("system time before UNIX_EPOCH") .as_nanos(); - path.push(format!("daft_text_stream_test_{unique}.txt")); + let counter = COUNTER.fetch_add(1, Ordering::SeqCst); + path.push(format!("daft_text_test_{unique}_{counter}.{extension}")); + path + } + + fn create_test_file(content: &str, compressed: bool) -> (std::path::PathBuf, String) { + let extension = if compressed { "gz" } else { "txt" }; + let path = unique_temp_path(extension); + + if compressed { + let mut encoder = GzEncoder::new(Vec::new(), Compression::default()); + encoder + .write_all(content.as_bytes()) + .expect("failed to compress content"); + let compressed_data = encoder.finish().expect("failed to finish compression"); + fs::write(&path, &compressed_data).expect("failed to write compressed file"); + } else { + fs::write(&path, content).expect("failed to write temp text file"); + } + + (path, content.to_string()) + } + + #[tokio::test] + async fn test_read_into_whole_text_stream() { + let io_config = Arc::new(IOConfig::default()); + let io_client = get_io_client(false, io_config).expect("failed to construct IOClient"); + + let test_cases = vec![ + ("uncompressed with default buffer", false, None, None), + ("uncompressed with small buffer", false, Some(16), None), + ( + "uncompressed with large buffer", + false, + Some(1024 * 1024), + None, + ), + ("gzip compressed with default buffer", true, None, None), + ("gzip compressed with small buffer", true, Some(16), None), + ( + "gzip compressed with large buffer", + true, + Some(1024 * 1024), + None, + ), + ("uncompressed with limit=0", false, None, Some(0)), + ("gzip compressed with limit=0", true, None, Some(0)), + ]; + + for (name, compressed, buffer_size, limit) in test_cases { + let content = "Hello, World!\nThis is a test file.\nMultiple lines here.\n"; + let (path, expected_content) = create_test_file(content, compressed); + + let read_options = TextReadOptions { + buffer_size, + ..Default::default() + }; - fs::write(&path, b"line1\nline2\n").expect("failed to write temp text file"); + let convert_options = TextConvertOptions { + limit, + ..Default::default() + }; + let stream = read_into_whole_text_stream( + path.to_string_lossy().to_string(), + convert_options, + read_options, + io_client.clone(), + None, + ) + .await + .expect(&format!( + "read_into_whole_text_stream should succeed for {name}" + )); + + let results: Vec<_> = stream.collect::>().await; + + if limit == Some(0) { + assert_eq!( + results.len(), + 0, + "[{name}] expected zero results for limit=0" + ); + } else { + assert_eq!(results.len(), 1, "[{name}] expected exactly one result"); + + let actual_content = results[0] + .as_ref() + .expect(&format!("[{name}] stream yielded error")) + .clone(); + assert_eq!( + actual_content, expected_content, + "[{name}] content mismatch" + ); + } + + let _ = fs::remove_file(&path); + } + } + + #[tokio::test] + async fn test_read_into_line_chunk_stream() { let io_config = Arc::new(IOConfig::default()); let io_client = get_io_client(false, io_config).expect("failed to construct IOClient"); - let stream = stream_text( - path.to_string_lossy().to_string(), - TextConvertOptions::default(), - TextReadOptions::default(), - io_client, - None, - ) - .await - .expect("stream_text should succeed for local file"); - - let batches: Vec<_> = stream.collect::>().await; - assert!(!batches.is_empty(), "expected at least one RecordBatch"); - - let total_rows: usize = batches - .into_iter() - .map(|res| res.expect("stream yielded error RecordBatch")) - .map(|rb| rb.num_rows()) - .sum(); - assert_eq!(total_rows, 2); - - let _ = fs::remove_file(&path); + let test_cases = vec![ + ("uncompressed with default buffer", false, None), + ("uncompressed with small buffer", false, Some(16)), + ("uncompressed with large buffer", false, Some(1024 * 1024)), + ("gzip compressed with default buffer", true, None), + ("gzip compressed with small buffer", true, Some(16)), + ("gzip compressed with large buffer", true, Some(1024 * 1024)), + ]; + + for (name, compressed, buffer_size) in test_cases { + let content = "line1\nline2\nline3\nline4\nline5\n"; + let (path, _) = create_test_file(content, compressed); + + let read_options = TextReadOptions { + buffer_size, + chunk_size: Some(2), + }; + + let stream = read_into_line_chunk_stream( + path.to_string_lossy().to_string(), + TextConvertOptions::default(), + read_options, + io_client.clone(), + None, + ) + .await + .expect(&format!( + "read_into_line_chunk_stream should succeed for {name}" + )); + + let chunks: Vec<_> = stream.collect::>().await; + + let all_lines: Vec = chunks + .iter() + .flat_map(|chunk_res| { + chunk_res + .as_ref() + .expect(&format!("[{name}] stream yielded error")) + .clone() + }) + .collect(); + + assert_eq!( + all_lines, + vec!["line1", "line2", "line3", "line4", "line5"], + "[{name}] lines mismatch" + ); + + let _ = fs::remove_file(&path); + } } } diff --git a/tests/io/test_text.py b/tests/io/test_text.py index fa80ca3074..397ab2348b 100755 --- a/tests/io/test_text.py +++ b/tests/io/test_text.py @@ -153,3 +153,98 @@ def test_read_with_encoding_setting(tmp_path): with pytest.raises(Exception, match=r"(?i)utf-?8"): daft.read_text(str(path)).to_pydict() + + +def test_read_whole_text_from_single_file(tmp_path): + path = tmp_path / "sample.txt" + path.write_text("hello\nworld\nfoo", encoding="utf-8") + + df = daft.read_text(str(path), whole_text=True) + assert df.schema() == Schema.from_pyarrow_schema(pa.schema([("text", pa.string())])) + result = df.to_pydict() + assert result["text"] == ["hello\nworld\nfoo"] + + +def test_read_whole_text_from_multiple_files(tmp_path): + file_a = tmp_path / "a.txt" + file_b = tmp_path / "b.txt" + file_a.write_text("content of file a\nwith multiple lines", encoding="utf-8") + file_b.write_text("content of file b", encoding="utf-8") + + df = daft.read_text([str(file_a), str(file_b)], whole_text=True) + result = df.to_pydict() + assert len(result["text"]) == 2 + assert "content of file a\nwith multiple lines" in result["text"] + assert "content of file b" in result["text"] + + +def test_read_whole_text_with_path_column(tmp_path): + file_a = tmp_path / "a.txt" + file_b = tmp_path / "b.txt" + file_a.write_text("content a", encoding="utf-8") + file_b.write_text("content b", encoding="utf-8") + + df = daft.read_text([str(file_a), str(file_b)], whole_text=True, file_path_column="path") + assert df.schema() == Schema.from_pyarrow_schema(pa.schema([("text", pa.string()), ("path", pa.string())])) + + data = df.to_pydict() + assert len(data["text"]) == 2 + assert len(data["path"]) == 2 + + rows = {(t, p) for t, p in zip(data["text"], data["path"])} + assert rows == { + ("content a", f"{tmp_path}/a.txt"), + ("content b", f"{tmp_path}/b.txt"), + } + + +def test_read_whole_text_from_empty_file(tmp_path): + path = tmp_path / "empty.txt" + path.write_text("", encoding="utf-8") + + df = daft.read_text(str(path), whole_text=True, skip_blank_lines=False) + result = df.to_pydict() + assert result["text"] == [""] + + df = daft.read_text(str(path), whole_text=True, skip_blank_lines=True) + result = df.to_pydict() + assert result["text"] == [] + + +def test_read_whole_text_with_glob_patterns(tmp_path): + file_a = tmp_path / "a.txt" + file_b = tmp_path / "b.txt" + file_c = tmp_path / "c.txt" + file_d = tmp_path / "d.txt" + file_a.write_text("content a1", encoding="utf-8") + file_b.write_text("content b1\ncontent b2\t", encoding="utf-8") + file_c.write_text("content c1\ncontent c2\ncontent c3\n\t", encoding="utf-8") + file_d.write_text("", encoding="utf-8") + + df = daft.read_text( + str(tmp_path / "*.txt"), + skip_blank_lines=True, + whole_text=True, + file_path_column="path", + ) + data = df.to_pydict() + assert len(data["text"]) == 3 + assert len(data["path"]) == 3 + + file_to_content = {p: t for p, t in zip(data["path"], data["text"])} + assert file_to_content[str(file_a)] == "content a1" + assert file_to_content[str(file_b)] == "content b1\ncontent b2\t" + assert file_to_content[str(file_c)] == "content c1\ncontent c2\ncontent c3\n\t" + + +def test_read_whole_text_with_gzip(tmp_path): + def _write_gzip(path: Path, content: bytes) -> None: + with gzip.open(path, "wb") as f: + f.write(content) + + path = tmp_path / "compressed.txt.gz" + _write_gzip(path, b"line1\nline2\nline3") + + df = daft.read_text(str(path), whole_text=True) + result = df.to_pydict() + assert result["text"] == ["line1\nline2\nline3"]