1515
1616from  __future__ import  annotations 
1717
18+ import  hashlib 
1819import  logging 
1920import  posixpath 
2021import  threading 
22+ from  collections  import  OrderedDict 
2123from  concurrent .futures  import  Future , ThreadPoolExecutor 
2224from  contextlib  import  ExitStack 
2325from  dataclasses  import  asdict , dataclass 
@@ -73,8 +75,16 @@ class CompletionRefs:
7375
7476JsonEncodeable  =  list [dict [str , Any ]]
7577
76- # mapping of upload path to function computing upload data dict 
77- UploadData  =  dict [str , Callable [[], JsonEncodeable ]]
78+ # maps (upload path, whether the filename is derived from a hash of the contents) to a function computing the upload data
79+ UploadData  =  dict [tuple [str , bool ], Callable [[], JsonEncodeable ]]
80+ 
81+ 
def is_system_instructions_hashable(
    system_instruction: list[types.MessagePart] | None,
) -> bool:
    """Report whether a system instruction qualifies for content hashing.

    Returns ``True`` only for a non-empty list in which every part is a
    ``types.Text`` instance. ``None`` and an empty list yield ``False``.
    """
    # Nothing to hash when the instruction is missing or empty.
    if not system_instruction:
        return False
    # A single non-text part disqualifies the whole instruction.
    for part in system_instruction:
        if not isinstance(part, types.Text):
            return False
    return True
7888
7989
8090class  UploadCompletionHook (CompletionHook ):
@@ -97,10 +107,13 @@ def __init__(
97107        base_path : str ,
98108        max_size : int  =  20 ,
99109        upload_format : Format  |  None  =  None ,
110+         lru_cache_max_size : int  =  1024 ,
100111    ) ->  None :
101112        self ._max_size  =  max_size 
102113        self ._fs , base_path  =  fsspec .url_to_fs (base_path )
103114        self ._base_path  =  self ._fs .unstrip_protocol (base_path )
115+         self .lru_dict : OrderedDict [str , bool ] =  OrderedDict ()
116+         self .lru_cache_max_size  =  lru_cache_max_size 
104117
105118        if  upload_format  not  in _FORMATS  +  (None ,):
106119            raise  ValueError (
@@ -132,7 +145,10 @@ def done(future: Future[None]) -> None:
132145            finally :
133146                self ._semaphore .release ()
134147
135-         for  path , json_encodeable  in  upload_data .items ():
148+         for  (
149+             path ,
150+             contents_hashed_to_filename ,
151+         ), json_encodeable  in  upload_data .items ():
136152            # could not acquire, drop data 
137153            if  not  self ._semaphore .acquire (blocking = False ):  # pylint: disable=consider-using-with 
138154                _logger .warning (
@@ -143,7 +159,10 @@ def done(future: Future[None]) -> None:
143159
144160            try :
145161                fut  =  self ._executor .submit (
146-                     self ._do_upload , path , json_encodeable 
162+                     self ._do_upload ,
163+                     path ,
164+                     contents_hashed_to_filename ,
165+                     json_encodeable ,
147166                )
148167                fut .add_done_callback (done )
149168            except  RuntimeError :
@@ -152,10 +171,20 @@ def done(future: Future[None]) -> None:
152171                )
153172                self ._semaphore .release ()
154173
155-     def  _calculate_ref_path (self ) ->  CompletionRefs :
174+     def  _calculate_ref_path (
175+         self , system_instruction : list [types .MessagePart ]
176+     ) ->  CompletionRefs :
156177        # TODO: experimental with using the trace_id and span_id, or fetching 
157178        # gen_ai.response.id from the active span. 
158- 
179+         system_instruction_hash  =  None 
180+         if  is_system_instructions_hashable (system_instruction ):
181+             # Get a hash of the text. 
182+             system_instruction_hash  =  hashlib .sha256 (
183+                 "\n " .join (x .content  for  x  in  system_instruction ).encode (  # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue, reportUnknownArgumentType] 
184+                     "utf-8" 
185+                 ),
186+                 usedforsecurity = False ,
187+             ).hexdigest ()
159188        uuid_str  =  str (uuid4 ())
160189        return  CompletionRefs (
161190            inputs_ref = posixpath .join (
@@ -166,13 +195,32 @@ def _calculate_ref_path(self) -> CompletionRefs:
166195            ),
167196            system_instruction_ref = posixpath .join (
168197                self ._base_path ,
169-                 f"{ uuid_str } { self ._format }  ,
198+                 f"{ system_instruction_hash   or   uuid_str } { self ._format }  ,
170199            ),
171200        )
172201
202+     def  _file_exists (self , path : str ) ->  bool :
203+         if  path  in  self .lru_dict :
204+             self .lru_dict .move_to_end (path )
205+             return  True 
206+         # https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.exists 
207+         file_exists  =  self ._fs .exists (path )
208+         # don't cache this because soon the file will exist.. 
209+         if  not  file_exists :
210+             return  False 
211+         self .lru_dict [path ] =  True 
212+         if  len (self .lru_dict ) >  self .lru_cache_max_size :
213+             self .lru_dict .popitem (last = False )
214+         return  True 
215+ 
173216    def  _do_upload (
174-         self , path : str , json_encodeable : Callable [[], JsonEncodeable ]
217+         self ,
218+         path : str ,
219+         contents_hashed_to_filename : bool ,
220+         json_encodeable : Callable [[], JsonEncodeable ],
175221    ) ->  None :
222+         if  contents_hashed_to_filename  and  self ._file_exists (path ):
223+             return 
176224        if  self ._format  ==  "json" :
177225            # output as a single line with the json messages array 
178226            message_lines  =  [json_encodeable ()]
@@ -194,6 +242,11 @@ def _do_upload(
194242                gen_ai_json_dump (message , file )
195243                file .write ("\n " )
196244
245+         if  contents_hashed_to_filename :
246+             self .lru_dict [path ] =  True 
247+             if  len (self .lru_dict ) >  self .lru_cache_max_size :
248+                 self .lru_dict .popitem (last = False )
249+ 
197250    def  on_completion (
198251        self ,
199252        * ,
@@ -213,7 +266,7 @@ def on_completion(
213266            system_instruction = system_instruction  or  None ,
214267        )
215268        # generate the paths to upload to 
216-         ref_names  =  self ._calculate_ref_path ()
269+         ref_names  =  self ._calculate_ref_path (system_instruction )
217270
218271        def  to_dict (
219272            dataclass_list : list [types .InputMessage ]
@@ -223,35 +276,40 @@ def to_dict(
223276            return  [asdict (dc ) for  dc  in  dataclass_list ]
224277
225278        references  =  [
226-             (ref_name , ref , ref_attr )
227-             for  ref_name , ref , ref_attr  in  [
279+             (ref_name , ref , ref_attr ,  contents_hashed_to_filename )
280+             for  ref_name , ref , ref_attr ,  contents_hashed_to_filename  in  [
228281                (
229282                    ref_names .inputs_ref ,
230283                    completion .inputs ,
231284                    GEN_AI_INPUT_MESSAGES_REF ,
285+                     False ,
232286                ),
233287                (
234288                    ref_names .outputs_ref ,
235289                    completion .outputs ,
236290                    GEN_AI_OUTPUT_MESSAGES_REF ,
291+                     False ,
237292                ),
238293                (
239294                    ref_names .system_instruction_ref ,
240295                    completion .system_instruction ,
241296                    GEN_AI_SYSTEM_INSTRUCTIONS_REF ,
297+                     is_system_instructions_hashable (
298+                         completion .system_instruction 
299+                     ),
242300                ),
243301            ]
244-             if  ref 
302+             if  ref    # Filter out empty input/output/sys instruction 
245303        ]
246304        self ._submit_all (
247305            {
248-                 ref_name : partial (to_dict , ref )
249-                 for  ref_name , ref , _  in  references 
306+                 ( ref_name ,  contents_hashed_to_filename ) : partial (to_dict , ref )
307+                 for  ref_name , ref , _ ,  contents_hashed_to_filename  in  references 
250308            }
251309        )
252310
253311        # stamp the refs on telemetry 
254-         references  =  {ref_attr : name  for  name , _ , ref_attr  in  references }
312+         references  =  {ref_attr : name  for  name , _ , ref_attr ,  _  in  references }
255313        if  span :
256314            span .set_attributes (references )
257315        if  log_record :
0 commit comments