Skip to content

Commit 5370b69

Browse files
committed
setting num_participant_samples to zero uses factorial design
1 parent 9c65d7a commit 5370b69

File tree

2 files changed

+22
-6
lines changed

2 files changed

+22
-6
lines changed

llm_cooperation/experiments/dilemma.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -299,8 +299,7 @@ def compute_freq_pd(choices: List[Choices[DilemmaChoice]]) -> float:
299299

300300
@lru_cache()
301301
def get_participants(num_participant_samples: int) -> List[Participant]:
302-
participant_conditions = GROUP_PROMPT_CONDITIONS
303-
random_attributes: Grid = {
302+
pd_attributes: Grid = {
304303
CONDITION_CHAIN_OF_THOUGHT: [True, False],
305304
CONDITION_LABEL: all_values(Label),
306305
CONDITION_CASE: all_values(Case),
@@ -310,11 +309,15 @@ def get_participants(num_participant_samples: int) -> List[Participant]:
310309
}
311310
result = list(
312311
participants(
313-
participant_conditions,
314-
random_attributes,
312+
GROUP_PROMPT_CONDITIONS,
313+
pd_attributes,
315314
num_participant_samples,
316315
seed=SEED_VALUE,
317316
)
317+
if num_participant_samples > 0
318+
else participants(
319+
GROUP_PROMPT_CONDITIONS | pd_attributes,
320+
)
318321
)
319322
for i, participant in enumerate(result):
320323
participant["id"] = i

tests/test_dilemma.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@
2929
from openai_pygenerator import Completion, logger
3030
from pytest_lazyfixture import lazy_fixture
3131

32-
from llm_cooperation import DEFAULT_MODEL_SETUP, Group, Participant, Payoffs
33-
from llm_cooperation.experiments import AI_PARTICIPANTS
32+
from llm_cooperation import DEFAULT_MODEL_SETUP, Group, Participant, Payoffs, exhaustive
33+
from llm_cooperation.experiments import AI_PARTICIPANTS, GROUP_PROMPT_CONDITIONS
3434
from llm_cooperation.experiments.dilemma import (
3535
CONDITION_LABELS_REVERSED,
3636
CONDITION_PRONOUN,
@@ -47,6 +47,7 @@
4747
defect_label,
4848
extract_choice_pd,
4949
get_choice_template,
50+
get_participants,
5051
get_prompt_pd,
5152
get_pronoun_phrasing,
5253
move_as_str,
@@ -236,6 +237,18 @@ def test_cooperate_label_reversed(condition: Participant, expected: str):
236237
assert cooperate_label(condition_reversed) == expected
237238

238239

240+
def test_get_participants():
241+
n = 5
242+
random_participants = get_participants(num_participant_samples=n)
243+
assert len(random_participants) == n * len(
244+
list(exhaustive(GROUP_PROMPT_CONDITIONS))
245+
)
246+
assert get_participants(n) == random_participants
247+
factorial_participants = get_participants(num_participant_samples=0)
248+
assert get_participants(0) == factorial_participants
249+
assert len(factorial_participants) == 3888
250+
251+
239252
def test_run_repeated_game(mocker, base_condition):
240253
completions = [
241254
{"role": "assistant", "content": "project green"},

0 commit comments

Comments
 (0)