55
66import instructor
77
8- from typing import Any
9- from vertexai .generative_models import GenerativeModel , HarmBlockThreshold , HarmCategory # type: ignore
8+ from typing import Any , Type
9+ from vertexai .generative_models import GenerativeModel , HarmBlockThreshold , HarmCategory
1010from deepeval .models .base_model import DeepEvalBaseLLM
1111from pydantic import BaseModel
1212
@@ -31,18 +31,20 @@ def load_model(self, *args, **kwargs):
3131 HarmCategory .HARM_CATEGORY_HARASSMENT : HarmBlockThreshold .BLOCK_NONE ,
3232 HarmCategory .HARM_CATEGORY_SEXUALLY_EXPLICIT : HarmBlockThreshold .BLOCK_NONE ,
3333 }
34+ if not self .model_name :
35+ raise ValueError ("Model name must be specified for Google Vertex AI." )
3436
3537 return GenerativeModel (
3638 model_name = self .model_name ,
3739 safety_settings = safety_settings ,
3840 )
3941
40- def generate (self , prompt : str , schema : BaseModel ) -> Any :
42+ def generate (self , prompt : str , schema : Type [ BaseModel ] ) -> Any :
4143 instructor_client = instructor .from_vertexai (
4244 client = self .load_model (),
4345 mode = instructor .Mode .VERTEXAI_TOOLS ,
4446 )
45- resp = instructor_client .messages .create ( # type: ignore
47+ resp = instructor_client .messages .create (
4648 messages = [
4749 {
4850 "role" : "user" ,
@@ -53,13 +55,12 @@ def generate(self, prompt: str, schema: BaseModel) -> Any:
5355 )
5456 return resp
5557
56- async def a_generate (self , prompt : str , schema : BaseModel ) -> Any :
58+ async def a_generate (self , prompt : str , schema : Any ) -> Any :
5759 instructor_client = instructor .from_vertexai (
5860 client = self .load_model (),
5961 mode = instructor .Mode .VERTEXAI_TOOLS ,
60- _async = True ,
6162 )
62- resp = await instructor_client .messages .create ( # type: ignore
63+ resp = await instructor_client .completions .create (
6364 messages = [
6465 {
6566 "role" : "user" ,
@@ -71,7 +72,7 @@ async def a_generate(self, prompt: str, schema: BaseModel) -> Any:
7172 return resp
7273
7374 def get_model_name (self ):
74- return self .model_name
75+ return self .model_name or "model-not-specified"
7576
7677
7778def main ():
@@ -86,7 +87,7 @@ async def main_async():
8687 model = GoogleVertexAILangChain (model_name = "gemini-1.5-pro-002" )
8788 prompt = "Write me a joke"
8889 print (f"Prompt: { prompt } " )
89- response = await model .a_generate (prompt , Response )
90+ response = await model .a_generate (prompt , schema = Response )
9091 print (f"Response: { response } " )
9192
9293
0 commit comments