Skip to content

Commit 1b92707

Browse files
vezhnickcopybara-github
authored andcommitted
Improve observation queue handling and add scene-aware event delivery.
MakeObservation improvements: - Extract ObservationQueue into standalone class with thread-safe operations - Add external_queue parameter to share queue across multiple GMs - Add allow_llm_fallback parameter to control LLM observation generation - Simplify internal implementation by always using ObservationQueue SendEventToRelevantPlayers improvements: - Add player_filter callback parameter to filter which players receive events - Decouples the component from scene-specific logic; callers provide any filtering function (e.g., scene_tracker.get_participants) Prefab updates (dialogic_and_dramaturgic, game_theoretic_and_dramaturgic, physically_situated_and_dramaturgic): - Add external_queue parameter for cross-GM observation persistence - Use player_filter callback for scene-aware event delivery PiperOrigin-RevId: 863371470 Change-Id: Ida87d5a917cf2316ebac66d989d9bc49b73459cf
1 parent 28cf0ef commit 1b92707

File tree

5 files changed

+149
-77
lines changed

5 files changed

+149
-77
lines changed

concordia/components/game_master/event_resolution.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,7 @@ def __init__(
353353
player_names: Sequence[str],
354354
make_observation_component_key: str,
355355
components: Sequence[str] = (),
356+
player_filter: Callable[[], Sequence[str]] | None = None,
356357
pre_act_label: str = DEFAULT_SEND_PRE_ACT_VALUES_TO_PLAYERS_PRE_ACT_LABEL,
357358
):
358359
"""Initializes a component that sends component pre-act values to players.
@@ -363,11 +364,15 @@ def __init__(
363364
364365
Args:
365366
model: The language model to use for the component.
366-
player_names: Names of players.
367+
player_names: Names of players (used if no player_filter is supplied).
367368
make_observation_component_key: the key for a MakeObservation component to
368369
add the pre-act values to the queue of events to observe.
369370
components: Keys of components to condition whether the event is relevant.
370371
If empty, all events are relevant.
372+
player_filter: Optional callback that returns the list of player names
373+
that should receive observations. When provided, this function is called
374+
during post_act to determine which players get the event. This can be
375+
used to filter by scene participants or any other criteria.
371376
pre_act_label: Prefix to add to the output of the component when called in
372377
`pre_act`.
373378
@@ -380,6 +385,7 @@ def __init__(
380385
self._player_names = player_names
381386
self._pre_act_label = pre_act_label
382387
self._components = components
388+
self._player_filter = player_filter
383389
self._queue = {}
384390
self._last_action_spec = None
385391
self._optional_make_observation_component_key = (
@@ -437,7 +443,12 @@ def post_act(
437443
prompt = interactive_document.InteractiveDocument(self._model)
438444
proceed = True
439445

440-
for active_entity_name in self._player_names:
446+
# Use player_filter if provided, otherwise all players
447+
relevant_players = self._player_names
448+
if self._player_filter:
449+
relevant_players = self._player_filter()
450+
451+
for active_entity_name in relevant_players:
441452
if self._components:
442453
component_states = '\n'.join(
443454
[self._component_pre_act_display(key) for key in self._components]

concordia/components/game_master/make_observation.py

Lines changed: 78 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from collections.abc import Sequence
1818
import copy
1919
import threading
20+
from typing import Any
2021

2122
from concordia.components.agent import action_spec_ignored
2223
from concordia.document import interactive_document
@@ -37,6 +38,50 @@
3738
)
3839

3940

41+
class ObservationQueue:
42+
"""A shared queue for observations that can be used across multiple GMs.
43+
44+
This allows observations queued in one GM to be delivered by another GM,
45+
preventing observation loss during GM transitions.
46+
"""
47+
48+
def __init__(self):
49+
self._queue = {}
50+
self._lock = threading.Lock()
51+
52+
def add(self, entity_name: str, event: str, player_names: Sequence[str]):
53+
"""Adds an event to the queue for the given entity."""
54+
with self._lock:
55+
if entity_name.lower().strip() == 'all':
56+
for player in player_names:
57+
if player not in self._queue:
58+
self._queue[player] = []
59+
self._queue[player].append(event)
60+
else:
61+
if entity_name not in self._queue:
62+
self._queue[entity_name] = []
63+
self._queue[entity_name].append(event)
64+
65+
def get_and_clear(self, entity_name: str) -> list[str]:
66+
"""Gets and clears the queue for the given entity."""
67+
with self._lock:
68+
if entity_name in self._queue and self._queue[entity_name]:
69+
events = self._queue[entity_name]
70+
self._queue[entity_name] = []
71+
return events
72+
return []
73+
74+
def get_all(self) -> dict[str, list[str]]:
75+
"""Returns a deep copy of the entire queue state."""
76+
with self._lock:
77+
return copy.deepcopy(self._queue)
78+
79+
def set_all(self, queue_state: Any):
80+
"""Sets the entire queue state from a deep copy."""
81+
with self._lock:
82+
self._queue = copy.deepcopy(queue_state)
83+
84+
4085
class MakeObservation(entity_component.ContextComponent,
4186
entity_component.ComponentWithLogging):
4287
"""A component that generates observations to send to players."""
@@ -49,14 +94,16 @@ def __init__(
4994
call_to_make_observation: str = DEFAULT_CALL_TO_MAKE_OBSERVATION,
5095
reformat_observations_in_specified_style: str = '',
5196
pre_act_label: str = DEFAULT_MAKE_OBSERVATION_PRE_ACT_LABEL,
97+
external_queue: ObservationQueue | None = None,
98+
allow_llm_fallback: bool = True,
5299
):
53100
"""Initializes the component.
54101
55102
Args:
56103
model: The language model to use for the component.
57104
player_names: Names of players.
58105
components: Keys of components to condition the observation on.
59-
call_to_make_observation: The call to action to make the observation.
106+
call_to_make_observation: The call to action to make the observation.
60107
Needed to extract the name of the active entity.
61108
reformat_observations_in_specified_style: If non-empty, the component will
62109
ask the model to reformat the observation to fit the style specified in
@@ -68,6 +115,12 @@ def __init__(
68115
description"."
69116
pre_act_label: Prefix to add to the output of the component when called in
70117
`pre_act`.
118+
external_queue: Optional shared ObservationQueue. If provided, this
119+
component will use the external queue instead of creating its own. This
120+
allows observations to persist across GM transitions.
121+
allow_llm_fallback: If True, when the queue is empty, the LLM will be used
122+
to generate an observation. If False, an empty observation will be
123+
returned when the queue is empty. Defaults to True.
71124
72125
Raises:
73126
ValueError: If the component order is not None and contains duplicate
@@ -82,9 +135,10 @@ def __init__(
82135
)
83136
self._call_to_make_observation = call_to_make_observation
84137
self._pre_act_label = pre_act_label
85-
self._lock = threading.Lock()
138+
self._allow_llm_fallback = allow_llm_fallback
86139

87-
self._queue = {}
140+
# Use external queue if provided, otherwise create one internally
141+
self._queue = external_queue if external_queue else ObservationQueue()
88142

89143
def get_named_component_pre_act_value(self, component_name: str) -> str:
90144
"""Returns the pre-act value of a named component of the parent entity."""
@@ -147,32 +201,25 @@ def pre_act(
147201
)
148202

149203
log_entry['Active Entity'] = active_entity_name
150-
with self._lock:
151-
log_entry['queue'] = copy.deepcopy(self._queue)
152-
153-
if (
154-
active_entity_name in self._queue
155-
and self._queue[active_entity_name]
156-
):
157-
log_entry['queue_active_entity'] = copy.deepcopy(
158-
self._queue[active_entity_name]
159-
)
160-
result = ''
161-
for event in self._queue[active_entity_name]:
162-
result += event + '\n\n\n'
163204

164-
self._queue[active_entity_name] = []
165-
else:
166-
result = prompt.open_question(
167-
question=(
168-
f'What does {active_entity_name} observe now? Never '
169-
'repeat information that was already provided to '
170-
f'{active_entity_name} unless absolutely necessary. Keep '
171-
'the story moving forward.'
172-
),
173-
max_tokens=1200,
174-
terminators=(),
175-
)
205+
events = self._queue.get_and_clear(active_entity_name)
206+
log_entry['queue'] = self._queue.get_all()
207+
if events:
208+
log_entry['queue_active_entity'] = events
209+
result = '\n\n\n'.join(events) + '\n\n\n'
210+
elif self._allow_llm_fallback:
211+
result = prompt.open_question(
212+
question=(
213+
f'What does {active_entity_name} observe now? Never '
214+
'repeat information that was already provided to '
215+
f'{active_entity_name} unless absolutely necessary. Keep '
216+
'the story moving forward.'
217+
),
218+
max_tokens=1200,
219+
terminators=(),
220+
)
221+
else:
222+
result = ''
176223

177224
if self._reformat_observations_in_specified_style:
178225
prompt.statement(
@@ -208,23 +255,12 @@ def pre_act(
208255

209256
def add_to_queue(self, entity_name: str, event: str):
210257
"""Adds an event to the queue of events to observe."""
211-
with self._lock:
212-
if entity_name.lower().strip() == 'all':
213-
for player in self._player_names:
214-
if player not in self._queue:
215-
self._queue[player] = []
216-
self._queue[player].append(event)
217-
else:
218-
if entity_name not in self._queue:
219-
self._queue[entity_name] = []
220-
self._queue[entity_name].append(event)
258+
self._queue.add(entity_name, event, self._player_names)
221259

222260
def get_state(self) -> entity_component.ComponentState:
223261
"""Returns the state of the component."""
224-
with self._lock:
225-
return {'queue': copy.deepcopy(self._queue)}
262+
return {'queue': self._queue.get_all()}
226263

227264
def set_state(self, state: entity_component.ComponentState) -> None:
228265
"""Sets the state of the component."""
229-
with self._lock:
230-
self._queue = copy.deepcopy(state['queue'])
266+
self._queue.set_all(state['queue'])

concordia/prefabs/game_master/dialogic_and_dramaturgic.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,10 @@ class GameMaster(prefab_lib.Prefab):
111111
# in the component order. If not specified, the extra components
112112
# will be inserted at the end of the component order.
113113
'extra_components_index': {},
114+
# Optional shared ObservationQueue for cross-GM observation
115+
# persistence. When provided, observations queued by one GM
116+
# persist across GM switches.
117+
'external_queue': None,
114118
}
115119
)
116120
entities: Sequence[entity_agent_with_logging.EntityAgentWithLogging] = ()
@@ -142,6 +146,7 @@ def build(
142146
)
143147

144148
player_names = [entity.name for entity in self.entities]
149+
external_queue = self.params.get('external_queue', None)
145150

146151
scenes = self.params.get('scenes', _configure_default_scenes(player_names))
147152
assert isinstance(scenes, Sequence), 'scenes must be a sequence.'
@@ -187,30 +192,34 @@ def build(
187192
observation_component_key,
188193
display_events_key,
189194
],
195+
external_queue=external_queue,
190196
)
191197

198+
scene_tracker = gm_components.scene_tracker.SceneTracker(
199+
model=model,
200+
scenes=scenes,
201+
observation_component_key=(
202+
gm_components.make_observation.DEFAULT_MAKE_OBSERVATION_COMPONENT_KEY
203+
),
204+
)
205+
206+
# SendEventToRelevantPlayers handles notifying players about events.
207+
# Use scene tracker's get_participants as a filter to limit notifications.
192208
send_events_to_players_key = (
193209
gm_components.event_resolution.DEFAULT_SEND_PRE_ACT_VALUES_TO_PLAYERS_PRE_ACT_LABEL
194210
)
211+
scene_tracker_key = (
212+
gm_components.next_game_master.DEFAULT_NEXT_GAME_MASTER_COMPONENT_KEY
213+
)
195214
send_events_to_players = (
196215
gm_components.event_resolution.SendEventToRelevantPlayers(
197216
model=model,
198217
player_names=player_names,
199218
make_observation_component_key=make_observation_key,
219+
player_filter=scene_tracker.get_participants,
200220
)
201221
)
202222

203-
scene_tracker_key = (
204-
gm_components.next_game_master.DEFAULT_NEXT_GAME_MASTER_COMPONENT_KEY
205-
)
206-
scene_tracker = gm_components.scene_tracker.SceneTracker(
207-
model=model,
208-
scenes=scenes,
209-
observation_component_key=(
210-
gm_components.make_observation.DEFAULT_MAKE_OBSERVATION_COMPONENT_KEY
211-
),
212-
)
213-
214223
next_actor_key = gm_components.next_acting.DEFAULT_NEXT_ACTING_COMPONENT_KEY
215224
next_action_spec_key = (
216225
gm_components.next_acting.DEFAULT_NEXT_ACTION_SPEC_COMPONENT_KEY
@@ -237,7 +246,6 @@ def build(
237246
event_resolution = gm_components.event_resolution.EventResolution(
238247
model=model,
239248
event_resolution_steps=(identity_without_prefix,),
240-
notify_observers=False,
241249
)
242250

243251
terminator_key = gm_components.terminate.DEFAULT_TERMINATE_COMPONENT_KEY
@@ -252,14 +260,18 @@ def build(
252260
observation_component_key: observation,
253261
observation_to_memory_key: observation_to_memory,
254262
display_events_key: display_events,
255-
send_events_to_players_key: send_events_to_players,
256263
make_observation_key: make_observation,
257264
memory_component_key: memory,
258265
scene_tracker_key: scene_tracker,
259266
next_actor_key: next_actor,
260267
next_action_spec_key: next_action_spec,
261268
event_resolution_key: event_resolution,
262269
}
270+
# Only add SendEventToRelevantPlayers when notify_observers is False.
271+
if send_events_to_players is not None:
272+
components_of_game_master[send_events_to_players_key] = (
273+
send_events_to_players
274+
)
263275

264276
component_order = list(components_of_game_master.keys())
265277

concordia/prefabs/game_master/game_theoretic_and_dramaturgic.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,9 @@ class GameMaster(prefab_lib.Prefab):
129129
'scenes': (),
130130
'action_to_scores': _default_action_to_scores,
131131
'scores_to_observation': _default_scores_to_observation,
132+
# Optional shared ObservationQueue for cross-GM observation
133+
# persistence.
134+
'external_queue': None,
132135
}
133136
)
134137
entities: (
@@ -152,6 +155,7 @@ def build(
152155
name = self.params.get('name', DEFAULT_NAME)
153156

154157
player_names = [entity.name for entity in self.entities]
158+
external_queue = self.params.get('external_queue', None)
155159

156160
scenes = self.params.get('scenes', _configure_default_scenes(player_names))
157161
assert isinstance(scenes, Sequence), 'scenes must be a sequence.'
@@ -196,15 +200,15 @@ def build(
196200

197201
make_observation_key = (
198202
gm_components.make_observation.DEFAULT_MAKE_OBSERVATION_COMPONENT_KEY)
199-
make_observation = (
200-
gm_components.make_observation.MakeObservation(
201-
model=model,
202-
player_names=player_names,
203-
components=[
204-
observation_component_key,
205-
display_events_key,
206-
],
207-
)
203+
make_observation = gm_components.make_observation.MakeObservation(
204+
model=model,
205+
player_names=player_names,
206+
components=[
207+
observation_component_key,
208+
display_events_key,
209+
],
210+
external_queue=external_queue,
211+
allow_llm_fallback=False,
208212
)
209213

210214
scene_tracker_key = (

0 commit comments

Comments
 (0)