Skip to content

Commit ecaddb9

Browse files
committed
feat: Add support for reading whole text files to read_text
Signed-off-by: plotor <zhenchao.wang@hotmail.com>
1 parent 19591bd commit ecaddb9

File tree

7 files changed

+199
-5
lines changed

7 files changed

+199
-5
lines changed

daft/daft/__init__.pyi

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -324,10 +324,18 @@ class TextSourceConfig:
324324

325325
encoding: str
326326
skip_blank_lines: bool
327+
whole_text: bool
327328
buffer_size: int | None
328329
chunk_size: int | None
329330

330-
def __init__(self, encoding: str, skip_blank_lines: bool, buffer_size: int | None, chunk_size: int | None): ...
331+
def __init__(
332+
self,
333+
encoding: str,
334+
skip_blank_lines: bool,
335+
whole_text: bool,
336+
buffer_size: int | None,
337+
chunk_size: int | None,
338+
): ...
331339

332340
class FileFormatConfig:
333341
"""Configuration for parsing a particular file format (Parquet, CSV, JSON)."""

daft/io/_text.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ def read_text(
1414
*,
1515
encoding: str = "utf-8",
1616
skip_blank_lines: bool = True,
17+
whole_text: bool = False,
1718
file_path_column: str | None = None,
1819
hive_partitioning: bool = False,
1920
io_config: IOConfig | None = None,
@@ -26,6 +27,10 @@ def read_text(
2627
path: Path to text file(s). Supports wildcards and remote URLs such as ``s3://`` or ``gs://``.
2728
encoding: Encoding of the input files, defaults to ``"utf-8"``.
2829
skip_blank_lines: Whether to skip empty lines (after stripping whitespace). Defaults to ``True``.
30+
When ``whole_text=True``, this skips files that are entirely blank.
31+
whole_text: Whether to read each file as a single row. Defaults to ``False``.
32+
When ``False``, each line in the file becomes a row in the DataFrame.
33+
When ``True``, the entire content of each file becomes a single row in the DataFrame.
2934
file_path_column: Include the source path(s) as a column with this name. Defaults to ``None``.
3035
hive_partitioning: Whether to infer hive-style partitions from file paths and include them as
3136
columns in the DataFrame. Defaults to ``False``.
@@ -34,7 +39,8 @@ def read_text(
3439
_chunk_size: Optional tuning parameter for the underlying streaming reader chunk size (rows).
3540
3641
Returns:
37-
DataFrame: A DataFrame with a single ``"text"`` column containing lines from the input files.
42+
DataFrame: A DataFrame with a single ``"text"`` column containing lines from the input files
43+
(when ``whole_text=False``) or entire file contents (when ``whole_text=True``).
3844
3945
Examples:
4046
Read a text file from a local path:
@@ -49,6 +55,11 @@ def read_text(
4955
>>> io_config = IOConfig(s3=S3Config(region="us-west-2", anonymous=True))
5056
>>> df = daft.read_text("s3://path/to/files-*.txt", io_config=io_config)
5157
>>> df.show()
58+
59+
Read multiple small files, each as a single row:
60+
61+
>>> df = daft.read_text("/path/to/files/*.txt", whole_text=True)
62+
>>> df.show()
5263
"""
5364
if isinstance(path, list) and len(path) == 0:
5465
raise ValueError("Cannot read DataFrame from empty list of text filepaths")
@@ -57,6 +68,7 @@ def read_text(
5768
text_config = TextSourceConfig(
5869
encoding=encoding,
5970
skip_blank_lines=skip_blank_lines,
71+
whole_text=whole_text,
6072
buffer_size=_buffer_size,
6173
chunk_size=_chunk_size,
6274
)

src/common/file-formats/src/file_format_config.rs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,7 @@ impl_bincode_py_state_serialization!(WarcSourceConfig);
464464
pub struct TextSourceConfig {
465465
pub encoding: String,
466466
pub skip_blank_lines: bool,
467+
pub whole_text: bool,
467468
pub buffer_size: Option<usize>,
468469
pub chunk_size: Option<usize>,
469470
}
@@ -477,18 +478,21 @@ impl TextSourceConfig {
477478
#[pyo3(signature = (
478479
encoding,
479480
skip_blank_lines,
481+
whole_text=false,
480482
buffer_size=None,
481-
chunk_size=None
483+
chunk_size=None,
482484
))]
483485
fn new(
484486
encoding: String,
485487
skip_blank_lines: bool,
488+
whole_text: bool,
486489
buffer_size: Option<usize>,
487490
chunk_size: Option<usize>,
488491
) -> PyResult<Self> {
489492
Ok(Self {
490493
encoding,
491494
skip_blank_lines,
495+
whole_text,
492496
buffer_size,
493497
chunk_size,
494498
})
@@ -501,6 +505,7 @@ impl TextSourceConfig {
501505
let mut res = vec![];
502506
res.push(format!("Encoding = {}", self.encoding));
503507
res.push(format!("Skip blank lines = {}", self.skip_blank_lines));
508+
res.push(format!("Whole text = {}", self.whole_text));
504509
if let Some(buffer_size) = self.buffer_size {
505510
res.push(format!("Buffer size = {buffer_size}"));
506511
}
@@ -511,4 +516,16 @@ impl TextSourceConfig {
511516
}
512517
}
513518

519+
impl Default for TextSourceConfig {
520+
fn default() -> Self {
521+
Self {
522+
encoding: "utf-8".to_string(),
523+
skip_blank_lines: true,
524+
whole_text: false,
525+
buffer_size: None,
526+
chunk_size: None,
527+
}
528+
}
529+
}
530+
514531
impl_bincode_py_state_serialization!(TextSourceConfig);

src/daft-local-execution/src/sources/scan_task.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -678,6 +678,7 @@ async fn stream_scan_task(
678678
let convert_options = TextConvertOptions::new(
679679
&cfg.encoding,
680680
cfg.skip_blank_lines,
681+
cfg.whole_text,
681682
Some(schema_of_file),
682683
scan_task.pushdowns.limit,
683684
);

src/daft-text/src/options.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use serde::{Deserialize, Serialize};
66
pub struct TextConvertOptions {
77
pub encoding: String,
88
pub skip_blank_lines: bool,
9+
pub whole_text: bool,
910
pub schema: Option<SchemaRef>,
1011
pub limit: Option<usize>,
1112
}
@@ -15,12 +16,14 @@ impl TextConvertOptions {
1516
pub fn new(
1617
encoding: &str,
1718
skip_blank_lines: bool,
19+
whole_text: bool,
1820
schema: Option<SchemaRef>,
1921
limit: Option<usize>,
2022
) -> Self {
2123
Self {
2224
encoding: encoding.to_string(),
2325
skip_blank_lines,
26+
whole_text,
2427
schema,
2528
limit,
2629
}
@@ -29,7 +32,7 @@ impl TextConvertOptions {
2932

3033
impl Default for TextConvertOptions {
3134
fn default() -> Self {
32-
Self::new("utf-8", true, None, None)
35+
Self::new("utf-8", true, false, None, None)
3336
}
3437
}
3538

src/daft-text/src/read.rs

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use daft_recordbatch::RecordBatch;
99
use futures::{Stream, StreamExt, stream::BoxStream};
1010
use tokio::{
1111
fs::File,
12-
io::{AsyncBufRead, AsyncBufReadExt, BufReader},
12+
io::{AsyncBufRead, AsyncBufReadExt, AsyncReadExt, BufReader},
1313
};
1414
use tokio_util::io::StreamReader;
1515

@@ -40,6 +40,19 @@ pub async fn stream_text(
4040
.clone()
4141
.unwrap_or_else(|| Arc::new(Schema::new(vec![Field::new("text", DataType::Utf8)])));
4242

43+
// Check if we're reading the whole file as a single row
44+
if convert_options.whole_text {
45+
let whole_text_stream =
46+
read_into_whole_text_stream(uri, convert_options, read_options, io_client, io_stats)
47+
.await?;
48+
return Ok(Box::pin(whole_text_stream.map(move |content_res| {
49+
let content = content_res?;
50+
let array = Utf8Array::from_values("text", std::iter::once(content.as_str()));
51+
let series = array.into_series();
52+
RecordBatch::new_with_size(schema.clone(), vec![series], 1)
53+
})));
54+
}
55+
4356
// Build a stream of line chunks
4457
let line_chunk_stream =
4558
read_into_line_chunk_stream(uri, convert_options, read_options, io_client, io_stats)
@@ -60,6 +73,51 @@ pub async fn stream_text(
6073
Ok(Box::pin(table_stream))
6174
}
6275

76+
async fn read_into_whole_text_stream(
77+
uri: String,
78+
convert_options: TextConvertOptions,
79+
read_options: TextReadOptions,
80+
io_client: Arc<IOClient>,
81+
io_stats: Option<IOStatsRef>,
82+
) -> DaftResult<impl Stream<Item = DaftResult<String>> + Send> {
83+
let buffer_size = read_options.buffer_size.unwrap_or(8 * 1024 * 1024);
84+
85+
let reader: Box<dyn AsyncBufRead + Unpin + Send> = match io_client
86+
.single_url_get(uri.clone(), None, io_stats)
87+
.await?
88+
{
89+
GetResult::File(file) => Box::new(BufReader::with_capacity(
90+
buffer_size,
91+
File::open(file.path).await?,
92+
)),
93+
GetResult::Stream(stream, ..) => Box::new(BufReader::with_capacity(
94+
buffer_size,
95+
StreamReader::new(stream),
96+
)),
97+
};
98+
99+
// If file is compressed, wrap stream in decoding stream.
100+
let mut reader: Box<dyn AsyncBufRead + Unpin + Send> = match CompressionCodec::from_uri(&uri) {
101+
Some(compression) => Box::new(BufReader::with_capacity(
102+
buffer_size,
103+
compression.to_decoder(reader),
104+
)),
105+
None => reader,
106+
};
107+
108+
Ok(try_stream! {
109+
let mut content = String::new();
110+
reader.read_to_string(&mut content).await?;
111+
112+
// Apply skip_blank_lines if needed (for whole file, this means skip if entire content is blank)
113+
if convert_options.skip_blank_lines && content.trim().is_empty() {
114+
return;
115+
}
116+
117+
yield content;
118+
})
119+
}
120+
63121
async fn read_into_line_chunk_stream(
64122
uri: String,
65123
convert_options: TextConvertOptions,

tests/io/test_text.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,3 +153,98 @@ def test_read_with_encoding_setting(tmp_path):
153153

154154
with pytest.raises(Exception, match=r"(?i)utf-?8"):
155155
daft.read_text(str(path)).to_pydict()
156+
157+
158+
def test_read_whole_text_from_single_file(tmp_path):
159+
path = tmp_path / "sample.txt"
160+
path.write_text("hello\nworld\nfoo", encoding="utf-8")
161+
162+
df = daft.read_text(str(path), whole_text=True)
163+
assert df.schema() == Schema.from_pyarrow_schema(pa.schema([("text", pa.string())]))
164+
result = df.to_pydict()
165+
assert result["text"] == ["hello\nworld\nfoo"]
166+
167+
168+
def test_read_whole_text_from_multiple_files(tmp_path):
169+
file_a = tmp_path / "a.txt"
170+
file_b = tmp_path / "b.txt"
171+
file_a.write_text("content of file a\nwith multiple lines", encoding="utf-8")
172+
file_b.write_text("content of file b", encoding="utf-8")
173+
174+
df = daft.read_text([str(file_a), str(file_b)], whole_text=True)
175+
result = df.to_pydict()
176+
assert len(result["text"]) == 2
177+
assert "content of file a\nwith multiple lines" in result["text"]
178+
assert "content of file b" in result["text"]
179+
180+
181+
def test_read_whole_text_with_path_column(tmp_path):
182+
file_a = tmp_path / "a.txt"
183+
file_b = tmp_path / "b.txt"
184+
file_a.write_text("content a", encoding="utf-8")
185+
file_b.write_text("content b", encoding="utf-8")
186+
187+
df = daft.read_text([str(file_a), str(file_b)], whole_text=True, file_path_column="path")
188+
assert df.schema() == Schema.from_pyarrow_schema(pa.schema([("text", pa.string()), ("path", pa.string())]))
189+
190+
data = df.to_pydict()
191+
assert len(data["text"]) == 2
192+
assert len(data["path"]) == 2
193+
194+
rows = {(t, p) for t, p in zip(data["text"], data["path"])}
195+
assert rows == {
196+
("content a", f"{tmp_path}/a.txt"),
197+
("content b", f"{tmp_path}/b.txt"),
198+
}
199+
200+
201+
def test_read_whole_text_from_empty_file(tmp_path):
202+
path = tmp_path / "empty.txt"
203+
path.write_text("", encoding="utf-8")
204+
205+
df = daft.read_text(str(path), whole_text=True, skip_blank_lines=False)
206+
result = df.to_pydict()
207+
assert result["text"] == [""]
208+
209+
df = daft.read_text(str(path), whole_text=True, skip_blank_lines=True)
210+
result = df.to_pydict()
211+
assert result["text"] == []
212+
213+
214+
def test_read_whole_text_with_glob_patterns(tmp_path):
215+
file_a = tmp_path / "a.txt"
216+
file_b = tmp_path / "b.txt"
217+
file_c = tmp_path / "c.txt"
218+
file_d = tmp_path / "d.txt"
219+
file_a.write_text("content a1", encoding="utf-8")
220+
file_b.write_text("content b1\ncontent b2\t", encoding="utf-8")
221+
file_c.write_text("content c1\ncontent c2\ncontent c3\n\t", encoding="utf-8")
222+
file_d.write_text("", encoding="utf-8")
223+
224+
df = daft.read_text(
225+
str(tmp_path / "*.txt"),
226+
skip_blank_lines=True,
227+
whole_text=True,
228+
file_path_column="path",
229+
)
230+
data = df.to_pydict()
231+
assert len(data["text"]) == 3
232+
assert len(data["path"]) == 3
233+
234+
file_to_content = {p: t for p, t in zip(data["path"], data["text"])}
235+
assert file_to_content[str(file_a)] == "content a1"
236+
assert file_to_content[str(file_b)] == "content b1\ncontent b2\t"
237+
assert file_to_content[str(file_c)] == "content c1\ncontent c2\ncontent c3\n\t"
238+
239+
240+
def test_read_whole_text_with_gzip(tmp_path):
241+
def _write_gzip(path: Path, content: bytes) -> None:
242+
with gzip.open(path, "wb") as f:
243+
f.write(content)
244+
245+
path = tmp_path / "compressed.txt.gz"
246+
_write_gzip(path, b"line1\nline2\nline3")
247+
248+
df = daft.read_text(str(path), whole_text=True)
249+
result = df.to_pydict()
250+
assert result["text"] == ["line1\nline2\nline3"]

0 commit comments

Comments
 (0)