Skip to content

Commit 4d4cff0

Browse files
Harrison/cohere experimental (#638)
Co-authored-by: inyourhead <[email protected]>
1 parent 5c97f70 commit 4d4cff0

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

langchain/llms/ai21.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@ class AI21(LLM, BaseModel):
6464

6565
ai21_api_key: Optional[str] = None
6666

67+
base_url: Optional[str] = None
68+
"""Base url to use, if None decides based on model name."""
69+
6770
class Config:
6871
"""Configuration for this pydantic object."""
6972

@@ -118,8 +121,15 @@ def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
118121
"""
119122
if stop is None:
120123
stop = []
124+
if self.base_url is not None:
125+
base_url = self.base_url
126+
else:
127+
if self.model in ("j1-grande-instruct",):
128+
base_url = "https://api.ai21.com/studio/v1/experimental"
129+
else:
130+
base_url = "https://api.ai21.com/studio/v1"
121131
response = requests.post(
122-
url=f"https://api.ai21.com/studio/v1/{self.model}/complete",
132+
url=f"{base_url}/{self.model}/complete",
123133
headers={"Authorization": f"Bearer {self.ai21_api_key}"},
124134
json={"prompt": prompt, "stopSequences": stop, **self._default_params},
125135
)

tests/integration_tests/llms/test_ai21.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,13 @@ def test_ai21_call() -> None:
1313
assert isinstance(output, str)
1414

1515

16+
def test_ai21_call_experimental() -> None:
17+
"""Test valid call to ai21 with an experimental model."""
18+
llm = AI21(maxTokens=10, model="j1-grande-instruct")
19+
output = llm("Say foo:")
20+
assert isinstance(output, str)
21+
22+
1623
def test_saving_loading_llm(tmp_path: Path) -> None:
1724
"""Test saving/loading an AI21 LLM."""
1825
llm = AI21(maxTokens=10)

0 commit comments

Comments
 (0)