Skip to content

Commit 5da18f7

Browse files
committed
Fix
1 parent cd35281 commit 5da18f7

File tree

4 files changed

+58
-7
lines changed

4 files changed

+58
-7
lines changed

datajunction-server/datajunction_server/sql/functions.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4791,12 +4791,42 @@ class Range(TableFunction):
47914791

47924792

47934793
@Range.register
4794-
def infer_type(
4795-
*args: ct.IntegerBase,
4794+
def infer_type(end: ct.IntegerBase) -> List[ct.NestedField]:
4795+
"""range(end) - generates 0 to end-1"""
4796+
return [ct.NestedField(name="id", field_type=ct.BigIntType())]
4797+
4798+
4799+
@Range.register
4800+
def infer_type( # type: ignore
4801+
start: ct.IntegerBase,
4802+
end: ct.IntegerBase,
4803+
) -> List[ct.NestedField]:
4804+
"""range(start, end)"""
4805+
return [ct.NestedField(name="id", field_type=ct.BigIntType())]
4806+
4807+
4808+
@Range.register
4809+
def infer_type( # type: ignore
4810+
start: ct.IntegerBase,
4811+
end: ct.IntegerBase,
4812+
step: ct.IntegerBase,
47964813
) -> List[ct.NestedField]:
4797-
from datajunction_server.sql.parsing.ast import Name
4814+
"""range(start, end, step)"""
4815+
print(f"DEBUG: Range.infer_type called with start={start}, end={end}, step={step}")
4816+
result = [ct.NestedField(name="id", field_type=ct.BigIntType())]
4817+
print(f"DEBUG: Returning {result}")
4818+
return result
4819+
47984820

4799-
return [ct.NestedField(name=Name("id"), field_type=ct.BigIntType())]
4821+
@Range.register
4822+
def infer_type( # type: ignore
4823+
start: ct.IntegerBase,
4824+
end: ct.IntegerBase,
4825+
step: ct.IntegerBase,
4826+
num_slices: ct.IntegerBase,
4827+
) -> List[ct.NestedField]:
4828+
"""range(start, end, step, numSlices)"""
4829+
return [ct.NestedField(name="id", field_type=ct.BigIntType())]
48004830

48014831

48024832
class FunctionRegistryDict(dict):

datajunction-server/datajunction_server/sql/parsing/ast.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2612,14 +2612,19 @@ async def _type(self, ctx: Optional[CompileContext] = None) -> List[NestedField]
26122612
if ctx:
26132613
await arg.compile(ctx)
26142614
arg_types.append(arg.type)
2615-
return dj_func.infer_type(*arg_types)
2615+
print(f"DEBUG _type: About to call {name}.infer_type")
2616+
result = dj_func.infer_type(*arg_types)
2617+
print(f"DEBUG _type: result={result}")
2618+
return result
26162619

26172620
async def compile(self, ctx):
26182621
if self.is_compiled():
26192622
return
26202623
self._is_compiled = True
26212624
types = await self._type(ctx)
2622-
for type, col in zip_longest(types, self.column_list):
2625+
print(f"DEBUG compile: types={types}")
2626+
print(f"DEBUG compile: self.column_list={self.column_list}")
2627+
for type, col in zip_longest(types, self.column_list or []):
26232628
if self.column_list:
26242629
if (type is None) or (col is None):
26252630
ctx.exception.errors.append(

datajunction-server/datajunction_server/sql/parsing/types.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
Dict,
2121
Optional,
2222
Tuple,
23+
Union,
2324
cast,
2425
Callable,
2526
)
@@ -245,7 +246,7 @@ class NestedField(ColumnType):
245246

246247
def __init__(
247248
self,
248-
name: "ast.Name",
249+
name: Union["ast.Name", str],
249250
field_type: ColumnType,
250251
is_optional: bool = True,
251252
doc: Optional[str] = None,

datajunction-server/tests/sql/functions_test.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1600,6 +1600,21 @@ async def test_explode_outer_func(session: AsyncSession):
16001600
assert query.select.projection[1].type == ct.StringType() # type: ignore
16011601

16021602

1603+
@pytest.mark.asyncio
1604+
async def test_range(session: AsyncSession):
1605+
"""
1606+
Test the `range` function
1607+
"""
1608+
query = parse(
1609+
"SELECT id FROM range(1, 10, 2)",
1610+
)
1611+
exc = DJException()
1612+
ctx = ast.CompileContext(session=session, exception=exc)
1613+
await query.compile(ctx)
1614+
assert not exc.errors
1615+
assert query.select.projection[0].type == ct.IntegerType() # type: ignore
1616+
1617+
16031618
@pytest.mark.asyncio
16041619
async def test_expm1_func(session: AsyncSession):
16051620
"""

0 commit comments

Comments
 (0)