Skip to content

Commit 82e39fa

Browse files
committed
Implemented Google and OpenAI APIs; the OpenAI tests pass, the Gemini tests do not yet.
1 parent 339a8d2 commit 82e39fa

File tree

9 files changed

+1318
-49
lines changed

9 files changed

+1318
-49
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -33,3 +33,4 @@ htmlcov/
3333

3434

3535
/uv.lock
36+
/.env

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "reptrace"
3-
version = "0.1.0b1"
3+
version = "0.1.1b1"
44
description = "Extract LLM DNA vectors — low-dimensional representations that capture functional behavior and model evolution."
55
authors = [{ name = "RepTrace Project" }]
66
license = { file = "LICENSE" }
@@ -30,6 +30,8 @@ dependencies = [
3030
"tqdm>=4.65.0",
3131
"pyyaml>=6.0",
3232
"wonderwords>=2.2.0",
33+
"openai>=1.0.0",
34+
"tiktoken>=0.7.0",
3335
]
3436

3537
[project.optional-dependencies]

src/reptrace/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99
"calc_dna_batch",
1010
]
1111

12-
__version__ = "0.1.0"
12+
__version__ = "0.1.1b1"
1313

1414

1515
def __getattr__(name: str):

src/reptrace/cli.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -86,7 +86,7 @@ def parse_arguments(argv: Optional[Iterable[str]] = None) -> argparse.Namespace:
8686
"--model-type",
8787
type=str,
8888
default="auto",
89-
choices=["auto", "huggingface", "openai", "anthropic"],
89+
choices=["auto", "huggingface", "openai", "gemini"],
9090
)
9191

9292
# Dataset and probes

src/reptrace/core/extraction.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -61,7 +61,7 @@ def parse_arguments(argv: Optional[List[str]] = None) -> argparse.Namespace:
6161
"--model-type",
6262
type=str,
6363
default="auto",
64-
choices=["auto", "huggingface", "openai", "anthropic"],
64+
choices=["auto", "huggingface", "openai", "gemini"],
6565
help="Type of model to load"
6666
)
6767

src/reptrace/models/ModelLoader.py

Lines changed: 65 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@
88
import logging
99

1010
import torch
11-
from .ModelWrapper import (LLMWrapper, HuggingFaceWrapper, OpenAIWrapper, AnthropicWrapper,
11+
from .ModelWrapper import (LLMWrapper, HuggingFaceWrapper, OpenAIWrapper, GeminiWrapper,
1212
DecoderOnlyWrapper, EncoderOnlyWrapper, EncoderDecoderWrapper)
1313

1414

@@ -31,7 +31,7 @@ def load_model(
3131
3232
Args:
3333
model_path_or_name: Path to local model or HuggingFace model name
34-
model_type: Type of model ("auto", "huggingface", "openai", "anthropic")
34+
model_type: Type of model ("auto", "huggingface", "openai", "gemini", "anthropic")
3535
device: Device for computation
3636
**kwargs: Additional arguments for model loading
3737
@@ -48,29 +48,32 @@ def load_model(
4848
return self._load_huggingface_model(model_path_or_name, device, **kwargs)
4949
elif model_type == "openai":
5050
return self._load_openai_model(model_path_or_name, **kwargs)
51-
elif model_type == "anthropic":
52-
return self._load_anthropic_model(model_path_or_name, **kwargs)
51+
elif model_type == "gemini":
52+
return self._load_gemini_model(model_path_or_name, **kwargs)
5353
else:
54-
raise ValueError(f"Unsupported model type: {model_type}")
54+
raise ValueError(f"Unsupported model type: {model_type}. Supported: huggingface, openai, gemini")
5555

5656
def _detect_model_type(self, model_path_or_name: str) -> str:
5757
"""Auto-detect model type based on path/name patterns."""
58-
# Check for OpenAI model names
59-
openai_models = [
60-
"gpt-3.5-turbo", "gpt-4", "gpt-4-turbo", "gpt-4o",
61-
"text-davinci-003", "text-curie-001", "text-babbage-001"
58+
# Check for OpenAI model names (including newer models)
59+
openai_prefixes = [
60+
"gpt-3.5", "gpt-4", "gpt-4o", "gpt-4-turbo",
61+
"o1-", "o3-", # Reasoning models
62+
"text-davinci", "text-curie", "text-babbage", "text-ada",
6263
]
6364

64-
if any(model_path_or_name.startswith(name) for name in openai_models):
65+
model_lower = model_path_or_name.lower()
66+
if any(model_lower.startswith(prefix) for prefix in openai_prefixes):
6567
return "openai"
66-
67-
# Check for Anthropic model names
68-
anthropic_models = [
69-
"claude-3", "claude-2", "claude-instant"
68+
69+
# Check for Google Gemini model names
70+
gemini_prefixes = [
71+
"gemini-",
72+
"models/gemini-",
73+
"gemini-pro", # Older naming
7074
]
71-
72-
if any(model_path_or_name.startswith(name) for name in anthropic_models):
73-
return "anthropic"
75+
if any(model_lower.startswith(prefix) for prefix in gemini_prefixes):
76+
return "gemini"
7477

7578
# Check if it's a local path
7679
if os.path.exists(model_path_or_name):
@@ -206,35 +209,56 @@ def _load_openai_model(
206209
"""Load OpenAI model."""
207210
# Get API key from environment if not provided
208211
if api_key is None:
209-
api_key = os.getenv("OPENAI_API_KEY")
212+
api_key = os.getenv("OPENAI_API_KEY") or os.getenv("APIKEY_OPENAI")
210213

211214
if api_key is None:
212215
raise ValueError("OpenAI API key required. Set OPENAI_API_KEY environment variable.")
216+
217+
allowed_kwargs = {
218+
"batch_poll_interval_seconds",
219+
"batch_timeout_seconds",
220+
"batch_max_requests",
221+
"prefer_batch_api",
222+
}
223+
wrapper_kwargs = {key: value for key, value in kwargs.items() if key in allowed_kwargs}
213224

214225
return OpenAIWrapper(
215226
model_name=model_name,
216227
api_key=api_key,
217-
**kwargs
228+
**wrapper_kwargs
218229
)
219-
220-
def _load_anthropic_model(
230+
231+
def _load_gemini_model(
221232
self,
222233
model_name: str,
223234
api_key: Optional[str] = None,
224235
**kwargs
225-
) -> AnthropicWrapper:
226-
"""Load Anthropic model."""
227-
# Get API key from environment if not provided
236+
) -> GeminiWrapper:
237+
"""Load Gemini model."""
228238
if api_key is None:
229-
api_key = os.getenv("ANTHROPIC_API_KEY")
230-
239+
api_key = (
240+
os.getenv("GEMINI_API_KEY")
241+
or os.getenv("GOOGLE_API_KEY")
242+
or os.getenv("APIKEY_GOOGLE")
243+
)
244+
231245
if api_key is None:
232-
raise ValueError("Anthropic API key required. Set ANTHROPIC_API_KEY environment variable.")
233-
234-
return AnthropicWrapper(
246+
raise ValueError("Gemini API key required. Set GEMINI_API_KEY or GOOGLE_API_KEY.")
247+
248+
allowed_kwargs = {
249+
"api_base",
250+
"batch_poll_interval_seconds",
251+
"batch_timeout_seconds",
252+
"batch_max_requests",
253+
"batch_max_payload_bytes",
254+
"prefer_batch_api",
255+
}
256+
wrapper_kwargs = {key: value for key, value in kwargs.items() if key in allowed_kwargs}
257+
258+
return GeminiWrapper(
235259
model_name=model_name,
236260
api_key=api_key,
237-
**kwargs
261+
**wrapper_kwargs
238262
)
239263

240264
def _is_unsupported_model(self, model_name: str) -> bool:
@@ -272,8 +296,8 @@ def list_available_models(self, model_type: str = "huggingface") -> Dict[str, An
272296
return self._list_huggingface_models()
273297
elif model_type == "openai":
274298
return self._list_openai_models()
275-
elif model_type == "anthropic":
276-
return self._list_anthropic_models()
299+
elif model_type == "gemini":
300+
return self._list_gemini_models()
277301
else:
278302
return {}
279303

@@ -333,17 +357,15 @@ def _list_openai_models(self) -> Dict[str, Any]:
333357
"text-ada-001"
334358
]
335359
}
336-
337-
def _list_anthropic_models(self) -> Dict[str, Any]:
338-
"""List available Anthropic models."""
360+
361+
def _list_gemini_models(self) -> Dict[str, Any]:
362+
"""List available Gemini models."""
339363
return {
340364
"chat_models": [
341-
"claude-3-opus-20240229",
342-
"claude-3-sonnet-20240229",
343-
"claude-3-haiku-20240307",
344-
"claude-2.1",
345-
"claude-2.0",
346-
"claude-instant-1.2"
365+
"gemini-2.0-flash",
366+
"gemini-2.0-flash-lite",
367+
"gemini-1.5-pro",
368+
"gemini-1.5-flash",
347369
]
348370
}
349371

@@ -367,7 +389,7 @@ def get_model_info(self, model_path_or_name: str) -> Dict[str, Any]:
367389

368390
if model_type == "huggingface":
369391
info.update(self._get_huggingface_info(model_path_or_name))
370-
elif model_type in ["openai", "anthropic"]:
392+
elif model_type in ["openai", "gemini", "anthropic"]:
371393
info.update({"requires_api_key": True})
372394

373395
return info
@@ -404,7 +426,7 @@ def load_model(
404426
405427
Args:
406428
model_path_or_name: Model identifier
407-
model_type: Model type ("auto", "huggingface", "openai", "anthropic")
429+
model_type: Model type ("auto", "huggingface", "openai", "gemini", "anthropic")
408430
device: Computation device
409431
config_dict: Model configuration dictionary
410432
**kwargs: Additional model loading arguments

0 commit comments

Comments (0)