1+ use arrow:: array:: { Float64Array , Float32Array , Int16Array , Int32Array , Int64Array , Int8Array , BooleanArray } ;
2+ use arrow:: record_batch:: RecordBatch ;
3+ use arrow:: array:: BooleanBuilder ;
4+ use arrow:: datatypes:: Schema ;
5+ use std:: sync:: Arc ;
6+
7+ /// Create a boolean mask based on the filters provided.
8+ ///
9+ /// # Arguments
10+ ///
11+ /// * `batch` - A reference to a RecordBatch that will be filtered.
12+ /// * `original_schema` - A reference to the original schema of the RecordBatch.
13+ /// * `filters` - A vector of tuples containing the column name, the comparison operator and the value to compare.
14+ ///
15+ /// # Returns
16+ ///
17+ /// This function returns an Arrow Result with the boolean mask.
18+ pub fn create_boolean_mask ( batch : & RecordBatch , original_schema : & Arc < Schema > , filters : Vec < ( & str , & str , & str ) > ) -> arrow:: error:: Result < Arc < BooleanArray > > {
19+ let num_rows = batch. num_rows ( ) ;
20+ let mut boolean_builder = BooleanBuilder :: new ( ) ;
21+
22+ // Initialize all rows as true
23+ for _ in 0 ..num_rows {
24+ boolean_builder. append_value ( true ) ;
25+ }
26+ let mut boolean_mask = boolean_builder. finish ( ) ;
27+
28+ for filter in filters. iter ( ) {
29+ let column = batch. column ( original_schema. index_of ( filter. 0 ) . unwrap ( ) ) ;
30+
31+ if column. data_type ( ) == & arrow:: datatypes:: DataType :: Float32 {
32+ let column = column. as_any ( ) . downcast_ref :: < Float32Array > ( ) . unwrap ( ) ;
33+ apply_filter ( & mut boolean_mask, column, filter) ?;
34+ } else if column. data_type ( ) == & arrow:: datatypes:: DataType :: Float64 {
35+ let column = column. as_any ( ) . downcast_ref :: < Float64Array > ( ) . unwrap ( ) ;
36+ apply_filter ( & mut boolean_mask, column, filter) ?;
37+ } else if column. data_type ( ) == & arrow:: datatypes:: DataType :: Int16 {
38+ let column = column. as_any ( ) . downcast_ref :: < Int16Array > ( ) . unwrap ( ) ;
39+ apply_filter ( & mut boolean_mask, column, filter) ?;
40+ } else if column. data_type ( ) == & arrow:: datatypes:: DataType :: Int32 {
41+ let column = column. as_any ( ) . downcast_ref :: < Int32Array > ( ) . unwrap ( ) ;
42+ apply_filter ( & mut boolean_mask, column, filter) ?;
43+ } else if column. data_type ( ) == & arrow:: datatypes:: DataType :: Int64 {
44+ let column = column. as_any ( ) . downcast_ref :: < Int64Array > ( ) . unwrap ( ) ;
45+ apply_filter ( & mut boolean_mask, column, filter) ?;
46+ } else if column. data_type ( ) == & arrow:: datatypes:: DataType :: Int8 {
47+ let column = column. as_any ( ) . downcast_ref :: < Int8Array > ( ) . unwrap ( ) ;
48+ apply_filter ( & mut boolean_mask, column, filter) ?;
49+ } else if column. data_type ( ) == & arrow:: datatypes:: DataType :: Boolean {
50+ let column = column. as_any ( ) . downcast_ref :: < Int16Array > ( ) . unwrap ( ) ;
51+ apply_filter ( & mut boolean_mask, column, filter) ?;
52+ } else {
53+ return Err ( arrow:: error:: ArrowError :: NotYetImplemented ( format ! ( "Data type {:?} not yet implemented" , column. data_type( ) ) ) ) ;
54+ }
55+ }
56+ Ok ( Arc :: new ( boolean_mask) )
57+ }
58+
59+ /// Apply a filter to a column and update the boolean mask.
60+ ///
61+ /// # Arguments
62+ ///
63+ /// * `boolean_mask` - A mutable reference to a BooleanArray that will be updated with the filter results.
64+ /// * `column` - A reference to a PrimitiveArray that will be filtered.
65+ /// * `filter` - A tuple containing the column name, the comparison operator and the value to compare.
66+ ///
67+ /// # Returns
68+ ///
69+ /// This function returns an Arrow Result.
70+ fn apply_filter < T > ( boolean_mask : & mut BooleanArray , column : & arrow:: array:: PrimitiveArray < T > , filter : & ( & str , & str , & str ) ) -> arrow:: error:: Result < ( ) >
71+ where
72+ T : arrow:: datatypes:: ArrowPrimitiveType ,
73+ T :: Native : std:: cmp:: PartialOrd + std:: str:: FromStr ,
74+ <T :: Native as std:: str:: FromStr >:: Err : std:: fmt:: Debug ,
75+ {
76+ let filter_value = filter. 2 . parse :: < T :: Native > ( ) . unwrap ( ) ;
77+ let mut new_mask = BooleanBuilder :: new ( ) ;
78+
79+ for ( index, value) in column. iter ( ) . enumerate ( ) {
80+ let current_mask = boolean_mask. value ( index) ;
81+ let result = match filter. 1 {
82+ ">" => value. map_or ( false , |v| v > filter_value) ,
83+ "<" => value. map_or ( false , |v| v < filter_value) ,
84+ "=" => value. map_or ( false , |v| v == filter_value) ,
85+ "!=" => value. map_or ( false , |v| v != filter_value) ,
86+ ">=" => value. map_or ( false , |v| v >= filter_value) ,
87+ "<=" => value. map_or ( false , |v| v <= filter_value) ,
88+ "==" => value. map_or ( false , |v| v == filter_value) ,
89+ _ => false ,
90+ } ;
91+ new_mask. append_value ( current_mask && result) ;
92+ }
93+
94+ * boolean_mask = new_mask. finish ( ) ;
95+ Ok ( ( ) )
96+ }
0 commit comments