|
6 | 6 | * compatible open source license. |
7 | 7 | */ |
8 | 8 |
|
9 | | -use jni::objects::JClass; |
10 | | -use jni::sys::{jlong, jstring}; |
| 9 | +use jni::objects::{JByteArray, JClass}; |
| 10 | +use jni::sys::{jbyteArray, jlong, jstring}; |
11 | 11 | use jni::JNIEnv; |
12 | 12 | use std::sync::Arc; |
13 | 13 |
|
| 14 | +mod util; |
| 15 | + |
14 | 16 | use datafusion::execution::context::SessionContext; |
15 | 17 |
|
16 | 18 | use datafusion::DATAFUSION_VERSION; |
| 19 | +use datafusion::datasource::file_format::csv::CsvFormat; |
| 20 | +use datafusion::datasource::file_format::parquet::ParquetFormat; |
17 | 21 | use datafusion::execution::cache::cache_manager::{CacheManager, CacheManagerConfig, FileStatisticsCache}; |
18 | 22 | use datafusion::execution::disk_manager::DiskManagerConfig; |
19 | 23 | use datafusion::execution::runtime_env::{RuntimeEnv, RuntimeEnvBuilder}; |
20 | 24 | use datafusion::prelude::SessionConfig; |
| 25 | +use crate::util::{create_object_meta_from_filenames, parse_string_arr}; |
| 26 | +use datafusion::datasource::listing::{ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl}; |
| 27 | +use datafusion::execution::cache::cache_unit::DefaultListFilesCache; |
| 28 | +use datafusion::execution::cache::CacheAccessor; |
| 29 | +use datafusion::execution::SendableRecordBatchStream; |
| 30 | +use datafusion_substrait::logical_plan::consumer::from_substrait_plan; |
| 31 | +use datafusion_substrait::substrait::proto::Plan; |
| 32 | +use jni::objects::{JObjectArray, JString}; |
| 33 | +use prost::Message; |
| 34 | +use tokio::runtime::Runtime; |
| 35 | +use object_store::ObjectMeta; |
21 | 36 |
|
22 | 37 | /// Create a new DataFusion session context |
23 | 38 | #[no_mangle] |
@@ -110,6 +125,232 @@ pub extern "system" fn Java_org_opensearch_datafusion_DataFusionQueryJNI_closeSe |
110 | 125 | } |
111 | 126 |
|
112 | 127 |
|
| 128 | +#[no_mangle] |
| 129 | +pub extern "system" fn Java_org_opensearch_datafusion_DataFusionQueryJNI_createReader( |
| 130 | + mut env: JNIEnv, |
| 131 | + _class: JClass, |
| 132 | + table_path: JString, |
| 133 | + files: JObjectArray |
| 134 | +) -> jlong { |
| 135 | + |
| 136 | + let table_path: String = env.get_string(&table_path).expect("Couldn't get java string!").into(); |
| 137 | + let files: Vec<String> = parse_string_arr(&mut env, files).expect("Expected list of files"); |
| 138 | + let files_meta = create_object_meta_from_filenames(&table_path, files); |
| 139 | + |
| 140 | + let table_path = ListingTableUrl::parse(table_path).unwrap(); |
| 141 | + let shard_view = ShardView::new(table_path, files_meta); |
| 142 | + Box::into_raw(Box::new(shard_view)) as jlong |
| 143 | +} |
| 144 | + |
| 145 | +#[no_mangle] |
| 146 | +pub extern "system" fn Java_org_opensearch_datafusion_DataFusionQueryJNI_destroyReader( |
| 147 | + mut env: JNIEnv, |
| 148 | + _class: JClass, |
| 149 | + ptr: jlong |
| 150 | +) { |
| 151 | + let _ = unsafe { Box::from_raw(ptr as *mut ShardView) }; |
| 152 | +} |
| 153 | + |
| 154 | +pub struct ShardView { |
| 155 | + table_path: ListingTableUrl, |
| 156 | + files_meta: Arc<Vec<ObjectMeta>> |
| 157 | +} |
| 158 | + |
| 159 | +impl ShardView { |
| 160 | + pub fn new(table_path: ListingTableUrl, files_meta: Vec<ObjectMeta>) -> Self { |
| 161 | + let files_meta = Arc::new(files_meta); |
| 162 | + ShardView { |
| 163 | + table_path, |
| 164 | + files_meta |
| 165 | + } |
| 166 | + } |
| 167 | + |
| 168 | + pub fn table_path(&self) -> ListingTableUrl { |
| 169 | + self.table_path.clone() |
| 170 | + } |
| 171 | + |
| 172 | + pub fn files_meta(&self) -> Arc<Vec<ObjectMeta>> { |
| 173 | + self.files_meta.clone() |
| 174 | + } |
| 175 | +} |
| 176 | + |
| 177 | + |
| 178 | +#[no_mangle] |
| 179 | +pub extern "system" fn Java_org_opensearch_datafusion_DataFusionQueryJNI_nativeExecuteSubstraitQuery( |
| 180 | + mut env: JNIEnv, |
| 181 | + _class: JClass, |
| 182 | + shard_view_ptr: jlong, |
| 183 | + substrait_bytes: jbyteArray, |
| 184 | + // callback: JObject, |
| 185 | +) -> jlong { |
| 186 | + let shard_view = unsafe { &*(shard_view_ptr as *const ShardView) }; |
| 187 | + let table_path = shard_view.table_path(); |
| 188 | + let files_meta = shard_view.files_meta(); |
| 189 | + |
| 190 | + // Will use it once the global RunTime is defined |
| 191 | + // let runtime_arc = unsafe { |
| 192 | + // let boxed = &*(runtime_env_ptr as *const Pin<Arc<RuntimeEnv>>); |
| 193 | + // (**boxed).clone() |
| 194 | + // }; |
| 195 | + |
| 196 | + let list_file_cache = Arc::new(DefaultListFilesCache::default()); |
| 197 | + list_file_cache.put(table_path.prefix(), files_meta); |
| 198 | + |
| 199 | + let runtime_env = RuntimeEnvBuilder::new() |
| 200 | + .with_cache_manager(CacheManagerConfig::default() |
| 201 | + .with_list_files_cache(Some(list_file_cache))).build().unwrap(); |
| 202 | + |
| 203 | + |
| 204 | + |
| 205 | + let ctx = SessionContext::new_with_config_rt(SessionConfig::new(), Arc::new(runtime_env)); |
| 206 | + |
| 207 | + |
| 208 | + // Create default parquet options |
| 209 | + let file_format = CsvFormat::default(); |
| 210 | + let listing_options = ListingOptions::new(Arc::new(file_format)) |
| 211 | + .with_file_extension(".csv"); |
| 212 | + |
| 213 | + // Ideally the executor will give this |
| 214 | + Runtime::new().expect("Failed to create Tokio Runtime").block_on(async { |
| 215 | + let resolved_schema = listing_options |
| 216 | + .infer_schema(&ctx.state(), &table_path.clone()) |
| 217 | + .await.unwrap(); |
| 218 | + |
| 219 | + |
| 220 | + let config = ListingTableConfig::new(table_path.clone()) |
| 221 | + .with_listing_options(listing_options) |
| 222 | + .with_schema(resolved_schema); |
| 223 | + |
| 224 | + // Create a new TableProvider |
| 225 | + let provider = Arc::new(ListingTable::try_new(config).unwrap()); |
| 226 | + let shard_id = table_path.prefix().filename().expect("error in fetching Path"); |
| 227 | + ctx.register_table(shard_id, provider) |
| 228 | + .expect("Failed to attach the Table"); |
| 229 | + |
| 230 | + }); |
| 231 | + |
| 232 | + // TODO : how to close ctx ? |
| 233 | + // Convert Java byte array to Rust Vec<u8> |
| 234 | + let plan_bytes_obj = unsafe { JByteArray::from_raw(substrait_bytes) }; |
| 235 | + let plan_bytes_vec = match env.convert_byte_array(plan_bytes_obj) { |
| 236 | + Ok(bytes) => bytes, |
| 237 | + Err(e) => { |
| 238 | + let error_msg = format!("Failed to convert plan bytes: {}", e); |
| 239 | + env.throw_new("java/lang/Exception", error_msg); |
| 240 | + return 0; |
| 241 | + } |
| 242 | + }; |
| 243 | + |
| 244 | + let substrait_plan = match Plan::decode(plan_bytes_vec.as_slice()) { |
| 245 | + Ok(plan) => { |
| 246 | + println!("SUBSTRAIT rust: Decoding is successful, Plan has {} relations", plan.relations.len()); |
| 247 | + plan |
| 248 | + }, |
| 249 | + Err(e) => { |
| 250 | + return 0; |
| 251 | + } |
| 252 | + }; |
| 253 | + |
| 254 | + //let runtime = unsafe { &mut *(runtime_ptr as *mut Runtime) }; |
| 255 | + Runtime::new().expect("Failed to create Tokio Runtime").block_on(async { |
| 256 | + |
| 257 | + let logical_plan = match from_substrait_plan(&ctx.state(), &substrait_plan).await { |
| 258 | + Ok(plan) => { |
| 259 | + println!("SUBSTRAIT Rust: LogicalPlan: {:?}", plan); |
| 260 | + plan |
| 261 | + }, |
| 262 | + Err(e) => { |
| 263 | + println!("SUBSTRAIT Rust: Failed to convert Substrait plan: {}", e); |
| 264 | + return; |
| 265 | + } |
| 266 | + }; |
| 267 | + |
| 268 | + let dataframe = ctx.execute_logical_plan(logical_plan) |
| 269 | + .await.expect("Failed to run Logical Plan"); |
| 270 | + |
| 271 | + // TODO : check if this works |
| 272 | + return match dataframe.execute_stream() { |
| 273 | + Ok(stream) => { |
| 274 | + let boxed_stream = Box::new(stream); |
| 275 | + let stream_ptr = Box::into_raw(boxed_stream); |
| 276 | + stream_ptr as jlong |
| 277 | + }, |
| 278 | + Err(e) => { |
| 279 | + 0 |
| 280 | + } |
| 281 | + } |
| 282 | + }) |
| 283 | + |
| 284 | + |
| 285 | + // Create DataFrame from the converted logical plan |
| 286 | + |
| 287 | + |
| 288 | +} |
| 289 | + |
| 290 | +// If we need to create session context separately |
| 291 | +#[no_mangle] |
| 292 | +pub extern "system" fn Java_org_opensearch_datafusion_DataFusionQueryJNI_nativeCreateSessionContext( |
| 293 | + mut env: JNIEnv, |
| 294 | + _class: JClass, |
| 295 | + runtime_ptr: jlong, |
| 296 | + shard_view_ptr: jlong, |
| 297 | + global_runtime_env_ptr: jlong, |
| 298 | +) -> jlong { |
| 299 | + let shard_view = unsafe { &*(shard_view_ptr as *const ShardView) }; |
| 300 | + let table_path = shard_view.table_path(); |
| 301 | + let files_meta = shard_view.files_meta(); |
| 302 | + |
| 303 | + // Will use it once the global RunTime is defined |
| 304 | + // let runtime_arc = unsafe { |
| 305 | + // let boxed = &*(runtime_env_ptr as *const Pin<Arc<RuntimeEnv>>); |
| 306 | + // (**boxed).clone() |
| 307 | + // }; |
| 308 | + |
| 309 | + let list_file_cache = Arc::new(DefaultListFilesCache::default()); |
| 310 | + list_file_cache.put(table_path.prefix(), files_meta); |
| 311 | + |
| 312 | + let runtime_env = RuntimeEnvBuilder::new() |
| 313 | + .with_cache_manager(CacheManagerConfig::default() |
| 314 | + .with_list_files_cache(Some(list_file_cache))).build().unwrap(); |
| 315 | + |
| 316 | + |
| 317 | + |
| 318 | + let ctx = SessionContext::new_with_config_rt(SessionConfig::new(), Arc::new(runtime_env)); |
| 319 | + |
| 320 | + |
| 321 | + // Create default parquet options |
| 322 | + let file_format = CsvFormat::default(); |
| 323 | + let listing_options = ListingOptions::new(Arc::new(file_format)) |
| 324 | + .with_file_extension(".csv"); |
| 325 | + |
| 326 | + |
| 327 | + // let runtime = unsafe { &mut *(runtime_ptr as *mut Runtime) }; |
| 328 | + let mut session_context_ptr = 0; |
| 329 | + |
| 330 | + // Ideally the executor will give this |
| 331 | + Runtime::new().expect("Failed to create Tokio Runtime").block_on(async { |
| 332 | + let resolved_schema = listing_options |
| 333 | + .infer_schema(&ctx.state(), &table_path.clone()) |
| 334 | + .await.unwrap(); |
| 335 | + |
| 336 | + |
| 337 | + let config = ListingTableConfig::new(table_path.clone()) |
| 338 | + .with_listing_options(listing_options) |
| 339 | + .with_schema(resolved_schema); |
| 340 | + |
| 341 | + // Create a new TableProvider |
| 342 | + let provider = Arc::new(ListingTable::try_new(config).unwrap()); |
| 343 | + let shard_id = table_path.prefix().filename().expect("error in fetching Path"); |
| 344 | + ctx.register_table(shard_id, provider) |
| 345 | + .expect("Failed to attach the Table"); |
| 346 | + |
| 347 | + // Return back after wrapping in Box |
| 348 | + session_context_ptr = Box::into_raw(Box::new(ctx)) as jlong |
| 349 | + }); |
| 350 | + |
| 351 | + session_context_ptr |
| 352 | +} |
| 353 | + |
113 | 354 |
|
114 | 355 |
|
115 | 356 |
|
0 commit comments