
Commit 5b8f4b8

Merge pull request #2850 from mabel-dev/#2849
Return column names as per SELECT #2849
2 parents 0c3e43e + d88da8d commit 5b8f4b8

18 files changed: +280 -212 lines changed

opteryx/__version__.py

Lines changed: 2 additions & 2 deletions

@@ -1,9 +1,9 @@
 # THIS FILE IS AUTOMATICALLY UPDATED DURING THE BUILD PROCESS
 # DO NOT EDIT THIS FILE DIRECTLY
 
-__build__ = 1652
+__build__ = 1654
 __author__ = "@joocer"
-__version__ = "0.26.0-beta.1652"
+__version__ = "0.26.0-beta.1654"
 
 # Store the version here so:
 # 1) we don't load dependencies by storing it in __init__.py

opteryx/compiled/joins/cross_join.pyx

Lines changed: 3 additions & 3 deletions

@@ -161,11 +161,11 @@ cpdef tuple numpy_build_filtered_rows_indices_and_column(numpy.ndarray column_da
 
     # Handle set initialization based on element dtype
     if numpy.issubdtype(element_dtype, numpy.integer):
-        valid_values_typed = set([int(v) for v in valid_values])
+        valid_values_typed = {int(v) for v in valid_values}
     elif numpy.issubdtype(element_dtype, numpy.floating):
-        valid_values_typed = set([parse_fast_float(v) for v in valid_values])
+        valid_values_typed = {parse_fast_float(v) for v in valid_values}
     elif numpy.issubdtype(element_dtype, numpy.str_):
-        valid_values_typed = set([unicode(v) for v in valid_values])
+        valid_values_typed = {unicode(v) for v in valid_values}
     else:
        valid_values_typed = valid_values  # Fallback to generic Python set
 
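
For context, the change above swaps set([...]) calls for set comprehensions, avoiding an intermediate list before the set is built. A rough pure-Python sketch of the branch logic, using str and float as stand-ins for the Cython helpers unicode and parse_fast_float:

import numpy

def build_typed_value_set(valid_values, element_dtype):
    # Coerce the filter values to match the column's element dtype so that
    # membership checks run against a homogeneous Python set.
    if numpy.issubdtype(element_dtype, numpy.integer):
        return {int(v) for v in valid_values}
    if numpy.issubdtype(element_dtype, numpy.floating):
        return {float(v) for v in valid_values}  # stand-in for parse_fast_float
    if numpy.issubdtype(element_dtype, numpy.str_):
        return {str(v) for v in valid_values}  # stand-in for unicode
    return set(valid_values)  # fallback to a generic Python set

print(build_typed_value_set(["1", "2", "2"], numpy.dtype("int64")))  # {1, 2}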

opteryx/connectors/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -270,7 +270,7 @@ def connector_factory(dataset, statistics, **config):
            break
    else:
        # Check if dataset is a file or contains wildcards
-        has_wildcards = any(char in dataset for char in ['*', '?', '['])
+        has_wildcards = any(char in dataset for char in ["*", "?", "["])
        if os.path.isfile(dataset) or has_wildcards:
            from opteryx.connectors import file_connector
 
@@ -286,7 +286,7 @@ def connector_factory(dataset, statistics, **config):
    remove_prefix = connector_entry.pop("remove_prefix", False)
    if prefix and remove_prefix and dataset.startswith(prefix):
        # Remove the prefix. If there's a separator (. or //) after the prefix, skip it too
-        dataset = dataset[len(prefix):]
+        dataset = dataset[len(prefix) :]
        if dataset.startswith(".") or dataset.startswith("//"):
            dataset = dataset[1:] if dataset.startswith(".") else dataset[2:]
 
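
Only formatting changed in the prefix handling above, but the behaviour is worth spelling out: when a registered prefix is removed, any separator that follows it ("." or "//") is dropped too. A standalone sketch with hypothetical prefixes:

def strip_prefix(dataset: str, prefix: str) -> str:
    # Remove the registered prefix; if a separator (. or //) follows, skip it too.
    if dataset.startswith(prefix):
        dataset = dataset[len(prefix) :]
        if dataset.startswith(".") or dataset.startswith("//"):
            dataset = dataset[1:] if dataset.startswith(".") else dataset[2:]
    return dataset

print(strip_prefix("gs://bucket/path/file.parquet", "gs:"))  # bucket/path/file.parquet
print(strip_prefix("warehouse.sales.orders", "warehouse"))   # sales.orders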

opteryx/connectors/aws_s3_connector.py

Lines changed: 8 additions & 8 deletions

@@ -86,13 +86,13 @@ def __init__(self, credentials=None, **kwargs):
        )
 
        self.minio = Minio(end_point, access_key, secret_key, secure=secure)
-
+
        # Only convert dots to path separators if the dataset doesn't already contain slashes
        # Dataset references like "my.dataset.table" use dots as separators
        # File paths like "bucket/path/file.parquet" already have slashes and should not be converted
        if OS_SEP not in self.dataset and "/" not in self.dataset:
            self.dataset = self.dataset.replace(".", OS_SEP)
-
+
        # Check if dataset contains wildcards
        self.has_wildcards = paths.has_wildcards(self.dataset)
        if self.has_wildcards:
 
@@ -111,28 +111,28 @@ def get_list_of_blob_names(self, *, prefix: str) -> List[str]:
        else:
            list_prefix = prefix
            filter_pattern = None
-
+
        bucket, object_path, _, _ = paths.get_parts(list_prefix)
        blobs = self.minio.list_objects(bucket_name=bucket, prefix=object_path, recursive=True)
-
+
        blob_list = []
        for blob in blobs:
            if blob.object_name.endswith("/"):
                continue
-
+
            full_path = bucket + "/" + blob.object_name
-
+
            # Check if blob has valid extension
            if ("." + full_path.split(".")[-1].lower()) not in VALID_EXTENSIONS:
                continue
-
+
            # If we have a wildcard pattern, filter by it
            if filter_pattern:
                if paths.match_wildcard(filter_pattern, full_path):
                    blob_list.append(full_path)
            else:
                blob_list.append(full_path)
-
+
        return sorted(blob_list)
 
    def read_dataset(
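
For illustration, the listing logic above skips folder markers, keeps only blobs with a recognised data extension, and then applies any wildcard filter. A hedged sketch of the same flow over a plain list of names, with fnmatch standing in for paths.match_wildcard and an illustrative extension set:

from fnmatch import fnmatch

VALID_EXTENSIONS = {".parquet", ".jsonl", ".csv"}  # illustrative subset

def filter_blobs(blob_names, bucket, filter_pattern=None):
    blob_list = []
    for name in blob_names:
        if name.endswith("/"):
            continue  # skip "directory" markers
        full_path = bucket + "/" + name
        if ("." + full_path.split(".")[-1].lower()) not in VALID_EXTENSIONS:
            continue  # not a data file we can read
        if filter_pattern is None or fnmatch(full_path, filter_pattern):
            blob_list.append(full_path)
    return sorted(blob_list)

print(filter_blobs(["2024/a.parquet", "2024/", "notes.txt"], "my-bucket", "my-bucket/2024/*.parquet"))
# ['my-bucket/2024/a.parquet']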

opteryx/connectors/file_connector.py

Lines changed: 13 additions & 13 deletions

@@ -136,10 +136,10 @@ def __init__(self, *args, **kwargs):
        if ".." in self.dataset or self.dataset[0] in ("\\", "/", "~"):
            # Don't find any datasets which look like path traversal
            raise DatasetNotFoundError(dataset=self.dataset)
-
+
        # Check if dataset contains wildcards
-        self.has_wildcards = any(char in self.dataset for char in ['*', '?', '['])
-
+        self.has_wildcards = any(char in self.dataset for char in ["*", "?", "["])
+
        if self.has_wildcards:
            # Expand wildcards to get list of files
            self.files = self._expand_wildcards(self.dataset)
 
@@ -150,43 +150,43 @@ def __init__(self, *args, **kwargs):
        else:
            self.files = [self.dataset]
            self.decoder = get_decoder(self.dataset)
-
+
    def _expand_wildcards(self, pattern: str) -> List[str]:
        """
        Expand wildcard patterns in file paths while preventing path traversal.
-
+
        Supports wildcards:
        - * matches any number of characters
-        - ? matches a single character
+        - ? matches a single character
        - [range] matches a range of characters (e.g., [0-9], [a-z])
-
+
        Args:
            pattern: File path pattern with wildcards
-
+
        Returns:
            List of matching file paths
        """
        # Additional path traversal check after expansion
        if ".." in pattern:
            raise DatasetNotFoundError(dataset=pattern)
-
+
        # Use glob to expand the pattern
        matched_files = glob.glob(pattern, recursive=False)
-
+
        # Filter out any results that might have path traversal
        # This is an extra safety check
        safe_files = []
        for file_path in matched_files:
            if ".." not in file_path and os.path.isfile(file_path):
                safe_files.append(file_path)
-
+
        return sorted(safe_files)
 
    def read_dataset(
        self, columns: list = None, predicates: list = None, limit: int = None, **kwargs
    ) -> pyarrow.Table:
        rows_read = 0
-
+
        # Iterate over all matched files
        for file_path in self.files:
            morsel = read_blob(
 
@@ -221,7 +221,7 @@ def get_dataset_schema(self) -> RelationSchema:
 
        # Use the first file to get the schema
        first_file = self.files[0]
-
+
        try:
            file_descriptor = os.open(first_file, os.O_RDONLY | os.O_BINARY)
            size = os.path.getsize(first_file)
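
The new _expand_wildcards is essentially glob with a path-traversal guard applied both before and after expansion. A minimal standalone sketch (raising ValueError where the connector raises DatasetNotFoundError):

import glob
import os
from typing import List

def expand_wildcards(pattern: str) -> List[str]:
    # Reject anything that looks like path traversal before expanding.
    if ".." in pattern:
        raise ValueError(f"unsafe pattern: {pattern}")
    matched = glob.glob(pattern, recursive=False)
    # Re-check the expanded paths and keep only real files, sorted for stable order.
    return sorted(p for p in matched if ".." not in p and os.path.isfile(p))

# e.g. expand_wildcards("data/2024-0[1-6]-*.parquet") returns the matching files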

opteryx/connectors/gcp_cloudstorage_connector.py

Lines changed: 3 additions & 3 deletions

@@ -97,7 +97,7 @@ def __init__(self, credentials=None, **kwargs):
        if OS_SEP not in self.dataset and "/" not in self.dataset:
            self.dataset = self.dataset.replace(".", OS_SEP)
        self.credentials = credentials
-
+
        # Check if dataset contains wildcards
        self.has_wildcards = paths.has_wildcards(self.dataset)
        if self.has_wildcards:
 
@@ -231,9 +231,9 @@ def get_list_of_blob_names(self, *, prefix: str) -> List[str]:
            name = blob["name"]
            if not name.endswith(TUPLE_OF_VALID_EXTENSIONS):
                continue
-
+
            full_path = f"{bucket}/{name}"
-
+
            # If we have a wildcard pattern, filter by it
            if filter_pattern:
                if paths.match_wildcard(filter_pattern, full_path):

opteryx/cursor.py

Lines changed: 19 additions & 1 deletion

@@ -336,7 +336,25 @@ def execute_to_arrow(
        if isinstance(result_data, pyarrow.Table):
            return result_data
        try:
-            return pyarrow.concat_tables(result_data, promote_options="permissive")
+            # arrow allows duplicate column names, but not when concatting
+            from itertools import chain
+
+            first_table = next(result_data, None)
+            if first_table is not None:
+                column_names = first_table.column_names
+                if len(column_names) != len(set(column_names)):
+                    temporary_names = [f"col_{i}" for i in range(len(column_names))]
+                    first_table = first_table.rename_columns(temporary_names)
+                    return_table = pyarrow.concat_tables(
+                        chain(
+                            [first_table], (t.rename_columns(temporary_names) for t in result_data)
+                        ),
+                        promote_options="permissive",
+                    )
+                    return return_table.rename_columns(column_names)
+            return pyarrow.concat_tables(
+                chain([first_table], result_data), promote_options="permissive"
+            )
        except (
            pyarrow.ArrowInvalid,
            pyarrow.ArrowTypeError,
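
The reason for this change: Arrow tables may legitimately carry duplicate column names (e.g. SELECT a.id, b.id), but pyarrow.concat_tables will not accept them when unifying schemas, so the morsels are renamed to unique placeholders, concatenated, then renamed back. A minimal sketch of the workaround in isolation, assuming a recent pyarrow with promote_options:

from itertools import chain
import pyarrow

def concat_with_duplicate_names(tables):
    tables = iter(tables)
    first = next(tables, None)
    if first is None:
        raise ValueError("no tables to concatenate")
    names = first.column_names
    if len(names) != len(set(names)):
        # Rename to unique placeholders, concatenate, then restore the names.
        placeholders = [f"col_{i}" for i in range(len(names))]
        combined = pyarrow.concat_tables(
            chain(
                [first.rename_columns(placeholders)],
                (t.rename_columns(placeholders) for t in tables),
            ),
            promote_options="permissive",
        )
        return combined.rename_columns(names)
    return pyarrow.concat_tables(chain([first], tables), promote_options="permissive")

t = pyarrow.table({"a": [1], "b": [2]}).rename_columns(["id", "id"])
print(concat_with_duplicate_names([t, t]).column_names)  # ['id', 'id']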

opteryx/operators/exit_node.py

Lines changed: 2 additions & 2 deletions

@@ -40,7 +40,7 @@ def __init__(self, properties: QueryProperties, **parameters):
        final_names = []
        for column in self.columns:
            final_columns.append(column.schema_column.identity)
-            final_names.append(column.current_name)
+            final_names.append(column.alias)
 
        if len(final_columns) != len(set(final_columns)):  # pragma: no cover
            from collections import Counter
 
@@ -57,7 +57,7 @@ def __init__(self, properties: QueryProperties, **parameters):
            # if column.schema_column.origin:
            #     final_names.append(f"{column.schema_column.origin[0]}.{column.current_name}")
            # else:
-            final_names.append(column.qualified_name)
+            final_names.append(column.alias)
 
        self.final_columns = final_columns
        self.final_names = final_names
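
Net effect of using column.alias here (together with the binder changes below): the names on the result table follow the SELECT clause, which is what #2849 asks for. A hedged usage sketch against the built-in $planets sample dataset; the DB-API style connection is assumed and the exact cursor API may differ:

import opteryx

# The alias should come back as the column name; an unaliased column keeps
# the name it was written with in the SELECT list.
cursor = opteryx.connect().cursor()
table = cursor.execute_to_arrow("SELECT name AS planet_name, gravity FROM $planets")
print(table.column_names)  # expected: ['planet_name', 'gravity']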

opteryx/planner/binder/binder_visitor.py

Lines changed: 4 additions & 20 deletions

@@ -380,14 +380,7 @@ def visit_exit(self, node: Node, context: BindingContext) -> Tuple[Node, Binding
        # clear the derived schema
        context.schemas.pop("$derived", None)
 
-        seen = set()
-        needs_qualifier = len(context.schemas) > 1 or any(
-            column.name in seen or seen.add(column.name) is not None  # type: ignore
-            for schema in context.schemas.values()
-            for column in schema.columns
-        )
-
-        def name_column(qualifier, column):
+        def name_column(column):
            for projection_column in node.columns:
                if (
                    projection_column.schema_column
 
@@ -396,20 +389,11 @@ def name_column(qualifier, column):
                    if projection_column.alias:
                        return projection_column.alias
 
-                    if len(context.relations) > 1 or needs_qualifier:
-                        if isinstance(projection_column, LogicalColumn):
-                            if qualifier:
-                                projection_column.source = qualifier
-                            return projection_column.qualified_name
-                        return f"{qualifier}.{column.name}"
-
                    if projection_column.query_column:
                        return str(projection_column.query_column)
                    if projection_column.current_name:
                        return projection_column.current_name
 
-            if needs_qualifier:
-                return f"{qualifier}.{column.name}"
            return column.name
 
        def keep_column(column, identities):
 
@@ -441,15 +425,15 @@ def keep_column(column, identities):
                identities.append(column.identity)
 
        columns = []
-        for qualifier, schema in context.schemas.items():
+        for _, schema in context.schemas.items():
            for column in schema.columns:
                if keep_column(column, identities):
-                    column_name = name_column(qualifier=qualifier, column=column)
+                    column_name = name_column(column=column)
                    column_reference = LogicalColumn(
                        node_type=NodeType.IDENTIFIER,
                        source_column=column_name,
                        source=None,
-                        alias=None,
+                        alias=column_name,
                        schema_column=column,
                    )
                    columns.append(column_reference)
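
After this change name_column no longer tries to add a relation qualifier: the precedence is simply the explicit alias, then the column expression as written in the query (query_column), then the current working name, falling back to the schema name. A plain-Python sketch of that precedence:

from types import SimpleNamespace

def resolve_display_name(projection_column, schema_name):
    # Precedence after the change: alias > query_column > current_name > schema name.
    if projection_column.alias:
        return projection_column.alias
    if projection_column.query_column:
        return str(projection_column.query_column)
    if projection_column.current_name:
        return projection_column.current_name
    return schema_name

col = SimpleNamespace(alias=None, query_column="p.name", current_name="name")
print(resolve_display_name(col, "name"))  # 'p.name'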

opteryx/planner/logical_planner/logical_planner_builders.py

Lines changed: 14 additions & 2 deletions

@@ -271,12 +271,21 @@ def ceiling(value, alias: Optional[List[str]] = None, key=None):
 
 
 def compound_identifier(branch, alias: Optional[List[str]] = None, key=None):
-    return LogicalColumn(
+    column = LogicalColumn(
        node_type=NodeType.IDENTIFIER,  # column type
        alias=alias,  # type: ignore
        source_column=branch[-1]["value"],  # the source column
        source=".".join(p["value"] for p in branch[:-1]),  # the source relation
    )
+    alias_name = alias[0] if isinstance(alias, list) and alias else alias
+    if alias_name:
+        column.query_column = alias_name
+    else:
+        qualifier = column.source
+        column.query_column = (
+            f"{qualifier}.{column.source_column}" if qualifier else column.source_column
+        )
+    return column
 
 
 def expression_with_alias(branch, alias: Optional[List[str]] = None, key=None):
 
@@ -424,11 +433,14 @@ def identifier(branch, alias: Optional[List[str]] = None, key=None):
    """idenitifier doesn't have a qualifier (recorded in source)"""
    if "Identifier" in branch:
        return build(branch["Identifier"], alias=alias)
-    return LogicalColumn(
+    column = LogicalColumn(
        node_type=NodeType.IDENTIFIER,  # column type
        alias=alias,  # type: ignore
        source_column=branch["value"],  # the source column
    )
+    alias_name = alias[0] if isinstance(alias, list) and alias else alias
+    column.query_column = alias_name or column.source_column
+    return column
 
 
 def in_list(branch, alias: Optional[List[str]] = None, key=None):
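
These builder changes record, at parse time, how each identifier was written in the query: an alias wins; otherwise a compound identifier keeps its qualifier and a simple identifier is just its name. A small sketch of the same rule in isolation:

from typing import List, Optional

def written_name(source_column: str, qualifier: Optional[str] = None,
                 alias: Optional[List[str]] = None) -> str:
    # Mirrors the new query_column logic in compound_identifier/identifier.
    alias_name = alias[0] if isinstance(alias, list) and alias else alias
    if alias_name:
        return alias_name
    return f"{qualifier}.{source_column}" if qualifier else source_column

print(written_name("name", "planets"))              # planets.name
print(written_name("name", "planets", ["p_name"]))  # p_name
print(written_name("gravity"))                      # gravity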
