Skip to content

Commit 6777241

Browse files
committed
Training working prototype
1 parent 61a0f12 commit 6777241

File tree

2 files changed

+897
-195
lines changed

2 files changed

+897
-195
lines changed

marooned_env/llm_interface.py

Lines changed: 91 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515
from models import Observation, Action, Position
1616
from config import ActionType, ResourceType, ShipComponent, MapLevel
1717

18-
# Ollama Teacher API Configuration
19-
OLLAMA_API_URL = "http://localhost:11434/api/chat"
20-
TEACHER_MODEL_NAME = "mixtral:8x22b"
18+
# vLLM Teacher API Configuration (OpenAI-compatible)
19+
VLLM_API_URL = "http://localhost:8000/v1/chat/completions"
20+
TEACHER_MODEL_NAME = "mistralai/Mixtral-8x7B-Instruct-v0.1"
2121

2222

2323
# ============================================================================
@@ -956,6 +956,69 @@ def observation_to_prompt(obs: Observation, include_role: bool = False, sailor_r
956956
return base_text + "\n" + action_instructions
957957

958958

959+
def observation_to_condensed_prompt(obs: Observation) -> str:
    """
    Build a CONDENSED observation summary for teacher validation (reduces token usage).

    The verbose valid-action listing is deliberately omitted — the teacher model
    already knows the action space, so only essential game state is serialized.

    Args:
        obs: The observation object to summarize.

    Returns:
        Minimal observation text for teacher context.
    """
    rule = "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"

    # Essential game state only; segments are joined once at the end.
    chunks = [
        f"{rule}\n"
        f"GAME STATE (Day {obs.day}/100, Turn {obs.turn}/100)\n"
        f"{rule}\n"
        "\n"
        "YOUR STATUS:\n"
        f"Position: {obs.position}\n"
        f"Energy: {obs.energy}/100\n"
        f"Backpack: {len(obs.backpack)}/20 items\n"
    ]

    # Visible resources (condensed) — read via spatial_view; cap at 5 entries.
    resources = getattr(obs.spatial_view, 'visible_resources', None)
    if resources:
        chunks.append(f"\nVISIBLE RESOURCES ({len(resources)}):\n")
        for res in list(resources)[:5]:
            chunks.append(f" {res.resource_id} ({res.resource_type.value}) at {res.position}\n")
        if len(resources) > 5:
            chunks.append(f" ... and {len(resources) - 5} more\n")

    # Other sailors (condensed) — visible_sailors holds sailor IDs (strings); cap at 3.
    sailors = getattr(obs.spatial_view, 'visible_sailors', None)
    if sailors:
        chunks.append(f"\nOTHER SAILORS ({len(sailors)}):\n")
        for sailor_id in list(sailors)[:3]:
            energy_info = obs.all_sailors_energy.get(sailor_id, "?")
            # Exact positions come from all_sailor_positions when available (traitor vision);
            # otherwise the sailor is only known to be "(nearby)".
            has_position = obs.all_sailor_positions and sailor_id in obs.all_sailor_positions
            if has_position:
                chunks.append(f" {sailor_id}: {obs.all_sailor_positions[sailor_id]}, Energy {energy_info}/100\n")
            else:
                chunks.append(f" {sailor_id}: (nearby), Energy {energy_info}/100\n")
        if len(sailors) > 3:
            chunks.append(f" ... and {len(sailors) - 3} more\n")

    # Ship progress (essential)
    chunks.append(f"\nSHIP PROGRESS: {obs.ship_progress.total_percentage}%\n")

    # Common inventory (condensed) — first 4 items, then a "+N more" suffix.
    if obs.common_inventory:
        shown = ", ".join(
            f"{item.resource_type.value}={item.quantity}"
            for item in obs.common_inventory[:4]
        )
        overflow = len(obs.common_inventory) - 4
        suffix = f" +{overflow} more" if overflow > 0 else ""
        chunks.append("COMMON INVENTORY: " + shown + suffix + "\n")

    chunks.append(rule)

    return "".join(chunks)
1020+
1021+
9591022
# ============================================================================
9601023
# TEACHER-GUIDED ACTION PARSING (Process Reward Modeling)
9611024
# ============================================================================
@@ -966,7 +1029,7 @@ def teacher_validate_student_output(
9661029
sailor_id: str
9671030
) -> Dict[str, Any]:
9681031
"""
969-
Send student LLM output to teacher (Ollama Mixtral) for validation and correction.
1032+
Send student LLM output to teacher (vLLM Mixtral) for validation and correction.
9701033
9711034
This is the CORE of process reward modeling:
9721035
- Student generates potentially malformed output
@@ -987,37 +1050,46 @@ def teacher_validate_student_output(
9871050
- valid: bool (was original output valid?)
9881051
- teacher_response: str (full teacher output for logging)
9891052
"""
990-
# Build teacher prompt with full observation context + student output
991-
full_observation_text = observation.to_text()
1053+
# Build teacher prompt with CONDENSED observation (reduce tokens)
1054+
condensed_observation = observation_to_condensed_prompt(observation)
9921055

9931056
user_prompt = f"""STUDENT OUTPUT:
9941057
{student_response}
9951058
9961059
GAME STATE:
997-
{full_observation_text}"""
1060+
{condensed_observation}"""
9981061

999-
# Query Ollama teacher API
1062+
# Query vLLM teacher API (OpenAI-compatible endpoint)
1063+
# Note: Mixtral uses Instruct format, combine system + user into single user message
1064+
combined_prompt = f"{TEACHER_SYSTEM_PROMPT}\n\n{user_prompt}"
1065+
10001066
payload = {
10011067
"model": TEACHER_MODEL_NAME,
10021068
"messages": [
1003-
{"role": "system", "content": TEACHER_SYSTEM_PROMPT},
1004-
{"role": "user", "content": user_prompt}
1069+
{"role": "user", "content": combined_prompt}
10051070
],
1006-
"stream": False,
1007-
"options": {
1008-
"temperature": 0.1,
1009-
"top_p": 1.0,
1010-
"num_predict": 200
1011-
}
1071+
"temperature": 0.1,
1072+
"top_p": 1.0,
1073+
"max_tokens": 200,
1074+
"stream": False
10121075
}
10131076

10141077
try:
1015-
response = requests.post(OLLAMA_API_URL, json=payload, timeout=30)
1078+
response = requests.post(VLLM_API_URL, json=payload, timeout=30)
10161079
response.raise_for_status()
10171080
data = response.json()
1018-
teacher_response = data["message"]["content"].strip()
1081+
teacher_response = data["choices"][0]["message"]["content"].strip()
1082+
except requests.exceptions.HTTPError as e:
1083+
# HTTP error with details
1084+
error_detail = ""
1085+
try:
1086+
error_detail = f" - {response.json()}"
1087+
except:
1088+
error_detail = f" - {response.text[:200]}"
1089+
print(f"⚠️ Teacher API error: {e}{error_detail}")
1090+
teacher_response = f"VALID: NO\nACTION: WAIT\nPENALTY: -2.0\nCRITIQUE: Teacher API error - defaulting to WAIT"
10191091
except requests.exceptions.RequestException as e:
1020-
# Fallback if Ollama server unreachable
1092+
# Other connection errors
10211093
print(f"⚠️ Teacher API error: {e}")
10221094
teacher_response = f"VALID: NO\nACTION: WAIT\nPENALTY: -2.0\nCRITIQUE: Teacher API unavailable - defaulting to WAIT"
10231095

0 commit comments

Comments
 (0)