11from abc import ABC , abstractmethod
2+ from copy import copy as lite_copy
23from dataclasses import dataclass , field
34import requests
45import os
@@ -22,7 +23,7 @@ class TransformArgs():
2223 trans_node : Optional [bool ] = None
2324 num_workers : int = 0
2425 kwargs : Dict = field (default_factory = dict )
25- pattern : Optional [str ] = None
26+ pattern : Optional [Union [ str , Callable [[ str ], bool ]] ] = None
2627
2728 @staticmethod
2829 def from_dict (d ):
@@ -73,6 +74,7 @@ def split_text_keep_separator(text: str, separator: str) -> List[str]:
7374class NodeTransform (ABC ):
7475 def __init__ (self , num_workers : int = 0 ):
7576 self ._number_workers = num_workers
77+ self ._name = None
7678
7779 def batch_forward (
7880 self , documents : Union [DocNode , List [DocNode ]], node_group : str , ** kwargs
@@ -100,36 +102,43 @@ def impl(node: DocNode):
100102 def transform (self , document : DocNode , ** kwargs ) -> List [Union [str , DocNode ]]:
101103 raise NotImplementedError ('Not implemented' )
102104
105+ def with_name (self , name : Optional [str ], * , copy : bool = True ) -> 'NodeTransform' :
106+ if name is not None :
107+ if copy : return lite_copy (self ).with_name (name , copy = False )
108+ self ._name = name
109+ return self
110+
103111 def __call__ (self , node : DocNode , ** kwargs : Any ) -> List [DocNode ]:
104112 # Parent and child should not be set here.
105113 results = self .transform (node , ** kwargs )
106114 if isinstance (results , (DocNode , str )): results = [results ]
107115 return [DocNode (text = chunk ) if isinstance (chunk , str ) else chunk for chunk in results if chunk ]
108116
109117
110- def make_transform (t ) :
118+ def make_transform (t : Union [ TransformArgs , Dict [ str , Any ]], group_name : Optional [ str ] = None ) -> NodeTransform :
111119 if isinstance (t , dict ): t = TransformArgs .from_dict (t )
112120 transform , trans_node , num_workers = t ['f' ], t ['trans_node' ], t ['num_workers' ]
113121 num_workers = dict (num_workers = num_workers ) if num_workers > 0 else dict ()
114- return (transform (** t ['kwargs' ], ** num_workers )
115- if isinstance (transform , type )
116- else transform if isinstance (transform , NodeTransform )
117- else FuncNodeTransform (transform , trans_node = trans_node , ** num_workers ))
122+ return (transform (** t ['kwargs' ], ** num_workers ).with_name (group_name , copy = False ) if isinstance (transform , type )
123+ else transform .with_name (group_name ) if isinstance (transform , NodeTransform )
124+ else FuncNodeTransform (transform , trans_node = trans_node , ** num_workers ).with_name (group_name , copy = False ))
118125
119126
120127class AdaptiveTransform (NodeTransform ):
121- def __init__ (self , transforms : Union [List [TransformArgs ], TransformArgs ]):
122- super ().__init__ (num_workers = 0 )
128+ def __init__ (self , transforms : Union [List [Union [TransformArgs , Dict ]], Union [TransformArgs , Dict ]],
129+ num_workers : int = 0 ):
130+ super ().__init__ (num_workers = num_workers )
123131 if not isinstance (transforms , (tuple , list )): transforms = [transforms ]
124132 self ._transformers = [(t .get ('pattern' ), make_transform (t )) for t in transforms ]
125133
126134 def transform (self , document : DocNode , ** kwargs ) -> List [Union [str , DocNode ]]:
135+ if not isinstance (document , DocNode ): LOG .warning (f'Invalud document type { type (document )} got' )
127136 for pt , transform in self ._transformers :
128- if pt and not pt .startswith ('*' ): pt = os .path .join (str (os .cwd ()), pt )
129- if not isinstance (document , DocNode ):
130- LOG .warning (f'Invalud document type { type (document )} got' )
131- if not pt or fnmatch .fnmatch (document .docpath , pt ):
137+ if pt and isinstance (pt , str ) and not pt .startswith ('*' ): pt = os .path .join (str (os .cwd ()), pt )
138+ if not pt or (callable (pt ) and pt (document .docpath )) or (
139+ isinstance (pt , str ) and fnmatch .fnmatch (document .docpath , pt )):
132140 return transform (document , ** kwargs )
141+ LOG .warning (f'No transform found for document { document .docpath } with group name `{ self ._name } `' )
133142 return []
134143
135144
0 commit comments