Skip to content

Commit 8ae9f8d

Browse files
authored
Use cloudpickle for parallel UDF processing (#65)
1 parent 00c846a commit 8ae9f8d

File tree

9 files changed

+373
-129
lines changed

9 files changed

+373
-129
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ dependencies = [
3838
"sqlalchemy>=2",
3939
"multiprocess==0.70.16",
4040
"dill==0.3.8",
41+
"cloudpickle",
4142
"ujson>=5.9.0",
4243
"pydantic>=2,<3",
4344
"jmespath>=1.0",

src/datachain/query/dataset.py

Lines changed: 15 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import ast
21
import contextlib
32
import datetime
43
import inspect
@@ -10,7 +9,6 @@
109
import string
1110
import subprocess
1211
import sys
13-
import types
1412
from abc import ABC, abstractmethod
1513
from collections.abc import Generator, Iterable, Iterator, Sequence
1614
from copy import copy
@@ -28,9 +26,7 @@
2826
import attrs
2927
import sqlalchemy
3028
from attrs import frozen
31-
from dill import dumps, source
3229
from fsspec.callbacks import DEFAULT_CALLBACK, Callback, TqdmCallback
33-
from pydantic import BaseModel
3430
from sqlalchemy import Column
3531
from sqlalchemy.sql import func as f
3632
from sqlalchemy.sql.elements import ColumnClause, ColumnElement
@@ -54,7 +50,11 @@
5450
from datachain.progress import CombinedDownloadCallback
5551
from datachain.sql.functions import rand
5652
from datachain.storage import Storage, StorageURI
57-
from datachain.utils import batched, determine_processes
53+
from datachain.utils import (
54+
batched,
55+
determine_processes,
56+
filtered_cloudpickle_dumps,
57+
)
5858

5959
from .metrics import metrics
6060
from .schema import C, UDFParamSpec, normalize_param
@@ -490,7 +490,7 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None:
490490
elif processes:
491491
# Parallel processing (faster for more CPU-heavy UDFs)
492492
udf_info = {
493-
"udf": self.udf,
493+
"udf_data": filtered_cloudpickle_dumps(self.udf),
494494
"catalog_init": self.catalog.get_init_params(),
495495
"id_generator_clone_params": (
496496
self.catalog.id_generator.clone_params()
@@ -511,16 +511,15 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None:
511511

512512
envs = dict(os.environ)
513513
envs.update({"PYTHONPATH": os.getcwd()})
514-
with self.process_feature_module():
515-
process_data = dumps(udf_info, recurse=True)
516-
result = subprocess.run( # noqa: S603
517-
[datachain_exec_path, "--internal-run-udf"],
518-
input=process_data,
519-
check=False,
520-
env=envs,
521-
)
522-
if result.returncode != 0:
523-
raise RuntimeError("UDF Execution Failed!")
514+
process_data = filtered_cloudpickle_dumps(udf_info)
515+
result = subprocess.run( # noqa: S603
516+
[datachain_exec_path, "--internal-run-udf"],
517+
input=process_data,
518+
check=False,
519+
env=envs,
520+
)
521+
if result.returncode != 0:
522+
raise RuntimeError("UDF Execution Failed!")
524523

525524
else:
526525
# Otherwise process single-threaded (faster for smaller UDFs)
@@ -569,57 +568,6 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None:
569568
self.catalog.warehouse.close()
570569
raise
571570

572-
@contextlib.contextmanager
573-
def process_feature_module(self):
574-
# Generate a random name for the feature module
575-
feature_module_name = "tmp" + _random_string(10)
576-
# Create a dynamic module with the generated name
577-
dynamic_module = types.ModuleType(feature_module_name)
578-
# Get the import lines for the necessary objects from the main module
579-
main_module = sys.modules["__main__"]
580-
if getattr(main_module, "__file__", None):
581-
import_lines = list(get_imports(main_module))
582-
else:
583-
import_lines = [
584-
source.getimport(obj, alias=name)
585-
for name, obj in main_module.__dict__.items()
586-
if _imports(obj) and not (name.startswith("__") and name.endswith("__"))
587-
]
588-
589-
# Get the feature classes from the main module
590-
feature_classes = {
591-
name: obj
592-
for name, obj in main_module.__dict__.items()
593-
if _feature_predicate(obj)
594-
}
595-
if not feature_classes:
596-
yield None
597-
return
598-
599-
# Get the source code of the feature classes
600-
feature_sources = [source.getsource(cls) for _, cls in feature_classes.items()]
601-
# Set the module name for the feature classes to the generated name
602-
for name, cls in feature_classes.items():
603-
cls.__module__ = feature_module_name
604-
setattr(dynamic_module, name, cls)
605-
# Add the dynamic module to the sys.modules dictionary
606-
sys.modules[feature_module_name] = dynamic_module
607-
# Combine the import lines and feature sources
608-
feature_file = "\n".join(import_lines) + "\n" + "\n".join(feature_sources)
609-
610-
# Write the module content to a .py file
611-
with open(f"{feature_module_name}.py", "w") as module_file:
612-
module_file.write(feature_file)
613-
614-
try:
615-
yield feature_module_name
616-
finally:
617-
for cls in feature_classes.values():
618-
cls.__module__ = main_module.__name__
619-
os.unlink(f"{feature_module_name}.py")
620-
# Remove the dynamic module from sys.modules
621-
del sys.modules[feature_module_name]
622-
623571
def create_partitions_table(self, query: Select) -> "Table":
624572
"""
625573
Create temporary table with group by partitions.
@@ -1829,34 +1777,3 @@ def _random_string(length: int) -> str:
18291777
random.choice(string.ascii_letters + string.digits) # noqa: S311
18301778
for i in range(length)
18311779
)
1832-
1833-
1834-
def _feature_predicate(obj):
1835-
return (
1836-
inspect.isclass(obj) and source.isfrommain(obj) and issubclass(obj, BaseModel)
1837-
)
1838-
1839-
1840-
def _imports(obj):
1841-
return not source.isfrommain(obj)
1842-
1843-
1844-
def get_imports(m):
1845-
root = ast.parse(inspect.getsource(m))
1846-
1847-
for node in ast.iter_child_nodes(root):
1848-
if isinstance(node, ast.Import):
1849-
module = None
1850-
elif isinstance(node, ast.ImportFrom):
1851-
module = node.module
1852-
else:
1853-
continue
1854-
1855-
for n in node.names:
1856-
import_script = ""
1857-
if module:
1858-
import_script += f"from {module} "
1859-
import_script += f"import {n.name}"
1860-
if n.asname:
1861-
import_script += f" as {n.asname}"
1862-
yield import_script

src/datachain/query/dispatch.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
import attrs
1212
import multiprocess
13-
from dill import load
13+
from cloudpickle import load, loads
1414
from fsspec.callbacks import DEFAULT_CALLBACK, Callback
1515
from multiprocess import get_context
1616

@@ -84,7 +84,7 @@ def put_into_queue(queue: Queue, item: Any) -> None:
8484

8585
def udf_entrypoint() -> int:
8686
# Load UDF info from stdin
87-
udf_info = load(stdin.buffer) # noqa: S301
87+
udf_info = load(stdin.buffer)
8888

8989
(
9090
warehouse_class,
@@ -95,7 +95,7 @@ def udf_entrypoint() -> int:
9595

9696
# Parallel processing (faster for more CPU-heavy UDFs)
9797
dispatch = UDFDispatcher(
98-
udf_info["udf"],
98+
udf_info["udf_data"],
9999
udf_info["catalog_init"],
100100
udf_info["id_generator_clone_params"],
101101
udf_info["metastore_clone_params"],
@@ -108,7 +108,7 @@ def udf_entrypoint() -> int:
108108
batching = udf_info["batching"]
109109
table = udf_info["table"]
110110
n_workers = udf_info["processes"]
111-
udf = udf_info["udf"]
111+
udf = loads(udf_info["udf_data"])
112112
if n_workers is True:
113113
# Use default number of CPUs (cores)
114114
n_workers = None
@@ -146,7 +146,7 @@ class UDFDispatcher:
146146

147147
def __init__(
148148
self,
149-
udf,
149+
udf_data,
150150
catalog_init_params,
151151
id_generator_clone_params,
152152
metastore_clone_params,
@@ -155,14 +155,7 @@ def __init__(
155155
is_generator=False,
156156
buffer_size=DEFAULT_BATCH_SIZE,
157157
):
158-
# isinstance cannot be used here, as dill packages the entire class definition,
159-
# and so these two types are not considered exactly equal,
160-
# even if they have the same import path.
161-
if full_module_type_path(type(udf)) != full_module_type_path(UDFFactory):
162-
self.udf = udf
163-
else:
164-
self.udf = None
165-
self.udf_factory = udf
158+
self.udf_data = udf_data
166159
self.catalog_init_params = catalog_init_params
167160
(
168161
self.id_generator_class,
@@ -214,6 +207,15 @@ def _create_worker(self) -> "UDFWorker":
214207
self.catalog = Catalog(
215208
id_generator, metastore, warehouse, **self.catalog_init_params
216209
)
210+
udf = loads(self.udf_data)
211+
# isinstance cannot be used here, as cloudpickle packages the entire class
212+
# definition, and so these two types are not considered exactly equal,
213+
# even if they have the same import path.
214+
if full_module_type_path(type(udf)) != full_module_type_path(UDFFactory):
215+
self.udf = udf
216+
else:
217+
self.udf = None
218+
self.udf_factory = udf
217219
if not self.udf:
218220
self.udf = self.udf_factory()
219221

src/datachain/utils.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import glob
22
import importlib.util
3+
import io
34
import json
45
import os
56
import os.path as osp
@@ -13,8 +14,10 @@
1314
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
1415
from uuid import UUID
1516

17+
import cloudpickle
1618
from dateutil import tz
1719
from dateutil.parser import isoparse
20+
from pydantic import BaseModel
1821

1922
if TYPE_CHECKING:
2023
import pandas as pd
@@ -388,3 +391,39 @@ def inside_notebook() -> bool:
388391
return False
389392

390393
return False
394+
395+
396+
def get_all_subclasses(cls):
    """Yield every direct and indirect subclass of *cls*.

    Traversal is depth-first: for each direct child, all of its own
    descendants are produced before the child itself. A class reachable
    through more than one parent (multiple inheritance) may appear more
    than once in the output.
    """
    found = []
    for child in cls.__subclasses__():
        # Descendants first, then the child — post-order traversal.
        found.extend(get_all_subclasses(child))
        found.append(child)
    yield from found
402+
403+
404+
def filtered_cloudpickle_dumps(obj: Any) -> bytes:
    """Equivalent to cloudpickle.dumps, but this supports Pydantic models."""
    saved_namespaces = {}

    with io.BytesIO() as buffer:
        pickler = cloudpickle.CloudPickler(buffer)

        # __pydantic_parent_namespace__ can contain many unnecessary and
        # unpickleable entities, so it is blanked out on every known model
        # class before serialization.
        for cls in get_all_subclasses(BaseModel):
            namespace = cls.__pydantic_parent_namespace__
            # Due to multiple inheritance, get_all_subclasses can yield the
            # same class twice; a class processed earlier already holds None
            # here and is skipped.
            if namespace is None:
                continue
            saved_namespaces[cls] = namespace
            cls.__pydantic_parent_namespace__ = None

        try:
            pickler.dump(obj)
            return buffer.getvalue()
        finally:
            # Restore the original namespaces so the running process is left
            # exactly as it was before serialization.
            for cls, namespace in saved_namespaces.items():
                cls.__pydantic_parent_namespace__ = namespace

0 commit comments

Comments
 (0)