1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14- from pyarrow import ListArray , StructArray , Table
14+ from pyarrow import ListArray , StructArray , Table , timestamp
1515from pyarrow .types import is_struct
1616
1717from pymongoarrow .types import _BsonArrowTypes , _get_internal_typemap
@@ -73,55 +73,55 @@ def __init__(self, schema, codec_options=None):
7373 self .tzinfo = codec_options .tzinfo
7474 else :
7575 self .tzinfo = None
76- self . manager = BuilderManager ( self . schema is not None , self . tzinfo )
76+ builder_map = {}
7777 if self .schema is not None :
78- schema_map = {}
7978 str_type_map = _get_internal_typemap (schema .typemap )
80- _parse_types (str_type_map , schema_map , self .tzinfo )
81- self .manager .parse_types (schema_map )
79+ _parse_types (str_type_map , builder_map , self .tzinfo )
80+
81+ self .manager = BuilderManager (builder_map , self .schema is not None , self .tzinfo )
8282
def process_bson_stream(self, stream):
    """Hand ``stream`` to the builder manager together with its length.

    :param stream: object supporting ``len()``; it is passed through
        unchanged to ``self.manager.process_bson_stream``.
    """
    forward = self.manager.process_bson_stream
    forward(stream, len(stream))
8585
def finish(self):
    """Finish the manager's builders and return the result as a ``Table``.

    The map returned by ``self.manager.finish()`` is post-processed by
    ``_parse_builder_map``; its values become the table columns.  When a
    schema was supplied the table is built against
    ``self.schema.to_arrow()``, otherwise the column names are taken from
    the keys of the processed map.
    """
    finished = _parse_builder_map(self.manager.finish())
    columns = list(finished.values())
    if self.schema is None:
        # No schema: name the columns after the map keys.
        return Table.from_arrays(arrays=columns, names=list(finished.keys()))
    return Table.from_arrays(arrays=columns, schema=self.schema.to_arrow())
11892
11993
120- def _parse_types (str_type_map , schema_map , tzinfo ):
def _parse_builder_map(builder_map):
    """Finish every builder in ``builder_map`` and fold children into parents.

    Each value is replaced in place by the result of its ``finish()`` call.
    For a ``DocumentBuilder`` the entries keyed ``<field>.<child>`` are
    combined into a ``StructArray``; for a ``ListBuilder`` the entry keyed
    ``<field>[]`` is combined into a ``ListArray``.  The child entries that
    were folded into a parent are removed afterwards.

    :param builder_map: dict keyed by UTF-8 encoded field names (``bytes``)
        whose values are builder objects.  It is modified in place.
    :return: the same ``builder_map`` object, now holding finished arrays.
    """
    # Keys of nested children that end up inside a parent array.
    consumed = []
    # Traverse the builder map right to left.  This relies on nested children
    # appearing after their parent in the map, so that a child is already
    # finished by the time its parent is visited.
    for key, value in reversed(builder_map.items()):
        if isinstance(value, DocumentBuilder):
            names = value.finish()
            child_keys = [key + b"." + name for name in names]
            children = [builder_map[child_key] for child_key in child_keys]
            # Store under the existing bytes key: replacing a value keeps the
            # dict size constant (safe while iterating), whereas inserting a
            # new str key would raise "dictionary changed size during
            # iteration" and leave the unfinished builder in place.
            builder_map[key] = StructArray.from_arrays(children, names=names)
            consumed.extend(child_keys)
        elif isinstance(value, ListBuilder):
            offsets = value.finish()
            child_key = key + b"[]"
            consumed.append(child_key)
            builder_map[key] = ListArray.from_arrays(offsets, builder_map[child_key])
        else:
            builder_map[key] = value.finish()

    # Deletions are deferred until the iteration is over.
    for child_key in consumed:
        builder_map.pop(child_key, None)

    # Callers use the return value, so hand the processed map back.
    return builder_map
119+
120+
121+ def _parse_types (str_type_map , builder_map , tzinfo ):
121122 for fname , (ftype , arrow_type ) in str_type_map .items ():
122123 builder_cls = _TYPE_TO_BUILDER_CLS [ftype ]
123124 encoded_fname = fname .encode ("utf-8" )
124- schema_map [encoded_fname ] = (arrow_type , builder_cls )
125125
126126 # special-case nested builders
127127 if builder_cls == DocumentBuilder :
@@ -132,6 +132,7 @@ def _parse_types(str_type_map, schema_map, tzinfo):
132132 sub_name = f"{ fname } .{ field .name } "
133133 sub_type_map [sub_name ] = field .type
134134 sub_type_map = _get_internal_typemap (sub_type_map )
135+ _parse_types (sub_type_map , builder_map , tzinfo )
135136 elif builder_cls == ListBuilder :
136137 if is_struct (arrow_type .value_type ):
137138 # construct a sub type map here
@@ -141,4 +142,15 @@ def _parse_types(str_type_map, schema_map, tzinfo):
141142 sub_name = f"{ fname } [].{ field .name } "
142143 sub_type_map [sub_name ] = field .type
143144 sub_type_map = _get_internal_typemap (sub_type_map )
144- _parse_types (sub_type_map , schema_map , tzinfo )
145+ _parse_types (sub_type_map , sub_type_map , tzinfo )
146+
147+ # special-case initializing builders for parameterized types
148+ if builder_cls == DatetimeBuilder :
149+ if tzinfo is not None and arrow_type .tz is None :
150+ arrow_type = timestamp (arrow_type .unit , tz = tzinfo ) # noqa: PLW2901
151+ builder_map [encoded_fname ] = DatetimeBuilder (dtype = arrow_type )
152+ elif builder_cls == BinaryBuilder :
153+ subtype = arrow_type .subtype
154+ builder_map [fname ] = BinaryBuilder (subtype )
155+ else :
156+ builder_map [fname ] = builder_cls ()
0 commit comments