Commit 966eaa8

Move adapter-related classes to separate files (#807)
* Move adapter-related classes to separate files
* Remove AwsBedrock in the Adapter
* Remove all adapters in `Langchain::Assistant` class
* Remove `assistants/llm/adapters/aws_bedrock.rb`
* Rename _base.rb to base.rb
* Update assistant.rb
* Rename open_ai.rb to openai.rb
* Rename base.rb to _base.rb

Co-authored-by: Andrei Bondarev <[email protected]>
1 parent acb72c3 commit 966eaa8

8 files changed: +361, -315 lines changed


lib/langchain/assistants/assistant.rb

Lines changed: 2 additions & 315 deletions
@@ -1,5 +1,7 @@
 # frozen_string_literal: true

+require_relative "llm/adapter"
+
 module Langchain
   # Assistants are Agent-like objects that leverage helpful instructions, LLMs, tools and knowledge to respond to user queries.
   # Assistants can be configured with an LLM of your choice, any vector search database and easily extended with additional tools.
@@ -412,320 +414,5 @@ def record_used_tokens(prompt_tokens, completion_tokens, total_tokens_from_opera
     def available_tool_names
       llm_adapter.available_tool_names(tools)
     end
-
-    # TODO: Fix the message truncation when context window is exceeded
-
-    module LLM
-      class Adapter
-        def self.build(llm)
-          case llm
-          when Langchain::LLM::Anthropic
-            Adapters::Anthropic.new
-          when Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI
-            Adapters::GoogleGemini.new
-          when Langchain::LLM::MistralAI
-            Adapters::MistralAI.new
-          when Langchain::LLM::Ollama
-            Adapters::Ollama.new
-          when Langchain::LLM::OpenAI
-            Adapters::OpenAI.new
-          else
-            raise ArgumentError, "Unsupported LLM type: #{llm.class}"
-          end
-        end
-      end
-
-      module Adapters
-        class Base
-          def build_chat_params(tools:, instructions:, messages:, tool_choice:)
-            raise NotImplementedError, "Subclasses must implement build_chat_params"
-          end
-
-          def extract_tool_call_args(tool_call:)
-            raise NotImplementedError, "Subclasses must implement extract_tool_call_args"
-          end
-
-          def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
-            raise NotImplementedError, "Subclasses must implement build_message"
-          end
-        end
-
-        class Ollama < Base
-          def build_chat_params(tools:, instructions:, messages:, tool_choice:)
-            params = {messages: messages}
-            if tools.any?
-              params[:tools] = build_tools(tools)
-            end
-            params
-          end
-
-          def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
-            warn "Image URL is not supported by Ollama currently" if image_url
-
-            Langchain::Messages::OllamaMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
-          end
-
-          # Extract the tool call information from the OpenAI tool call hash
-          #
-          # @param tool_call [Hash] The tool call hash
-          # @return [Array] The tool call information
-          def extract_tool_call_args(tool_call:)
-            tool_call_id = tool_call.dig("id")
-
-            function_name = tool_call.dig("function", "name")
-            tool_name, method_name = function_name.split("__")
-
-            tool_arguments = tool_call.dig("function", "arguments")
-            tool_arguments = if tool_arguments.is_a?(Hash)
-              Langchain::Utils::HashTransformer.symbolize_keys(tool_arguments)
-            else
-              JSON.parse(tool_arguments, symbolize_names: true)
-            end
-
-            [tool_call_id, tool_name, method_name, tool_arguments]
-          end
-
-          def available_tool_names(tools)
-            build_tools(tools).map { |tool| tool.dig(:function, :name) }
-          end
-
-          def allowed_tool_choices
-            ["auto", "none"]
-          end
-
-          private
-
-          def build_tools(tools)
-            tools.map { |tool| tool.class.function_schemas.to_openai_format }.flatten
-          end
-        end
-
-        class OpenAI < Base
-          def build_chat_params(tools:, instructions:, messages:, tool_choice:)
-            params = {messages: messages}
-            if tools.any?
-              params[:tools] = build_tools(tools)
-              params[:tool_choice] = build_tool_choice(tool_choice)
-            end
-            params
-          end
-
-          def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
-            Langchain::Messages::OpenAIMessage.new(role: role, content: content, image_url: image_url, tool_calls: tool_calls, tool_call_id: tool_call_id)
-          end
-
-          # Extract the tool call information from the OpenAI tool call hash
-          #
-          # @param tool_call [Hash] The tool call hash
-          # @return [Array] The tool call information
-          def extract_tool_call_args(tool_call:)
-            tool_call_id = tool_call.dig("id")
-
-            function_name = tool_call.dig("function", "name")
-            tool_name, method_name = function_name.split("__")
-
-            tool_arguments = tool_call.dig("function", "arguments")
-            tool_arguments = if tool_arguments.is_a?(Hash)
-              Langchain::Utils::HashTransformer.symbolize_keys(tool_arguments)
-            else
-              JSON.parse(tool_arguments, symbolize_names: true)
-            end
-
-            [tool_call_id, tool_name, method_name, tool_arguments]
-          end
-
-          def build_tools(tools)
-            tools.map { |tool| tool.class.function_schemas.to_openai_format }.flatten
-          end
-
-          def allowed_tool_choices
-            ["auto", "none"]
-          end
-
-          def available_tool_names(tools)
-            build_tools(tools).map { |tool| tool.dig(:function, :name) }
-          end
-
-          private
-
-          def build_tool_choice(choice)
-            case choice
-            when "auto"
-              choice
-            else
-              {"type" => "function", "function" => {"name" => choice}}
-            end
-          end
-        end
-
-        class MistralAI < Base
-          def build_chat_params(tools:, instructions:, messages:, tool_choice:)
-            params = {messages: messages}
-            if tools.any?
-              params[:tools] = build_tools(tools)
-              params[:tool_choice] = build_tool_choice(tool_choice)
-            end
-            params
-          end
-
-          def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
-            Langchain::Messages::MistralAIMessage.new(role: role, content: content, image_url: image_url, tool_calls: tool_calls, tool_call_id: tool_call_id)
-          end
-
-          # Extract the tool call information from the OpenAI tool call hash
-          #
-          # @param tool_call [Hash] The tool call hash
-          # @return [Array] The tool call information
-          def extract_tool_call_args(tool_call:)
-            tool_call_id = tool_call.dig("id")
-
-            function_name = tool_call.dig("function", "name")
-            tool_name, method_name = function_name.split("__")
-
-            tool_arguments = tool_call.dig("function", "arguments")
-            tool_arguments = if tool_arguments.is_a?(Hash)
-              Langchain::Utils::HashTransformer.symbolize_keys(tool_arguments)
-            else
-              JSON.parse(tool_arguments, symbolize_names: true)
-            end
-
-            [tool_call_id, tool_name, method_name, tool_arguments]
-          end
-
-          def build_tools(tools)
-            tools.map { |tool| tool.class.function_schemas.to_openai_format }.flatten
-          end
-
-          def allowed_tool_choices
-            ["auto", "none"]
-          end
-
-          def available_tool_names(tools)
-            build_tools(tools).map { |tool| tool.dig(:function, :name) }
-          end
-
-          private
-
-          def build_tool_choice(choice)
-            case choice
-            when "auto"
-              choice
-            else
-              {"type" => "function", "function" => {"name" => choice}}
-            end
-          end
-        end
-
-        class GoogleGemini < Base
-          def build_chat_params(tools:, instructions:, messages:, tool_choice:)
-            params = {messages: messages}
-            if tools.any?
-              params[:tools] = build_tools(tools)
-              params[:system] = instructions if instructions
-              params[:tool_choice] = build_tool_config(tool_choice)
-            end
-            params
-          end
-
-          def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
-            warn "Image URL is not supported by Google Gemini" if image_url
-
-            Langchain::Messages::GoogleGeminiMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
-          end
-
-          # Extract the tool call information from the Google Gemini tool call hash
-          #
-          # @param tool_call [Hash] The tool call hash, format: {"functionCall"=>{"name"=>"weather__execute", "args"=>{"input"=>"NYC"}}}
-          # @return [Array] The tool call information
-          def extract_tool_call_args(tool_call:)
-            tool_call_id = tool_call.dig("functionCall", "name")
-            function_name = tool_call.dig("functionCall", "name")
-            tool_name, method_name = function_name.split("__")
-            tool_arguments = tool_call.dig("functionCall", "args").transform_keys(&:to_sym)
-            [tool_call_id, tool_name, method_name, tool_arguments]
-          end
-
-          def build_tools(tools)
-            tools.map { |tool| tool.class.function_schemas.to_google_gemini_format }.flatten
-          end
-
-          def allowed_tool_choices
-            ["auto", "none"]
-          end
-
-          def available_tool_names(tools)
-            build_tools(tools).map { |tool| tool.dig(:name) }
-          end
-
-          private
-
-          def build_tool_config(choice)
-            case choice
-            when "auto"
-              {function_calling_config: {mode: "auto"}}
-            when "none"
-              {function_calling_config: {mode: "none"}}
-            else
-              {function_calling_config: {mode: "any", allowed_function_names: [choice]}}
-            end
-          end
-        end
-
-        class Anthropic < Base
-          def build_chat_params(tools:, instructions:, messages:, tool_choice:)
-            params = {messages: messages}
-            if tools.any?
-              params[:tools] = build_tools(tools)
-              params[:tool_choice] = build_tool_choice(tool_choice)
-            end
-            params[:system] = instructions if instructions
-            params
-          end
-
-          def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
-            warn "Image URL is not supported by Anthropic currently" if image_url
-
-            Langchain::Messages::AnthropicMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
-          end
-
-          # Extract the tool call information from the Anthropic tool call hash
-          #
-          # @param tool_call [Hash] The tool call hash, format: {"type"=>"tool_use", "id"=>"toolu_01TjusbFApEbwKPRWTRwzadR", "name"=>"news_retriever__get_top_headlines", "input"=>{"country"=>"us", "page_size"=>10}}], "stop_reason"=>"tool_use"}
-          # @return [Array] The tool call information
-          def extract_tool_call_args(tool_call:)
-            tool_call_id = tool_call.dig("id")
-            function_name = tool_call.dig("name")
-            tool_name, method_name = function_name.split("__")
-            tool_arguments = tool_call.dig("input").transform_keys(&:to_sym)
-            [tool_call_id, tool_name, method_name, tool_arguments]
-          end
-
-          def build_tools(tools)
-            tools.map { |tool| tool.class.function_schemas.to_anthropic_format }.flatten
-          end
-
-          def allowed_tool_choices
-            ["auto", "any"]
-          end
-
-          def available_tool_names(tools)
-            build_tools(tools).map { |tool| tool.dig(:name) }
-          end
-
-          private
-
-          def build_tool_choice(choice)
-            case choice
-            when "auto"
-              {type: "auto"}
-            when "any"
-              {type: "any"}
-            else
-              {type: "tool", name: choice}
-            end
-          end
-        end
-      end
-    end
   end
 end
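
The surviving context lines still call llm_adapter.available_tool_names(tools), and the added require_relative "llm/adapter" loads the factory that now lives in its own file (next diff). A minimal sketch of how that adapter could be resolved inside the assistant, assuming llm_adapter is a memoized private helper that is not shown in this diff:

```ruby
# Hypothetical wiring, not part of this commit: how Langchain::Assistant
# could resolve its provider adapter via the extracted factory.
module Langchain
  class Assistant
    private

    # Lazily build and cache the adapter for the configured LLM client.
    def llm_adapter
      @llm_adapter ||= LLM::Adapter.build(llm)
    end
  end
end
```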
lib/langchain/assistants/llm/adapter.rb (new file; path inferred from the require_relative "llm/adapter" above)

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+Dir[Pathname.new(__FILE__).dirname.join("adapters", "*.rb")].sort.each { |file| require file }
+
+module Langchain
+  class Assistant
+    module LLM
+      # TODO: Fix the message truncation when context window is exceeded
+      class Adapter
+        def self.build(llm)
+          case llm
+          when Langchain::LLM::Anthropic
+            LLM::Adapters::Anthropic.new
+          when Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI
+            LLM::Adapters::GoogleGemini.new
+          when Langchain::LLM::MistralAI
+            LLM::Adapters::MistralAI.new
+          when Langchain::LLM::Ollama
+            LLM::Adapters::Ollama.new
+          when Langchain::LLM::OpenAI
+            LLM::Adapters::OpenAI.new
+          else
+            raise ArgumentError, "Unsupported LLM type: #{llm.class}"
+          end
+        end
+      end
+    end
+  end
+end
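
For orientation, the factory above maps an LLM client instance to its provider-specific adapter, which then builds provider-shaped chat parameters. A hedged usage sketch (the OpenAI client construction, the API key handling, and the message/tool values are illustrative assumptions, not taken from this commit):

```ruby
require "langchain"

# Build one of the supported LLM clients (OpenAI is one branch of the
# case statement above).
llm = Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"])

# The factory returns the matching adapter; an unsupported client class
# raises ArgumentError ("Unsupported LLM type: ...").
adapter = Langchain::Assistant::LLM::Adapter.build(llm)

# Adapters translate generic assistant inputs into provider-specific
# chat parameters. With no tools, only the messages are passed through.
params = adapter.build_chat_params(
  tools: [],
  instructions: "You are a helpful assistant",
  messages: [{role: "user", content: "Hello"}],
  tool_choice: "auto"
)
# => {messages: [{role: "user", content: "Hello"}]}
```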
lib/langchain/assistants/llm/adapters/_base.rb (new file; path inferred from the adapters/*.rb glob above and the rename noted in the commit message)

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+module Langchain
+  class Assistant
+    module LLM
+      module Adapters
+        class Base
+          def build_chat_params(tools:, instructions:, messages:, tool_choice:)
+            raise NotImplementedError, "Subclasses must implement build_chat_params"
+          end
+
+          def extract_tool_call_args(tool_call:)
+            raise NotImplementedError, "Subclasses must implement extract_tool_call_args"
+          end
+
+          def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
+            raise NotImplementedError, "Subclasses must implement build_message"
+          end
+        end
+      end
+    end
+  end
+end
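
Base is a pure interface: each method raises until a provider adapter (such as the OpenAI, Ollama, MistralAI, GoogleGemini, or Anthropic classes moved out of assistant.rb above) overrides it. A small sketch of that contract in action, assuming the gem is loaded:

```ruby
require "langchain"

# Base cannot be used directly; every method raises NotImplementedError
# until a concrete adapter subclass overrides it.
base = Langchain::Assistant::LLM::Adapters::Base.new

begin
  base.build_message(role: "user", content: "hi")
rescue NotImplementedError => e
  puts e.message  # => "Subclasses must implement build_message"
end
```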
