"""
config.py — Client LLM unificato / Unified LLM Client
======================================================

Supporta Ollama (locale), Azure OpenAI, OpenAI diretto.
Supports Ollama (local), Azure OpenAI, direct OpenAI.

Uso / Usage:
    from config import get_client, get_model, get_embedding_model, describe_config, settings

    client = get_client()
    model = get_model()
"""
from __future__ import annotations

import os
from pathlib import Path
from dataclasses import dataclass, field

from dotenv import load_dotenv
from openai import OpenAI, AzureOpenAI

# Carica .env dalla root del repo / Load .env from repo root
_env_path = Path(__file__).resolve().parent / ".env"
load_dotenv(_env_path)

# ------------------------------------------------------------------
# Settings
# ------------------------------------------------------------------
@dataclass
class Settings:
    """Configurazione caricata da variabili d'ambiente / Configuration loaded from env vars."""

    provider: str = field(default_factory=lambda: os.getenv("LLM_PROVIDER", "ollama").lower())

    # Ollama
    ollama_base_url: str = field(default_factory=lambda: os.getenv("OLLAMA_BASE_URL", "http://localhost:11434/v1"))
    ollama_model: str = field(default_factory=lambda: os.getenv("OLLAMA_MODEL", "llama3.2"))
    ollama_embedding_model: str = field(default_factory=lambda: os.getenv("OLLAMA_EMBEDDING_MODEL", "nomic-embed-text"))

    # Azure OpenAI
    azure_endpoint: str = field(default_factory=lambda: os.getenv("AZURE_OPENAI_ENDPOINT", ""))
    azure_api_key: str = field(default_factory=lambda: os.getenv("AZURE_OPENAI_API_KEY", ""))
    azure_api_version: str = field(default_factory=lambda: os.getenv("AZURE_OPENAI_API_VERSION", "2024-12-01-preview"))
    azure_deployment: str = field(default_factory=lambda: os.getenv("AZURE_OPENAI_DEPLOYMENT", "gpt-4"))
    azure_embedding_deployment: str = field(default_factory=lambda: os.getenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT", "text-embedding-3-large"))
    azure_tts_deployment: str = field(default_factory=lambda: os.getenv("AZURE_TTS_DEPLOYMENT", "tts"))
    azure_stt_deployment: str = field(default_factory=lambda: os.getenv("AZURE_STT_DEPLOYMENT", "whisper"))
    azure_realtime_deployment: str = field(default_factory=lambda: os.getenv("AZURE_REALTIME_DEPLOYMENT", "gpt-4o-realtime-preview"))

    # OpenAI diretto / Direct OpenAI
    openai_api_key: str = field(default_factory=lambda: os.getenv("OPENAI_API_KEY", ""))
    openai_model: str = field(default_factory=lambda: os.getenv("OPENAI_MODEL", "gpt-4o-mini"))
    openai_embedding_model: str = field(default_factory=lambda: os.getenv("OPENAI_EMBEDDING_MODEL", "text-embedding-3-small"))


settings = Settings()
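
# Esempio di .env / Example .env: an illustrative sketch, not a tested
# configuration. The variable names come from the Settings fields above;
# the values are placeholders.
#
#   LLM_PROVIDER=azure
#   AZURE_OPENAI_ENDPOINT=https://<your-resource>.openai.azure.com/
#   AZURE_OPENAI_API_KEY=<your-key>
#   AZURE_OPENAI_API_VERSION=2024-12-01-preview
#   AZURE_OPENAI_DEPLOYMENT=gpt-4
#   AZURE_OPENAI_EMBEDDING_DEPLOYMENT=text-embedding-3-large
#
# Con LLM_PROVIDER=ollama bastano i default e non serve alcuna chiave;
# con LLM_PROVIDER=openai impostare OPENAI_API_KEY e, opzionalmente, OPENAI_MODEL.
# With LLM_PROVIDER=ollama the defaults are enough and no key is needed;
# with LLM_PROVIDER=openai set OPENAI_API_KEY and, optionally, OPENAI_MODEL.
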
# ------------------------------------------------------------------
# Client factory
# ------------------------------------------------------------------
def get_client() -> OpenAI | AzureOpenAI:
    """Restituisce un client OpenAI configurato per il provider scelto.
    Returns an OpenAI client configured for the chosen provider."""
    if settings.provider == "ollama":
        return OpenAI(
            base_url=settings.ollama_base_url,
            api_key="ollama",  # Ollama non richiede API key / Ollama doesn't need an API key
        )
    if settings.provider == "azure":
        return AzureOpenAI(
            azure_endpoint=settings.azure_endpoint,
            api_key=settings.azure_api_key,
            api_version=settings.azure_api_version,
        )
    if settings.provider == "openai":
        return OpenAI(api_key=settings.openai_api_key)
    raise ValueError(
        f"Provider '{settings.provider}' non supportato. "
        f"Usa: ollama, azure, openai. / "
        f"Provider '{settings.provider}' not supported. "
        f"Use: ollama, azure, openai."
    )


def get_model() -> str:
    """Restituisce il nome del modello per il provider corrente.
    Returns the model name for the current provider."""
    if settings.provider == "ollama":
        return settings.ollama_model
    if settings.provider == "azure":
        return settings.azure_deployment
    if settings.provider == "openai":
        return settings.openai_model
    raise ValueError(f"Provider sconosciuto / Unknown provider: {settings.provider}")


def get_embedding_model() -> str:
    """Restituisce il modello di embedding per il provider corrente.
    Returns the embedding model for the current provider."""
    if settings.provider == "ollama":
        return settings.ollama_embedding_model
    if settings.provider == "azure":
        return settings.azure_embedding_deployment
    if settings.provider == "openai":
        return settings.openai_embedding_model
    raise ValueError(f"Provider sconosciuto / Unknown provider: {settings.provider}")


def describe_config() -> str:
    """Descrizione leggibile della configurazione corrente.
    Human-readable description of the current configuration."""
    model = get_model()
    if settings.provider == "ollama":
        return f"Ollama (local) — {model} @ {settings.ollama_base_url}"
    if settings.provider == "azure":
        return f"Azure OpenAI — {model} @ {settings.azure_endpoint}"
    if settings.provider == "openai":
        return f"OpenAI — {model}"
    return f"Unknown provider: {settings.provider}"


# ------------------------------------------------------------------
# Quick test
# ------------------------------------------------------------------
if __name__ == "__main__":
    print(f"Provider: {settings.provider}")
    print(f"Config: {describe_config()}")
    print(f"Model: {get_model()}")
    print(f"Embed: {get_embedding_model()}")

    client = get_client()
    resp = client.chat.completions.create(
        model=get_model(),
        messages=[{"role": "user", "content": "Rispondi solo: ciao!"}],
        max_tokens=10,
    )
    print(f"Test LLM: {resp.choices[0].message.content}")
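
    # Facoltativo / Optional: smoke test degli embeddings. Uno sketch
    # illustrativo, assumendo che il provider configurato esponga un endpoint
    # /embeddings in stile OpenAI (l'API /v1 di Ollama e i deployment di
    # embedding Azure in genere lo fanno).
    # An illustrative embeddings smoke test, assuming the configured provider
    # exposes an OpenAI-style /embeddings endpoint (Ollama's /v1 API and Azure
    # embedding deployments generally do).
    emb = client.embeddings.create(
        model=get_embedding_model(),
        input="ciao",
    )
    print(f"Test embedding: {len(emb.data[0].embedding)} dimensioni / dimensions")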