@@ -8,12 +8,8 @@ use std::time::Duration;
88use anyhow:: { anyhow, Context , Result } ;
99use clap:: Parser ;
1010use crossbeam_channel:: { bounded, Receiver , RecvTimeoutError , Sender , TryRecvError } ;
11+ use fastcrawl:: embedder:: openai:: OpenAiEmbedder ;
1112use fastcrawl:: { EmbeddedChunkRecord , ManifestRecord , NormalizedPage } ;
12- use reqwest:: blocking:: Client ;
13- use reqwest:: header:: { HeaderMap , HeaderValue , AUTHORIZATION , CONTENT_TYPE } ;
14- use reqwest:: StatusCode ;
15- use serde:: Deserialize ;
16- use serde:: Serialize ;
1713
1814#[ derive( Parser , Debug ) ]
1915#[ command(
@@ -456,144 +452,3 @@ struct EmbeddingBatchResult {
456452}
457453
458454type EmbeddingResult = Result < EmbeddingBatchResult > ;
459-
460- #[ derive( Clone ) ]
461- struct OpenAiEmbedder {
462- client : Client ,
463- endpoint : String ,
464- model : String ,
465- dimensions : Option < usize > ,
466- max_retries : usize ,
467- batch_size : usize ,
468- }
469-
470- impl OpenAiEmbedder {
471- fn new (
472- api_key : String ,
473- base_url : String ,
474- model : String ,
475- dimensions : Option < usize > ,
476- timeout : Duration ,
477- max_retries : usize ,
478- batch_size : usize ,
479- ) -> Result < Self > {
480- anyhow:: ensure!( !api_key. trim( ) . is_empty( ) , "missing OpenAI API key" ) ;
481- anyhow:: ensure!( !model. trim( ) . is_empty( ) , "missing OpenAI model name" ) ;
482- let mut headers = HeaderMap :: new ( ) ;
483- let auth = format ! ( "Bearer {}" , api_key. trim( ) ) ;
484- headers. insert (
485- AUTHORIZATION ,
486- HeaderValue :: from_str ( & auth) . context ( "invalid OpenAI API key" ) ?,
487- ) ;
488- headers. insert ( CONTENT_TYPE , HeaderValue :: from_static ( "application/json" ) ) ;
489- let client = Client :: builder ( )
490- . timeout ( timeout)
491- . default_headers ( headers)
492- . build ( )
493- . context ( "failed to build OpenAI HTTP client" ) ?;
494- let endpoint = format ! ( "{}/embeddings" , base_url. trim_end_matches( '/' ) ) ;
495- Ok ( Self {
496- client,
497- endpoint,
498- model,
499- dimensions,
500- max_retries,
501- batch_size,
502- } )
503- }
504-
505- fn embed_batch ( & self , inputs : & [ & str ] ) -> Result < Vec < Vec < f32 > > > {
506- if inputs. is_empty ( ) {
507- return Ok ( Vec :: new ( ) ) ;
508- }
509- anyhow:: ensure!(
510- inputs. len( ) <= self . batch_size,
511- "batch of {} exceeds configured max {}" ,
512- inputs. len( ) ,
513- self . batch_size
514- ) ;
515-
516- let mut attempt = 0usize ;
517- loop {
518- let request = EmbeddingRequest {
519- model : & self . model ,
520- input : inputs,
521- dimensions : self . dimensions ,
522- } ;
523- let response = self . client . post ( & self . endpoint ) . json ( & request) . send ( ) ;
524- match response {
525- Ok ( resp) => {
526- let status = resp. status ( ) ;
527- if status. is_success ( ) {
528- let mut parsed: EmbeddingResponse = resp
529- . json ( )
530- . context ( "failed to parse OpenAI embedding response" ) ?;
531- parsed. data . sort_by_key ( |entry| entry. index ) ;
532- anyhow:: ensure!(
533- parsed. data. len( ) == inputs. len( ) ,
534- "OpenAI returned {} embeddings for {} inputs" ,
535- parsed. data. len( ) ,
536- inputs. len( )
537- ) ;
538- return Ok ( parsed
539- . data
540- . into_iter ( )
541- . map ( |entry| entry. embedding )
542- . collect ( ) ) ;
543- }
544-
545- let body = resp
546- . text ( )
547- . unwrap_or_else ( |_| "<body unavailable>" . to_string ( ) ) ;
548- if self . should_retry ( status) && attempt + 1 < self . max_retries {
549- attempt += 1 ;
550- thread:: sleep ( self . retry_backoff ( attempt) ) ;
551- continue ;
552- }
553- anyhow:: bail!( "OpenAI embeddings request failed ({}): {}" , status, body) ;
554- }
555- Err ( err) => {
556- if self . is_retryable_error ( & err) && attempt + 1 < self . max_retries {
557- attempt += 1 ;
558- thread:: sleep ( self . retry_backoff ( attempt) ) ;
559- continue ;
560- }
561- return Err ( err. into ( ) ) ;
562- }
563- }
564- }
565- }
566-
567- fn should_retry ( & self , status : StatusCode ) -> bool {
568- status == StatusCode :: TOO_MANY_REQUESTS || status. is_server_error ( )
569- }
570-
571- fn is_retryable_error ( & self , err : & reqwest:: Error ) -> bool {
572- err. is_timeout ( ) || err. is_connect ( ) || err. is_body ( ) || err. is_request ( ) || err. is_decode ( )
573- }
574-
575- fn retry_backoff ( & self , attempt : usize ) -> Duration {
576- let capped = attempt. min ( 5 ) as u32 ;
577- Duration :: from_millis ( 500 * ( 1 << capped) )
578- }
579- }
580-
581- #[ derive( Serialize ) ]
582- struct EmbeddingRequest < ' a > {
583- model : & ' a str ,
584- #[ serde( borrow) ]
585- input : & ' a [ & ' a str ] ,
586- #[ serde( skip_serializing_if = "Option::is_none" ) ]
587- dimensions : Option < usize > ,
588- }
589-
590- #[ derive( Debug , Deserialize ) ]
591- struct EmbeddingResponse {
592- data : Vec < EmbeddingData > ,
593- }
594-
595- #[ derive( Debug , Deserialize ) ]
596- struct EmbeddingData {
597- embedding : Vec < f32 > ,
598- index : usize ,
599- }
0 commit comments