Skip to content

Commit bc7608b

Browse files
authored
Adding DataChain.column(...) and fixing functions and types (#226)
* fixing sql to python
* added tests for sql to python and changed input type of sql to python
* changing docstring
* fixing tests and Decimal type conversion
* returning exception when column is not found
* changed docstring
* fixed typo
* added new exception type
* renaming error class
* skipping division expression tests for CH
* using new column method from dc
* updating studio branch
* return to develop
1 parent d47aee3 commit bc7608b

File tree

7 files changed

+175
-26
lines changed

7 files changed

+175
-26
lines changed
Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,18 @@
1-
from datetime import datetime
1+
from decimal import Decimal
22
from typing import Any
33

4-
from sqlalchemy import ARRAY, JSON, Boolean, DateTime, Float, Integer, String
4+
from sqlalchemy import ColumnElement
55

6-
from datachain.data_storage.sqlite import Column
76

8-
SQL_TO_PYTHON = {
9-
String: str,
10-
Integer: int,
11-
Float: float,
12-
Boolean: bool,
13-
DateTime: datetime,
14-
ARRAY: list,
15-
JSON: dict,
16-
}
7+
def sql_to_python(args_map: dict[str, ColumnElement]) -> dict[str, Any]:
8+
res = {}
9+
for name, sql_exp in args_map.items():
10+
try:
11+
type_ = sql_exp.type.python_type
12+
if type_ == Decimal:
13+
type_ = float
14+
except NotImplementedError:
15+
type_ = str
16+
res[name] = type_
1717

18-
19-
def sql_to_python(args_map: dict[str, Column]) -> dict[str, Any]:
20-
return {
21-
k: SQL_TO_PYTHON.get(type(v.type), str) # type: ignore[union-attr]
22-
for k, v in args_map.items()
23-
}
18+
return res

src/datachain/lib/dc.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@
2020
import sqlalchemy
2121
from pydantic import BaseModel, create_model
2222
from sqlalchemy.sql.functions import GenericFunction
23+
from sqlalchemy.sql.sqltypes import NullType
2324

2425
from datachain import DataModel
26+
from datachain.lib.convert.python_to_sql import python_to_sql
2527
from datachain.lib.convert.values_to_tuples import values_to_tuples
2628
from datachain.lib.data_model import DataType
2729
from datachain.lib.dataset_info import DatasetInfo
@@ -110,6 +112,11 @@ def __init__(self, on: Sequence[str], right_on: Optional[Sequence[str]], msg: st
110112
super().__init__(f"Merge error on='{on_str}'{right_on_str}: {msg}")
111113

112114

115+
class DataChainColumnError(DataChainParamsError): # noqa: D101
116+
def __init__(self, col_name, msg): # noqa: D107
117+
super().__init__(f"Error for column {col_name}: {msg}")
118+
119+
113120
OutputType = Union[None, DataType, Sequence[str], dict[str, DataType]]
114121

115122

@@ -225,6 +232,17 @@ def schema(self) -> dict[str, DataType]:
225232
"""Get schema of the chain."""
226233
return self._effective_signals_schema.values
227234

235+
def column(self, name: str) -> Column:
236+
"""Returns Column instance with a type if name is found in current schema,
237+
otherwise raises an exception.
238+
"""
239+
name_path = name.split(".")
240+
for path, type_, _, _ in self.signals_schema.get_flat_tree():
241+
if path == name_path:
242+
return Column(name, python_to_sql(type_))
243+
244+
raise ValueError(f"Column with name {name} not found in the schema")
245+
228246
def print_schema(self) -> None:
229247
"""Print schema of the chain."""
230248
self._effective_signals_schema.print_tree()
@@ -829,6 +847,12 @@ def mutate(self, **kwargs) -> "Self":
829847
)
830848
```
831849
"""
850+
for col_name, expr in kwargs.items():
851+
if not isinstance(expr, Column) and isinstance(expr.type, NullType):
852+
raise DataChainColumnError(
853+
col_name, f"Cannot infer type with expression {expr}"
854+
)
855+
832856
mutated = {}
833857
schema = self.signals_schema
834858
for name, value in kwargs.items():

src/datachain/sql/functions/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
from sqlalchemy.sql.expression import func
22

3-
from . import path, string
3+
from . import array, path, string
4+
from .array import avg
45
from .conditional import greatest, least
56
from .random import rand
67

78
count = func.count
89
sum = func.sum
9-
avg = func.avg
1010
min = func.min
1111
max = func.max
1212

1313
__all__ = [
14+
"array",
1415
"avg",
1516
"count",
1617
"func",

src/datachain/sql/functions/array.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,15 @@ class sip_hash_64(GenericFunction): # noqa: N801
4444
inherit_cache = True
4545

4646

47+
class avg(GenericFunction): # noqa: N801
48+
type = Float()
49+
package = "array"
50+
name = "avg"
51+
inherit_cache = True
52+
53+
4754
compiler_not_implemented(cosine_distance)
4855
compiler_not_implemented(euclidean_distance)
4956
compiler_not_implemented(length)
5057
compiler_not_implemented(sip_hash_64)
58+
compiler_not_implemented(avg)

src/datachain/sql/sqlite/base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def setup():
7878
compiles(conditional.least, "sqlite")(compile_least)
7979
compiles(Values, "sqlite")(compile_values)
8080
compiles(random.rand, "sqlite")(compile_rand)
81+
compiles(array.avg, "sqlite")(compile_avg)
8182

8283
if load_usearch_extension(sqlite3.connect(":memory:")):
8384
compiles(array.cosine_distance, "sqlite")(compile_cosine_distance_ext)
@@ -349,6 +350,10 @@ def compile_rand(element, compiler, **kwargs):
349350
return compiler.process(func.random(), **kwargs)
350351

351352

353+
def compile_avg(element, compiler, **kwargs):
354+
return compiler.process(func.avg(*element.clauses.clauses), **kwargs)
355+
356+
352357
def load_usearch_extension(conn) -> bool:
353358
try:
354359
# usearch is part of the vector optional dependencies

tests/unit/lib/test_datachain.py

Lines changed: 94 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from datachain import Column
1212
from datachain.lib.data_model import DataModel
13-
from datachain.lib.dc import C, DataChain, Sys
13+
from datachain.lib.dc import C, DataChain, DataChainColumnError, Sys
1414
from datachain.lib.file import File
1515
from datachain.lib.signal_schema import (
1616
SignalResolvingError,
@@ -19,6 +19,8 @@
1919
)
2020
from datachain.lib.udf_signature import UdfSignatureError
2121
from datachain.lib.utils import DataChainParamsError
22+
from datachain.sql import functions as func
23+
from datachain.sql.types import Float, Int64, String
2224
from tests.utils import skip_if_not_sqlite
2325

2426
DF_DATA = {
@@ -1254,14 +1256,20 @@ def test_column_math(test_session):
12541256
fib = [1, 1, 2, 3, 5, 8]
12551257
chain = DataChain.from_values(num=fib, session=test_session)
12561258

1257-
ch = chain.mutate(add2=Column("num") + 2)
1259+
ch = chain.mutate(add2=chain.column("num") + 2)
12581260
assert list(ch.collect("add2")) == [x + 2 for x in fib]
12591261

1260-
ch = chain.mutate(div2=Column("num") / 2.0)
1261-
assert list(ch.collect("div2")) == [x / 2.0 for x in fib]
1262+
ch2 = ch.mutate(x=1 - ch.column("add2"))
1263+
assert list(ch2.collect("x")) == [1 - (x + 2.0) for x in fib]
1264+
1265+
1266+
def test_column_math_division(test_session):
1267+
skip_if_not_sqlite()
1268+
fib = [1, 1, 2, 3, 5, 8]
1269+
chain = DataChain.from_values(num=fib, session=test_session)
12621270

1263-
ch2 = ch.mutate(x=1 - Column("div2"))
1264-
assert list(ch2.collect("x")) == [1 - (x / 2.0) for x in fib]
1271+
ch = chain.mutate(div2=chain.column("num") / 2.0)
1272+
assert list(ch.collect("div2")) == [x / 2.0 for x in fib]
12651273

12661274

12671275
def test_from_values_array_of_floats(test_session):
@@ -1409,3 +1417,83 @@ def test_rename_object_name_with_mutate(catalog):
14091417
assert ds.signals_schema.values.get("ids") is int
14101418
assert "file" not in ds.signals_schema.values
14111419
assert list(ds.order_by("my_file.name").collect("my_file.name")) == ["a", "b", "c"]
1420+
1421+
1422+
def test_column(catalog):
1423+
ds = DataChain.from_values(
1424+
ints=[1, 2], floats=[0.5, 0.5], file=[File(name="a"), File(name="b")]
1425+
)
1426+
1427+
c = ds.column("ints")
1428+
assert isinstance(c, Column)
1429+
assert c.name == "ints"
1430+
assert isinstance(c.type, Int64)
1431+
1432+
c = ds.column("floats")
1433+
assert isinstance(c, Column)
1434+
assert c.name == "floats"
1435+
assert isinstance(c.type, Float)
1436+
1437+
c = ds.column("file.name")
1438+
assert isinstance(c, Column)
1439+
assert c.name == "file__name"
1440+
assert isinstance(c.type, String)
1441+
1442+
with pytest.raises(ValueError):
1443+
c = ds.column("missing")
1444+
1445+
1446+
def test_mutate_with_subtraction():
1447+
ds = DataChain.from_values(id=[1, 2])
1448+
assert ds.mutate(new=ds.column("id") - 1).signals_schema.values["new"] is int
1449+
1450+
1451+
def test_mutate_with_addition():
1452+
ds = DataChain.from_values(id=[1, 2])
1453+
assert ds.mutate(new=ds.column("id") + 1).signals_schema.values["new"] is int
1454+
1455+
1456+
def test_mutate_with_division():
1457+
ds = DataChain.from_values(id=[1, 2])
1458+
assert ds.mutate(new=ds.column("id") / 10).signals_schema.values["new"] is float
1459+
1460+
1461+
def test_mutate_with_multiplication():
1462+
ds = DataChain.from_values(id=[1, 2])
1463+
assert ds.mutate(new=ds.column("id") * 10).signals_schema.values["new"] is int
1464+
1465+
1466+
def test_mutate_with_func():
1467+
ds = DataChain.from_values(id=[1, 2])
1468+
assert (
1469+
ds.mutate(new=func.avg(ds.column("id"))).signals_schema.values["new"] is float
1470+
)
1471+
1472+
1473+
def test_mutate_with_complex_expression():
1474+
ds = DataChain.from_values(id=[1, 2], name=["Jim", "Jon"])
1475+
assert (
1476+
ds.mutate(
1477+
new=(func.sum(ds.column("id"))) * (5 - func.min(ds.column("id")))
1478+
).signals_schema.values["new"]
1479+
is int
1480+
)
1481+
1482+
1483+
def test_mutate_with_saving():
1484+
skip_if_not_sqlite()
1485+
ds = DataChain.from_values(id=[1, 2])
1486+
ds = ds.mutate(new=ds.column("id") / 2).save("mutated")
1487+
1488+
ds = DataChain(name="mutated")
1489+
assert ds.signals_schema.values["new"] is float
1490+
assert list(ds.collect("new")) == [0.5, 1.0]
1491+
1492+
1493+
def test_mutate_with_expression_without_type(catalog):
1494+
with pytest.raises(DataChainColumnError) as excinfo:
1495+
DataChain.from_values(id=[1, 2]).mutate(new=(Column("id") - 1)).save()
1496+
1497+
assert str(excinfo.value) == (
1498+
"Error for column new: Cannot infer type with expression id - :id_1"
1499+
)
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from sqlalchemy.sql.sqltypes import NullType
2+
3+
from datachain import Column
4+
from datachain.lib.convert.sql_to_python import sql_to_python
5+
from datachain.sql import functions as func
6+
from datachain.sql.types import Float, Int64, String
7+
8+
9+
def test_sql_columns_to_python_types():
10+
assert sql_to_python(
11+
{
12+
"name": Column("name", String),
13+
"age": Column("age", Int64),
14+
"score": Column("score", Float),
15+
}
16+
) == {"name": str, "age": int, "score": float}
17+
18+
19+
def test_sql_expression_to_python_types():
20+
assert sql_to_python({"age": Column("age", Int64) - 2}) == {"age": int}
21+
22+
23+
def test_sql_function_to_python_types():
24+
assert sql_to_python({"age": func.avg(Column("age", Int64))}) == {"age": float}
25+
26+
27+
def test_sql_to_python_types_default_type():
28+
assert sql_to_python({"null": Column("null", NullType)}) == {"null": str}

0 commit comments

Comments
 (0)