diff --git a/lib/langchain/output_parsers/output_fixing_parser.rb b/lib/langchain/output_parsers/output_fixing_parser.rb
index eaf92487f..00e761984 100644
--- a/lib/langchain/output_parsers/output_fixing_parser.rb
+++ b/lib/langchain/output_parsers/output_fixing_parser.rb
@@ -47,12 +47,7 @@ def parse(completion)
       parser.parse(completion)
     rescue OutputParserException => e
       new_completion = llm.chat(
-        messages: [{role: "user",
-          content: prompt.format(
-            instructions: parser.get_format_instructions,
-            completion: completion,
-            error: e
-          )}]
+        messages: chat_messages(completion, e)
       ).completion
       parser.parse(new_completion)
     end
@@ -70,6 +65,38 @@ def self.from_llm(llm:, parser:, prompt: nil)
 
     private
 
+    def chat_messages(completion, e)
+      # For Google LLMs, use the parts format
+      if llm.is_a?(Langchain::LLM::GoogleGemini) || llm.is_a?(Langchain::LLM::GoogleVertexAI)
+        return [
+          {
+            role: "user",
+            parts: [
+              {
+                text: prompt.format(
+                  instructions: parser.get_format_instructions,
+                  completion: completion,
+                  error: e
+                )
+              }
+            ]
+          }
+        ]
+      end
+
+      # For other LLMs, use the standard content format
+      [
+        {
+          role: "user",
+          content: prompt.format(
+            instructions: parser.get_format_instructions,
+            completion: completion,
+            error: e
+          )
+        }
+      ]
+    end
+
     private_class_method def self.naive_fix_prompt
       Langchain::Prompt.load_from_path(
         file_path: Langchain.root.join("langchain/output_parsers/prompts/naive_fix_prompt.yaml")
diff --git a/spec/lib/langchain/output_parsers/fix_spec.rb b/spec/lib/langchain/output_parsers/fix_spec.rb
index 22d951f77..5ba72d88d 100644
--- a/spec/lib/langchain/output_parsers/fix_spec.rb
+++ b/spec/lib/langchain/output_parsers/fix_spec.rb
@@ -186,5 +186,26 @@
       expect { parser.parse("Whoops I don't understand") }.to raise_error(Langchain::OutputParsers::OutputParserException)
       expect(parser.llm).to have_received(:chat).once
     end
+
+    context "with a Gemini based model" do
+      let(:llm_example) { Langchain::LLM::GoogleGemini.new(api_key: "123") }
+
+      it "parses when the llm is a Gemini based model" do
+        parser = described_class.new(**kwargs_example.merge(prompt: fix_prompt_template_example))
+        expect(parser.llm).to receive(:chat).with({
+          messages: [
+            {
+              role: "user",
+              parts: [
+                {
+                  text: match(fix_prompt_matcher_example)
+                }
+              ]
+            }
+          ]
+        }).and_return(double(completion: json_text_response))
+        expect(parser.parse("Whoops I don't understand")).to eq(json_response)
+      end
+    end
   end
 end