Skip to content

Commit 246825e

Browse files
committed
continue refactoring
1 parent 7e3b50c commit 246825e

File tree

6 files changed

+292
-291
lines changed

6 files changed

+292
-291
lines changed

.pre-commit-config.yaml

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -94,6 +94,11 @@ repos:
9494
stages: [manual]
9595
args: ["--no-strict-imports"]
9696

97+
- repo: https://github.com/MarcoGorelli/cython-lint
98+
rev: v0.16.2
99+
hooks:
100+
- id: cython-lint
101+
97102
- repo: https://github.com/codespell-project/codespell
98103
rev: "v2.2.6"
99104
hooks:

bindings/python/pymongoarrow/api.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -88,7 +88,7 @@ def find_arrow_all(collection, query, *, schema=None, **kwargs):
8888
:Returns:
8989
An instance of :class:`pyarrow.Table`.
9090
"""
91-
context = PyMongoArrowContext.from_schema(schema, codec_options=collection.codec_options)
91+
context = PyMongoArrowContext(schema, codec_options=collection.codec_options)
9292

9393
for opt in ("cursor_type",):
9494
if kwargs.pop(opt, None):
@@ -126,7 +126,7 @@ def aggregate_arrow_all(collection, pipeline, *, schema=None, **kwargs):
126126
:Returns:
127127
An instance of :class:`pyarrow.Table`.
128128
"""
129-
context = PyMongoArrowContext.from_schema(schema, codec_options=collection.codec_options)
129+
context = PyMongoArrowContext(schema, codec_options=collection.codec_options)
130130

131131
if pipeline and ("$out" in pipeline[-1] or "$merge" in pipeline[-1]):
132132
msg = (

bindings/python/pymongoarrow/context.py

Lines changed: 25 additions & 53 deletions
Original file line number | Diff line number | Diff line change
@@ -11,8 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
from bson.codec_options import DEFAULT_CODEC_OPTIONS
15-
from pyarrow import ListArray, StructArray, Table, timestamp
14+
from pyarrow import ListArray, StructArray, Table
1615
from pyarrow.types import is_struct
1716

1817
from pymongoarrow.types import _BsonArrowTypes, _get_internal_typemap
@@ -54,15 +53,14 @@
5453
_BsonArrowTypes.date64: Date64Builder,
5554
_BsonArrowTypes.null: NullBuilder,
5655
}
57-
5856
except ImportError:
5957
pass
6058

6159

6260
class PyMongoArrowContext:
6361
"""A context for converting BSON-formatted data to an Arrow Table."""
6462

65-
def __init__(self, schema, builder_map, codec_options=None):
63+
def __init__(self, schema, codec_options=None):
6664
"""Initialize the context.
6765
6866
:Parameters:
@@ -71,90 +69,70 @@ def __init__(self, schema, builder_map, codec_options=None):
7169
:class:`~pymongoarrow.builders._BuilderBase` instances.
7270
"""
7371
self.schema = schema
74-
self.builder_map = builder_map
7572
if self.schema is None and codec_options is not None:
7673
self.tzinfo = codec_options.tzinfo
7774
else:
7875
self.tzinfo = None
79-
self.manager = BuilderManager(builder_map, self.schema is not None, self.tzinfo)
80-
81-
@classmethod
82-
def from_schema(cls, schema, codec_options=DEFAULT_CODEC_OPTIONS):
83-
"""Initialize the context from a :class:`~pymongoarrow.schema.Schema`
84-
instance.
85-
86-
:Parameters:
87-
- `schema`: Instance of :class:`~pymongoarrow.schema.Schema`.
88-
- `codec_options` (optional): An instance of
89-
:class:`~bson.codec_options.CodecOptions`.
90-
"""
91-
if schema is None:
92-
return cls(schema, {}, codec_options)
93-
94-
builder_map = {}
95-
tzinfo = codec_options.tzinfo
96-
str_type_map = _get_internal_typemap(schema.typemap)
97-
_parse_types(str_type_map, builder_map, tzinfo)
98-
return cls(schema, builder_map)
76+
self.manager = BuilderManager(self.schema is not None, self.tzinfo)
77+
if self.schema is not None:
78+
schema_map = {}
79+
str_type_map = _get_internal_typemap(schema.typemap)
80+
_parse_types(str_type_map, schema_map, self.tzinfo)
81+
self.manager.parse_types(schema_map)
9982

10083
def process_bson_stream(self, stream):
10184
self.manager.process_bson_stream(stream, len(stream))
10285

10386
def finish(self):
104-
return self._finish(self.builder_map, self.schema)
87+
builder_map = self.manager.finish().copy()
10588

106-
@staticmethod
107-
def _finish(builder_map, schema):
89+
# Handle nested builders.
10890
to_remove = []
10991
# Traverse the builder map right to left.
11092
for key, value in reversed(builder_map.items()):
11193
field = key.decode("utf-8")
112-
arr = value.finish()
11394
if isinstance(value, DocumentBuilder):
95+
arr = value.finish()
11496
full_names = [f"{field}.{name.decode('utf-8')}" for name in arr]
11597
arrs = [builder_map[c.encode("utf-8")] for c in full_names]
11698
builder_map[field] = StructArray.from_arrays(arrs, names=arr)
11799
to_remove.extend(full_names)
118100
elif isinstance(value, ListBuilder):
119-
child = field + "[]"
120-
to_remove.append(child)
121-
builder_map[key] = ListArray.from_arrays(arr, builder_map.get(child, []))
101+
arr = value.finish()
102+
child_name = field + "[]"
103+
to_remove.append(child_name)
104+
child = builder_map[child_name.encode("utf-8")]
105+
builder_map[key] = ListArray.from_arrays(arr, child)
122106
else:
123-
builder_map[key] = arr
107+
builder_map[key] = value.finish()
124108

125109
for field in to_remove:
126110
key = field.encode("utf-8")
127111
if key in builder_map:
128112
del builder_map[key]
129113

130114
arrays = list(builder_map.values())
131-
if schema is not None:
132-
return Table.from_arrays(arrays=arrays, schema=schema.to_arrow())
115+
if self.schema is not None:
116+
return Table.from_arrays(arrays=arrays, schema=self.schema.to_arrow())
133117
return Table.from_arrays(arrays=arrays, names=list(builder_map.keys()))
134118

135119

136-
def _parse_types(str_type_map, builder_map, tzinfo):
120+
def _parse_types(str_type_map, schema_map, tzinfo):
137121
for fname, (ftype, arrow_type) in str_type_map.items():
138122
builder_cls = _TYPE_TO_BUILDER_CLS[ftype]
139123
encoded_fname = fname.encode("utf-8")
140-
# special-case initializing builders for parameterized types
141-
if builder_cls == DatetimeBuilder:
142-
if tzinfo is not None and arrow_type.tz is None:
143-
arrow_type = timestamp(arrow_type.unit, tz=tzinfo) # noqa: PLW2901
144-
builder_map[encoded_fname] = DatetimeBuilder(dtype=arrow_type)
145-
elif builder_cls == DocumentBuilder:
146-
builder_map[encoded_fname] = DocumentBuilder()
124+
schema_map[encoded_fname] = (arrow_type, builder_cls)
125+
126+
# special-case nested builders
127+
if builder_cls == DocumentBuilder:
147128
# construct a sub type map here
148129
sub_type_map = {}
149130
for i in range(arrow_type.num_fields):
150131
field = arrow_type[i]
151132
sub_name = f"{fname}.{field.name}"
152133
sub_type_map[sub_name] = field.type
153134
sub_type_map = _get_internal_typemap(sub_type_map)
154-
_parse_types(sub_type_map, builder_map, tzinfo)
155-
continue
156135
elif builder_cls == ListBuilder:
157-
builder_map[encoded_fname] = ListBuilder()
158136
if is_struct(arrow_type.value_type):
159137
# construct a sub type map here
160138
sub_type_map = {}
@@ -163,10 +141,4 @@ def _parse_types(str_type_map, builder_map, tzinfo):
163141
sub_name = f"{fname}[].{field.name}"
164142
sub_type_map[sub_name] = field.type
165143
sub_type_map = _get_internal_typemap(sub_type_map)
166-
_parse_types(sub_type_map, builder_map, tzinfo)
167-
continue
168-
elif builder_cls == BinaryBuilder:
169-
subtype = arrow_type.subtype
170-
builder_map[encoded_fname] = BinaryBuilder(subtype)
171-
else:
172-
builder_map[encoded_fname] = builder_cls()
144+
_parse_types(sub_type_map, schema_map, tzinfo)

0 commit comments

Comments
 (0)