Skip to content

Commit b67cb70

Browse files
authored
Move 'create_pre_udf_table' function to warehouse module (#187)
1 parent 8f431dd commit b67cb70

File tree

4 files changed

+51
-47
lines changed

4 files changed

+51
-47
lines changed

src/datachain/data_storage/sqlite.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from sqlalchemy.dialects.sqlite import Insert
4343
from sqlalchemy.schema import SchemaItem
4444
from sqlalchemy.sql.elements import ColumnClause, ColumnElement, TextClause
45+
from sqlalchemy.sql.selectable import Select
4546
from sqlalchemy.types import TypeEngine
4647

4748

@@ -705,3 +706,23 @@ def export_dataset_table(
705706
client_config=None,
706707
) -> list[str]:
707708
raise NotImplementedError("Exporting dataset table not implemented for SQLite")
709+
710+
def create_pre_udf_table(self, query: "Select") -> "Table":
711+
"""
712+
Create a temporary table from a query for use in a UDF.
713+
"""
714+
columns = [
715+
sqlalchemy.Column(c.name, c.type)
716+
for c in query.selected_columns
717+
if c.name != "sys__id"
718+
]
719+
table = self.create_udf_table(columns)
720+
721+
select_q = query.with_only_columns(
722+
*[c for c in query.selected_columns if c.name != "sys__id"]
723+
)
724+
self.db.execute(
725+
table.insert().from_select(list(select_q.selected_columns), select_q)
726+
)
727+
728+
return table

src/datachain/data_storage/warehouse.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import json
33
import logging
44
import posixpath
5+
import random
6+
import string
57
from abc import ABC, abstractmethod
68
from collections.abc import Generator, Iterable, Iterator, Sequence
79
from typing import TYPE_CHECKING, Any, Optional, Union
@@ -24,6 +26,7 @@
2426
if TYPE_CHECKING:
2527
from sqlalchemy.sql._typing import _ColumnsClauseArgument
2628
from sqlalchemy.sql.elements import ColumnElement
29+
from sqlalchemy.sql.selectable import Select
2730
from sqlalchemy.types import TypeEngine
2831

2932
from datachain.data_storage import AbstractIDGenerator, schema
@@ -252,6 +255,12 @@ def dataset_table_name(self, dataset_name: str, version: int) -> str:
252255
prefix = self.DATASET_SOURCE_TABLE_PREFIX
253256
return f"{prefix}{dataset_name}_{version}"
254257

258+
def temp_table_name(self) -> str:
259+
return self.TMP_TABLE_NAME_PREFIX + _random_string(6)
260+
261+
def udf_table_name(self) -> str:
262+
return self.UDF_TABLE_NAME_PREFIX + _random_string(6)
263+
255264
#
256265
# Datasets
257266
#
@@ -869,23 +878,29 @@ def update_node(self, node_id: int, values: dict[str, Any]) -> None:
869878

870879
def create_udf_table(
871880
self,
872-
name: str,
873881
columns: Sequence["sa.Column"] = (),
882+
name: Optional[str] = None,
874883
) -> "sa.Table":
875884
"""
876885
Create a temporary table for storing custom signals generated by a UDF.
877886
SQLite TEMPORARY tables cannot be directly used as they are process-specific,
878887
and UDFs are run in other processes when run in parallel.
879888
"""
880889
tbl = sa.Table(
881-
name,
890+
name or self.udf_table_name(),
882891
sa.MetaData(),
883892
sa.Column("sys__id", Int, primary_key=True),
884893
*columns,
885894
)
886895
self.db.create_table(tbl, if_not_exists=True)
887896
return tbl
888897

898+
@abstractmethod
899+
def create_pre_udf_table(self, query: "Select") -> "Table":
900+
"""
901+
Create a temporary table from a query for use in a UDF.
902+
"""
903+
889904
def is_temp_table_name(self, name: str) -> bool:
890905
"""Returns if the given table name refers to a temporary
891906
or no longer needed table."""
@@ -937,3 +952,10 @@ def changed_query(
937952
& (tq.c.is_latest == true())
938953
)
939954
)
955+
956+
957+
def _random_string(length: int) -> str:
958+
return "".join(
959+
random.choice(string.ascii_letters + string.digits) # noqa: S311
960+
for i in range(length)
961+
)

src/datachain/query/dataset.py

Lines changed: 5 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -262,9 +262,7 @@ def apply(self, query_generator, temp_tables: list[str]):
262262
temp_tables.extend(self.dq.temp_table_names)
263263

264264
# creating temp table that will hold subtract results
265-
temp_table_name = self.catalog.warehouse.TMP_TABLE_NAME_PREFIX + _random_string(
266-
6
267-
)
265+
temp_table_name = self.catalog.warehouse.temp_table_name()
268266
temp_tables.append(temp_table_name)
269267

270268
columns = [
@@ -448,9 +446,6 @@ def create_result_query(
448446
to select
449447
"""
450448

451-
def udf_table_name(self) -> str:
452-
return self.catalog.warehouse.UDF_TABLE_NAME_PREFIX + _random_string(6)
453-
454449
def populate_udf_table(self, udf_table: "Table", query: Select) -> None:
455450
use_partitioning = self.partition_by is not None
456451
batching = self.udf.properties.get_batching(use_partitioning)
@@ -574,9 +569,7 @@ def create_partitions_table(self, query: Select) -> "Table":
574569
list_partition_by = [self.partition_by]
575570

576571
# create table with partitions
577-
tbl = self.catalog.warehouse.create_udf_table(
578-
self.udf_table_name(), partition_columns()
579-
)
572+
tbl = self.catalog.warehouse.create_udf_table(partition_columns())
580573

581574
# fill table with partitions
582575
cols = [
@@ -638,37 +631,12 @@ def create_udf_table(self, query: Select) -> "Table":
638631
for (col_name, col_type) in self.udf.output.items()
639632
]
640633

641-
return self.catalog.warehouse.create_udf_table(
642-
self.udf_table_name(), udf_output_columns
643-
)
644-
645-
def create_pre_udf_table(self, query: Select) -> "Table":
646-
columns = [
647-
sqlalchemy.Column(c.name, c.type)
648-
for c in query.selected_columns
649-
if c.name != "sys__id"
650-
]
651-
table = self.catalog.warehouse.create_udf_table(self.udf_table_name(), columns)
652-
select_q = query.with_only_columns(
653-
*[c for c in query.selected_columns if c.name != "sys__id"]
654-
)
655-
656-
# if there is order by clause we need row_number to preserve order
657-
# if there is no order by clause we still need row_number to generate
658-
# unique ids as uniqueness is important for this table
659-
select_q = select_q.add_columns(
660-
f.row_number().over(order_by=select_q._order_by_clauses).label("sys__id")
661-
)
662-
663-
self.catalog.warehouse.db.execute(
664-
table.insert().from_select(list(select_q.selected_columns), select_q)
665-
)
666-
return table
634+
return self.catalog.warehouse.create_udf_table(udf_output_columns)
667635

668636
def process_input_query(self, query: Select) -> tuple[Select, list["Table"]]:
669637
if os.getenv("DATACHAIN_DISABLE_QUERY_CACHE", "") not in ("", "0"):
670638
return query, []
671-
table = self.create_pre_udf_table(query)
639+
table = self.catalog.warehouse.create_pre_udf_table(query)
672640
q: Select = sqlalchemy.select(*table.c)
673641
if query._order_by_clauses:
674642
# we are adding ordering only if it's explicitly added by user in
@@ -732,7 +700,7 @@ class RowGenerator(UDFStep):
732700
def create_udf_table(self, query: Select) -> "Table":
733701
warehouse = self.catalog.warehouse
734702

735-
table_name = self.udf_table_name()
703+
table_name = self.catalog.warehouse.udf_table_name()
736704
columns: tuple[Column, ...] = tuple(
737705
Column(name, typ) for name, typ in self.udf.output.items()
738706
)
@@ -1802,10 +1770,3 @@ def query_wrapper(dataset_query: DatasetQuery) -> DatasetQuery:
18021770

18031771
_send_result(dataset_query)
18041772
return dataset_query
1805-
1806-
1807-
def _random_string(length: int) -> str:
1808-
return "".join(
1809-
random.choice(string.ascii_letters + string.digits) # noqa: S311
1810-
for i in range(length)
1811-
)

tests/func/test_catalog.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1113,7 +1113,7 @@ def test_garbage_collect(cloud_test_catalog, from_cli, capsys):
11131113
assert catalog.get_temp_table_names() == []
11141114
temp_tables = ["tmp_vc12F", "udf_jh653", "ds_shadow_12345", "old_ds_shadow"]
11151115
for t in temp_tables:
1116-
catalog.warehouse.create_udf_table(t)
1116+
catalog.warehouse.create_udf_table(name=t)
11171117
assert set(catalog.get_temp_table_names()) == set(temp_tables)
11181118
if from_cli:
11191119
garbage_collect(catalog)

0 commit comments

Comments
 (0)