1111
1212from __future__ import annotations
1313
14+ import copy
1415import heapq
16+ import itertools
1517from abc import ABC , abstractmethod
1618from collections import defaultdict
1719from dataclasses import dataclass
2022 Any ,
2123 Callable ,
2224 Dict ,
23- Generic ,
2425 List ,
2526 NamedTuple ,
2627 Optional ,
3738from typing_extensions import cast , dataclass_transform
3839
3940from .categories import Categories
40- from .schema import Field , Schema , Semantic
41+ from .schema import AttributeSpec , Field , Schema , Semantic
42+ from .transform import Transform
4143
4244TField = TypeVar ("TField" , bound = Field )
4345
@@ -50,30 +52,10 @@ class ConversionPaths(NamedTuple):
5052 while lazy converters must be deferred and applied at sample access time.
5153 """
5254
53- batch_converters : List ["Converter" ]
54- lazy_converters : Dict [str , List ["Converter" ]]
55-
56-
57- @dataclass (frozen = True )
58- class AttributeSpec (Generic [TField ]):
59- """
60- Specification for an attribute used in converters.
61-
62- Links an attribute name with its corresponding field type definition,
63- providing the complete specification needed for converter operations.
64-
65- Args:
66- TField: The specific Field type, defaults to Field
67-
68- Attributes:
69- name: The attribute name
70- field: The field type specification
71- categories: Optional categories information (e.g., LabelCategories, MaskCategories)
72- """
73-
74- name : str
75- field : TField
76- categories : Optional [Categories ] = None
55+ converters : Dict [str , List ["Converter" ]]
56+ lazy_outputs : Dict [str , List ["Converter" ]]
57+ required_inputs_by_output : dict [str , set [str ]]
58+ dependent_outputs_by_input : dict [str , set [str ]]
7759
7860
7961@dataclass_transform ()
@@ -960,71 +942,140 @@ def _separate_batch_and_lazy_converters(
960942 ConversionPaths with separated batch and lazy converter lists
961943 """
962944 if not conversion_path :
963- return ConversionPaths (batch_converters = [], lazy_converters = {})
964-
965- # Track which converters must be lazy
966- lazy_indices : Set [int ] = set ()
945+ return ConversionPaths (
946+ converters = {},
947+ lazy_outputs = {},
948+ required_inputs_by_output = {},
949+ dependent_outputs_by_input = {},
950+ )
967951
952+ # Track which outputs must be lazy
968953 lazy_fields : dict [str , bool ] = defaultdict (
969954 bool
970955 ) # Maps fields whether they were produced lazily
971956
957+ required_inputs_by_output : dict [str , set [str ]] = defaultdict (set )
958+
972959 for i , converter in enumerate (conversion_path ):
973960 lazy = False
961+ input_specs = converter .get_input_attr_specs ()
974962
975963 if converter .lazy :
976964 # Mark all intrinsically lazy converters as lazy
977965 lazy = True
978966 else :
979967 # Check whether the converter depends on a lazy converter
980- input_specs = converter .get_input_attr_specs ()
981968 for attr_spec in input_specs :
982969 if attr_spec .name in lazy_fields :
983970 lazy = True
984971 break
985972
986- if lazy :
987- lazy_indices .add (i )
973+ output_specs = converter .get_output_attr_specs ()
988974
975+ if lazy :
989976 # Mark all output fields as lazy
990- output_specs = converter .get_output_attr_specs ()
991977 for attr_spec in output_specs :
992978 lazy_fields [attr_spec .name ] = True
993979
994- # Collect batch converters (non-lazy ones)
995- batch_converters : List [Converter ] = []
996- for i , converter in enumerate (conversion_path ):
997- if i not in lazy_indices :
998- batch_converters .append (converter )
980+ required_inputs = [
981+ required_inputs_by_output [attr_spec .name ]
982+ if attr_spec .name in required_inputs_by_output
983+ else {attr_spec .name }
984+ for attr_spec in input_specs
985+ ]
986+ flattened_required_inputs = set (itertools .chain (* required_inputs ))
987+ for attr_spec in output_specs :
988+ required_inputs_by_output [attr_spec .name ] = flattened_required_inputs
999989
1000990 # Collect lazy converters by output attribute
1001- lazy_converters_by_output : Dict [str , List [Converter ]] = defaultdict (list )
991+ converters_by_output : Dict [str , List [Converter ]] = defaultdict (list )
1002992
1003993 # Iterate through converters in reverse to propagate output dependencies
1004- dependents_by_output : Dict [str , Set [Converter ]] = defaultdict (set )
994+ dependents_by_output : Dict [str , Set [str ]] = defaultdict (set )
1005995
1006996 for i , converter in reversed (list (enumerate (conversion_path ))):
1007- if i in lazy_indices :
1008- # This is a lazy converter - track its outputs
1009- dependents = set ()
997+ # This is a lazy converter - track its outputs
998+ dependents = set ()
1010999
1011- output_specs = converter .get_output_attr_specs ()
1012- for attr_spec in output_specs :
1013- dependents .update (dependents_by_output .get (attr_spec .name , []))
1014- dependents .add (attr_spec .name )
1000+ output_specs = converter .get_output_attr_specs ()
1001+ for attr_spec in output_specs :
1002+ dependents .update (dependents_by_output .get (attr_spec .name , []))
1003+ dependents .add (attr_spec .name )
10151004
1016- for dependent in dependents :
1017- lazy_converters_by_output [dependent ].append (converter )
1005+ for dependent in dependents :
1006+ converters_by_output [dependent ].append (converter )
10181007
1019- # Propagate dependencies from outputs to inputs
1020- input_specs = converter .get_input_attr_specs ()
1021- for input_spec in input_specs :
1022- dependents_by_output [input_spec .name ].update (dependents )
1008+ # Propagate dependencies from outputs to inputs
1009+ input_specs = converter .get_input_attr_specs ()
1010+ for input_spec in input_specs :
1011+ dependents_by_output [input_spec .name ].update (dependents )
10231012
10241013 # Reverse all chains to get dependencies-first order
1025- for output_name , chain in lazy_converters_by_output .items ():
1026- lazy_converters_by_output [output_name ] = list (reversed (chain ))
1014+ for output_name , chain in converters_by_output .items ():
1015+ converters_by_output [output_name ] = list (reversed (chain ))
10271016
10281017 return ConversionPaths (
1029- batch_converters = batch_converters , lazy_converters = lazy_converters_by_output
1018+ converters = converters_by_output ,
1019+ lazy_outputs = lazy_fields ,
1020+ required_inputs_by_output = required_inputs_by_output ,
1021+ dependent_outputs_by_input = dependents_by_output ,
10301022 )
1023+
1024+
1025+ class ConverterTransform (Transform ):
1026+ def __init__ (self , parent : Transform , schema : Schema , conversion_paths : ConversionPaths ):
1027+ super ().__init__ (schema )
1028+
1029+ lazy_inputs = parent .get_lazy_attributes ()
1030+
1031+ lazy_outputs = set (conversion_paths .lazy_outputs )
1032+ for input in lazy_inputs :
1033+ lazy_outputs .update (conversion_paths .dependent_outputs_by_input [input ])
1034+ self ._lazy_outputs = lazy_outputs
1035+
1036+ batch_outputs = self .get_batch_attributes ()
1037+
1038+ self ._parent = parent
1039+ self ._conversion_paths = conversion_paths
1040+ self ._df_input_columns = set ()
1041+ self ._df = pl .DataFrame ()
1042+ self ._applied_converters = set ()
1043+
1044+ self .apply (batch_outputs )
1045+
1046+ def apply (self , fields : Sequence [str ]) -> pl .DataFrame :
1047+ required_inputs = set ()
1048+ for field in fields :
1049+ if field in self ._conversion_paths .converters :
1050+ required_inputs .update (self ._conversion_paths .required_inputs_by_output [field ])
1051+
1052+ parent_df = self ._parent .apply (required_inputs )
1053+ input_columns = set (parent_df .columns )
1054+ new_columns = set (parent_df .columns ) - self ._df_input_columns
1055+
1056+ self ._df = self ._df .with_columns (parent_df .select (new_columns ))
1057+ self ._df_input_columns = input_columns
1058+
1059+ for field in fields :
1060+ converters = self ._conversion_paths .converters .get (field , None )
1061+
1062+ if converters is not None :
1063+ for converter in converters :
1064+ if id (converter ) not in self ._applied_converters :
1065+ self ._df = converter .convert (self ._df )
1066+ self ._applied_converters .add (id (converter ))
1067+
1068+ return self ._df
1069+
1070+ def get_lazy_attributes (self ) -> set [str ]:
1071+ return self ._lazy_outputs
1072+
1073+ def slice (self , offset : int , length : int | None = None ) -> "Transform" :
1074+ instance = copy .copy (self )
1075+ instance ._parent = self ._parent .slice (offset , length )
1076+ instance ._applied_converters = copy .copy (self ._applied_converters )
1077+ instance ._df = self ._df .slice (offset , length )
1078+ return instance
1079+
1080+ def __len__ (self ):
1081+ return len (self ._df )
0 commit comments