From 9b69774ab3bd798920d5076dbf27ae17f896ad7a Mon Sep 17 00:00:00 2001
From: Chad Kirby
Date: Mon, 13 May 2024 14:03:58 -0700
Subject: [PATCH 1/3] added GoogleAI completion and instruct engines

---
 guidance/models/_googleai.py  | 84 ++++++++++++++++++++++++++++++-----
 tests/models/test_googleai.py | 12 +++++
 2 files changed, 86 insertions(+), 10 deletions(-)

diff --git a/guidance/models/_googleai.py b/guidance/models/_googleai.py
index 8cf1c2a5a..4e2951bc3 100644
--- a/guidance/models/_googleai.py
+++ b/guidance/models/_googleai.py
@@ -6,7 +6,6 @@
 
 _image_token_pattern = re.compile(r"<\|_image:(.*)\|>")
 
-
 class GoogleAIEngine(GrammarlessEngine):
     def __init__(
         self,
@@ -65,14 +64,6 @@ def __init__(
             # chat
             found_subclass = GoogleAIChat  # we assume all models are chat right now
 
-            # instruct
-            # elif "instruct" in model:
-            #     found_subclass = GoogleAIInstruct
-
-            # # regular completion
-            # else:
-            #     found_subclass = GoogleAICompletion
-
             # convert to any found subclass
             self.__class__ = found_subclass
             found_subclass.__init__(
@@ -89,7 +80,11 @@ def __init__(
             return  # we return since we just ran init above and don't need to run again
 
         # this allows us to use a single constructor for all our subclasses
-        engine_map = {GoogleAIChat: GoogleAIChatEngine}
+        engine_map = {
+            GoogleAIChat: GoogleAIChatEngine,
+            GoogleAIInstruct: GoogleAIInstructEngine,
+            GoogleAICompletion: GoogleAICompletionEngine,
+        }
 
         super().__init__(
             engine=engine_map[self.__class__](
@@ -104,6 +99,75 @@ def __init__(
             echo=echo,
         )
 
+class GoogleAICompletion(GoogleAI):
+    pass
+
+class GoogleAICompletionEngine(GoogleAIEngine):
+    def _generator(self, prompt, temperature):
+
+        self._not_running_stream.clear()  # so we know we are running
+        self._data = prompt  # we start with this data
+
+        try:
+            kwargs = {}
+            generation_config = {"temperature": temperature}
+            if self.max_streaming_tokens is not None:
+                generation_config["max_output_tokens"] = self.max_streaming_tokens
+            kwargs["generation_config"] = generation_config
+
+            generator = self.model_obj.generate_content(
+                contents=prompt,
+                stream=True,
+                **kwargs,
+            )
+        except Exception as e:  # TODO: add retry logic
+            raise e
+
+        for chunk in generator:
+            yield chunk.candidates[0].content.parts[0].text.encode("utf8")
+
+class GoogleAIInstruct(GoogleAI, Instruct):
+    def get_role_start(self, name):
+        return ""
+
+    def get_role_end(self, name):
+        if name == "instruction":
+            return "<|endofprompt|>"
+        else:
+            raise Exception(
+                f"The GoogleAIInstruct model does not know about the {name} role type!"
+            )
+
+class GoogleAIInstructEngine(GoogleAIEngine):
+    def _generator(self, prompt, temperature):
+        # start the new stream
+        eop_count = prompt.count(b"<|endofprompt|>")
+        if eop_count > 1:
+            raise Exception(
+                "This model has been given multiple instruct blocks or <|endofprompt|> tokens, but this is not allowed!"
+            )
+        updated_prompt = prompt + b"<|endofprompt|>" if eop_count == 0 else prompt
+
+        self._not_running_stream.clear()  # so we know we are running
+        self._data = updated_prompt  # we start with this data
+
+        try:
+            kwargs = {}
+            generation_config = {"temperature": temperature}
+            if self.max_streaming_tokens is not None:
+                generation_config["max_output_tokens"] = self.max_streaming_tokens
+            kwargs["generation_config"] = generation_config
+
+            generator = self.model_obj.generate_content(
+                contents=self._data.decode("utf8"),
+                stream=True,
+                **kwargs,
+            )
+        except Exception as e:  # TODO: add retry logic
+            raise e
+
+        for chunk in generator:
+            yield chunk.candidates[0].content.parts[0].text.encode("utf8")
 
 class GoogleAIChatEngine(GoogleAIEngine):
     def _generator(self, prompt, temperature):
diff --git a/tests/models/test_googleai.py b/tests/models/test_googleai.py
index 44378b1d6..706e842c4 100644
--- a/tests/models/test_googleai.py
+++ b/tests/models/test_googleai.py
@@ -4,6 +4,18 @@
 
 from ..utils import get_model
 
+def test_googleai_basic():
+    try:
+        lm = models.GoogleAICompletion("gemini-pro")
+    except:
+        pytest.skip("Skipping GoogleAI test because we can't load the model!")
+
+    lm += "Count to 20: 1,2,3,4,"
+    nl = "\n"
+    lm += f"""\
+5,6,7"""
+    lm += f"""{gen(max_tokens=1, suffix=nl)}aaaaaa"""
+    assert str(lm)[-5:] == "aaaaa"
 
 def test_gemini_pro():
     from guidance import assistant, gen, models, system, user

From db1b748ce9d9db506a5b23600b7a72411a2d0662 Mon Sep 17 00:00:00 2001
From: Chad Kirby
Date: Mon, 13 May 2024 14:50:42 -0700
Subject: [PATCH 2/3] fixed GoogleAI prompt input

---
 guidance/models/_googleai.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/guidance/models/_googleai.py b/guidance/models/_googleai.py
index 4e2951bc3..3da285f32 100644
--- a/guidance/models/_googleai.py
+++ b/guidance/models/_googleai.py
@@ -116,7 +116,7 @@ def _generator(self, prompt, temperature):
             kwargs["generation_config"] = generation_config
 
             generator = self.model_obj.generate_content(
-                contents=prompt,
+                contents=self._data.decode("utf8"),
                 stream=True,
                 **kwargs,
             )

From 8ca177765f207fd6ed5cd5e4148b13628a85bd22 Mon Sep 17 00:00:00 2001
From: Chad Kirby
Date: Mon, 13 May 2024 15:00:30 -0700
Subject: [PATCH 3/3] Added GoogleAIInstruct test

---
 tests/models/test_googleai.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/tests/models/test_googleai.py b/tests/models/test_googleai.py
index 706e842c4..fe4f819f7 100644
--- a/tests/models/test_googleai.py
+++ b/tests/models/test_googleai.py
@@ -17,6 +17,17 @@ def test_googleai_basic():
     lm += f"""{gen(max_tokens=1, suffix=nl)}aaaaaa"""
     assert str(lm)[-5:] == "aaaaa"
 
+def test_googleai_instruct():
+    try:
+        lm = models.GoogleAIInstruct("gemini-pro")
+    except:
+        pytest.skip("Skipping GoogleAI test because we can't load the model!")
+
+    with instruction():
+        lm += "this is a test about"
+    lm += gen("test", max_tokens=100)
+    assert len(lm["test"]) > 0
+
 def test_gemini_pro():
     from guidance import assistant, gen, models, system, user
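
---
For reviewers, a minimal usage sketch of the classes this series adds,
mirroring the new tests. This is a sketch under assumptions: the
"gemini-pro" model name and an API key available to the GoogleAI
constructor are taken from the tests, not guaranteed by the patches.

    from guidance import gen, instruction, models

    # Completion-style usage: raw text in, streamed text out, served by
    # GoogleAICompletionEngine._generator (model name is an assumption).
    lm = models.GoogleAICompletion("gemini-pro")
    lm += "Count to 20: 1,2,3,4,"
    lm += gen(max_tokens=20)

    # Instruct-style usage: closing the instruction() block appends
    # <|endofprompt|> via GoogleAIInstruct.get_role_end, and
    # GoogleAIInstructEngine rejects prompts with more than one such token.
    lm = models.GoogleAIInstruct("gemini-pro")
    with instruction():
        lm += "this is a test about"
    lm += gen("test", max_tokens=100)
    print(lm["test"])  # the captured generation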