1717from tuneapi .apis .turbo import distributed_chat , distributed_chat_async
1818
1919
20- class Openai (tt .ModelInterface ):
20+ class OpenAIProtocol (tt .ModelInterface ):
2121 def __init__ (
2222 self ,
23- id : str = "gpt-4o" ,
24- base_url : str = "https://api.openai.com/v1/chat/completions" ,
25- extra_headers : Optional [Dict [str , str ]] = None ,
26- api_token : Optional [str ] = None ,
27- emebdding_url : Optional [str ] = None ,
28- image_gen_url : Optional [str ] = None ,
29- audio_transcribe : Optional [str ] = None ,
30- audio_gen_url : Optional [str ] = None ,
23+ id : str ,
24+ base_url : str ,
25+ extra_headers : Optional [Dict [str , str ]],
26+ api_token : Optional [str ],
27+ emebdding_url : Optional [str ],
28+ image_gen_url : Optional [str ],
29+ audio_transcribe_url : Optional [str ],
30+ audio_gen_url : Optional [str ],
3131 ):
3232 self .model_id = id
3333 self .base_url = base_url
34- self .api_token = api_token or tu . ENV . OPENAI_TOKEN ( "" )
34+ self .api_token = api_token
3535 self .extra_headers = extra_headers
36- self .emebdding_url = emebdding_url or base_url .replace (
37- "/chat/completions" ,
38- "/embeddings" ,
39- )
40- self .image_gen_url = image_gen_url or base_url .replace (
41- "/chat/completions" ,
42- "/images/generations" ,
43- )
44- self .audio_transcribe_url = audio_transcribe or base_url .replace (
45- "/chat/completions" ,
46- "/audio/transcriptions" ,
47- )
48- self .audio_gen_url = audio_gen_url or base_url .replace (
49- "/chat/completions" ,
50- "/audio/speech" ,
51- )
36+ self .emebdding_url = emebdding_url
37+ self .image_gen_url = image_gen_url
38+ self .audio_transcribe_url = audio_transcribe_url
39+ self .audio_gen_url = audio_gen_url
5240
5341 def set_api_token (self , token : str ) -> None :
5442 self .api_token = token
@@ -875,7 +863,47 @@ async def text_to_speech_async(
875863# Other OpenAI compatible models
876864
877865
878- class Mistral (Openai ):
866+ class Openai (OpenAIProtocol ):
867+ def __init__ (
868+ self ,
869+ id : str = "gpt-4o" ,
870+ base_url : str = "https://api.openai.com/v1/chat/completions" ,
871+ extra_headers : Optional [Dict [str , str ]] = None ,
872+ api_token : Optional [str ] = None ,
873+ emebdding_url : Optional [str ] = None ,
874+ image_gen_url : Optional [str ] = None ,
875+ audio_transcribe : Optional [str ] = None ,
876+ audio_gen_url : Optional [str ] = None ,
877+ ):
878+ super ().__init__ (
879+ id = id ,
880+ base_url = base_url ,
881+ api_token = api_token or tu .ENV .OPENAI_TOKEN ("" ),
882+ extra_headers = extra_headers ,
883+ emebdding_url = emebdding_url
884+ or base_url .replace (
885+ "/chat/completions" ,
886+ "/embeddings" ,
887+ ),
888+ image_gen_url = image_gen_url
889+ or base_url .replace (
890+ "/chat/completions" ,
891+ "/images/generations" ,
892+ ),
893+ audio_transcribe_url = audio_transcribe
894+ or base_url .replace (
895+ "/chat/completions" ,
896+ "/audio/transcriptions" ,
897+ ),
898+ audio_gen_url = audio_gen_url
899+ or base_url .replace (
900+ "/chat/completions" ,
901+ "/audio/speech" ,
902+ ),
903+ )
904+
905+
906+ class Mistral (OpenAIProtocol ):
879907 """
880908 A class to interact with Mistral's Large Language Models (LLMs) via their API. Note this class does not contain the
881909 `embedding` method.
@@ -905,13 +933,17 @@ def __init__(
905933 base_url = base_url ,
906934 extra_headers = extra_headers ,
907935 api_token = api_token or tu .ENV .MISTRAL_TOKEN (),
936+ emebdding_url = None ,
937+ image_gen_url = None ,
938+ audio_transcribe_url = None ,
939+ audio_gen_url = None ,
908940 )
909941
910942 def embedding (* a , ** k ):
911943 raise NotImplementedError ("Mistral does not support embeddings" )
912944
913945
914- class Groq (Openai ):
946+ class Groq (OpenAIProtocol ):
915947 """
916948 A class to interact with Groq's Large Language Models (LLMs) via their API. Note this class does not contain the
917949 `embedding` method.
@@ -938,13 +970,17 @@ def __init__(
938970 base_url = base_url ,
939971 extra_headers = extra_headers ,
940972 api_token = api_token or tu .ENV .GROQ_TOKEN (),
973+ emebdding_url = None ,
974+ image_gen_url = None ,
975+ audio_transcribe_url = None ,
976+ audio_gen_url = None ,
941977 )
942978
943979 def embedding (* a , ** k ):
944980 raise NotImplementedError ("Groq does not support embeddings" )
945981
946982
947- class TuneModel (Openai ):
983+ class TuneModel (OpenAIProtocol ):
948984 """
949985 A class to interact with Groq's Large Language Models (LLMs) via their API.
950986
@@ -978,6 +1014,9 @@ def __init__(
9781014 extra_headers = extra_headers ,
9791015 api_token = api_token or tu .ENV .TUNEAPI_TOKEN (),
9801016 emebdding_url = "https://proxy.tune.app/v1/embeddings" ,
1017+ image_gen_url = None ,
1018+ audio_transcribe_url = None ,
1019+ audio_gen_url = None ,
9811020 )
9821021
9831022 def embedding (
0 commit comments