diff --git a/src/datachain/data_storage/db_engine.py b/src/datachain/data_storage/db_engine.py index 2c0ca4686..24ef7dd47 100644 --- a/src/datachain/data_storage/db_engine.py +++ b/src/datachain/data_storage/db_engine.py @@ -112,7 +112,7 @@ def has_table(self, name: str) -> bool: return sa.inspect(self.engine).has_table(name) @abstractmethod - def create_table(self, table: "Table", if_not_exists: bool = True) -> None: ... + def create_table(self, table: "Table", if_not_exists: bool = True) -> "Table": ... @abstractmethod def drop_table(self, table: "Table", if_exists: bool = False) -> None: ... diff --git a/src/datachain/data_storage/sqlite.py b/src/datachain/data_storage/sqlite.py index d0247a4df..ecdfe7eab 100644 --- a/src/datachain/data_storage/sqlite.py +++ b/src/datachain/data_storage/sqlite.py @@ -276,8 +276,9 @@ def has_table(self, name: str) -> bool: ) return bool(next(self.execute(query))[0]) - def create_table(self, table: "Table", if_not_exists: bool = True) -> None: + def create_table(self, table: "Table", if_not_exists: bool = True) -> "Table": self.execute(CreateTable(table, if_not_exists=if_not_exists)) + return table def drop_table(self, table: "Table", if_exists: bool = False) -> None: self.execute(DropTable(table, if_exists=if_exists)) diff --git a/tests/func/test_datachain_merge.py b/tests/func/test_datachain_merge.py index 78b5049f1..dafdae85d 100644 --- a/tests/func/test_datachain_merge.py +++ b/tests/func/test_datachain_merge.py @@ -1,7 +1,6 @@ import pytest from datachain.lib.dc import DataChain -from datachain.sql.types import Int @pytest.mark.parametrize( @@ -11,7 +10,6 @@ ) @pytest.mark.parametrize("inner", [True, False]) def test_merge_union(cloud_test_catalog, inner, cloud_type): - catalog = cloud_test_catalog.catalog session = cloud_test_catalog.session src = cloud_test_catalog.src_uri @@ -19,8 +17,6 @@ def test_merge_union(cloud_test_catalog, inner, cloud_type): dogs = DataChain.from_storage(f"{src}/dogs/*", session=session) cats = DataChain.from_storage(f"{src}/cats/*", session=session) - signal_default_value = Int.default_value(catalog.warehouse.db.dialect) - dogs1 = dogs.map(sig1=lambda: 1, output={"sig1": int}) dogs2 = dogs.map(sig2=lambda: 2, output={"sig2": int}) cats1 = cats.map(sig1=lambda: 1, output={"sig1": int}) @@ -37,8 +33,8 @@ def test_merge_union(cloud_test_catalog, inner, cloud_type): ] else: assert signals == [ - ("cats/cat1", 1, signal_default_value), - ("cats/cat2", 1, signal_default_value), + ("cats/cat1", 1, None), + ("cats/cat2", 1, None), ("dogs/dog1", 1, 2), ("dogs/dog2", 1, 2), ("dogs/dog3", 1, 2), @@ -55,7 +51,6 @@ def test_merge_union(cloud_test_catalog, inner, cloud_type): @pytest.mark.parametrize("inner2", [True, False]) @pytest.mark.parametrize("inner3", [True, False]) def test_merge_multiple(cloud_test_catalog, inner1, inner2, inner3): - catalog = cloud_test_catalog.catalog session = cloud_test_catalog.session src = cloud_test_catalog.src_uri @@ -63,8 +58,6 @@ def test_merge_multiple(cloud_test_catalog, inner1, inner2, inner3): dogs = DataChain.from_storage(f"{src}/dogs/*", session=session) cats = DataChain.from_storage(f"{src}/cats/*", session=session) - signal_default_value = Int.default_value(catalog.warehouse.db.dialect) - dogs_and_cats = dogs | cats dogs1 = dogs.map(sig1=lambda: 1, output={"sig1": int}) cats1 = cats.map(sig2=lambda: 2, output={"sig2": int}) @@ -80,22 +73,22 @@ def test_merge_multiple(cloud_test_catalog, inner1, inner2, inner3): assert merged_signals == [] elif inner1: assert merged_signals == [ - ("dogs/dog1", 1, signal_default_value), - ("dogs/dog2", 1, signal_default_value), - ("dogs/dog3", 1, signal_default_value), - ("dogs/others/dog4", 1, signal_default_value), + ("dogs/dog1", 1, None), + ("dogs/dog2", 1, None), + ("dogs/dog3", 1, None), + ("dogs/others/dog4", 1, None), ] elif inner2 and inner3: assert merged_signals == [ - ("cats/cat1", signal_default_value, 2), - ("cats/cat2", signal_default_value, 2), + ("cats/cat1", None, 2), + ("cats/cat2", None, 2), ] else: assert merged_signals == [ - ("cats/cat1", signal_default_value, 2), - ("cats/cat2", signal_default_value, 2), - ("dogs/dog1", 1, signal_default_value), - ("dogs/dog2", 1, signal_default_value), - ("dogs/dog3", 1, signal_default_value), - ("dogs/others/dog4", 1, signal_default_value), + ("cats/cat1", None, 2), + ("cats/cat2", None, 2), + ("dogs/dog1", 1, None), + ("dogs/dog2", 1, None), + ("dogs/dog3", 1, None), + ("dogs/others/dog4", 1, None), ] diff --git a/tests/func/test_dataset_query.py b/tests/func/test_dataset_query.py index 658a47273..1b0e189b9 100644 --- a/tests/func/test_dataset_query.py +++ b/tests/func/test_dataset_query.py @@ -743,10 +743,9 @@ def test_join_with_binary_expression( ("dogs/others/dog4", "dogs/others/dog4"), ] else: - string_default = String.default_value(catalog.warehouse.db.dialect) expected = [ - ("cats/cat1", string_default), - ("cats/cat2", string_default), + ("cats/cat1", None), + ("cats/cat2", None), ("dogs/dog1", "dogs/dog1"), ("dogs/dog2", "dogs/dog2"), ("dogs/dog3", "dogs/dog3"), @@ -793,10 +792,9 @@ def test_join_with_combination_binary_expression_and_column_predicates( ("dogs/others/dog4", "dogs/others/dog4"), ] else: - string_default = String.default_value(catalog.warehouse.db.dialect) expected = [ - ("cats/cat1", string_default), - ("cats/cat2", string_default), + ("cats/cat1", None), + ("cats/cat2", None), ("dogs/dog1", "dogs/dog1"), ("dogs/dog2", "dogs/dog2"), ("dogs/dog3", "dogs/dog3"), @@ -918,10 +916,9 @@ def test_join_with_using_functions_in_expression( ("dogs/others/dog4", "dogs/others/dog4"), ] else: - string_default = String.default_value(catalog.warehouse.db.dialect) expected = [ - ("cats/cat1", string_default), - ("cats/cat2", string_default), + ("cats/cat1", None), + ("cats/cat2", None), ("dogs/dog1", "dogs/dog1"), ("dogs/dog2", "dogs/dog2"), ("dogs/dog3", "dogs/dog3"), diff --git a/tests/unit/lib/test_datachain.py b/tests/unit/lib/test_datachain.py index 1d0496170..92f3dfec8 100644 --- a/tests/unit/lib/test_datachain.py +++ b/tests/unit/lib/test_datachain.py @@ -1052,17 +1052,13 @@ def test_parse_nested_json(tmp_dir, test_session): # E.g. nAmE -> name, l--as@t -> l_as_t, etc df1 = dc.select("na_me", "age", "city").to_pandas() - # In CH we replace None with '' for peforance reasons, - # have to handle it here - string_default = String.default_value(test_session.catalog.warehouse.db.dialect) - assert sorted(df1["na_me"]["first_select"].to_list()) == sorted( d["first-SELECT"] for d in df["nA-mE"].to_list() ) assert sorted( df1["na_me"]["l_as_t"].to_list(), key=lambda x: (x is None, x) ) == sorted( - [d.get("l--as@t", string_default) for d in df["nA-mE"].to_list()], + [d.get("l--as@t", None) for d in df["nA-mE"].to_list()], key=lambda x: (x is None, x), ) @@ -1304,6 +1300,7 @@ def test_from_csv_null_collect(tmp_dir, test_session): for i, row in enumerate(dc.collect()): # None value in numeric column will get converted to nan. if not height[i]: + print(row[1].height) assert math.isnan(row[1].height) else: assert row[1].height == height[i] @@ -1420,10 +1417,6 @@ def test_explode(tmp_dir, test_session, column_type, object_name, model_name): object_name = object_name or "content_expl" model_name = model_name or "ContentExplodedModel" - # In CH we have (atm at least) None converted to '' - # for performance reasons, so we need to handle this case - string_default = String.default_value(test_session.catalog.warehouse.db.dialect) - assert set( dc.collect( f"{object_name}.na_me.first_select", @@ -1433,7 +1426,7 @@ def test_explode(tmp_dir, test_session, column_type, object_name, model_name): ) == { ("Alice", 25, "New York"), ("Bob", 30, "Los Angeles"), - ("Charlie", 35, string_default), + ("Charlie", 35, None), ("David", 40, "Houston"), ("Eva", 45, "Phoenix"), ("Ivan", 41, "San Francisco"), @@ -2097,6 +2090,22 @@ def test_from_values_array_of_floats(test_session): assert list(chain.order_by("emd").collect("emd")) == embeddings +def test_from_values_array_of_ints_with_nones(test_session): + ids = [1, 2] + embeddings = [[1, None], [4, 5]] + chain = DataChain.from_values(emd=embeddings, ids=ids, session=test_session) + + assert list(chain.order_by("ids").collect("emd")) == embeddings + + +def test_from_values_with_nones(test_session): + ids = [1, 2, 3, 4] + embeddings = [100, None, 300, None] + chain = DataChain.from_values(emd=embeddings, ids=ids, session=test_session) + + assert list(chain.order_by("ids").collect("emd")) == [100, None, 300, None] + + def test_custom_model_with_nested_lists(test_session): class Trace(BaseModel): x: float diff --git a/tests/unit/lib/test_datachain_merge.py b/tests/unit/lib/test_datachain_merge.py index e55e5ea1c..30e3c2fa1 100644 --- a/tests/unit/lib/test_datachain_merge.py +++ b/tests/unit/lib/test_datachain_merge.py @@ -7,7 +7,6 @@ from sqlalchemy import func from datachain.lib.dc import C, DataChain, DatasetMergeError -from datachain.sql.types import Int, String from tests.utils import skip_if_not_sqlite @@ -52,8 +51,6 @@ def test_merge_objects(test_session): ch2 = DataChain.from_values(team=team, session=test_session) ch = ch1.merge(ch2, "emp.person.name", "team.player") - str_default = String.default_value(test_session.catalog.warehouse.db.dialect) - i = 0 j = 0 for items in ch.order_by("emp.person.name", "team.player").collect(): @@ -72,8 +69,8 @@ def test_merge_objects(test_session): assert math.isclose(player.height, team[j].height, rel_tol=1e-7) j += 1 else: - assert player.player == str_default - assert player.sport == str_default + assert player.player is None + assert player.sport is None assert pd.isnull(player.weight) assert pd.isnull(player.height) @@ -95,9 +92,6 @@ def test_merge_objects_full_join(test_session, multiple_predicates): else: ch = ch1.merge(ch2, "emp.person.name", "team.player", full=True) - str_default = String.default_value(test_session.catalog.warehouse.db.dialect) - int_default = Int.default_value(test_session.catalog.warehouse.db.dialect) - i = 0 for items in ch.order_by("emp.person.name", "team.player").collect(): assert len(items) == 2 @@ -107,13 +101,13 @@ def test_merge_objects_full_join(test_session, multiple_predicates): assert isinstance(player, TeamMember) if player.player == "John": - assert empl.person.name == str_default - assert empl.person.age == int_default + assert empl.person.name is None + assert empl.person.age is None continue if empl.person.name == "Bob": - assert player.player == str_default - assert player.sport == str_default + assert player.player is None + assert player.sport is None assert pd.isnull(player.weight) assert pd.isnull(player.height) continue diff --git a/tests/unit/lib/test_diff.py b/tests/unit/lib/test_diff.py index e7a128df3..176ad50b7 100644 --- a/tests/unit/lib/test_diff.py +++ b/tests/unit/lib/test_diff.py @@ -6,7 +6,6 @@ from datachain.diff import CompareStatus, compare_and_split from datachain.lib.dc import DataChain from datachain.lib.file import File -from datachain.sql.types import Int64, String from tests.utils import sorted_dicts @@ -169,15 +168,13 @@ def test_compare_with_explicit_compare_fields(test_session, right_name): status_col="diff", ) - string_default = String.default_value(test_session.catalog.warehouse.db.dialect) - expected = [ (CompareStatus.MODIFIED, 1, "John1", "New York"), (CompareStatus.ADDED, 2, "Doe", "Boston"), ( CompareStatus.DELETED, 3, - string_default if right_name == "other_name" else "Mark", + None if right_name == "other_name" else "Mark", "Seattle", ), (CompareStatus.SAME, 4, "Andy", "San Francisco"), @@ -208,13 +205,11 @@ def test_compare_different_left_right_on_columns(test_session): status_col="diff", ) - int_default = Int64.default_value(test_session.catalog.warehouse.db.dialect) - expected = [ (CompareStatus.SAME, 4, "Andy"), (CompareStatus.ADDED, 2, "Doe"), (CompareStatus.MODIFIED, 1, "John1"), - (CompareStatus.DELETED, int_default, "Mark"), + (CompareStatus.DELETED, None, "Mark"), ] collect_fields = ["diff", "id", "name"] @@ -322,8 +317,6 @@ def test_compare_additional_column_on_left(test_session): session=test_session, ).save("ds2") - string_default = String.default_value(test_session.catalog.warehouse.db.dialect) - diff = ds1.compare(ds2, same=True, on=["id"], status_col="diff") assert sorted_dicts(diff.to_records(), "id") == sorted_dicts( @@ -334,7 +327,7 @@ def test_compare_additional_column_on_left(test_session): "diff": CompareStatus.DELETED, "id": 3, "name": "Mark", - "city": string_default, + "city": None, }, {"diff": CompareStatus.MODIFIED, "id": 4, "name": "Andy", "city": "Tokyo"}, ],