|
10 | 10 |
|
11 | 11 | from datachain import Column |
12 | 12 | from datachain.lib.data_model import DataModel |
13 | | -from datachain.lib.dc import C, DataChain, Sys |
| 13 | +from datachain.lib.dc import C, DataChain, DataChainColumnError, Sys |
14 | 14 | from datachain.lib.file import File |
15 | 15 | from datachain.lib.signal_schema import ( |
16 | 16 | SignalResolvingError, |
|
19 | 19 | ) |
20 | 20 | from datachain.lib.udf_signature import UdfSignatureError |
21 | 21 | from datachain.lib.utils import DataChainParamsError |
| 22 | +from datachain.sql import functions as func |
| 23 | +from datachain.sql.types import Float, Int64, String |
22 | 24 | from tests.utils import skip_if_not_sqlite |
23 | 25 |
|
24 | 26 | DF_DATA = { |
@@ -1254,14 +1256,20 @@ def test_column_math(test_session): |
1254 | 1256 | fib = [1, 1, 2, 3, 5, 8] |
1255 | 1257 | chain = DataChain.from_values(num=fib, session=test_session) |
1256 | 1258 |
|
1257 | | - ch = chain.mutate(add2=Column("num") + 2) |
| 1259 | + ch = chain.mutate(add2=chain.column("num") + 2) |
1258 | 1260 | assert list(ch.collect("add2")) == [x + 2 for x in fib] |
1259 | 1261 |
|
1260 | | - ch = chain.mutate(div2=Column("num") / 2.0) |
1261 | | - assert list(ch.collect("div2")) == [x / 2.0 for x in fib] |
| 1262 | + ch2 = ch.mutate(x=1 - ch.column("add2")) |
| 1263 | + assert list(ch2.collect("x")) == [1 - (x + 2.0) for x in fib] |
| 1264 | + |
| 1265 | + |
| 1266 | +def test_column_math_division(test_session): |
| 1267 | + skip_if_not_sqlite() |
| 1268 | + fib = [1, 1, 2, 3, 5, 8] |
| 1269 | + chain = DataChain.from_values(num=fib, session=test_session) |
1262 | 1270 |
|
1263 | | - ch2 = ch.mutate(x=1 - Column("div2")) |
1264 | | - assert list(ch2.collect("x")) == [1 - (x / 2.0) for x in fib] |
| 1271 | + ch = chain.mutate(div2=chain.column("num") / 2.0) |
| 1272 | + assert list(ch.collect("div2")) == [x / 2.0 for x in fib] |
1265 | 1273 |
|
1266 | 1274 |
|
1267 | 1275 | def test_from_values_array_of_floats(test_session): |
@@ -1409,3 +1417,83 @@ def test_rename_object_name_with_mutate(catalog): |
1409 | 1417 | assert ds.signals_schema.values.get("ids") is int |
1410 | 1418 | assert "file" not in ds.signals_schema.values |
1411 | 1419 | assert list(ds.order_by("my_file.name").collect("my_file.name")) == ["a", "b", "c"] |
| 1420 | + |
| 1421 | + |
| 1422 | +def test_column(catalog): |
| 1423 | + ds = DataChain.from_values( |
| 1424 | + ints=[1, 2], floats=[0.5, 0.5], file=[File(name="a"), File(name="b")] |
| 1425 | + ) |
| 1426 | + |
| 1427 | + c = ds.column("ints") |
| 1428 | + assert isinstance(c, Column) |
| 1429 | + assert c.name == "ints" |
| 1430 | + assert isinstance(c.type, Int64) |
| 1431 | + |
| 1432 | + c = ds.column("floats") |
| 1433 | + assert isinstance(c, Column) |
| 1434 | + assert c.name == "floats" |
| 1435 | + assert isinstance(c.type, Float) |
| 1436 | + |
| 1437 | + c = ds.column("file.name") |
| 1438 | + assert isinstance(c, Column) |
| 1439 | + assert c.name == "file__name" |
| 1440 | + assert isinstance(c.type, String) |
| 1441 | + |
| 1442 | + with pytest.raises(ValueError): |
| 1443 | + c = ds.column("missing") |
| 1444 | + |
| 1445 | + |
| 1446 | +def test_mutate_with_subtraction(): |
| 1447 | + ds = DataChain.from_values(id=[1, 2]) |
| 1448 | + assert ds.mutate(new=ds.column("id") - 1).signals_schema.values["new"] is int |
| 1449 | + |
| 1450 | + |
| 1451 | +def test_mutate_with_addition(): |
| 1452 | + ds = DataChain.from_values(id=[1, 2]) |
| 1453 | + assert ds.mutate(new=ds.column("id") + 1).signals_schema.values["new"] is int |
| 1454 | + |
| 1455 | + |
| 1456 | +def test_mutate_with_division(): |
| 1457 | + ds = DataChain.from_values(id=[1, 2]) |
| 1458 | + assert ds.mutate(new=ds.column("id") / 10).signals_schema.values["new"] is float |
| 1459 | + |
| 1460 | + |
| 1461 | +def test_mutate_with_multiplication(): |
| 1462 | + ds = DataChain.from_values(id=[1, 2]) |
| 1463 | + assert ds.mutate(new=ds.column("id") * 10).signals_schema.values["new"] is int |
| 1464 | + |
| 1465 | + |
| 1466 | +def test_mutate_with_func(): |
| 1467 | + ds = DataChain.from_values(id=[1, 2]) |
| 1468 | + assert ( |
| 1469 | + ds.mutate(new=func.avg(ds.column("id"))).signals_schema.values["new"] is float |
| 1470 | + ) |
| 1471 | + |
| 1472 | + |
| 1473 | +def test_mutate_with_complex_expression(): |
| 1474 | + ds = DataChain.from_values(id=[1, 2], name=["Jim", "Jon"]) |
| 1475 | + assert ( |
| 1476 | + ds.mutate( |
| 1477 | + new=(func.sum(ds.column("id"))) * (5 - func.min(ds.column("id"))) |
| 1478 | + ).signals_schema.values["new"] |
| 1479 | + is int |
| 1480 | + ) |
| 1481 | + |
| 1482 | + |
| 1483 | +def test_mutate_with_saving(): |
| 1484 | + skip_if_not_sqlite() |
| 1485 | + ds = DataChain.from_values(id=[1, 2]) |
| 1486 | + ds = ds.mutate(new=ds.column("id") / 2).save("mutated") |
| 1487 | + |
| 1488 | + ds = DataChain(name="mutated") |
| 1489 | + assert ds.signals_schema.values["new"] is float |
| 1490 | + assert list(ds.collect("new")) == [0.5, 1.0] |
| 1491 | + |
| 1492 | + |
| 1493 | +def test_mutate_with_expression_without_type(catalog): |
| 1494 | + with pytest.raises(DataChainColumnError) as excinfo: |
| 1495 | + DataChain.from_values(id=[1, 2]).mutate(new=(Column("id") - 1)).save() |
| 1496 | + |
| 1497 | + assert str(excinfo.value) == ( |
| 1498 | + "Error for column new: Cannot infer type with expression id - :id_1" |
| 1499 | + ) |
0 commit comments