
Commit 6c17b68

remove duplicate tests (#1397)
1 parent a72c865 commit 6c17b68

File tree

2 files changed (+11, -373 lines)


tests/func/test_datachain.py

Lines changed: 1 addition & 373 deletions
@@ -16,7 +16,7 @@
 from PIL import Image
 
 import datachain as dc
-from datachain import DataChainError, DataModel, Mapper, func
+from datachain import DataModel, func
 from datachain.data_storage.sqlite import SQLiteWarehouse
 from datachain.dataset import DatasetDependencyType
 from datachain.lib.file import File, ImageFile
@@ -25,7 +25,6 @@
 from datachain.query.dataset import QueryStep
 from tests.utils import (
     ANY_VALUE,
-    LARGE_TREE,
     TARRED_TREE,
     df_equal,
     images_equal,
@@ -734,377 +733,6 @@ def test_parallel(processes, test_session_tmpfile):
     assert res == [prefix + v for v in vals]
 
 
-@pytest.mark.parametrize(
-    "cloud_type,version_aware",
-    [("s3", True)],
-    indirect=True,
-)
-def test_udf(cloud_test_catalog):
-    session = cloud_test_catalog.session
-
-    def name_len(path):
-        return (len(posixpath.basename(path)),)
-
-    chain = (
-        dc.read_storage(cloud_test_catalog.src_uri, session=session)
-        .filter(dc.C("file.size") < 13)
-        .filter(dc.C("file.path").glob("cats*") | (dc.C("file.size") < 4))
-        .map(name_len, params=["file.path"], output={"name_len": int})
-    )
-    result1 = chain.select("file.path", "name_len").to_list()
-    # ensure that we're able to run with same query multiple times
-    result2 = chain.select("file.path", "name_len").to_list()
-    count = chain.count()
-    assert len(result1) == 3
-    assert len(result2) == 3
-    assert count == 3
-
-    for r1, r2 in zip(result1, result2, strict=False):
-        # Check that the UDF ran successfully
-        assert len(posixpath.basename(r1[0])) == r1[1]
-        assert len(posixpath.basename(r2[0])) == r2[1]
-
-
-@pytest.mark.parametrize(
-    "cloud_type,version_aware",
-    [("s3", True)],
-    indirect=True,
-)
-@pytest.mark.xdist_group(name="tmpfile")
-def test_udf_parallel(cloud_test_catalog_tmpfile):
-    session = cloud_test_catalog_tmpfile.session
-
-    def name_len(name):
-        return (len(name),)
-
-    chain = (
-        dc.read_storage(cloud_test_catalog_tmpfile.src_uri, session=session)
-        .settings(parallel=True)
-        .map(name_len, params=["file.path"], output={"name_len": int})
-        .select("file.path", "name_len")
-    )
-
-    # Check that the UDF ran successfully
-    count = 0
-    for r in chain:
-        count += 1
-        assert len(r[0]) == r[1]
-    assert count == 7
-
-
-@pytest.mark.xdist_group(name="tmpfile")
-def test_udf_parallel_boostrap(test_session_tmpfile):
-    vals = ["a", "b", "c", "d", "e", "f"]
-
-    class MyMapper(Mapper):
-        DEFAULT_VALUE = 84
-        BOOTSTRAP_VALUE = 1452
-        TEARDOWN_VALUE = 98763
-
-        def __init__(self):
-            super().__init__()
-            self.value = MyMapper.DEFAULT_VALUE
-            self._had_teardown = False
-
-        def process(self, key) -> int:
-            return self.value
-
-        def setup(self):
-            self.value = MyMapper.BOOTSTRAP_VALUE
-
-        def teardown(self):
-            self.value = MyMapper.TEARDOWN_VALUE
-
-    chain = dc.read_values(key=vals, session=test_session_tmpfile)
-
-    res = chain.settings(parallel=4).map(res=MyMapper()).to_values("res")
-
-    assert res == [MyMapper.BOOTSTRAP_VALUE] * len(vals)
-
-
-@pytest.mark.parametrize(
-    "cloud_type,version_aware,tree",
-    [("s3", True, LARGE_TREE)],
-    indirect=True,
-)
-@pytest.mark.parametrize("workers", (1, 2))
-@pytest.mark.parametrize("parallel", (1, 2))
-@pytest.mark.skipif(
-    "not os.environ.get('DATACHAIN_DISTRIBUTED')",
-    reason="Set the DATACHAIN_DISTRIBUTED environment variable "
-    "to test distributed UDFs",
-)
-@pytest.mark.xdist_group(name="tmpfile")
-def test_udf_distributed(
-    cloud_test_catalog_tmpfile, workers, parallel, tree, run_datachain_worker
-):
-    session = cloud_test_catalog_tmpfile.session
-
-    def name_len(name):
-        return (len(name),)
-
-    chain = (
-        dc.read_storage(cloud_test_catalog_tmpfile.src_uri, session=session)
-        .settings(parallel=parallel, workers=workers)
-        .map(name_len, params=["file.path"], output={"name_len": int})
-        .select("file.path", "name_len")
-    )
-
-    # Check that the UDF ran successfully
-    count = 0
-    for r in chain:
-        count += 1
-        assert len(r[0]) == r[1]
-    assert count == 225
-
-
-@pytest.mark.parametrize(
-    "cloud_type,version_aware",
-    [("s3", True)],
-    indirect=True,
-)
-def test_class_udf(cloud_test_catalog):
-    session = cloud_test_catalog.session
-
-    class MyUDF(Mapper):
-        def __init__(self, constant, multiplier=1):
-            self.constant = constant
-            self.multiplier = multiplier
-
-        def process(self, size):
-            return (self.constant + size * self.multiplier,)
-
-    chain = (
-        dc.read_storage(cloud_test_catalog.src_uri, session=session)
-        .filter(dc.C("file.size") < 13)
-        .map(
-            MyUDF(5, multiplier=2),
-            output={"total": int},
-            params=["file.size"],
-        )
-        .select("file.size", "total")
-        .order_by("file.size")
-    )
-
-    assert chain.to_list() == [
-        (3, 11),
-        (4, 13),
-        (4, 13),
-        (4, 13),
-        (4, 13),
-        (4, 13),
-    ]
-
-
-@pytest.mark.parametrize(
-    "cloud_type,version_aware",
-    [("s3", True)],
-    indirect=True,
-)
-@pytest.mark.xdist_group(name="tmpfile")
-def test_class_udf_parallel(cloud_test_catalog_tmpfile):
-    session = cloud_test_catalog_tmpfile.session
-
-    class MyUDF(Mapper):
-        def __init__(self, constant, multiplier=1):
-            self.constant = constant
-            self.multiplier = multiplier
-
-        def process(self, size):
-            return (self.constant + size * self.multiplier,)
-
-    chain = (
-        dc.read_storage(cloud_test_catalog_tmpfile.src_uri, session=session)
-        .filter(dc.C("file.size") < 13)
-        .settings(parallel=2)
-        .map(
-            MyUDF(5, multiplier=2),
-            output={"total": int},
-            params=["file.size"],
-        )
-        .select("file.size", "total")
-        .order_by("file.size")
-    )
-
-    assert chain.to_list() == [
-        (3, 11),
-        (4, 13),
-        (4, 13),
-        (4, 13),
-        (4, 13),
-        (4, 13),
-    ]
-
-
-@pytest.mark.parametrize(
-    "cloud_type,version_aware",
-    [("s3", True)],
-    indirect=True,
-)
-@pytest.mark.xdist_group(name="tmpfile")
-def test_udf_parallel_exec_error(cloud_test_catalog_tmpfile):
-    session = cloud_test_catalog_tmpfile.session
-
-    def name_len_error(_name):
-        # A udf that raises an exception
-        raise RuntimeError("Test Error!")
-
-    chain = (
-        dc.read_storage(cloud_test_catalog_tmpfile.src_uri, session=session)
-        .filter(dc.C("file.size") < 13)
-        .filter(dc.C("file.path").glob("cats*") | (dc.C("file.size") < 4))
-        .settings(parallel=True)
-        .map(name_len_error, params=["file.path"], output={"name_len": int})
-    )
-
-    if os.environ.get("DATACHAIN_DISTRIBUTED"):
-        # in distributed mode we expect DataChainError with the error message
-        with pytest.raises(DataChainError, match="Test Error!"):
-            chain.show()
-    else:
-        # while in local mode we expect RuntimeError with the error message
-        with pytest.raises(RuntimeError, match="UDF Execution Failed!"):
-            chain.show()
-
-
-@pytest.mark.parametrize(
-    "cloud_type,version_aware,tree",
-    [("s3", True, LARGE_TREE)],
-    indirect=True,
-)
-@pytest.mark.parametrize("workers", (1, 2))
-@pytest.mark.parametrize("parallel", (1, 2))
-@pytest.mark.skipif(
-    "not os.environ.get('DATACHAIN_DISTRIBUTED')",
-    reason="Set the DATACHAIN_DISTRIBUTED environment variable "
-    "to test distributed UDFs",
-)
-@pytest.mark.xdist_group(name="tmpfile")
-def test_udf_distributed_exec_error(
-    cloud_test_catalog_tmpfile, workers, parallel, tree, run_datachain_worker
-):
-    session = cloud_test_catalog_tmpfile.session
-
-    def name_len_error(_name):
-        # A udf that raises an exception
-        raise RuntimeError("Test Error!")
-
-    chain = (
-        dc.read_storage(cloud_test_catalog_tmpfile.src_uri, session=session)
-        .filter(dc.C("file.size") < 13)
-        .filter(dc.C("file.path").glob("cats*") | (dc.C("file.size") < 4))
-        .settings(parallel=parallel, workers=workers)
-        .map(name_len_error, params=["file.path"], output={"name_len": int})
-    )
-    with pytest.raises(DataChainError, match="Test Error!"):
-        chain.show()
-
-
-@pytest.mark.parametrize(
-    "cloud_type,version_aware",
-    [("s3", True)],
-    indirect=True,
-)
-@pytest.mark.xdist_group(name="tmpfile")
-def test_udf_reuse_on_error(cloud_test_catalog_tmpfile):
-    session = cloud_test_catalog_tmpfile.session
-
-    error_state = {"error": True}
-
-    def name_len_maybe_error(path):
-        if error_state["error"]:
-            # A udf that raises an exception
-            raise RuntimeError("Test Error!")
-        return (len(path),)
-
-    chain = (
-        dc.read_storage(cloud_test_catalog_tmpfile.src_uri, session=session)
-        .filter(dc.C("file.size") < 13)
-        .filter(dc.C("file.path").glob("cats*") | (dc.C("file.size") < 4))
-        .map(name_len_maybe_error, params=["file.path"], output={"path_len": int})
-        .select("file.path", "path_len")
-    )
-
-    with pytest.raises(DataChainError, match="Test Error!"):
-        chain.show()
-
-    # Simulate fixing the error
-    error_state["error"] = False
-
-    # Retry Query
-    count = 0
-    for r in chain:
-        # Check that the UDF ran successfully
-        count += 1
-        assert len(r[0]) == r[1]
-    assert count == 3
-
-
-@pytest.mark.parametrize(
-    "cloud_type,version_aware",
-    [("s3", True)],
-    indirect=True,
-)
-@pytest.mark.xdist_group(name="tmpfile")
-def test_udf_parallel_interrupt(cloud_test_catalog_tmpfile, capfd):
-    session = cloud_test_catalog_tmpfile.session
-
-    def name_len_interrupt(_name):
-        # A UDF that emulates cancellation due to a KeyboardInterrupt.
-        raise KeyboardInterrupt
-
-    chain = (
-        dc.read_storage(cloud_test_catalog_tmpfile.src_uri, session=session)
-        .filter(dc.C("file.size") < 13)
-        .filter(dc.C("file.path").glob("cats*") | (dc.C("file.size") < 4))
-        .settings(parallel=True)
-        .map(name_len_interrupt, params=["file.path"], output={"name_len": int})
-    )
-    if os.environ.get("DATACHAIN_DISTRIBUTED"):
-        with pytest.raises(KeyboardInterrupt):
-            chain.show()
-    else:
-        with pytest.raises(RuntimeError, match="UDF Execution Failed!"):
-            chain.show()
-    captured = capfd.readouterr()
-    assert "semaphore" not in captured.err
-
-
-@pytest.mark.parametrize(
-    "cloud_type,version_aware,tree",
-    [("s3", True, LARGE_TREE)],
-    indirect=True,
-)
-@pytest.mark.skipif(
-    "not os.environ.get('DATACHAIN_DISTRIBUTED')",
-    reason="Set the DATACHAIN_DISTRIBUTED environment variable "
-    "to test distributed UDFs",
-)
-@pytest.mark.parametrize("workers", (1, 2))
-@pytest.mark.parametrize("parallel", (1, 2))
-@pytest.mark.xdist_group(name="tmpfile")
-def test_udf_distributed_interrupt(
-    cloud_test_catalog_tmpfile, capfd, tree, workers, parallel, run_datachain_worker
-):
-    session = cloud_test_catalog_tmpfile.session
-
-    def name_len_interrupt(_name):
-        # A UDF that emulates cancellation due to a KeyboardInterrupt.
-        raise KeyboardInterrupt
-
-    chain = (
-        dc.read_storage(cloud_test_catalog_tmpfile.src_uri, session=session)
-        .filter(dc.C("file.size") < 13)
-        .filter(dc.C("file.path").glob("cats*") | (dc.C("file.size") < 4))
-        .settings(parallel=parallel, workers=workers)
-        .map(name_len_interrupt, params=["file.path"], output={"name_len": int})
-    )
-    with pytest.raises(KeyboardInterrupt):
-        chain.show()
-    captured = capfd.readouterr()
-    assert "semaphore" not in captured.err
-
-
 @pytest.mark.parametrize(
     "cloud_type,version_aware",
     [("s3", True)],

0 commit comments
