Skip to content

Commit 5122775

Browse files
authored
Merge pull request game-by-virtuals#70 from game-by-virtuals/model_selection
Add functionality to allow llm model selection for GAME when creating agents
2 parents 0804c99 + 7f6de79 commit 5122775

File tree

8 files changed

+54
-14
lines changed

8 files changed

+54
-14
lines changed

README.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,18 @@ cd plugins/twitter
3232
pip install -e .
3333
```
3434

35+
## Model Selection
36+
The foundation models which power the GAME framework can also be confgiured and selected. This can be specified when creating agents or workers. The default model used is "Llama-3.1-405B-Instruct".
37+
38+
The models currently supported are:
39+
- "Llama-3.1-405B-Instruct" (default)
40+
- "Llama-3.3-70B-Instruct"
41+
- "DeepSeek-R1"
42+
- "DeepSeek-V3"
43+
- "Qwen-2.5-72B-Instruct"
44+
45+
**Note: If model is not specified in the API call (REST API level) or on the SDK level when creating the agents, the default (Llama-3.1-405B-Instruct) model will be used.
46+
3547
## Usage
3648
1. `game`:
3749
- Request for a GAME API key in the Game Console https://console.game.virtuals.io/

examples/game/test_agent.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,8 @@ def throw_furniture(object: str, **kwargs) -> Tuple[FunctionResultStatus, str, d
274274
agent_goal="Conquer the world by causing chaos.",
275275
agent_description="You are a mischievous master of chaos is very strong but with a very short attention span, and not so much brains",
276276
get_agent_state_fn=get_agent_state_fn,
277-
workers=[fruit_thrower, furniture_thrower]
277+
workers=[fruit_thrower, furniture_thrower],
278+
model_name="Llama-3.1-405B-Instruct"
278279
)
279280

280281
# # interact and instruct the worker to do something

examples/game/test_worker.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,8 @@ def sit_on_object(object: str, **kwargs) -> Tuple[FunctionResultStatus, str, dic
107107
description="You are an evil NPC in a game.",
108108
instruction="Choose the evil-est actions.",
109109
get_state_fn=get_state_fn,
110-
action_space=action_space
110+
action_space=action_space,
111+
model_name="Llama-3.1-405B-Instruct"
111112
)
112113

113114
# interact and instruct the worker to do something

src/game_sdk/game/agent.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ def __init__(self,
105105
agent_description: str,
106106
get_agent_state_fn: Callable,
107107
workers: Optional[List[WorkerConfig]] = None,
108+
model_name: str = "Llama-3.1-405B-Instruct",
108109
):
109110

110111
if api_key.startswith("apt-"):
@@ -114,6 +115,8 @@ def __init__(self,
114115

115116
self._api_key: str = api_key
116117

118+
self._model_name: str = model_name
119+
117120
# checks
118121
if not self._api_key:
119122
raise ValueError("API key not set")
@@ -238,6 +241,7 @@ def _get_action(
238241
response = self.client.get_agent_action(
239242
agent_id=self.agent_id,
240243
data=data,
244+
model_name=self._model_name
241245
)
242246

243247
return ActionResponse.model_validate(response)

src/game_sdk/game/api.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import requests
2-
from typing import List, Dict
2+
from typing import List, Dict, Optional
3+
34

45
class GAMEClient:
56
def __init__(self, api_key: str):
@@ -13,7 +14,7 @@ def _get_access_token(self) -> str:
1314
response = requests.post(
1415
"https://api.virtuals.io/api/accesses/tokens",
1516
json={"data": {}},
16-
headers={"x-api-key": self.api_key}
17+
headers={"x-api-key": self.api_key},
1718
)
1819

1920
if response.status_code != 200:
@@ -22,12 +23,21 @@ def _get_access_token(self) -> str:
2223
response_json = response.json()
2324
return response_json["data"]["accessToken"]
2425

25-
def _post(self, endpoint: str, data: dict) -> dict:
26+
def _post(
27+
self, endpoint: str, data: dict, extra_headers: Optional[Dict[str, str]] = None
28+
) -> dict:
2629
"""
2730
Internal method to post data
2831
"""
2932
access_token = self._get_access_token()
3033

34+
# Default headers with Authorization
35+
headers = {"Authorization": f"Bearer {access_token}"}
36+
37+
# Merge additional headers if provided
38+
if extra_headers:
39+
headers.update(extra_headers)
40+
3141
response = requests.post(
3242
f"{self.base_url}/prompts",
3343
json={
@@ -40,7 +50,7 @@ def _post(self, endpoint: str, data: dict) -> dict:
4050
"data": data,
4151
},
4252
},
43-
headers={"Authorization": f"Bearer {access_token}"},
53+
headers=headers,
4454
)
4555

4656
if response.status_code != 200:
@@ -89,20 +99,28 @@ def set_worker_task(self, agent_id: str, task: str) -> Dict:
8999
data={"task": task},
90100
)
91101

92-
def get_worker_action(self, agent_id: str, submission_id: str, data: dict) -> Dict:
102+
def get_worker_action(
103+
self,
104+
agent_id: str,
105+
submission_id: str,
106+
data: dict,
107+
model_name: str,
108+
) -> Dict:
93109
"""
94110
Get worker actions (for standalone worker)
95111
"""
96112
return self._post(
97113
endpoint=f"/v2/agents/{agent_id}/tasks/{submission_id}/next",
98114
data=data,
115+
extra_headers={"model_name": model_name},
99116
)
100117

101-
def get_agent_action(self, agent_id: str, data: dict) -> Dict:
118+
def get_agent_action(self, agent_id: str, data: dict, model_name: str) -> Dict:
102119
"""
103120
Get agent actions/next step (for agent)
104121
"""
105122
return self._post(
106123
endpoint=f"/v2/agents/{agent_id}/actions",
107124
data=data,
125+
extra_headers={"model_name": model_name},
108126
)

src/game_sdk/game/api_v2.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,13 +84,13 @@ def set_worker_task(self, agent_id: str, task: str) -> Dict:
8484

8585
return response_json["data"]
8686

87-
def get_worker_action(self, agent_id: str, submission_id: str, data: dict) -> Dict:
87+
def get_worker_action(self, agent_id: str, submission_id: str, data: dict, model_name: str) -> Dict:
8888
"""
8989
API call to get worker actions (for standalone worker)
9090
"""
9191
response = requests.post(
9292
f"{self.base_url}/agents/{agent_id}/tasks/{submission_id}/next",
93-
headers=self.headers,
93+
headers=self.headers | {"model_name": model_name},
9494
json={
9595
"data": data
9696
}
@@ -103,13 +103,13 @@ def get_worker_action(self, agent_id: str, submission_id: str, data: dict) -> Di
103103

104104
return response_json["data"]
105105

106-
def get_agent_action(self, agent_id: str, data: dict) -> Dict:
106+
def get_agent_action(self, agent_id: str, data: dict, model_name: str) -> Dict:
107107
"""
108108
API call to get agent actions/next step (for agent)
109109
"""
110110
response = requests.post(
111111
f"{self.base_url}/agents/{agent_id}/actions",
112-
headers=self.headers,
112+
headers=self.headers | {"model_name": model_name},
113113
json={
114114
"data": data
115115
}

src/game_sdk/game/custom_types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,4 +237,4 @@ class ActionResponse(BaseModel):
237237
"""
238238
action_type: ActionType
239239
agent_state: AgentStateResponse
240-
action_args: Optional[Dict[str, Any]] = None
240+
action_args: Optional[Dict[str, Any]] = None

src/game_sdk/game/worker.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def __init__(
4949
action_space: List[Function],
5050
# specific additional instruction for the worker (PROMPT)
5151
instruction: Optional[str] = "",
52+
model_name: str = "Llama-3.1-405B-Instruct",
5253
):
5354

5455
if api_key.startswith("apt-"):
@@ -58,6 +59,8 @@ def __init__(
5859

5960
self._api_key: str = api_key
6061

62+
self._model_name: str = model_name
63+
6164
# checks
6265
if not self._api_key:
6366
raise ValueError("API key not set")
@@ -156,7 +159,8 @@ def _get_action(
156159
response = self.client.get_worker_action(
157160
self._agent_id,
158161
self._submission_id,
159-
data
162+
data,
163+
model_name=self._model_name
160164
)
161165

162166
return ActionResponse.model_validate(response)

0 commit comments

Comments
 (0)