Skip to content

Commit 64d9bd1

Browse files
committed
Add Python gym interface with combat-focused rewards
- Add QuakeLiveEnv with proper episode reset and bot management
- Add combat-focused reward system (frag +500, damage +2/pt)
- Add engagement reward to prevent passive play
- Add performance metrics tracking with deepcopy fix
- Add in_fov and time_to_spawn_ms to state objects
- Add obs_mode parameter for partial observability
- Add thrash guard to prevent reset loops
1 parent 0ba3330 commit 64d9bd1

File tree

6 files changed

+308
-96
lines changed

6 files changed

+308
-96
lines changed

QuakeLiveInterface/client.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,13 @@ def __init__(self, redis_host='localhost', redis_port=6379, redis_db=0):
2121

2222
def update_game_state(self):
2323
"""
24-
Checks for a new game state message from Redis and updates the local game state.
24+
Gets the latest game state from Redis and updates the local game state.
25+
Uses GET on ql:agent:last_state for reliable polling instead of pubsub.
2526
"""
26-
message = self.connection.get_message(self.game_state_pubsub)
27-
if message:
28-
self.game_state.update_from_redis(message)
27+
# Poll the stored state instead of using pubsub (more reliable)
28+
state_data = self.connection.get('ql:agent:last_state')
29+
if state_data:
30+
self.game_state.update_from_redis(state_data)
2931
return True
3032
return False
3133

@@ -103,6 +105,10 @@ def stop_demo_recording(self):
103105
"""Stops recording a demo on the server."""
104106
self.send_admin_command('stop_demo_record')
105107

108+
def kick_all_bots(self):
109+
"""Kicks all bots from the server."""
110+
self.send_admin_command('kickbots')
111+
106112
# Other getters
107113
def get_game_state(self):
108114
return self.game_state

QuakeLiveInterface/connection.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,26 @@ def get_message(self, pubsub, timeout: float = 1.0):
162162
logger.error(f"Error getting message: {e}")
163163
raise
164164

165+
def get(self, key: str):
166+
"""
167+
Gets a value from Redis by key.
168+
169+
Args:
170+
key: The key to retrieve.
171+
Returns:
172+
The value, or None if key doesn't exist.
173+
"""
174+
try:
175+
self._ensure_connected()
176+
return self.redis.get(key)
177+
except (redis.exceptions.ConnectionError, redis.exceptions.TimeoutError) as e:
178+
logger.warning(f"Get failed, attempting reconnect: {e}")
179+
self.reconnect()
180+
return self.redis.get(key)
181+
except redis.exceptions.RedisError as e:
182+
logger.error(f"Error getting key {key}: {e}")
183+
raise
184+
165185
def close(self):
166186
"""
167187
Closes the Redis connection.

QuakeLiveInterface/env.py

Lines changed: 101 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,21 @@ def __init__(self, redis_host='localhost', redis_port=6379, redis_db=0,
4444
max_health=200, max_armor=200, map_dims=(4000, 4000, 1000),
4545
max_velocity=800, max_ammo=200, num_items=10, num_opponents=3,
4646
max_episode_steps=1000, demo_dir=None, weapon_list=None,
47-
view_sensitivity=VIEW_SENSITIVITY):
47+
view_sensitivity=VIEW_SENSITIVITY, obs_mode='oracle',
48+
agent_bot_name=None, agent_bot_skill=5):
49+
"""
50+
Args:
51+
obs_mode: Observation mode for opponent visibility.
52+
'oracle' - Always include all opponent features (full information)
53+
'human' - Mask opponent features when not in agent's FOV (partial observability)
54+
agent_bot_name: If set, the agent is a bot that will be re-added after reset.
55+
agent_bot_skill: Skill level for the agent bot (1-5).
56+
"""
4857
super(QuakeLiveEnv, self).__init__()
4958

59+
if obs_mode not in ('oracle', 'human'):
60+
raise ValueError(f"obs_mode must be 'oracle' or 'human', got '{obs_mode}'")
61+
5062
self.client = QuakeLiveClient(redis_host, redis_port, redis_db)
5163
self.game_state = GameState()
5264
self.reward_system = RewardSystem()
@@ -56,6 +68,10 @@ def __init__(self, redis_host='localhost', redis_port=6379, redis_db=0,
5668
self.max_episode_steps = max_episode_steps
5769
self.demo_dir = demo_dir
5870
self.view_sensitivity = view_sensitivity
71+
self.obs_mode = obs_mode
72+
self.agent_bot_name = agent_bot_name
73+
self.agent_bot_skill = agent_bot_skill
74+
5975

6076
# MultiDiscrete action space for universal RL compatibility
6177
# [forward/back, left/right, jump/crouch, attack, look_pitch, look_yaw]
@@ -114,7 +130,30 @@ def step(self, action):
114130
# Log performance metrics
115131
self.performance_tracker.log_step(self.game_state, action)
116132

117-
return obs, reward, terminated, truncated, {}
133+
# Build info dict with episode metrics on termination/truncation
134+
# Use 'terminal_info' key to avoid VecMonitor overwriting 'episode'
135+
info = {}
136+
if terminated or truncated:
137+
tracker = self.performance_tracker
138+
info['terminal_info'] = {
139+
'damage_dealt': tracker.damage_dealt,
140+
'damage_taken': tracker.damage_taken,
141+
'frags': tracker.kills,
142+
'deaths': tracker.deaths,
143+
'frag_diff': tracker.kills - tracker.deaths,
144+
'shots_fired': tracker.shots_fired,
145+
'hits': tracker.successful_hits,
146+
'accuracy': (tracker.successful_hits / tracker.shots_fired * 100) if tracker.shots_fired > 0 else 0,
147+
'health_pickups': tracker.items_collected.get('Health', 0),
148+
'armor_pickups': tracker.items_collected.get('Armor', 0),
149+
'distance_traveled': tracker.total_distance_traveled,
150+
}
151+
# Quick validation print
152+
logger.info(f"Episode {self.episode_num} end: frags={tracker.kills} deaths={tracker.deaths} "
153+
f"dmg_dealt={tracker.damage_dealt} dmg_taken={tracker.damage_taken} "
154+
f"accuracy={info['terminal_info']['accuracy']:.1f}%")
155+
156+
return obs, reward, terminated, truncated, info
118157

119158
def reset(self, seed=None, options=None, reset_timeout=15.0):
120159
"""
@@ -145,11 +184,52 @@ def reset(self, seed=None, options=None, reset_timeout=15.0):
145184
self.performance_tracker.reset()
146185
self.game_state = GameState() # Reset game state
147186

148-
# Send a command to restart the game
149-
logger.info(f"Episode {self.episode_num}: Sending command to restart game.")
150-
self.client.send_admin_command('restart_game')
151-
152-
# Wait for the game to restart and for the agent to be alive
187+
import time as time_module
188+
189+
# First, check current roster before doing anything disruptive
190+
self.client.update_game_state()
191+
current_state = self.client.get_game_state()
192+
num_opponents = len(current_state.get_opponents()) if current_state else 0
193+
agent = current_state.get_agent() if current_state else None
194+
195+
# Log roster for debugging
196+
if agent:
197+
logger.info(f"Episode {self.episode_num}: Roster check - Agent={agent.name}, Opponents={num_opponents}")
198+
else:
199+
logger.info(f"Episode {self.episode_num}: No agent found, Opponents={num_opponents}")
200+
201+
# Only fix roster if it's actually wrong
202+
roster_correct = agent is not None and num_opponents == 1
203+
if self.agent_bot_name and not roster_correct:
204+
logger.info(f"Episode {self.episode_num}: Roster incorrect, fixing (once)...")
205+
206+
# Kick ALL bots first
207+
self.client.send_admin_command('kickbots')
208+
time_module.sleep(2.0)
209+
210+
# Add agent bot
211+
logger.info(f"Episode {self.episode_num}: Adding agent bot: {self.agent_bot_name}")
212+
self.client.send_admin_command('addbot', {
213+
'name': self.agent_bot_name,
214+
'skill': self.agent_bot_skill
215+
})
216+
time_module.sleep(2.0)
217+
218+
# Add opponent bot (different from agent)
219+
opponent_name = 'crash' if self.agent_bot_name.lower() != 'crash' else 'doom'
220+
logger.info(f"Episode {self.episode_num}: Adding opponent bot: {opponent_name}")
221+
self.client.send_admin_command('addbot', {
222+
'name': opponent_name,
223+
'skill': self.agent_bot_skill
224+
})
225+
time_module.sleep(3.0)
226+
else:
227+
# Roster is fine - just restart the match (fast, no kicks)
228+
logger.info(f"Episode {self.episode_num}: Roster OK, restarting match.")
229+
self.client.send_admin_command('restart_game')
230+
time_module.sleep(1.5) # Shorter wait since no bot changes
231+
232+
# Wait for the game to be ready and for the agent to be alive
153233
start_time = time.time()
154234
while time.time() - start_time < reset_timeout:
155235
if self.client.update_game_state():
@@ -219,6 +299,10 @@ def _get_observation(self):
219299
for i, opp in enumerate(opponents[:self.NUM_OPPONENTS]):
220300
start_idx = i * 11
221301
end_idx = start_idx + 11
302+
# In 'human' mode, only include opponents that are in FOV
303+
if self.obs_mode == 'human' and not getattr(opp, 'in_fov', True):
304+
# Opponent not in FOV - leave as zeros (masked)
305+
continue
222306
opponent_feats[start_idx:end_idx] = self._get_player_features(opp)
223307

224308

@@ -286,10 +370,16 @@ def _get_item_features(self, items):
286370
for i, item in enumerate(items):
287371
if i >= self.NUM_ITEMS:
288372
break
289-
pos = self._normalize_pos(item['position'])
290-
is_available = 1 if item['is_available'] else 0
291-
spawn_time = item['spawn_time'] / 30000.0 # Normalize by 30 seconds
292-
features[i*5 : i*5 + 5] = [*pos, is_available, spawn_time]
373+
# Handle both Item objects and dicts for backwards compatibility
374+
if hasattr(item, 'position'):
375+
pos = self._normalize_pos(item.position)
376+
is_available = 1 if item.is_available else 0
377+
time_to_spawn = getattr(item, 'time_to_spawn_ms', 0) / 30000.0
378+
else:
379+
pos = self._normalize_pos(item['position'])
380+
is_available = 1 if item['is_available'] else 0
381+
time_to_spawn = item.get('time_to_spawn_ms', item.get('spawn_time', 0)) / 30000.0
382+
features[i*5 : i*5 + 5] = [*pos, is_available, time_to_spawn]
293383
return features
294384

295385
@staticmethod

QuakeLiveInterface/metrics.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22
import numpy as np
3+
import copy
34

45
logger = logging.getLogger(__name__)
56

@@ -48,12 +49,17 @@ def log_step(self, current_state, action):
4849
self.shots_fired += 1
4950

5051
if self.previous_state is None:
51-
self.previous_state = current_state
52+
self.previous_state = copy.deepcopy(current_state)
5253
return
5354

5455
prev_agent = self.previous_state.get_agent()
5556
curr_agent = current_state.get_agent()
5657

58+
# Safety check - need both agents to compare
59+
if prev_agent is None or curr_agent is None:
60+
self.previous_state = copy.deepcopy(current_state)
61+
return
62+
5763
# Damage taken
5864
health_diff = prev_agent.health - curr_agent.health
5965
armor_diff = prev_agent.armor - curr_agent.armor
@@ -91,7 +97,7 @@ def log_step(self, current_state, action):
9197
curr_pos = np.array(list(curr_agent.position.values()))
9298
self.total_distance_traveled += np.linalg.norm(curr_pos - prev_pos)
9399

94-
self.previous_state = current_state
100+
self.previous_state = copy.deepcopy(current_state)
95101

96102
def log_episode(self, episode_num):
97103
"""Logs the summary of the episode's performance."""

0 commit comments

Comments
 (0)