Skip to content

Commit 95c8981

Browse files
committed
Arrow to MongoDB writer
Writes record batches to a MongoDB collection. Fixes #2
1 parent 460e086 commit 95c8981

File tree

4 files changed

+577
-255
lines changed

4 files changed

+577
-255
lines changed

Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@ edition = "2018"
88
bson = { version = "0.14", features = ["decimal128"] }
99
mongodb = { git = "https://github.com/nevi-me/mongo-rust-driver", branch = "decimal-128-hack" }
1010
arrow = "0.16.0"
11-
# mongodb-schema-parser = { git = "https://github.com/nevi-me/mongodb-schema-parser", branch = "update-bson" }
11+
mongodb-schema-parser = { git = "https://github.com/nevi-me/mongodb-schema-parser", branch = "write-bson", default-features = false }
12+
chrono = "0.4"

src/lib.rs

Lines changed: 2 additions & 254 deletions
Original file line numberDiff line numberDiff line change
@@ -1,257 +1,5 @@
1-
use std::sync::Arc;
1+
pub mod reader;
2+
pub mod writer;
23

3-
use arrow::{
4-
array::*,
5-
datatypes::{DataType, Schema, TimeUnit},
6-
record_batch::{RecordBatch, RecordBatchReader},
7-
};
84
#[macro_use(bson)]
95
extern crate bson;
10-
use bson::{doc, Bson};
11-
use mongodb::{
12-
options::{AggregateOptions, ClientOptions, StreamAddress},
13-
Client,
14-
};
15-
// use mongodb_schema_parser::SchemaParser;
16-
17-
pub struct ReaderConfig<'a> {
18-
pub hostname: &'a str,
19-
pub port: Option<u16>,
20-
// read_preference,
21-
pub database: &'a str,
22-
pub collection: &'a str,
23-
}
24-
25-
pub struct Reader {
26-
client: Client,
27-
database: String,
28-
collection: String,
29-
schema: Schema,
30-
current_index: usize,
31-
batch_size: usize,
32-
}
33-
34-
impl Reader {
35-
/// Try to create a new reader
36-
pub fn try_new(config: &ReaderConfig, schema: Schema) -> Result<Self, ()> {
37-
let options = ClientOptions::builder()
38-
.hosts(vec![StreamAddress {
39-
hostname: config.hostname.to_string(),
40-
port: config.port,
41-
}])
42-
.build();
43-
// TODO: support connection with uri_string
44-
let client = Client::with_options(options).expect("Unable to connect to MongoDB");
45-
46-
Ok(Self {
47-
/// MongoDB client. The client supports connection pooling, and is suitable for parallel querying
48-
client,
49-
/// Database name
50-
database: config.database.to_string(),
51-
/// Collection name
52-
collection: config.collection.to_string(),
53-
/// The schema of the collection being read
54-
schema,
55-
/// An internal counter to track the number of documents read
56-
current_index: 0,
57-
/// The batch size that should be returned from the database
58-
///
59-
/// If documents are relatively small, or there is ample RAM, a very large batch size should be used
60-
/// to reduce the number of roundtrips to the database
61-
batch_size: 1024000,
62-
})
63-
}
64-
65-
/// Read the next record batch
66-
pub fn next(&mut self) -> Result<Option<RecordBatch>, ()> {
67-
let mut criteria = doc! {};
68-
let mut project = doc! {};
69-
for field in self.schema.fields() {
70-
project.insert(field.name(), bson::Bson::I32(1));
71-
}
72-
criteria.insert("$project", project);
73-
let coll = self
74-
.client
75-
.database(self.database.as_ref())
76-
.collection(self.collection.as_ref());
77-
78-
let aggregate_options = AggregateOptions::builder()
79-
.batch_size(Some(self.batch_size as u32))
80-
.build();
81-
82-
let mut cursor = coll
83-
.aggregate(
84-
vec![criteria, doc! {"$skip": self.current_index as i32}],
85-
Some(aggregate_options),
86-
)
87-
.expect("Unable to run aggregation");
88-
89-
// collect results from cursor into batches
90-
let mut docs = vec![];
91-
for _ in 0..self.batch_size {
92-
if let Some(Ok(doc)) = cursor.next() {
93-
docs.push(doc);
94-
} else {
95-
break;
96-
}
97-
}
98-
99-
let docs_len = docs.len();
100-
self.current_index = self.current_index + docs_len;
101-
if docs_len == 0 {
102-
return Ok(None);
103-
}
104-
dbg!(&self.current_index);
105-
106-
let mut builder = StructBuilder::from_schema(self.schema.clone(), self.current_index);
107-
108-
let field_len = self.schema.fields().len();
109-
for i in 0..field_len {
110-
let field = self.schema.field(i);
111-
match field.data_type() {
112-
DataType::Binary => {}
113-
DataType::Boolean => {
114-
let field_builder = builder.field_builder::<BooleanBuilder>(i).unwrap();
115-
for v in 0..docs_len {
116-
let doc: &_ = docs.get(v).unwrap();
117-
match doc.get_bool(field.name()) {
118-
Ok(val) => field_builder.append_value(val).unwrap(),
119-
Err(_) => field_builder.append_null().unwrap(),
120-
};
121-
}
122-
}
123-
DataType::Timestamp(time_unit, _) => {
124-
let field_builder = match time_unit {
125-
TimeUnit::Millisecond => builder
126-
.field_builder::<TimestampMillisecondBuilder>(i)
127-
.unwrap(),
128-
t @ _ => panic!("Timestamp arrays can only be read as milliseconds, found {:?}. \nPlease read as milliseconds then cast to desired resolution.", t)
129-
};
130-
for v in 0..docs_len {
131-
let doc: &_ = docs.get(v).unwrap();
132-
match doc.get_utc_datetime(field.name()) {
133-
Ok(val) => field_builder.append_value(val.timestamp_millis()).unwrap(),
134-
Err(_) => field_builder.append_null().unwrap(),
135-
};
136-
}
137-
}
138-
DataType::Float64 => {
139-
let field_builder = builder.field_builder::<Float64Builder>(i).unwrap();
140-
for v in 0..docs_len {
141-
let doc: &_ = docs.get(v).unwrap();
142-
match doc.get_f64(field.name()) {
143-
Ok(val) => field_builder.append_value(val).unwrap(),
144-
Err(_) => field_builder.append_null().unwrap(),
145-
};
146-
}
147-
}
148-
DataType::Int32 => {
149-
let field_builder = builder.field_builder::<Int32Builder>(i).unwrap();
150-
for v in 0..docs_len {
151-
let doc: &_ = docs.get(v).unwrap();
152-
match doc.get_i32(field.name()) {
153-
Ok(val) => field_builder.append_value(val).unwrap(),
154-
Err(_) => field_builder.append_null().unwrap(),
155-
};
156-
}
157-
}
158-
DataType::Int64 => {
159-
let field_builder = builder.field_builder::<Int64Builder>(i).unwrap();
160-
for v in 0..docs_len {
161-
let doc: &_ = docs.get(v).unwrap();
162-
match doc.get_i64(field.name()) {
163-
Ok(val) => field_builder.append_value(val).unwrap(),
164-
Err(_) => field_builder.append_null().unwrap(),
165-
};
166-
}
167-
}
168-
DataType::Utf8 => {
169-
let field_builder = builder.field_builder::<StringBuilder>(i).unwrap();
170-
for v in 0..docs_len {
171-
let doc: &_ = docs.get(v).unwrap();
172-
match doc.get(field.name()) {
173-
Some(Bson::ObjectId(oid)) => {
174-
field_builder.append_value(oid.to_hex().as_str()).unwrap()
175-
}
176-
Some(Bson::String(val)) => field_builder.append_value(&val).unwrap(),
177-
Some(Bson::Null) => field_builder.append_null().unwrap(),
178-
Some(t) => panic!(
179-
"Option to cast non-string types to string not yet implemented for {:?}", t
180-
),
181-
None => field_builder.append_null().unwrap(),
182-
};
183-
}
184-
}
185-
DataType::List(_dtype) => panic!("Creating lists not yet implemented"),
186-
DataType::Struct(_fields) => panic!("Creating nested structs not yet implemented"),
187-
t @ _ => panic!("Data type {:?} not supported when reading from MongoDB", t),
188-
}
189-
}
190-
// append true to all struct records
191-
for _ in 0..docs_len {
192-
builder.append(true).unwrap();
193-
}
194-
Ok(Some(RecordBatch::from(&builder.finish())))
195-
}
196-
}
197-
198-
impl RecordBatchReader for Reader {
199-
fn schema(&mut self) -> Arc<Schema> {
200-
Arc::new(self.schema.clone())
201-
}
202-
fn next_batch(&mut self) -> arrow::error::Result<Option<RecordBatch>> {
203-
self.next().map_err(|_| {
204-
arrow::error::ArrowError::IoError("Unable to read next batch from MongoDB".to_string())
205-
})
206-
}
207-
}
208-
209-
#[cfg(test)]
210-
mod tests {
211-
use super::*;
212-
213-
use std::fs::File;
214-
215-
use arrow::csv;
216-
use arrow::datatypes::Field;
217-
218-
#[test]
219-
fn test_read_collection() -> Result<(), ()> {
220-
let fields = vec![
221-
Field::new("_id", DataType::Utf8, false),
222-
Field::new("trip_id", DataType::Utf8, false),
223-
Field::new("trip_status", DataType::Utf8, false),
224-
Field::new("route_name", DataType::Utf8, false),
225-
Field::new("route_variant", DataType::Utf8, true),
226-
Field::new(
227-
"trip_date",
228-
DataType::Timestamp(TimeUnit::Millisecond, None),
229-
false,
230-
),
231-
Field::new("trip_time", DataType::Int32, false),
232-
Field::new("direction", DataType::Utf8, false),
233-
Field::new("line", DataType::Utf8, true),
234-
Field::new("stop_id", DataType::Utf8, true),
235-
Field::new("stop_index", DataType::Int32, false),
236-
Field::new("scheduled_departure", DataType::Int32, false),
237-
Field::new("observed_departure", DataType::Int32, true),
238-
Field::new("stop_relevance", DataType::Utf8, false),
239-
];
240-
let schema = Schema::new(fields);
241-
let config = ReaderConfig {
242-
hostname: "localhost",
243-
port: None,
244-
database: "mycollection",
245-
collection: "delays",
246-
};
247-
let mut reader = Reader::try_new(&config, schema)?;
248-
249-
// write results to CSV as the schema would allow
250-
let file = File::create("./target/debug/delays.csv").unwrap();
251-
let mut writer = csv::Writer::new(file);
252-
while let Ok(Some(batch)) = reader.next() {
253-
writer.write(&batch).unwrap();
254-
}
255-
Ok(())
256-
}
257-
}

0 commit comments

Comments (0)