99from strenum import StrEnum
1010from urllib .parse import urlparse
1111
12- # you can also set https://api.githubcopilot.com if you prefer
13- # but beware that your taskflows need to reference the correct model id
14- # since different APIs use their own id schema, use -l with your desired
15- # endpoint to retrieve the correct id names to use for your taskflow
16- AI_API_ENDPOINT = os .getenv ('AI_API_ENDPOINT' , default = 'https://models.github.ai/inference' )
17-
1812# Enumeration of currently supported API endpoints.
1913class AI_API_ENDPOINT_ENUM (StrEnum ):
2014 AI_API_MODELS_GITHUB = 'models.github.ai'
2115 AI_API_GITHUBCOPILOT = 'api.githubcopilot.com'
2216
2317COPILOT_INTEGRATION_ID = 'vscode-chat'
2418
19+ # you can also set https://api.githubcopilot.com if you prefer
20+ # but beware that your taskflows need to reference the correct model id
21+ # since different APIs use their own id schema, use -l with your desired
22+ # endpoint to retrieve the correct id names to use for your taskflow
23+ def get_AI_endpoint ():
24+ return os .getenv ('AI_API_ENDPOINT' , default = 'https://models.github.ai/inference' )
25+
26+ def get_AI_token ():
27+ """
28+ Get the token for the AI API from the environment.
29+ The environment variable can be named either AI_API_TOKEN
30+ or COPILOT_TOKEN.
31+ """
32+ token = os .getenv ('AI_API_TOKEN' )
33+ if token :
34+ return token
35+ token = os .getenv ('COPILOT_TOKEN' )
36+ if token :
37+ return token
38+ raise RuntimeError ("AI_API_TOKEN environment variable is not set." )
39+
2540# assume we are >= python 3.9 for our type hints
2641def list_capi_models (token : str ) -> dict [str , dict ]:
2742 """Retrieve a dictionary of available CAPI models"""
2843 models = {}
2944 try :
30- netloc = urlparse (AI_API_ENDPOINT ).netloc
45+ api_endpoint = get_AI_endpoint ()
46+ netloc = urlparse (endpoint ).netloc
3147 match netloc :
3248 case AI_API_ENDPOINT_ENUM .AI_API_GITHUBCOPILOT :
3349 models_catalog = 'models'
3450 case AI_API_ENDPOINT_ENUM .AI_API_MODELS_GITHUB :
3551 models_catalog = 'catalog/models'
3652 case _:
37- raise ValueError (f"Unsupported Model Endpoint: { AI_API_ENDPOINT } " )
38- r = httpx .get (httpx .URL (AI_API_ENDPOINT ).join (models_catalog ),
53+ raise ValueError (f"Unsupported Model Endpoint: { api_endpoint } " )
54+ r = httpx .get (httpx .URL (api_endpoint ).join (models_catalog ),
3955 headers = {
4056 'Accept' : 'application/json' ,
4157 'Authorization' : f'Bearer { token } ' ,
@@ -49,7 +65,7 @@ def list_capi_models(token: str) -> dict[str, dict]:
4965 case AI_API_ENDPOINT_ENUM .AI_API_MODELS_GITHUB :
5066 models_list = r .json ()
5167 case _:
52- raise ValueError (f"Unsupported Model Endpoint: { AI_API_ENDPOINT } " )
68+ raise ValueError (f"Unsupported Model Endpoint: { api_endpoint } " )
5369 for model in models_list :
5470 models [model .get ('id' )] = dict (model )
5571 except httpx .RequestError as e :
@@ -61,7 +77,8 @@ def list_capi_models(token: str) -> dict[str, dict]:
6177 return models
6278
6379def supports_tool_calls (model : str , models : dict ) -> bool :
64- match urlparse (AI_API_ENDPOINT ).netloc :
80+ api_endpoint = get_AI_endpoint ()
81+ match urlparse (api_endpoint ).netloc :
6582 case AI_API_ENDPOINT_ENUM .AI_API_GITHUBCOPILOT :
6683 return models .get (model , {}).\
6784 get ('capabilities' , {}).\
@@ -71,7 +88,7 @@ def supports_tool_calls(model: str, models: dict) -> bool:
7188 return 'tool-calling' in models .get (model , {}).\
7289 get ('capabilities' , [])
7390 case _:
74- raise ValueError (f"Unsupported Model Endpoint: { AI_API_ENDPOINT } " )
91+ raise ValueError (f"Unsupported Model Endpoint: { api_endpoint } " )
7592
7693def list_tool_call_models (token : str ) -> dict [str , dict ]:
7794 models = list_capi_models (token )
0 commit comments