@@ -44,9 +44,21 @@ def __init__(self, redis_host='localhost', redis_port=6379, redis_db=0,
4444 max_health = 200 , max_armor = 200 , map_dims = (4000 , 4000 , 1000 ),
4545 max_velocity = 800 , max_ammo = 200 , num_items = 10 , num_opponents = 3 ,
4646 max_episode_steps = 1000 , demo_dir = None , weapon_list = None ,
47- view_sensitivity = VIEW_SENSITIVITY ):
47+ view_sensitivity = VIEW_SENSITIVITY , obs_mode = 'oracle' ,
48+ agent_bot_name = None , agent_bot_skill = 5 ):
49+ """
50+ Args:
51+ obs_mode: Observation mode for opponent visibility.
52+ 'oracle' - Always include all opponent features (full information)
53+ 'human' - Mask opponent features when not in agent's FOV (partial observability)
 54+ agent_bot_name: If set, the agent is a bot; reset() kicks all bots and re-adds it (plus one opponent bot) whenever the roster check fails.
55+ agent_bot_skill: Skill level for the agent bot (1-5).
56+ """
4857 super (QuakeLiveEnv , self ).__init__ ()
4958
59+ if obs_mode not in ('oracle' , 'human' ):
60+ raise ValueError (f"obs_mode must be 'oracle' or 'human', got '{ obs_mode } '" )
61+
5062 self .client = QuakeLiveClient (redis_host , redis_port , redis_db )
5163 self .game_state = GameState ()
5264 self .reward_system = RewardSystem ()
@@ -56,6 +68,10 @@ def __init__(self, redis_host='localhost', redis_port=6379, redis_db=0,
5668 self .max_episode_steps = max_episode_steps
5769 self .demo_dir = demo_dir
5870 self .view_sensitivity = view_sensitivity
71+ self .obs_mode = obs_mode
72+ self .agent_bot_name = agent_bot_name
73+ self .agent_bot_skill = agent_bot_skill
74+
5975
6076 # MultiDiscrete action space for universal RL compatibility
6177 # [forward/back, left/right, jump/crouch, attack, look_pitch, look_yaw]
@@ -114,7 +130,30 @@ def step(self, action):
114130 # Log performance metrics
115131 self .performance_tracker .log_step (self .game_state , action )
116132
117- return obs , reward , terminated , truncated , {}
133+ # Build info dict with episode metrics on termination/truncation
134+ # Use 'terminal_info' key to avoid VecMonitor overwriting 'episode'
135+ info = {}
136+ if terminated or truncated :
137+ tracker = self .performance_tracker
138+ info ['terminal_info' ] = {
139+ 'damage_dealt' : tracker .damage_dealt ,
140+ 'damage_taken' : tracker .damage_taken ,
141+ 'frags' : tracker .kills ,
142+ 'deaths' : tracker .deaths ,
143+ 'frag_diff' : tracker .kills - tracker .deaths ,
144+ 'shots_fired' : tracker .shots_fired ,
145+ 'hits' : tracker .successful_hits ,
146+ 'accuracy' : (tracker .successful_hits / tracker .shots_fired * 100 ) if tracker .shots_fired > 0 else 0 ,
147+ 'health_pickups' : tracker .items_collected .get ('Health' , 0 ),
148+ 'armor_pickups' : tracker .items_collected .get ('Armor' , 0 ),
149+ 'distance_traveled' : tracker .total_distance_traveled ,
150+ }
 151+ # Log a one-line episode summary for quick validation
152+ logger .info (f"Episode { self .episode_num } end: frags={ tracker .kills } deaths={ tracker .deaths } "
153+ f"dmg_dealt={ tracker .damage_dealt } dmg_taken={ tracker .damage_taken } "
154+ f"accuracy={ info ['terminal_info' ]['accuracy' ]:.1f} %" )
155+
156+ return obs , reward , terminated , truncated , info
118157
119158 def reset (self , seed = None , options = None , reset_timeout = 15.0 ):
120159 """
@@ -145,11 +184,52 @@ def reset(self, seed=None, options=None, reset_timeout=15.0):
145184 self .performance_tracker .reset ()
146185 self .game_state = GameState () # Reset game state
147186
148- # Send a command to restart the game
149- logger .info (f"Episode { self .episode_num } : Sending command to restart game." )
150- self .client .send_admin_command ('restart_game' )
151-
152- # Wait for the game to restart and for the agent to be alive
187+ import time as time_module
188+
189+ # First, check current roster before doing anything disruptive
190+ self .client .update_game_state ()
191+ current_state = self .client .get_game_state ()
192+ num_opponents = len (current_state .get_opponents ()) if current_state else 0
193+ agent = current_state .get_agent () if current_state else None
194+
195+ # Log roster for debugging
196+ if agent :
197+ logger .info (f"Episode { self .episode_num } : Roster check - Agent={ agent .name } , Opponents={ num_opponents } " )
198+ else :
199+ logger .info (f"Episode { self .episode_num } : No agent found, Opponents={ num_opponents } " )
200+
201+ # Only fix roster if it's actually wrong
202+ roster_correct = agent is not None and num_opponents == 1
203+ if self .agent_bot_name and not roster_correct :
204+ logger .info (f"Episode { self .episode_num } : Roster incorrect, fixing (once)..." )
205+
206+ # Kick ALL bots first
207+ self .client .send_admin_command ('kickbots' )
208+ time_module .sleep (2.0 )
209+
210+ # Add agent bot
211+ logger .info (f"Episode { self .episode_num } : Adding agent bot: { self .agent_bot_name } " )
212+ self .client .send_admin_command ('addbot' , {
213+ 'name' : self .agent_bot_name ,
214+ 'skill' : self .agent_bot_skill
215+ })
216+ time_module .sleep (2.0 )
217+
218+ # Add opponent bot (different from agent)
219+ opponent_name = 'crash' if self .agent_bot_name .lower () != 'crash' else 'doom'
220+ logger .info (f"Episode { self .episode_num } : Adding opponent bot: { opponent_name } " )
221+ self .client .send_admin_command ('addbot' , {
222+ 'name' : opponent_name ,
223+ 'skill' : self .agent_bot_skill
224+ })
225+ time_module .sleep (3.0 )
226+ else :
 227+ # Roster is fine, or no agent bot is configured - just restart the match (fast, no kicks)
228+ logger .info (f"Episode { self .episode_num } : Roster OK, restarting match." )
229+ self .client .send_admin_command ('restart_game' )
230+ time_module .sleep (1.5 ) # Shorter wait since no bot changes
231+
232+ # Wait for the game to be ready and for the agent to be alive
153233 start_time = time .time ()
154234 while time .time () - start_time < reset_timeout :
155235 if self .client .update_game_state ():
@@ -219,6 +299,10 @@ def _get_observation(self):
219299 for i , opp in enumerate (opponents [:self .NUM_OPPONENTS ]):
220300 start_idx = i * 11
221301 end_idx = start_idx + 11
302+ # In 'human' mode, only include opponents that are in FOV
303+ if self .obs_mode == 'human' and not getattr (opp , 'in_fov' , True ):
304+ # Opponent not in FOV - leave as zeros (masked)
305+ continue
222306 opponent_feats [start_idx :end_idx ] = self ._get_player_features (opp )
223307
224308
@@ -286,10 +370,16 @@ def _get_item_features(self, items):
286370 for i , item in enumerate (items ):
287371 if i >= self .NUM_ITEMS :
288372 break
289- pos = self ._normalize_pos (item ['position' ])
290- is_available = 1 if item ['is_available' ] else 0
291- spawn_time = item ['spawn_time' ] / 30000.0 # Normalize by 30 seconds
292- features [i * 5 : i * 5 + 5 ] = [* pos , is_available , spawn_time ]
373+ # Handle both Item objects and dicts for backwards compatibility
374+ if hasattr (item , 'position' ):
375+ pos = self ._normalize_pos (item .position )
376+ is_available = 1 if item .is_available else 0
377+ time_to_spawn = getattr (item , 'time_to_spawn_ms' , 0 ) / 30000.0
378+ else :
379+ pos = self ._normalize_pos (item ['position' ])
380+ is_available = 1 if item ['is_available' ] else 0
381+ time_to_spawn = item .get ('time_to_spawn_ms' , item .get ('spawn_time' , 0 )) / 30000.0
382+ features [i * 5 : i * 5 + 5 ] = [* pos , is_available , time_to_spawn ]
293383 return features
294384
295385 @staticmethod
0 commit comments