Skip to content

Commit e780287

Browse files
committed
fix(trino): add explicit type when compiling sge.Struct fields
1 parent f75a64e commit e780287

File tree

1 file changed

+17
-3
lines changed

1 file changed

+17
-3
lines changed

ibis/backends/sql/compilers/trino.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -171,13 +171,20 @@ def visit_ArrayMap(self, op, *, arg, param, body, index):
171171
sge.Lambda(this=body, expressions=[param, index]),
172172
)
173173

174-
def visit_ArrayFilter(self, op, *, arg, param, body, index):
174+
def visit_ArrayFilter(
175+
self,
176+
op: ops.ArrayFilter,
177+
*,
178+
arg,
179+
param,
180+
body: sge.Identifier,
181+
index: sge.Identifier | None,
182+
):
175183
# no index, life is simpler
176184
if index is None:
177185
return self.f.filter(arg, sge.Lambda(this=body, expressions=[param]))
178186

179187
placeholder = sg.to_identifier("__trino_filter__")
180-
index = sg.to_identifier(index)
181188
keep, value = map(sg.to_identifier, ("keep", "value"))
182189

183190
# first, zip the array with the index and call the user's function,
@@ -192,7 +199,14 @@ def visit_ArrayFilter(self, op, *, arg, param, body, index):
192199
sge.Struct(
193200
expressions=[
194201
sge.PropertyEQ(this=keep, expression=body),
195-
sge.PropertyEQ(this=value, expression=param),
202+
# When sqlglot compiles a sge.Struct on presto/trino,
203+
# it warns if the struct fields don't have explicit types:
204+
# https://github.com/tobymao/sqlglot/blob/fc55b9889bcb1e0dad404dc15d357d8c755d85e6/sqlglot/dialects/presto.py#L711-L737
205+
# (and if ibis gets a warning during compilation, it raises an error)
206+
sge.PropertyEQ(
207+
this=value,
208+
expression=self.cast(param, op.arg.dtype.value_type),
209+
),
196210
]
197211
),
198212
dt.Struct(

0 commit comments

Comments
 (0)