1- import ast
21import contextlib
32import datetime
43import inspect
109import string
1110import subprocess
1211import sys
13- import types
1412from abc import ABC , abstractmethod
1513from collections .abc import Generator , Iterable , Iterator , Sequence
1614from copy import copy
2826import attrs
2927import sqlalchemy
3028from attrs import frozen
31- from dill import dumps , source
3229from fsspec .callbacks import DEFAULT_CALLBACK , Callback , TqdmCallback
33- from pydantic import BaseModel
3430from sqlalchemy import Column
3531from sqlalchemy .sql import func as f
3632from sqlalchemy .sql .elements import ColumnClause , ColumnElement
5450from datachain .progress import CombinedDownloadCallback
5551from datachain .sql .functions import rand
5652from datachain .storage import Storage , StorageURI
57- from datachain .utils import batched , determine_processes
53+ from datachain .utils import (
54+ batched ,
55+ determine_processes ,
56+ filtered_cloudpickle_dumps ,
57+ )
5858
5959from .metrics import metrics
6060from .schema import C , UDFParamSpec , normalize_param
@@ -490,7 +490,7 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None:
490490 elif processes :
491491 # Parallel processing (faster for more CPU-heavy UDFs)
492492 udf_info = {
493- "udf " : self .udf ,
493+ "udf_data " : filtered_cloudpickle_dumps ( self .udf ) ,
494494 "catalog_init" : self .catalog .get_init_params (),
495495 "id_generator_clone_params" : (
496496 self .catalog .id_generator .clone_params ()
@@ -511,16 +511,15 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None:
511511
512512 envs = dict (os .environ )
513513 envs .update ({"PYTHONPATH" : os .getcwd ()})
514- with self .process_feature_module ():
515- process_data = dumps (udf_info , recurse = True )
516- result = subprocess .run ( # noqa: S603
517- [datachain_exec_path , "--internal-run-udf" ],
518- input = process_data ,
519- check = False ,
520- env = envs ,
521- )
522- if result .returncode != 0 :
523- raise RuntimeError ("UDF Execution Failed!" )
514+ process_data = filtered_cloudpickle_dumps (udf_info )
515+ result = subprocess .run ( # noqa: S603
516+ [datachain_exec_path , "--internal-run-udf" ],
517+ input = process_data ,
518+ check = False ,
519+ env = envs ,
520+ )
521+ if result .returncode != 0 :
522+ raise RuntimeError ("UDF Execution Failed!" )
524523
525524 else :
526525 # Otherwise process single-threaded (faster for smaller UDFs)
@@ -569,57 +568,6 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None:
569568 self .catalog .warehouse .close ()
570569 raise
571570
@contextlib.contextmanager
def process_feature_module(self):
    """Temporarily expose feature classes from ``__main__`` as a real module.

    Classes in ``__main__`` that satisfy ``_feature_predicate`` are re-homed
    into a randomly named module: their ``__module__`` is pointed at it, the
    module is registered in ``sys.modules``, and a matching ``<name>.py``
    file (import lines + class sources) is written to the current working
    directory. Everything is undone when the context exits.

    Yields:
        The generated module name, or ``None`` when ``__main__`` defines no
        feature classes (in which case nothing is created or modified).
    """
    main_module = sys.modules["__main__"]

    # Get the feature classes from the main module first: without any there
    # is nothing to export, so skip collecting imports altogether.
    feature_classes = {
        name: obj
        for name, obj in main_module.__dict__.items()
        if _feature_predicate(obj)
    }
    if not feature_classes:
        yield None
        return

    # Get the import lines for the necessary objects from the main module
    if getattr(main_module, "__file__", None):
        # A script on disk: re-use its top-level import statements
        import_lines = list(get_imports(main_module))
    else:
        # No backing file (e.g. interactive session): derive an import line
        # for every non-dunder global that was not defined in __main__
        import_lines = [
            source.getimport(obj, alias=name)
            for name, obj in main_module.__dict__.items()
            if _imports(obj) and not (name.startswith("__") and name.endswith("__"))
        ]

    # Get the source code of the feature classes (must happen before their
    # __module__ is rewritten below)
    feature_sources = [source.getsource(cls) for cls in feature_classes.values()]
    # Combine the import lines and feature sources
    feature_file = "\n".join(import_lines) + "\n" + "\n".join(feature_sources)

    # Generate a random name for the feature module and create it dynamically
    feature_module_name = "tmp" + _random_string(10)
    module_path = f"{feature_module_name}.py"
    dynamic_module = types.ModuleType(feature_module_name)

    try:
        # All global mutations live inside the try block so that a failure
        # at any step (for instance while writing the file) is rolled back.
        for name, cls in feature_classes.items():
            cls.__module__ = feature_module_name
            setattr(dynamic_module, name, cls)
        # Add the dynamic module to the sys.modules dictionary
        sys.modules[feature_module_name] = dynamic_module

        # Write the module content to a .py file; Python reads source files
        # as UTF-8 by default, so do not rely on the locale encoding here.
        with open(module_path, "w", encoding="utf-8") as module_file:
            module_file.write(feature_file)

        yield feature_module_name
    finally:
        for cls in feature_classes.values():
            cls.__module__ = main_module.__name__
        # The file may not exist if writing failed before it was created
        with contextlib.suppress(FileNotFoundError):
            os.unlink(module_path)
        # Remove the dynamic module from sys.modules
        sys.modules.pop(feature_module_name, None)
622-
623571 def create_partitions_table (self , query : Select ) -> "Table" :
624572 """
625573 Create temporary table with group by partitions.
@@ -1829,34 +1777,3 @@ def _random_string(length: int) -> str:
18291777 random .choice (string .ascii_letters + string .digits ) # noqa: S311
18301778 for i in range (length )
18311779 )
1832-
1833-
1834- def _feature_predicate (obj ):
1835- return (
1836- inspect .isclass (obj ) and source .isfrommain (obj ) and issubclass (obj , BaseModel )
1837- )
1838-
1839-
def _imports(obj):
    """Return True when ``obj`` was not defined in ``__main__``."""
    defined_in_main = source.isfrommain(obj)
    return not defined_in_main
1842-
1843-
def get_imports(m):
    """Yield one import statement per name imported at the top level of ``m``.

    The source of module ``m`` is parsed and every top-level ``import`` /
    ``from ... import`` node is expanded into single-name statements, e.g.
    ``from a import b, c as d`` becomes ``from a import b`` and
    ``from a import c as d``. Imports nested inside functions, classes or
    conditionals are not visited.

    Relative imports keep their leading dots: ``from . import x`` and
    ``from ..pkg import y`` are reproduced as written instead of being
    turned into ``import x`` / ``from pkg import y``.

    Args:
        m: A module whose source is retrievable via ``inspect.getsource``.

    Yields:
        Import statements as strings, in source order.
    """
    root = ast.parse(inspect.getsource(m))

    for node in ast.iter_child_nodes(root):
        if isinstance(node, ast.Import):
            prefix = ""
        elif isinstance(node, ast.ImportFrom):
            # ``level`` counts the leading dots of a relative import and
            # ``module`` is None for the bare ``from . import x`` form.
            dots = "." * (node.level or 0)
            prefix = f"from {dots}{node.module or ''} "
        else:
            continue

        for alias in node.names:
            statement = f"{prefix}import {alias.name}"
            if alias.asname:
                statement += f" as {alias.asname}"
            yield statement
0 commit comments