diff --git a/guidance/models/_googleai.py b/guidance/models/_googleai.py
index 8cf1c2a5a..3da285f32 100644
--- a/guidance/models/_googleai.py
+++ b/guidance/models/_googleai.py
@@ -6,7 +6,6 @@
 _image_token_pattern = re.compile(r"<\|_image:(.*)\|>")
 
-
 class GoogleAIEngine(GrammarlessEngine):
     def __init__(
         self,
@@ -65,14 +64,6 @@ def __init__(
             # chat
             found_subclass = GoogleAIChat  # we assume all models are chat right now
 
-            # instruct
-            # elif "instruct" in model:
-            #     found_subclass = GoogleAIInstruct
-
-            # # regular completion
-            # else:
-            #     found_subclass = GoogleAICompletion
-
             # convert to any found subclass
             self.__class__ = found_subclass
             found_subclass.__init__(
@@ -89,7 +80,11 @@ def __init__(
             return  # we return since we just ran init above and don't need to run again
 
         # this allows us to use a single constructor for all our subclasses
-        engine_map = {GoogleAIChat: GoogleAIChatEngine}
+        engine_map = {
+            GoogleAIChat: GoogleAIChatEngine,
+            GoogleAIInstruct: GoogleAIInstructEngine,
+            GoogleAICompletion: GoogleAICompletionEngine,
+        }
 
         super().__init__(
             engine=engine_map[self.__class__](
@@ -104,6 +99,75 @@ def __init__(
             echo=echo,
         )
 
+
+class GoogleAICompletion(GoogleAI):
+    pass
+
+
+class GoogleAICompletionEngine(GoogleAIEngine):
+    def _generator(self, prompt, temperature):
+        self._not_running_stream.clear()  # so we know we are running
+        self._data = prompt  # we start with this data
+
+        try:
+            kwargs = {}
+            generation_config = {"temperature": temperature}
+            if self.max_streaming_tokens is not None:
+                generation_config["max_output_tokens"] = self.max_streaming_tokens
+            kwargs["generation_config"] = generation_config
+
+            generator = self.model_obj.generate_content(
+                contents=self._data.decode("utf8"),
+                stream=True,
+                **kwargs,
+            )
+        except Exception as e:  # TODO: add retry logic
+            raise e
+
+        for chunk in generator:
+            yield chunk.candidates[0].content.parts[0].text.encode("utf8")
+
+
+class GoogleAIInstruct(GoogleAI, Instruct):
+    def get_role_start(self, name):
+        return ""
+
+    def get_role_end(self, name):
+        if name == "instruction":
+            return "<|endofprompt|>"
+        else:
+            raise Exception(
+                f"The GoogleAIInstruct model does not know about the {name} role type!"
+            )
+
+
+class GoogleAIInstructEngine(GoogleAIEngine):
+    def _generator(self, prompt, temperature):
+        # start the new stream
+        eop_count = prompt.count(b"<|endofprompt|>")
+        if eop_count > 1:
+            raise Exception(
+                "This model has been given multiple instruct blocks or <|endofprompt|> tokens, but this is not allowed!"
+            )
+        updated_prompt = prompt + b"<|endofprompt|>" if eop_count == 0 else prompt
+
+        self._not_running_stream.clear()  # so we know we are running
+        self._data = updated_prompt  # we start with this data
+
+        try:
+            kwargs = {}
+            generation_config = {"temperature": temperature}
+            if self.max_streaming_tokens is not None:
+                generation_config["max_output_tokens"] = self.max_streaming_tokens
+            kwargs["generation_config"] = generation_config
+
+            generator = self.model_obj.generate_content(
+                contents=self._data.decode("utf8"),
+                stream=True,
+                **kwargs,
+            )
+        except Exception as e:  # TODO: add retry logic
+            raise e
+
+        for chunk in generator:
+            yield chunk.candidates[0].content.parts[0].text.encode("utf8")
+
+
 class GoogleAIChatEngine(GoogleAIEngine):
     def _generator(self, prompt, temperature):
diff --git a/tests/models/test_googleai.py b/tests/models/test_googleai.py
index 44378b1d6..fe4f819f7 100644
--- a/tests/models/test_googleai.py
+++ b/tests/models/test_googleai.py
@@ -4,6 +4,29 @@
 from ..utils import get_model
 
 
+def test_googleai_basic():
+    try:
+        lm = models.GoogleAICompletion("gemini-pro")
+    except:
+        pytest.skip("Skipping GoogleAI test because we can't load the model!")
+
+    lm += "Count to 20: 1,2,3,4,"
+    nl = "\n"
+    lm += f"""\
+5,6,7"""
+    lm += f"""{gen(max_tokens=1, suffix=nl)}aaaaaa"""
+    assert str(lm)[-5:] == "aaaaa"
+
+
+def test_googleai_instruct():
+    try:
+        lm = models.GoogleAIInstruct("gemini-pro")
+    except:
+        pytest.skip("Skipping GoogleAI test because we can't load the model!")
+
+    with instruction():
+        lm += "this is a test about"
+    lm += gen("test", max_tokens=100)
+    assert len(lm["test"]) > 0
+
+
 def test_gemini_pro():
     from guidance import assistant, gen, models, system, user
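
Usage sketch (not part of the diff): a minimal example of how the two new entry points are expected to be driven, assuming the google-generativeai package is installed and a Google API key is configured in the environment; the capture name "answer" and the prompt strings are illustrative.

from guidance import models, gen, instruction

# Completion-style: the raw prompt is streamed through
# GoogleAICompletionEngine._generator() into generate_content().
lm = models.GoogleAICompletion("gemini-pro")
lm += "Count to 20: 1,2,3,4,"
lm += gen(max_tokens=10)

# Instruct-style: get_role_end("instruction") closes the block with
# <|endofprompt|>; GoogleAIInstructEngine._generator() appends that marker
# when it is missing and raises if it appears more than once.
lm = models.GoogleAIInstruct("gemini-pro")
with instruction():
    lm += "Write one sentence about streaming APIs."
lm += gen("answer", max_tokens=50)
print(lm["answer"])

The engine_map dispatch means both styles share the single GoogleAI constructor: the user-facing subclass is looked up in the map and paired with its engine, so adding a new interaction style only requires a model subclass and an engine with a _generator() method.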