diff --git a/.env.example b/.env.example
index a6724f8..be06f38 100644
--- a/.env.example
+++ b/.env.example
@@ -1,2 +1,3 @@
-OPENAI_API_KEY=
-ANTHROPIC_API_KEY=
\ No newline at end of file
+OPENAI_API_KEY=your_openai_api_key_here
+ANTHROPIC_API_KEY=your_anthropic_api_key_here
+GOOGLE_API_KEY=your_google_api_key_here
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 09e8c00..fd23068 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -8,7 +8,7 @@ on:
 
 jobs:
   test:
-    runs-on: ubuntu-latest
+    runs-on: ubuntu-22.04
     steps:
       - uses: actions/checkout@v3
 
diff --git a/requirements.txt b/requirements.txt
index ad2d878..21e17b2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -11,4 +11,10 @@ anthropic>=0.42.0
 python-dotenv>=1.0.0
 
 # Testing
-unittest2>=1.1.0
\ No newline at end of file
+unittest2>=1.1.0
+
+# Google Generative AI
+google-generativeai
+
+# gRPC pin for google-generativeai; prevents the warning "All log messages before absl::InitializeLog() is called are written to STDERR"
+grpcio==1.60.1
diff --git a/tools/llm_api.py b/tools/llm_api.py
index fa379c3..3d02316 100644
--- a/tools/llm_api.py
+++ b/tools/llm_api.py
@@ -1,5 +1,6 @@
 #!/usr/bin/env /workspace/tmp_windsurf/py310/bin/python3
 
+import google.generativeai as genai
 from openai import OpenAI
 from anthropic import Anthropic
 import argparse
@@ -7,7 +8,6 @@
 from dotenv import load_dotenv
 from pathlib import Path
 
-# 載入 .env.local 檔案
 env_path = Path('.') / '.env.local'
 load_dotenv(dotenv_path=env_path)
 
@@ -26,10 +26,16 @@ def create_llm_client(provider="openai"):
         return Anthropic(
             api_key=api_key
         )
+    elif provider == "gemini":
+        api_key = os.getenv('GOOGLE_API_KEY')
+        if not api_key:
+            raise ValueError("GOOGLE_API_KEY not found in environment variables")
+        genai.configure(api_key=api_key)
+        return genai
     elif provider == "local":
         return OpenAI(
             base_url="http://192.168.180.137:8006/v1",
-            api_key="not-needed"  # 本地部署可能不需要 API key
+            api_key="not-needed"
         )
     else:
         raise ValueError(f"Unsupported provider: {provider}")
@@ -39,12 +45,14 @@ def query_llm(prompt, client=None, model=None, provider="openai"):
         client = create_llm_client(provider)
 
     try:
-        # 設定預設模型
+        # Set default model
         if model is None:
             if provider == "openai":
                 model = "gpt-3.5-turbo"
             elif provider == "anthropic":
                 model = "claude-3-sonnet-20240229"
+            elif provider == "gemini":
+                model = "gemini-pro"
             elif provider == "local":
                 model = "Qwen/Qwen2.5-32B-Instruct-AWQ"
 
@@ -66,6 +74,10 @@ def query_llm(prompt, client=None, model=None, provider="openai"):
             ]
         )
         return response.content[0].text
+    elif provider == "gemini":
+        gemini_model = client.GenerativeModel(model)  # avoid shadowing the model-name string
+        response = gemini_model.generate_content(prompt)
+        return response.text
     except Exception as e:
         print(f"Error querying LLM: {e}")
         return None
@@ -73,18 +85,17 @@
 def main():
     parser = argparse.ArgumentParser(description='Query an LLM with a prompt')
     parser.add_argument('--prompt', type=str, help='The prompt to send to the LLM', required=True)
-    parser.add_argument('--provider', type=str, choices=['openai', 'anthropic'],
-                        default="openai", help='The API provider to use')
-    parser.add_argument('--model', type=str,
-                        help='The model to use (default depends on provider)')
+    parser.add_argument('--provider', choices=['openai', 'anthropic', 'gemini', 'local'], default='openai', help='The API provider to use')
+    parser.add_argument('--model', type=str, help='The model to use (default depends on provider)')
     args = parser.parse_args()
 
-    # 設定預設模型
     if not args.model:
-        if args.provider == "openai":
+        if args.provider == 'openai':
             args.model = "gpt-3.5-turbo"
-        else:
+        elif args.provider == 'anthropic':
             args.model = "claude-3-5-sonnet-20241022"
+        elif args.provider == 'gemini':
+            args.model = "gemini-2.0-flash-exp"
 
     client = create_llm_client(args.provider)
     response = query_llm(args.prompt, client, model=args.model, provider=args.provider)
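
Usage note: a minimal sketch (not part of the patch) of how the new gemini provider path is exercised once this lands. It assumes a valid GOOGLE_API_KEY in .env.local or the environment; all names below come from the patched tools/llm_api.py.

# Sketch: query Gemini through the patched tools/llm_api.py (assumes GOOGLE_API_KEY is set).
from tools.llm_api import create_llm_client, query_llm

client = create_llm_client("gemini")  # calls genai.configure() and returns the genai module
reply = query_llm(
    "Summarize this repo in one sentence.",
    client=client,
    provider="gemini",  # with model=None, query_llm falls back to "gemini-pro"
)
print(reply)

The same path is reachable from the CLI via python tools/llm_api.py --provider gemini --prompt "...". Note that main() defaults --model to "gemini-2.0-flash-exp" while query_llm defaults to "gemini-pro", so the library and CLI defaults intentionally or not diverge.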