@@ -40,6 +40,7 @@ def __init__(
         self.parse_html = (
             True if node_config is None else node_config.get("parse_html", True)
         )
+        self.llm_model = node_config['llm_model']

     def execute(self, state: dict) -> dict:
         """
@@ -64,31 +65,33 @@ def execute(self, state: dict) -> dict:
         input_data = [state[key] for key in input_keys]
         docs_transformed = input_data[0]

+        def count_tokens(text):
+            from ..utils import token_count
+            return token_count(text, self.llm_model.model_name)
+
         if self.parse_html:
             docs_transformed = Html2TextTransformer().transform_documents(input_data[0])
             docs_transformed = docs_transformed[0]

             chunks = chunk(text=docs_transformed.page_content,
                            chunk_size=self.node_config.get("chunk_size", 4096) - 250,
-                           token_counter=lambda text: len(text.split()),
+                           token_counter=count_tokens,
                            memoize=False)
         else:
             docs_transformed = docs_transformed[0]
-
             chunk_size = self.node_config.get("chunk_size", 4096)
             chunk_size = min(chunk_size - 500, int(chunk_size * 0.9))

             if isinstance(docs_transformed, Document):
                 chunks = chunk(text=docs_transformed.page_content,
                                chunk_size=chunk_size,
-                               token_counter=lambda text: len(text.split()),
+                               token_counter=count_tokens,
                                memoize=False)
             else:
                 chunks = chunk(text=docs_transformed,
                                chunk_size=chunk_size,
-                               token_counter=lambda text: len(text.split()),
+                               token_counter=count_tokens,
                                memoize=False)

         state.update({self.output[0]: chunks})
-
         return state
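Replacing the whitespace-split lambda with `count_tokens` makes chunk sizing follow the model's actual tokenizer rather than word counts. Below is a minimal, self-contained sketch of that pattern; it assumes `chunk` is semchunk's `chunk` and substitutes a hypothetical `token_count` built on tiktoken, since the repository's own `..utils.token_count` may be implemented differently.

```python
# Sketch only: a tiktoken-backed stand-in for the project's token_count helper,
# plus a chunk() call mirroring the diff's token-aware sizing.
import tiktoken
from semchunk import chunk


def token_count(text: str, model_name: str) -> int:
    """Count tokens with the tokenizer matching model_name, with a generic fallback."""
    try:
        encoding = tiktoken.encoding_for_model(model_name)
    except KeyError:
        encoding = tiktoken.get_encoding("cl100k_base")
    return len(encoding.encode(text))


page_text = "some long scraped page content ..."
chunk_size = 4096

# Size chunks by real tokens instead of whitespace-split words.
chunks = chunk(text=page_text,
               chunk_size=min(chunk_size - 500, int(chunk_size * 0.9)),
               token_counter=lambda t: token_count(t, "gpt-3.5-turbo"),
               memoize=False)
print(len(chunks), "chunks")
```

The headroom subtracted from `chunk_size` mirrors the diff (250 tokens in the HTML path, `min(chunk_size - 500, 0.9 * chunk_size)` otherwise), presumably to leave room for prompt overhead added around each chunk downstream.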