|
16 | 16 | from PIL import Image |
17 | 17 |
|
18 | 18 | import datachain as dc |
19 | | -from datachain import DataChainError, DataModel, Mapper, func |
| 19 | +from datachain import DataModel, func |
20 | 20 | from datachain.data_storage.sqlite import SQLiteWarehouse |
21 | 21 | from datachain.dataset import DatasetDependencyType |
22 | 22 | from datachain.lib.file import File, ImageFile |
|
25 | 25 | from datachain.query.dataset import QueryStep |
26 | 26 | from tests.utils import ( |
27 | 27 | ANY_VALUE, |
28 | | - LARGE_TREE, |
29 | 28 | TARRED_TREE, |
30 | 29 | df_equal, |
31 | 30 | images_equal, |
@@ -734,377 +733,6 @@ def test_parallel(processes, test_session_tmpfile): |
734 | 733 | assert res == [prefix + v for v in vals] |
735 | 734 |
|
736 | 735 |
|
737 | | -@pytest.mark.parametrize( |
738 | | - "cloud_type,version_aware", |
739 | | - [("s3", True)], |
740 | | - indirect=True, |
741 | | -) |
742 | | -def test_udf(cloud_test_catalog): |
743 | | - session = cloud_test_catalog.session |
744 | | - |
745 | | - def name_len(path): |
746 | | - return (len(posixpath.basename(path)),) |
747 | | - |
748 | | - chain = ( |
749 | | - dc.read_storage(cloud_test_catalog.src_uri, session=session) |
750 | | - .filter(dc.C("file.size") < 13) |
751 | | - .filter(dc.C("file.path").glob("cats*") | (dc.C("file.size") < 4)) |
752 | | - .map(name_len, params=["file.path"], output={"name_len": int}) |
753 | | - ) |
754 | | - result1 = chain.select("file.path", "name_len").to_list() |
755 | | - # ensure that we're able to run with same query multiple times |
756 | | - result2 = chain.select("file.path", "name_len").to_list() |
757 | | - count = chain.count() |
758 | | - assert len(result1) == 3 |
759 | | - assert len(result2) == 3 |
760 | | - assert count == 3 |
761 | | - |
762 | | - for r1, r2 in zip(result1, result2, strict=False): |
763 | | - # Check that the UDF ran successfully |
764 | | - assert len(posixpath.basename(r1[0])) == r1[1] |
765 | | - assert len(posixpath.basename(r2[0])) == r2[1] |
766 | | - |
767 | | - |
768 | | -@pytest.mark.parametrize( |
769 | | - "cloud_type,version_aware", |
770 | | - [("s3", True)], |
771 | | - indirect=True, |
772 | | -) |
773 | | -@pytest.mark.xdist_group(name="tmpfile") |
774 | | -def test_udf_parallel(cloud_test_catalog_tmpfile): |
775 | | - session = cloud_test_catalog_tmpfile.session |
776 | | - |
777 | | - def name_len(name): |
778 | | - return (len(name),) |
779 | | - |
780 | | - chain = ( |
781 | | - dc.read_storage(cloud_test_catalog_tmpfile.src_uri, session=session) |
782 | | - .settings(parallel=True) |
783 | | - .map(name_len, params=["file.path"], output={"name_len": int}) |
784 | | - .select("file.path", "name_len") |
785 | | - ) |
786 | | - |
787 | | - # Check that the UDF ran successfully |
788 | | - count = 0 |
789 | | - for r in chain: |
790 | | - count += 1 |
791 | | - assert len(r[0]) == r[1] |
792 | | - assert count == 7 |
793 | | - |
794 | | - |
795 | | -@pytest.mark.xdist_group(name="tmpfile") |
796 | | -def test_udf_parallel_boostrap(test_session_tmpfile): |
797 | | - vals = ["a", "b", "c", "d", "e", "f"] |
798 | | - |
799 | | - class MyMapper(Mapper): |
800 | | - DEFAULT_VALUE = 84 |
801 | | - BOOTSTRAP_VALUE = 1452 |
802 | | - TEARDOWN_VALUE = 98763 |
803 | | - |
804 | | - def __init__(self): |
805 | | - super().__init__() |
806 | | - self.value = MyMapper.DEFAULT_VALUE |
807 | | - self._had_teardown = False |
808 | | - |
809 | | - def process(self, key) -> int: |
810 | | - return self.value |
811 | | - |
812 | | - def setup(self): |
813 | | - self.value = MyMapper.BOOTSTRAP_VALUE |
814 | | - |
815 | | - def teardown(self): |
816 | | - self.value = MyMapper.TEARDOWN_VALUE |
817 | | - |
818 | | - chain = dc.read_values(key=vals, session=test_session_tmpfile) |
819 | | - |
820 | | - res = chain.settings(parallel=4).map(res=MyMapper()).to_values("res") |
821 | | - |
822 | | - assert res == [MyMapper.BOOTSTRAP_VALUE] * len(vals) |
823 | | - |
824 | | - |
825 | | -@pytest.mark.parametrize( |
826 | | - "cloud_type,version_aware,tree", |
827 | | - [("s3", True, LARGE_TREE)], |
828 | | - indirect=True, |
829 | | -) |
830 | | -@pytest.mark.parametrize("workers", (1, 2)) |
831 | | -@pytest.mark.parametrize("parallel", (1, 2)) |
832 | | -@pytest.mark.skipif( |
833 | | - "not os.environ.get('DATACHAIN_DISTRIBUTED')", |
834 | | - reason="Set the DATACHAIN_DISTRIBUTED environment variable " |
835 | | - "to test distributed UDFs", |
836 | | -) |
837 | | -@pytest.mark.xdist_group(name="tmpfile") |
838 | | -def test_udf_distributed( |
839 | | - cloud_test_catalog_tmpfile, workers, parallel, tree, run_datachain_worker |
840 | | -): |
841 | | - session = cloud_test_catalog_tmpfile.session |
842 | | - |
843 | | - def name_len(name): |
844 | | - return (len(name),) |
845 | | - |
846 | | - chain = ( |
847 | | - dc.read_storage(cloud_test_catalog_tmpfile.src_uri, session=session) |
848 | | - .settings(parallel=parallel, workers=workers) |
849 | | - .map(name_len, params=["file.path"], output={"name_len": int}) |
850 | | - .select("file.path", "name_len") |
851 | | - ) |
852 | | - |
853 | | - # Check that the UDF ran successfully |
854 | | - count = 0 |
855 | | - for r in chain: |
856 | | - count += 1 |
857 | | - assert len(r[0]) == r[1] |
858 | | - assert count == 225 |
859 | | - |
860 | | - |
861 | | -@pytest.mark.parametrize( |
862 | | - "cloud_type,version_aware", |
863 | | - [("s3", True)], |
864 | | - indirect=True, |
865 | | -) |
866 | | -def test_class_udf(cloud_test_catalog): |
867 | | - session = cloud_test_catalog.session |
868 | | - |
869 | | - class MyUDF(Mapper): |
870 | | - def __init__(self, constant, multiplier=1): |
871 | | - self.constant = constant |
872 | | - self.multiplier = multiplier |
873 | | - |
874 | | - def process(self, size): |
875 | | - return (self.constant + size * self.multiplier,) |
876 | | - |
877 | | - chain = ( |
878 | | - dc.read_storage(cloud_test_catalog.src_uri, session=session) |
879 | | - .filter(dc.C("file.size") < 13) |
880 | | - .map( |
881 | | - MyUDF(5, multiplier=2), |
882 | | - output={"total": int}, |
883 | | - params=["file.size"], |
884 | | - ) |
885 | | - .select("file.size", "total") |
886 | | - .order_by("file.size") |
887 | | - ) |
888 | | - |
889 | | - assert chain.to_list() == [ |
890 | | - (3, 11), |
891 | | - (4, 13), |
892 | | - (4, 13), |
893 | | - (4, 13), |
894 | | - (4, 13), |
895 | | - (4, 13), |
896 | | - ] |
897 | | - |
898 | | - |
899 | | -@pytest.mark.parametrize( |
900 | | - "cloud_type,version_aware", |
901 | | - [("s3", True)], |
902 | | - indirect=True, |
903 | | -) |
904 | | -@pytest.mark.xdist_group(name="tmpfile") |
905 | | -def test_class_udf_parallel(cloud_test_catalog_tmpfile): |
906 | | - session = cloud_test_catalog_tmpfile.session |
907 | | - |
908 | | - class MyUDF(Mapper): |
909 | | - def __init__(self, constant, multiplier=1): |
910 | | - self.constant = constant |
911 | | - self.multiplier = multiplier |
912 | | - |
913 | | - def process(self, size): |
914 | | - return (self.constant + size * self.multiplier,) |
915 | | - |
916 | | - chain = ( |
917 | | - dc.read_storage(cloud_test_catalog_tmpfile.src_uri, session=session) |
918 | | - .filter(dc.C("file.size") < 13) |
919 | | - .settings(parallel=2) |
920 | | - .map( |
921 | | - MyUDF(5, multiplier=2), |
922 | | - output={"total": int}, |
923 | | - params=["file.size"], |
924 | | - ) |
925 | | - .select("file.size", "total") |
926 | | - .order_by("file.size") |
927 | | - ) |
928 | | - |
929 | | - assert chain.to_list() == [ |
930 | | - (3, 11), |
931 | | - (4, 13), |
932 | | - (4, 13), |
933 | | - (4, 13), |
934 | | - (4, 13), |
935 | | - (4, 13), |
936 | | - ] |
937 | | - |
938 | | - |
939 | | -@pytest.mark.parametrize( |
940 | | - "cloud_type,version_aware", |
941 | | - [("s3", True)], |
942 | | - indirect=True, |
943 | | -) |
944 | | -@pytest.mark.xdist_group(name="tmpfile") |
945 | | -def test_udf_parallel_exec_error(cloud_test_catalog_tmpfile): |
946 | | - session = cloud_test_catalog_tmpfile.session |
947 | | - |
948 | | - def name_len_error(_name): |
949 | | - # A udf that raises an exception |
950 | | - raise RuntimeError("Test Error!") |
951 | | - |
952 | | - chain = ( |
953 | | - dc.read_storage(cloud_test_catalog_tmpfile.src_uri, session=session) |
954 | | - .filter(dc.C("file.size") < 13) |
955 | | - .filter(dc.C("file.path").glob("cats*") | (dc.C("file.size") < 4)) |
956 | | - .settings(parallel=True) |
957 | | - .map(name_len_error, params=["file.path"], output={"name_len": int}) |
958 | | - ) |
959 | | - |
960 | | - if os.environ.get("DATACHAIN_DISTRIBUTED"): |
961 | | - # in distributed mode we expect DataChainError with the error message |
962 | | - with pytest.raises(DataChainError, match="Test Error!"): |
963 | | - chain.show() |
964 | | - else: |
965 | | - # while in local mode we expect RuntimeError with the error message |
966 | | - with pytest.raises(RuntimeError, match="UDF Execution Failed!"): |
967 | | - chain.show() |
968 | | - |
969 | | - |
970 | | -@pytest.mark.parametrize( |
971 | | - "cloud_type,version_aware,tree", |
972 | | - [("s3", True, LARGE_TREE)], |
973 | | - indirect=True, |
974 | | -) |
975 | | -@pytest.mark.parametrize("workers", (1, 2)) |
976 | | -@pytest.mark.parametrize("parallel", (1, 2)) |
977 | | -@pytest.mark.skipif( |
978 | | - "not os.environ.get('DATACHAIN_DISTRIBUTED')", |
979 | | - reason="Set the DATACHAIN_DISTRIBUTED environment variable " |
980 | | - "to test distributed UDFs", |
981 | | -) |
982 | | -@pytest.mark.xdist_group(name="tmpfile") |
983 | | -def test_udf_distributed_exec_error( |
984 | | - cloud_test_catalog_tmpfile, workers, parallel, tree, run_datachain_worker |
985 | | -): |
986 | | - session = cloud_test_catalog_tmpfile.session |
987 | | - |
988 | | - def name_len_error(_name): |
989 | | - # A udf that raises an exception |
990 | | - raise RuntimeError("Test Error!") |
991 | | - |
992 | | - chain = ( |
993 | | - dc.read_storage(cloud_test_catalog_tmpfile.src_uri, session=session) |
994 | | - .filter(dc.C("file.size") < 13) |
995 | | - .filter(dc.C("file.path").glob("cats*") | (dc.C("file.size") < 4)) |
996 | | - .settings(parallel=parallel, workers=workers) |
997 | | - .map(name_len_error, params=["file.path"], output={"name_len": int}) |
998 | | - ) |
999 | | - with pytest.raises(DataChainError, match="Test Error!"): |
1000 | | - chain.show() |
1001 | | - |
1002 | | - |
1003 | | -@pytest.mark.parametrize( |
1004 | | - "cloud_type,version_aware", |
1005 | | - [("s3", True)], |
1006 | | - indirect=True, |
1007 | | -) |
1008 | | -@pytest.mark.xdist_group(name="tmpfile") |
1009 | | -def test_udf_reuse_on_error(cloud_test_catalog_tmpfile): |
1010 | | - session = cloud_test_catalog_tmpfile.session |
1011 | | - |
1012 | | - error_state = {"error": True} |
1013 | | - |
1014 | | - def name_len_maybe_error(path): |
1015 | | - if error_state["error"]: |
1016 | | - # A udf that raises an exception |
1017 | | - raise RuntimeError("Test Error!") |
1018 | | - return (len(path),) |
1019 | | - |
1020 | | - chain = ( |
1021 | | - dc.read_storage(cloud_test_catalog_tmpfile.src_uri, session=session) |
1022 | | - .filter(dc.C("file.size") < 13) |
1023 | | - .filter(dc.C("file.path").glob("cats*") | (dc.C("file.size") < 4)) |
1024 | | - .map(name_len_maybe_error, params=["file.path"], output={"path_len": int}) |
1025 | | - .select("file.path", "path_len") |
1026 | | - ) |
1027 | | - |
1028 | | - with pytest.raises(DataChainError, match="Test Error!"): |
1029 | | - chain.show() |
1030 | | - |
1031 | | - # Simulate fixing the error |
1032 | | - error_state["error"] = False |
1033 | | - |
1034 | | - # Retry Query |
1035 | | - count = 0 |
1036 | | - for r in chain: |
1037 | | - # Check that the UDF ran successfully |
1038 | | - count += 1 |
1039 | | - assert len(r[0]) == r[1] |
1040 | | - assert count == 3 |
1041 | | - |
1042 | | - |
1043 | | -@pytest.mark.parametrize( |
1044 | | - "cloud_type,version_aware", |
1045 | | - [("s3", True)], |
1046 | | - indirect=True, |
1047 | | -) |
1048 | | -@pytest.mark.xdist_group(name="tmpfile") |
1049 | | -def test_udf_parallel_interrupt(cloud_test_catalog_tmpfile, capfd): |
1050 | | - session = cloud_test_catalog_tmpfile.session |
1051 | | - |
1052 | | - def name_len_interrupt(_name): |
1053 | | - # A UDF that emulates cancellation due to a KeyboardInterrupt. |
1054 | | - raise KeyboardInterrupt |
1055 | | - |
1056 | | - chain = ( |
1057 | | - dc.read_storage(cloud_test_catalog_tmpfile.src_uri, session=session) |
1058 | | - .filter(dc.C("file.size") < 13) |
1059 | | - .filter(dc.C("file.path").glob("cats*") | (dc.C("file.size") < 4)) |
1060 | | - .settings(parallel=True) |
1061 | | - .map(name_len_interrupt, params=["file.path"], output={"name_len": int}) |
1062 | | - ) |
1063 | | - if os.environ.get("DATACHAIN_DISTRIBUTED"): |
1064 | | - with pytest.raises(KeyboardInterrupt): |
1065 | | - chain.show() |
1066 | | - else: |
1067 | | - with pytest.raises(RuntimeError, match="UDF Execution Failed!"): |
1068 | | - chain.show() |
1069 | | - captured = capfd.readouterr() |
1070 | | - assert "semaphore" not in captured.err |
1071 | | - |
1072 | | - |
1073 | | -@pytest.mark.parametrize( |
1074 | | - "cloud_type,version_aware,tree", |
1075 | | - [("s3", True, LARGE_TREE)], |
1076 | | - indirect=True, |
1077 | | -) |
1078 | | -@pytest.mark.skipif( |
1079 | | - "not os.environ.get('DATACHAIN_DISTRIBUTED')", |
1080 | | - reason="Set the DATACHAIN_DISTRIBUTED environment variable " |
1081 | | - "to test distributed UDFs", |
1082 | | -) |
1083 | | -@pytest.mark.parametrize("workers", (1, 2)) |
1084 | | -@pytest.mark.parametrize("parallel", (1, 2)) |
1085 | | -@pytest.mark.xdist_group(name="tmpfile") |
1086 | | -def test_udf_distributed_interrupt( |
1087 | | - cloud_test_catalog_tmpfile, capfd, tree, workers, parallel, run_datachain_worker |
1088 | | -): |
1089 | | - session = cloud_test_catalog_tmpfile.session |
1090 | | - |
1091 | | - def name_len_interrupt(_name): |
1092 | | - # A UDF that emulates cancellation due to a KeyboardInterrupt. |
1093 | | - raise KeyboardInterrupt |
1094 | | - |
1095 | | - chain = ( |
1096 | | - dc.read_storage(cloud_test_catalog_tmpfile.src_uri, session=session) |
1097 | | - .filter(dc.C("file.size") < 13) |
1098 | | - .filter(dc.C("file.path").glob("cats*") | (dc.C("file.size") < 4)) |
1099 | | - .settings(parallel=parallel, workers=workers) |
1100 | | - .map(name_len_interrupt, params=["file.path"], output={"name_len": int}) |
1101 | | - ) |
1102 | | - with pytest.raises(KeyboardInterrupt): |
1103 | | - chain.show() |
1104 | | - captured = capfd.readouterr() |
1105 | | - assert "semaphore" not in captured.err |
1106 | | - |
1107 | | - |
1108 | 736 | @pytest.mark.parametrize( |
1109 | 737 | "cloud_type,version_aware", |
1110 | 738 | [("s3", True)], |
|
0 commit comments