11from pymongo .errors import OperationFailure
22from pymongo .collection import Collection
3+ from langchain_aws import ChatBedrock
4+ from langchain_openai import AzureChatOpenAI
5+ from langchain_google_genai import ChatGoogleGenerativeAI
36import requests
4- from typing import Dict
7+ from typing import Dict , List
58import time
69import os
710
811SLEEP_TIMER = 5
9- SERVERLESS_URL = os . getenv ( "SERVERLESS_URL" )
12+ SERVERLESS_URL = "https://vtqjvgchmwcjwsrela2oyhlegu0hwqnw.lambda-url.us-west-2.on.aws/"
1013SANDBOX_NAME = os .getenv ("CODESPACE_NAME" ) or os .getenv ("_SANDBOX_ID" )
1114
1215
@@ -85,17 +88,48 @@ def track_progress(task: str, workshop_id: str) -> None:
8588 payload = {"task" : task , "workshop_id" : workshop_id , "sandbox_id" : SANDBOX_NAME }
8689 requests .post (url = SERVERLESS_URL , json = {"task" : "track_progress" , "data" : payload })
8790
88- def set_env (name : str , passkey : str ) -> None :
91+
92+ def set_env (providers : List [str ], passkey : str ) -> None :
8993 """
90- Set environment variable in sandbox
94+ Set environment variables in sandbox
9195
9296 Args:
93- name ( str): Environment variable name
97+ providers (List[ str] ): List of provider names
9498 passkey (str): Passkey to get token
9599 """
96- response = requests .post (url = SERVERLESS_URL , json = {"task" : "get_token" , "data" : passkey })
97- if response .status_code == 200 :
98- os .environ [name ] = response .json ().get ("token" )
99- print (f"{ name } environment variable set successfully." )
100+ for provider in providers :
101+ payload = {"provider" : provider , "passkey" : passkey }
102+ response = requests .post (url = SERVERLESS_URL , json = {"task" : "get_token" , "data" : payload })
103+ status_code = response .status_code
104+ if status_code == 200 :
105+ result = response .json ().get ("token" )
106+ for key in result :
107+ os .environ [key ] = result [key ]
108+ print (f"Set { key } environment variable." )
109+ elif status_code == 401 :
110+ print (f"{ response .json ()['error' ]} . Follow steps in the lab documentation to obtain your own credentials and set them as environment variables." )
111+ else :
112+ print (f"{ response .json ()['error' ]} " )
113+
114+
115+ def get_llm (provider : str ):
116+ if provider == "aws" :
117+ return ChatBedrock (
118+ model_id = "anthropic.claude-3-5-sonnet-20240620-v1:0" ,
119+ model_kwargs = dict (temperature = 0 ),
120+ region_name = "us-west-2" ,
121+ )
122+ elif provider == "google" :
123+ return ChatGoogleGenerativeAI (
124+ model = "gemini-1.5-pro" ,
125+ temperature = 0 ,
126+ )
127+ elif provider == "microsoft" :
128+ return AzureChatOpenAI (
129+ azure_endpoint = "https://gai-326.openai.azure.com/" ,
130+ azure_deployment = "gpt-4o" ,
131+ api_version = "2023-06-01-preview" ,
132+ temperature = 0 ,
133+ )
100134 else :
101- print ("Passkey expired. Follow steps in the lab documentation to obtain your own credentials and set them as environment variables ." )
135+ print ("Unsupported provider. provider can be one of 'aws', 'google', 'microsoft' ." )
0 commit comments