-use std::sync::Arc;
+pub mod reader;
+pub mod writer;

-use arrow::{
-    array::*,
-    datatypes::{DataType, Schema, TimeUnit},
-    record_batch::{RecordBatch, RecordBatchReader},
-};
 #[macro_use(bson)]
 extern crate bson;
-use bson::{doc, Bson};
-use mongodb::{
-    options::{AggregateOptions, ClientOptions, StreamAddress},
-    Client,
-};
-// use mongodb_schema_parser::SchemaParser;
-
-pub struct ReaderConfig<'a> {
-    pub hostname: &'a str,
-    pub port: Option<u16>,
-    // read_preference,
-    pub database: &'a str,
-    pub collection: &'a str,
-}
-
-pub struct Reader {
-    client: Client,
-    database: String,
-    collection: String,
-    schema: Schema,
-    current_index: usize,
-    batch_size: usize,
-}
-
-impl Reader {
-    /// Try to create a new reader
-    pub fn try_new(config: &ReaderConfig, schema: Schema) -> Result<Self, ()> {
-        let options = ClientOptions::builder()
-            .hosts(vec![StreamAddress {
-                hostname: config.hostname.to_string(),
-                port: config.port,
-            }])
-            .build();
-        // TODO: support connection with uri_string
-        let client = Client::with_options(options).expect("Unable to connect to MongoDB");
-
-        Ok(Self {
-            /// MongoDB client. The client supports connection pooling, and is suitable for parallel querying
-            client,
-            /// Database name
-            database: config.database.to_string(),
-            /// Collection name
-            collection: config.collection.to_string(),
-            /// The schema of the collection being read
-            schema,
-            /// An internal counter to track the number of documents read
-            current_index: 0,
-            /// The batch size that should be returned from the database
-            ///
-            /// If documents are relatively small, or there is ample RAM, a very large batch size should be used
-            /// to reduce the number of roundtrips to the database
-            batch_size: 1024000,
-        })
-    }
-
-    /// Read the next record batch
-    pub fn next(&mut self) -> Result<Option<RecordBatch>, ()> {
-        let mut criteria = doc! {};
-        let mut project = doc! {};
-        for field in self.schema.fields() {
-            project.insert(field.name(), bson::Bson::I32(1));
-        }
-        criteria.insert("$project", project);
-        let coll = self
-            .client
-            .database(self.database.as_ref())
-            .collection(self.collection.as_ref());
-
-        let aggregate_options = AggregateOptions::builder()
-            .batch_size(Some(self.batch_size as u32))
-            .build();
-
-        let mut cursor = coll
-            .aggregate(
-                vec![criteria, doc! {"$skip": self.current_index as i32}],
-                Some(aggregate_options),
-            )
-            .expect("Unable to run aggregation");
-
-        // collect results from cursor into batches
-        let mut docs = vec![];
-        for _ in 0..self.batch_size {
-            if let Some(Ok(doc)) = cursor.next() {
-                docs.push(doc);
-            } else {
-                break;
-            }
-        }
-
-        let docs_len = docs.len();
-        self.current_index = self.current_index + docs_len;
-        if docs_len == 0 {
-            return Ok(None);
-        }
-        dbg!(&self.current_index);
-
-        let mut builder = StructBuilder::from_schema(self.schema.clone(), self.current_index);
-
-        let field_len = self.schema.fields().len();
-        for i in 0..field_len {
-            let field = self.schema.field(i);
-            match field.data_type() {
-                DataType::Binary => {}
-                DataType::Boolean => {
-                    let field_builder = builder.field_builder::<BooleanBuilder>(i).unwrap();
-                    for v in 0..docs_len {
-                        let doc: &_ = docs.get(v).unwrap();
-                        match doc.get_bool(field.name()) {
-                            Ok(val) => field_builder.append_value(val).unwrap(),
-                            Err(_) => field_builder.append_null().unwrap(),
-                        };
-                    }
-                }
-                DataType::Timestamp(time_unit, _) => {
-                    let field_builder = match time_unit {
-                        TimeUnit::Millisecond => builder
-                            .field_builder::<TimestampMillisecondBuilder>(i)
-                            .unwrap(),
-                        t @ _ => panic!("Timestamp arrays can only be read as milliseconds, found {:?}. \nPlease read as milliseconds then cast to desired resolution.", t)
-                    };
-                    for v in 0..docs_len {
-                        let doc: &_ = docs.get(v).unwrap();
-                        match doc.get_utc_datetime(field.name()) {
-                            Ok(val) => field_builder.append_value(val.timestamp_millis()).unwrap(),
-                            Err(_) => field_builder.append_null().unwrap(),
-                        };
-                    }
-                }
-                DataType::Float64 => {
-                    let field_builder = builder.field_builder::<Float64Builder>(i).unwrap();
-                    for v in 0..docs_len {
-                        let doc: &_ = docs.get(v).unwrap();
-                        match doc.get_f64(field.name()) {
-                            Ok(val) => field_builder.append_value(val).unwrap(),
-                            Err(_) => field_builder.append_null().unwrap(),
-                        };
-                    }
-                }
-                DataType::Int32 => {
-                    let field_builder = builder.field_builder::<Int32Builder>(i).unwrap();
-                    for v in 0..docs_len {
-                        let doc: &_ = docs.get(v).unwrap();
-                        match doc.get_i32(field.name()) {
-                            Ok(val) => field_builder.append_value(val).unwrap(),
-                            Err(_) => field_builder.append_null().unwrap(),
-                        };
-                    }
-                }
-                DataType::Int64 => {
-                    let field_builder = builder.field_builder::<Int64Builder>(i).unwrap();
-                    for v in 0..docs_len {
-                        let doc: &_ = docs.get(v).unwrap();
-                        match doc.get_i64(field.name()) {
-                            Ok(val) => field_builder.append_value(val).unwrap(),
-                            Err(_) => field_builder.append_null().unwrap(),
-                        };
-                    }
-                }
-                DataType::Utf8 => {
-                    let field_builder = builder.field_builder::<StringBuilder>(i).unwrap();
-                    for v in 0..docs_len {
-                        let doc: &_ = docs.get(v).unwrap();
-                        match doc.get(field.name()) {
-                            Some(Bson::ObjectId(oid)) => {
-                                field_builder.append_value(oid.to_hex().as_str()).unwrap()
-                            }
-                            Some(Bson::String(val)) => field_builder.append_value(&val).unwrap(),
-                            Some(Bson::Null) => field_builder.append_null().unwrap(),
-                            Some(t) => panic!(
-                                "Option to cast non-string types to string not yet implemented for {:?}", t
-                            ),
-                            None => field_builder.append_null().unwrap(),
-                        };
-                    }
-                }
-                DataType::List(_dtype) => panic!("Creating lists not yet implemented"),
-                DataType::Struct(_fields) => panic!("Creating nested structs not yet implemented"),
-                t @ _ => panic!("Data type {:?} not supported when reading from MongoDB", t),
-            }
-        }
-        // append true to all struct records
-        for _ in 0..docs_len {
-            builder.append(true).unwrap();
-        }
-        Ok(Some(RecordBatch::from(&builder.finish())))
-    }
-}
-
-impl RecordBatchReader for Reader {
-    fn schema(&mut self) -> Arc<Schema> {
-        Arc::new(self.schema.clone())
-    }
-    fn next_batch(&mut self) -> arrow::error::Result<Option<RecordBatch>> {
-        self.next().map_err(|_| {
-            arrow::error::ArrowError::IoError("Unable to read next batch from MongoDB".to_string())
-        })
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-
-    use std::fs::File;
-
-    use arrow::csv;
-    use arrow::datatypes::Field;
-
-    #[test]
-    fn test_read_collection() -> Result<(), ()> {
-        let fields = vec![
-            Field::new("_id", DataType::Utf8, false),
-            Field::new("trip_id", DataType::Utf8, false),
-            Field::new("trip_status", DataType::Utf8, false),
-            Field::new("route_name", DataType::Utf8, false),
-            Field::new("route_variant", DataType::Utf8, true),
-            Field::new(
-                "trip_date",
-                DataType::Timestamp(TimeUnit::Millisecond, None),
-                false,
-            ),
-            Field::new("trip_time", DataType::Int32, false),
-            Field::new("direction", DataType::Utf8, false),
-            Field::new("line", DataType::Utf8, true),
-            Field::new("stop_id", DataType::Utf8, true),
-            Field::new("stop_index", DataType::Int32, false),
-            Field::new("scheduled_departure", DataType::Int32, false),
-            Field::new("observed_departure", DataType::Int32, true),
-            Field::new("stop_relevance", DataType::Utf8, false),
-        ];
-        let schema = Schema::new(fields);
-        let config = ReaderConfig {
-            hostname: "localhost",
-            port: None,
-            database: "mycollection",
-            collection: "delays",
-        };
-        let mut reader = Reader::try_new(&config, schema)?;
-
-        // write results to CSV as the schema would allow
-        let file = File::create("./target/debug/delays.csv").unwrap();
-        let mut writer = csv::Writer::new(file);
-        while let Ok(Some(batch)) = reader.next() {
-            writer.write(&batch).unwrap();
-        }
-        Ok(())
-    }
-}
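
Note: the `Reader` and `ReaderConfig` removed from this file are presumably relocated to the new `reader` module, with `writer` added alongside it. Below is a minimal usage sketch under that assumption, reusing the API shown in the removed code; the crate path `mongodb_arrow_connector` and the example schema fields, database, and collection names are assumptions for illustration, not part of this commit.

```rust
// Hypothetical usage sketch: assumes the ReaderConfig/Reader API removed from
// this file is available unchanged from the new `reader` module, and that the
// crate is named `mongodb_arrow_connector` (an assumption).
use arrow::datatypes::{DataType, Field, Schema};
use mongodb_arrow_connector::reader::{Reader, ReaderConfig};

fn main() -> Result<(), ()> {
    // Describe the fields to read from the collection (example names).
    let schema = Schema::new(vec![
        Field::new("_id", DataType::Utf8, false),
        Field::new("value", DataType::Float64, true),
    ]);
    let config = ReaderConfig {
        hostname: "localhost",
        port: None,
        database: "mydb",
        collection: "mycoll",
    };
    let mut reader = Reader::try_new(&config, schema)?;
    // Pull record batches until the collection is exhausted.
    while let Ok(Some(batch)) = reader.next() {
        println!("read a batch with {} rows", batch.num_rows());
    }
    Ok(())
}
```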