@@ -21,13 +21,61 @@ use arrow_array::builder::BooleanBufferBuilder;
2121use arrow_buffer:: buffer:: NullBuffer ;
2222use arrow_data:: { ArrayData , ArrayDataBuilder } ;
2323use arrow_schema:: { ArrowError , DataType , Fields } ;
24+ use std:: collections:: HashMap ;
25+
26+ /// Reusable buffer for tape positions, indexed by (field_idx, row_idx).
27+ /// A value of 0 indicates the field is absent for that row.
28+ struct FieldTapePositions {
29+ data : Vec < u32 > ,
30+ row_count : usize ,
31+ }
32+
33+ impl FieldTapePositions {
34+ fn new ( ) -> Self {
35+ Self {
36+ data : Vec :: new ( ) ,
37+ row_count : 0 ,
38+ }
39+ }
40+
41+ fn resize ( & mut self , field_count : usize , row_count : usize ) -> Result < ( ) , ArrowError > {
42+ let total_len = field_count. checked_mul ( row_count) . ok_or_else ( || {
43+ ArrowError :: JsonError ( format ! (
44+ "FieldTapePositions buffer size overflow for rows={row_count} fields={field_count}"
45+ ) )
46+ } ) ?;
47+ self . data . clear ( ) ;
48+ self . data . resize ( total_len, 0 ) ;
49+ self . row_count = row_count;
50+ Ok ( ( ) )
51+ }
52+
53+ fn try_set ( & mut self , field_idx : usize , row_idx : usize , pos : u32 ) -> Option < ( ) > {
54+ let idx = field_idx
55+ . checked_mul ( self . row_count ) ?
56+ . checked_add ( row_idx) ?;
57+ * self . data . get_mut ( idx) ? = pos;
58+ Some ( ( ) )
59+ }
60+
61+ fn set ( & mut self , field_idx : usize , row_idx : usize , pos : u32 ) {
62+ self . data [ field_idx * self . row_count + row_idx] = pos;
63+ }
64+
65+ fn field_positions ( & self , field_idx : usize ) -> & [ u32 ] {
66+ let start = field_idx * self . row_count ;
67+ & self . data [ start..start + self . row_count ]
68+ }
69+ }
2470
2571pub struct StructArrayDecoder {
2672 data_type : DataType ,
2773 decoders : Vec < Box < dyn ArrayDecoder > > ,
2874 strict_mode : bool ,
2975 is_nullable : bool ,
3076 struct_mode : StructMode ,
77+ field_name_to_index : Option < HashMap < String , usize > > ,
78+ field_tape_positions : FieldTapePositions ,
3179}
3280
3381impl StructArrayDecoder {
@@ -38,119 +86,140 @@ impl StructArrayDecoder {
3886 is_nullable : bool ,
3987 struct_mode : StructMode ,
4088 ) -> Result < Self , ArrowError > {
41- let decoders = struct_fields ( & data_type)
42- . iter ( )
43- . map ( |f| {
44- // If this struct nullable, need to permit nullability in child array
45- // StructArrayDecoder::decode verifies that if the child is not nullable
46- // it doesn't contain any nulls not masked by its parent
47- let nullable = f. is_nullable ( ) || is_nullable;
48- make_decoder (
49- f. data_type ( ) . clone ( ) ,
50- coerce_primitive,
51- strict_mode,
52- nullable,
53- struct_mode,
54- )
55- } )
56- . collect :: < Result < Vec < _ > , ArrowError > > ( ) ?;
89+ let ( decoders, field_name_to_index) = {
90+ let fields = struct_fields ( & data_type) ;
91+ let decoders = fields
92+ . iter ( )
93+ . map ( |f| {
94+ // If this struct nullable, need to permit nullability in child array
95+ // StructArrayDecoder::decode verifies that if the child is not nullable
96+ // it doesn't contain any nulls not masked by its parent
97+ let nullable = f. is_nullable ( ) || is_nullable;
98+ make_decoder (
99+ f. data_type ( ) . clone ( ) ,
100+ coerce_primitive,
101+ strict_mode,
102+ nullable,
103+ struct_mode,
104+ )
105+ } )
106+ . collect :: < Result < Vec < _ > , ArrowError > > ( ) ?;
107+ let field_name_to_index = if struct_mode == StructMode :: ObjectOnly {
108+ build_field_index ( fields)
109+ } else {
110+ None
111+ } ;
112+ ( decoders, field_name_to_index)
113+ } ;
57114
58115 Ok ( Self {
59116 data_type,
60117 decoders,
61118 strict_mode,
62119 is_nullable,
63120 struct_mode,
121+ field_name_to_index,
122+ field_tape_positions : FieldTapePositions :: new ( ) ,
64123 } )
65124 }
66125}
67126
68127impl ArrayDecoder for StructArrayDecoder {
69128 fn decode ( & mut self , tape : & Tape < ' _ > , pos : & [ u32 ] ) -> Result < ArrayData , ArrowError > {
70129 let fields = struct_fields ( & self . data_type ) ;
71- let mut child_pos: Vec < _ > = ( 0 ..fields. len ( ) ) . map ( |_| vec ! [ 0 ; pos. len( ) ] ) . collect ( ) ;
72-
130+ let row_count = pos. len ( ) ;
131+ let field_count = fields. len ( ) ;
132+ self . field_tape_positions . resize ( field_count, row_count) ?;
73133 let mut nulls = self
74134 . is_nullable
75135 . then ( || BooleanBufferBuilder :: new ( pos. len ( ) ) ) ;
76136
77- // We avoid having the match on self.struct_mode inside the hot loop for performance
78- // TODO: Investigate how to extract duplicated logic.
79- match self . struct_mode {
80- StructMode :: ObjectOnly => {
81- for ( row, p) in pos. iter ( ) . enumerate ( ) {
82- let end_idx = match ( tape. get ( * p) , nulls. as_mut ( ) ) {
83- ( TapeElement :: StartObject ( end_idx) , None ) => end_idx,
84- ( TapeElement :: StartObject ( end_idx) , Some ( nulls) ) => {
85- nulls. append ( true ) ;
86- end_idx
87- }
88- ( TapeElement :: Null , Some ( nulls) ) => {
89- nulls. append ( false ) ;
90- continue ;
91- }
92- ( _, _) => return Err ( tape. error ( * p, "{" ) ) ,
93- } ;
94-
95- let mut cur_idx = * p + 1 ;
96- while cur_idx < end_idx {
97- // Read field name
98- let field_name = match tape. get ( cur_idx) {
99- TapeElement :: String ( s) => tape. get_string ( s) ,
100- _ => return Err ( tape. error ( cur_idx, "field name" ) ) ,
137+ {
138+ // We avoid having the match on self.struct_mode inside the hot loop for performance
139+ // TODO: Investigate how to extract duplicated logic.
140+ match self . struct_mode {
141+ StructMode :: ObjectOnly => {
142+ for ( row, p) in pos. iter ( ) . enumerate ( ) {
143+ let end_idx = match ( tape. get ( * p) , nulls. as_mut ( ) ) {
144+ ( TapeElement :: StartObject ( end_idx) , None ) => end_idx,
145+ ( TapeElement :: StartObject ( end_idx) , Some ( nulls) ) => {
146+ nulls. append ( true ) ;
147+ end_idx
148+ }
149+ ( TapeElement :: Null , Some ( nulls) ) => {
150+ nulls. append ( false ) ;
151+ continue ;
152+ }
153+ ( _, _) => return Err ( tape. error ( * p, "{" ) ) ,
101154 } ;
102155
103- // Update child pos if match found
104- match fields. iter ( ) . position ( |x| x. name ( ) == field_name) {
105- Some ( field_idx) => child_pos[ field_idx] [ row] = cur_idx + 1 ,
106- None => {
107- if self . strict_mode {
108- return Err ( ArrowError :: JsonError ( format ! (
109- "column '{field_name}' missing from schema" ,
110- ) ) ) ;
156+ let mut cur_idx = * p + 1 ;
157+ while cur_idx < end_idx {
158+ // Read field name
159+ let field_name = match tape. get ( cur_idx) {
160+ TapeElement :: String ( s) => tape. get_string ( s) ,
161+ _ => return Err ( tape. error ( cur_idx, "field name" ) ) ,
162+ } ;
163+
164+ // Update child pos if match found
165+ let field_idx = match & self . field_name_to_index {
166+ Some ( map) => map. get ( field_name) . copied ( ) ,
167+ None => fields. iter ( ) . position ( |x| x. name ( ) == field_name) ,
168+ } ;
169+ match field_idx {
170+ Some ( field_idx) => {
171+ self . field_tape_positions . set ( field_idx, row, cur_idx + 1 ) ;
172+ }
173+ None => {
174+ if self . strict_mode {
175+ return Err ( ArrowError :: JsonError ( format ! (
176+ "column '{field_name}' missing from schema" ,
177+ ) ) ) ;
178+ }
111179 }
112180 }
181+ // Advance to next field
182+ cur_idx = tape. next ( cur_idx + 1 , "field value" ) ?;
113183 }
114- // Advance to next field
115- cur_idx = tape. next ( cur_idx + 1 , "field value" ) ?;
116184 }
117185 }
118- }
119- StructMode :: ListOnly => {
120- for ( row, p) in pos. iter ( ) . enumerate ( ) {
121- let end_idx = match ( tape. get ( * p) , nulls. as_mut ( ) ) {
122- ( TapeElement :: StartList ( end_idx) , None ) => end_idx,
123- ( TapeElement :: StartList ( end_idx) , Some ( nulls) ) => {
124- nulls. append ( true ) ;
125- end_idx
126- }
127- ( TapeElement :: Null , Some ( nulls) ) => {
128- nulls. append ( false ) ;
129- continue ;
130- }
131- ( _, _) => return Err ( tape. error ( * p, "[" ) ) ,
132- } ;
186+ StructMode :: ListOnly => {
187+ for ( row, p) in pos. iter ( ) . enumerate ( ) {
188+ let end_idx = match ( tape. get ( * p) , nulls. as_mut ( ) ) {
189+ ( TapeElement :: StartList ( end_idx) , None ) => end_idx,
190+ ( TapeElement :: StartList ( end_idx) , Some ( nulls) ) => {
191+ nulls. append ( true ) ;
192+ end_idx
193+ }
194+ ( TapeElement :: Null , Some ( nulls) ) => {
195+ nulls. append ( false ) ;
196+ continue ;
197+ }
198+ ( _, _) => return Err ( tape. error ( * p, "[" ) ) ,
199+ } ;
133200
134- let mut cur_idx = * p + 1 ;
135- let mut entry_idx = 0 ;
136- while cur_idx < end_idx {
137- if entry_idx >= fields. len ( ) {
201+ let mut cur_idx = * p + 1 ;
202+ let mut entry_idx = 0 ;
203+ while cur_idx < end_idx {
204+ self . field_tape_positions
205+ . try_set ( entry_idx, row, cur_idx)
206+ . ok_or_else ( || {
207+ ArrowError :: JsonError ( format ! (
208+ "found extra columns for {} fields" ,
209+ fields. len( )
210+ ) )
211+ } ) ?;
212+ entry_idx += 1 ;
213+ // Advance to next field
214+ cur_idx = tape. next ( cur_idx, "field value" ) ?;
215+ }
216+ if entry_idx != fields. len ( ) {
138217 return Err ( ArrowError :: JsonError ( format ! (
139- "found extra columns for {} fields" ,
218+ "found {} columns for {} fields" ,
219+ entry_idx,
140220 fields. len( )
141221 ) ) ) ;
142222 }
143- child_pos[ entry_idx] [ row] = cur_idx;
144- entry_idx += 1 ;
145- // Advance to next field
146- cur_idx = tape. next ( cur_idx, "field value" ) ?;
147- }
148- if entry_idx != fields. len ( ) {
149- return Err ( ArrowError :: JsonError ( format ! (
150- "found {} columns for {} fields" ,
151- entry_idx,
152- fields. len( )
153- ) ) ) ;
154223 }
155224 }
156225 }
@@ -159,10 +228,11 @@ impl ArrayDecoder for StructArrayDecoder {
159228 let child_data = self
160229 . decoders
161230 . iter_mut ( )
162- . zip ( child_pos )
231+ . enumerate ( )
163232 . zip ( fields)
164- . map ( |( ( d, pos) , f) | {
165- d. decode ( tape, & pos) . map_err ( |e| match e {
233+ . map ( |( ( field_idx, d) , f) | {
234+ let pos = self . field_tape_positions . field_positions ( field_idx) ;
235+ d. decode ( tape, pos) . map_err ( |e| match e {
166236 ArrowError :: JsonError ( s) => {
167237 ArrowError :: JsonError ( format ! ( "whilst decoding field '{}': {s}" , f. name( ) ) )
168238 }
@@ -205,3 +275,20 @@ fn struct_fields(data_type: &DataType) -> &Fields {
205275 _ => unreachable ! ( ) ,
206276 }
207277}
278+
279+ fn build_field_index ( fields : & Fields ) -> Option < HashMap < String , usize > > {
280+ // Heuristic threshold: for small field counts, linear scan avoids HashMap overhead.
281+ const FIELD_INDEX_LINEAR_THRESHOLD : usize = 16 ;
282+ if fields. len ( ) < FIELD_INDEX_LINEAR_THRESHOLD {
283+ return None ;
284+ }
285+
286+ let mut map = HashMap :: with_capacity ( fields. len ( ) ) ;
287+ for ( idx, field) in fields. iter ( ) . enumerate ( ) {
288+ let name = field. name ( ) ;
289+ if !map. contains_key ( name) {
290+ map. insert ( name. to_string ( ) , idx) ;
291+ }
292+ }
293+ Some ( map)
294+ }
0 commit comments