Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions datajunction-server/datajunction_server/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4780,6 +4780,55 @@ def infer_type(
return [arg.key, arg.value]


class Range(TableFunction):
"""
range(start[, end[, step[, numSlices]]]) / range(end)
Returns a table with a single BIGINT column `id` containing values
within the specified range.
"""

dialects = [Dialect.SPARK]


@Range.register
def infer_type(end: ct.IntegerBase) -> List[ct.NestedField]:
"""range(end) - generates 0 to end-1"""
return [ct.NestedField(name="id", field_type=ct.BigIntType())]


@Range.register
def infer_type( # type: ignore
start: ct.IntegerBase,
end: ct.IntegerBase,
) -> List[ct.NestedField]:
"""range(start, end)"""
return [ct.NestedField(name="id", field_type=ct.BigIntType())]


@Range.register
def infer_type( # type: ignore
start: ct.IntegerBase,
end: ct.IntegerBase,
step: ct.IntegerBase,
) -> List[ct.NestedField]:
"""range(start, end, step)"""
print(f"DEBUG: Range.infer_type called with start={start}, end={end}, step={step}")
result = [ct.NestedField(name="id", field_type=ct.BigIntType())]
print(f"DEBUG: Returning {result}")
return result


@Range.register
def infer_type( # type: ignore
start: ct.IntegerBase,
end: ct.IntegerBase,
step: ct.IntegerBase,
num_slices: ct.IntegerBase,
) -> List[ct.NestedField]:
"""range(start, end, step, numSlices)"""
return [ct.NestedField(name="id", field_type=ct.BigIntType())]


class FunctionRegistryDict(dict):
"""
Custom dictionary mapping for functions
Expand Down
9 changes: 7 additions & 2 deletions datajunction-server/datajunction_server/sql/parsing/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -2612,14 +2612,19 @@ async def _type(self, ctx: Optional[CompileContext] = None) -> List[NestedField]
if ctx:
await arg.compile(ctx)
arg_types.append(arg.type)
return dj_func.infer_type(*arg_types)
print(f"DEBUG _type: About to call {name}.infer_type")
result = dj_func.infer_type(*arg_types)
print(f"DEBUG _type: result={result}")
return result

async def compile(self, ctx):
if self.is_compiled():
return
self._is_compiled = True
types = await self._type(ctx)
for type, col in zip_longest(types, self.column_list):
print(f"DEBUG compile: types={types}")
print(f"DEBUG compile: self.column_list={self.column_list}")
for type, col in zip_longest(types, self.column_list or []):
if self.column_list:
if (type is None) or (col is None):
ctx.exception.errors.append(
Expand Down
3 changes: 2 additions & 1 deletion datajunction-server/datajunction_server/sql/parsing/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
Dict,
Optional,
Tuple,
Union,
cast,
Callable,
)
Expand Down Expand Up @@ -245,7 +246,7 @@ class NestedField(ColumnType):

def __init__(
self,
name: "ast.Name",
name: Union["ast.Name", str],
field_type: ColumnType,
is_optional: bool = True,
doc: Optional[str] = None,
Expand Down
15 changes: 15 additions & 0 deletions datajunction-server/tests/sql/functions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1600,6 +1600,21 @@ async def test_explode_outer_func(session: AsyncSession):
assert query.select.projection[1].type == ct.StringType() # type: ignore


@pytest.mark.asyncio
async def test_range(session: AsyncSession):
"""
Test the `range` function
"""
query = parse(
"SELECT id FROM range(1, 10, 2)",
)
exc = DJException()
ctx = ast.CompileContext(session=session, exception=exc)
await query.compile(ctx)
assert not exc.errors
assert query.select.projection[0].type == ct.IntegerType() # type: ignore


@pytest.mark.asyncio
async def test_expm1_func(session: AsyncSession):
"""
Expand Down
Loading