Skip to content

Commit 0472e2e

Browse files
committed
Re-implement infer_json_schema to make use of TapeDecoder, removing
the need to parse rows into `serde_json::Value`s first.
1 parent 9fa2d38 commit 0472e2e

File tree

3 files changed

+181
-11
lines changed

3 files changed

+181
-11
lines changed

arrow-json/src/reader/schema.rs

Lines changed: 82 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ use arrow_schema::{ArrowError, Schema};
2222
use bumpalo::Bump;
2323
use serde_json::Value;
2424

25-
use self::infer::{EMPTY_OBJECT_TY, infer_json_type};
26-
use super::ValueIter;
25+
use super::tape::TapeDecoder;
26+
use infer::{ANY_TY, EMPTY_OBJECT_TY, InferredType, TapeValue, infer_json_type};
2727

2828
mod infer;
2929

@@ -100,12 +100,29 @@ pub fn infer_json_schema_from_seekable<R: BufRead + Seek>(
100100
/// file.seek(SeekFrom::Start(0)).unwrap();
101101
/// ```
102102
pub fn infer_json_schema<R: BufRead>(
103-
reader: R,
103+
mut reader: R,
104104
max_read_records: Option<usize>,
105105
) -> Result<(Schema, usize), ArrowError> {
106-
let mut values = ValueIter::new(reader, max_read_records);
107-
let schema = infer_json_schema_from_iterator(&mut values)?;
108-
Ok((schema, values.record_count()))
106+
let arena = Bump::new();
107+
let mut decoder = SchemaDecoder::new(max_read_records, &arena);
108+
109+
loop {
110+
let buf = reader.fill_buf()?;
111+
let read = buf.len();
112+
113+
if read == 0 {
114+
break;
115+
}
116+
117+
let decoded = decoder.decode(buf)?;
118+
reader.consume(decoded);
119+
120+
if decoded != read {
121+
break;
122+
}
123+
}
124+
125+
decoder.finish()
109126
}
110127

111128
/// Infer the fields of a JSON file by reading all items from the JSON Value Iterator.
@@ -136,6 +153,60 @@ where
136153
.into_schema()
137154
}
138155

156+
struct SchemaDecoder<'a> {
157+
decoder: TapeDecoder,
158+
max_read_records: Option<usize>,
159+
record_count: usize,
160+
schema: InferredType<'a>,
161+
arena: &'a Bump,
162+
}
163+
164+
impl<'a> SchemaDecoder<'a> {
165+
pub fn new(max_read_records: Option<usize>, arena: &'a Bump) -> Self {
166+
Self {
167+
decoder: TapeDecoder::new(1024, 8),
168+
max_read_records,
169+
record_count: 0,
170+
schema: ANY_TY,
171+
arena,
172+
}
173+
}
174+
175+
pub fn decode(&mut self, buf: &[u8]) -> Result<usize, ArrowError> {
176+
let read = self.decoder.decode(buf)?;
177+
if read != buf.len() {
178+
self.infer_batch()?;
179+
}
180+
Ok(read)
181+
}
182+
183+
pub fn finish(mut self) -> Result<(Schema, usize), ArrowError> {
184+
self.infer_batch()?;
185+
Ok((self.schema.into_schema()?, self.record_count))
186+
}
187+
188+
fn infer_batch(&mut self) -> Result<(), ArrowError> {
189+
let tape = self.decoder.finish()?;
190+
191+
let remaining_records = self
192+
.max_read_records
193+
.map_or(usize::MAX, |max| max - self.record_count);
194+
195+
let records = tape
196+
.iter_rows()
197+
.map(|idx| TapeValue::new(&tape, idx))
198+
.take(remaining_records);
199+
200+
for record in records {
201+
self.schema = infer_json_type(record, self.schema, self.arena)?;
202+
self.record_count += 1;
203+
}
204+
205+
self.decoder.clear();
206+
Ok(())
207+
}
208+
}
209+
139210
#[cfg(test)]
140211
mod tests {
141212
use super::*;
@@ -306,7 +377,7 @@ mod tests {
306377
let re = infer_json_schema_from_seekable(Cursor::new(b"}"), None);
307378
assert_eq!(
308379
re.err().unwrap().to_string(),
309-
"Json error: Not valid JSON: expected value at line 1 column 1",
380+
"Json error: Encountered unexpected '}' whilst parsing value"
310381
);
311382
}
312383

@@ -320,14 +391,14 @@ mod tests {
320391
let (inferred_schema, _) =
321392
infer_json_schema_from_seekable(Cursor::new(data), None).expect("infer");
322393
let schema = Schema::new(vec![
323-
Field::new("an", list_type_of(DataType::Null), true),
324394
Field::new("in", DataType::Int64, true),
325-
Field::new("n", DataType::Null, true),
326-
Field::new("na", list_type_of(DataType::Null), true),
327-
Field::new("nas", list_type_of(DataType::Utf8), true),
328395
Field::new("ni", DataType::Int64, true),
329396
Field::new("ns", DataType::Utf8, true),
330397
Field::new("sn", DataType::Utf8, true),
398+
Field::new("n", DataType::Null, true),
399+
Field::new("an", list_type_of(DataType::Null), true),
400+
Field::new("na", list_type_of(DataType::Null), true),
401+
Field::new("nas", list_type_of(DataType::Utf8), true),
331402
]);
332403
assert_eq!(inferred_schema, schema);
333404
}

arrow-json/src/reader/schema/infer.rs

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ use std::collections::HashMap;
2020
use arrow_schema::{ArrowError, DataType, Field, Fields, Schema};
2121
use bumpalo::Bump;
2222

23+
use crate::reader::tape::{Tape, TapeElement};
24+
2325
#[derive(Clone, Copy, Debug)]
2426
pub struct InferredType<'t>(&'t TyKind<'t>);
2527

@@ -257,6 +259,54 @@ pub enum JsonType {
257259
Object,
258260
}
259261

262+
#[derive(Copy, Clone, Debug)]
263+
pub struct TapeValue<'a> {
264+
tape: &'a Tape<'a>,
265+
idx: u32,
266+
}
267+
268+
impl<'a> TapeValue<'a> {
269+
pub fn new(tape: &'a Tape<'a>, idx: u32) -> Self {
270+
Self { tape, idx }
271+
}
272+
}
273+
274+
impl<'a> JsonValue<'a> for TapeValue<'a> {
275+
fn get(&self) -> JsonType {
276+
match self.tape.get(self.idx) {
277+
TapeElement::Null => JsonType::Null,
278+
TapeElement::False => JsonType::Bool,
279+
TapeElement::True => JsonType::Bool,
280+
TapeElement::I64(_) | TapeElement::I32(_) => JsonType::Int64,
281+
TapeElement::F64(_) | TapeElement::F32(_) => JsonType::Float64,
282+
TapeElement::Number(s) => {
283+
if self.tape.get_string(s).parse::<i64>().is_ok() {
284+
JsonType::Int64
285+
} else {
286+
JsonType::Float64
287+
}
288+
}
289+
TapeElement::String(_) => JsonType::String,
290+
TapeElement::StartList(_) => JsonType::Array,
291+
TapeElement::EndList(_) => unreachable!(),
292+
TapeElement::StartObject(_) => JsonType::Object,
293+
TapeElement::EndObject(_) => unreachable!(),
294+
}
295+
}
296+
297+
fn elements(&self) -> impl Iterator<Item = Self> {
298+
self.tape
299+
.iter_elements(self.idx)
300+
.map(move |idx| Self { idx, ..*self })
301+
}
302+
303+
fn fields(&self) -> impl Iterator<Item = (&'a str, Self)> {
304+
self.tape
305+
.iter_fields(self.idx)
306+
.map(move |(key, idx)| (key, Self { idx, ..*self }))
307+
}
308+
}
309+
260310
impl<'a> JsonValue<'a> for &'a serde_json::Value {
261311
fn get(&self) -> JsonType {
262312
use serde_json::Value;

arrow-json/src/reader/tape.rs

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,55 @@ impl<'a> Tape<'a> {
142142
self.num_rows
143143
}
144144

145+
/// Iterates over the rows of the tape
146+
pub fn iter_rows(&self) -> impl Iterator<Item = u32> {
147+
self.iter_values(0, self.elements.len() as u32)
148+
}
149+
150+
/// Iterates over the elements of the array starting at `idx`
151+
pub fn iter_elements(&self, idx: u32) -> impl Iterator<Item = u32> {
152+
let end = match self.get(idx) {
153+
TapeElement::StartList(end) => end,
154+
elem => panic!("Expected the start of a list, found {:?}", elem),
155+
};
156+
157+
self.iter_values(idx, end)
158+
}
159+
160+
/// Iterates over the fields of the object starting at `idx`
161+
pub fn iter_fields(&self, idx: u32) -> impl Iterator<Item = (&'a str, u32)> {
162+
let end = match self.get(idx) {
163+
TapeElement::StartObject(end) => end,
164+
elem => panic!("Expected the start of an object, found {:?}", elem),
165+
};
166+
167+
let mut idx = idx + 1;
168+
169+
std::iter::from_fn(move || {
170+
(idx < end).then(|| {
171+
let key = match self.get(idx) {
172+
TapeElement::String(s) => self.get_string(s),
173+
elem => panic!("Expected a string, found {:?}", elem),
174+
};
175+
let value_idx = idx + 1;
176+
idx = self.next(value_idx, "field value").unwrap();
177+
(key, value_idx)
178+
})
179+
})
180+
}
181+
182+
fn iter_values(&self, start: u32, end: u32) -> impl Iterator<Item = u32> {
183+
let mut idx = start + 1;
184+
185+
std::iter::from_fn(move || {
186+
(idx < end).then(|| {
187+
let value_id = idx;
188+
idx = self.next(idx, "value").unwrap();
189+
value_id
190+
})
191+
})
192+
}
193+
145194
/// Serialize the tape element at index `idx` to `out` returning the next field index
146195
fn serialize(&self, out: &mut String, idx: u32) -> u32 {
147196
match self.get(idx) {

0 commit comments

Comments
 (0)