Skip to content

Commit 6beeedd

Browse files
vezhnickcopybara-github
authored andcommitted
Add allow_duplicates option to AssociativeMemoryBank
Fixes a bug where identical actions across rounds were incorrectly deduplicated, causing EventResolution to pick up stale data. - Add allow_duplicates constructor parameter to AssociativeMemoryBank - Enable allow_duplicates for game_master_memory_bank in generic.py PiperOrigin-RevId: 862771227 Change-Id: I5416428f0549f3ffa1d176fbfa41fcaa431684b1
1 parent bfb7107 commit 6beeedd

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

concordia/associative_memory/basic_associative_memory.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,16 +33,21 @@ class AssociativeMemoryBank:
3333
def __init__(
3434
self,
3535
sentence_embedder: Callable[[str], np.ndarray] | None = None,
36+
allow_duplicates: bool = False,
3637
):
3738
"""Constructor.
3839
3940
Args:
4041
sentence_embedder: text embedding model, if None then skip setting the
4142
embedder on initialization of the object. It still must be set before
4243
calling `add` or `retrieve` methods.
44+
allow_duplicates: if True, allow adding duplicate entries to the memory.
45+
This is useful for Game Master memories where the same action may recur
46+
across different rounds.
4347
"""
4448
self._memory_bank_lock = threading.Lock()
4549
self._embedder = sentence_embedder
50+
self._allow_duplicates = allow_duplicates
4651

4752
self._memory_bank = pd.DataFrame(columns=['text', 'embedding'])
4853
self._stored_hashes = set()
@@ -100,7 +105,7 @@ def add(
100105
hashed_contents = hash(tuple(contents.values()))
101106

102107
with self._memory_bank_lock:
103-
if hashed_contents in self._stored_hashes:
108+
if not self._allow_duplicates and hashed_contents in self._stored_hashes:
104109
return
105110

106111
derived = {'embedding': self._embedder(text)}

concordia/prefabs/simulation/generic.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,11 @@ def __init__(
7878
self._checkpoint_counter = 0
7979

8080
# All game masters share the same memory bank.
81+
# allow_duplicates=True because the same action (e.g., "Bob: defect") may
82+
# recur across different rounds and should not be deduplicated.
8183
self.game_master_memory_bank = associative_memory.AssociativeMemoryBank(
8284
sentence_embedder=embedder,
85+
allow_duplicates=True,
8386
)
8487
all_data = self._config.instances
8588
gm_configs = [

0 commit comments

Comments
 (0)