Skip to content
This repository was archived by the owner on Nov 3, 2023. It is now read-only.

Commit 942952d

Browse files
authored
TOD project folder (#4437)
* [TOD] Projects folder for tod_simulator + scripts + documentation [lots of commits from this being a stacked diff removed cause... no one needs to see all that.]
1 parent 24cc648 commit 942952d

19 files changed

+1845
-5
lines changed

parlai/core/tod/tod_agents.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -736,6 +736,19 @@ def __init__(self, opt, shared=None):
736736
self._num_examples_cache = sum([len(x.rounds) for x in self.episodes])
737737
self._num_episodes_cache = len(self.episodes)
738738

739+
@classmethod
740+
def add_cmdline_args(
741+
cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None
742+
) -> ParlaiParser:
743+
parser = super().add_cmdline_args(parser, partial_opt)
744+
parser.add_argument(
745+
"--api-schemas",
746+
type="bool",
747+
default=False,
748+
help="Preempt first turn with intents + required/optional parameters as key/value for given domain. NOOP for this teacher, but including to make sweeps easier",
749+
)
750+
return parser
751+
739752
def setup_data(self, fold):
740753
for episode in self.generate_episodes():
741754
if len(episode.rounds) < 1:

parlai/core/tod/tod_core.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def delex(cls, text, slots):
168168
def inner_list_join(cls, values):
169169
if isinstance(values, str):
170170
return values
171-
return ", ".join(sorted([v.strip() for v in values]))
171+
return ", ".join(sorted([str(v).strip() for v in values]))
172172

173173
@classmethod
174174
def inner_list_split(cls, s):
@@ -185,12 +185,18 @@ def inner_list_split(cls, s):
185185
def maybe_inner_list_join(cls, values):
186186
if type(values) is dict:
187187
return str(values)
188-
if isinstance(values, str) or isinstance(values, int):
188+
if (
189+
isinstance(values, str)
190+
or isinstance(values, int)
191+
or isinstance(values, float)
192+
):
189193
return values
190194
elif isinstance(values, Iterable):
191195
return SerializationHelpers.inner_list_join(values)
192196
else:
193-
raise RuntimeError("invalid type of argument for maybe_inner_list_join")
197+
raise RuntimeError(
198+
f"invalid type of argument for maybe_inner_list_join: {values}; type {type(values)}"
199+
)
194200

195201
@classmethod
196202
def api_dict_to_str(cls, apidict):

parlai/scripts/tod_world_script.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,14 @@ def _is_batch_world(self, world):
5858
def _log_batch(self, world):
5959
batch_acts = world.get_batch_acts()
6060
for i, acts in enumerate(batch_acts):
61-
# filter out for empty
62-
acts = [act for act in acts if act["id"] != "" and act["text"] != ""]
61+
acts = [
62+
act for act in acts if act is not None and "id" in act and "text" in act
63+
]
64+
acts = [
65+
act
66+
for act in acts
67+
if act["id"] != "" and (act["text"] != "" or "Human" in act["id"])
68+
]
6369
if len(acts) > 0:
6470
self._add_msgs(acts, idx=i)
6571
if world.episode_done():

projects/tod_simulator/README.md

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
# Task Oriented Dialogue (TOD): Agents, Worlds, Scripts, etc
2+
3+
### _Teaching Models new APIs: Domain-Agnostic Simulators for Task Oriented Dialogue_
4+
5+
Moya Chen, Paul A. Crook, Stephen Roller
6+
7+
## Abstract
8+
9+
We demonstrate that large language models are able to simulate Task Oriented Dialogues in novel domains, provided only with an API implementation and a list of goals. We show these simulations can formulate online, automatic metrics that correlate well with human evaluations. Furthermore, by checking for whether the User's goals are met, we can use simulation to repeatedly generate training data and improve the quality of simulations themselves. With no human intervention or domain-specific training data, our simulations bootstrap end-to-end models which achieve a 37% error reduction in previously unseen domains. By including as few as 32 domain-specific conversations, bootstrapped models can match the performance of a fully-supervised model with 10× more data. To our knowledge, this is the first time simulations have been shown to be effective at bootstrapping models without explicitly requiring any domain-specific training data, rule-engineering, or humans-in-the-loop.
10+
11+
## Paper
12+
13+
[Link to arXiv](https://arxiv.org/abs/2110.06905)
14+
15+
# Explanation of content in project
16+
17+
This directory contains code for executing conversations for task-oriented dialogue (ex. setting an alarm, asking for the time) in a structured format. We introduce this structured format then go into the operational details for our setup: dataset generation + model training, simulation script usage, then give an overview of scripts in this folder. We then go into details of the specific datasets that we use as well as how to download and interact with our pre-trained models.
18+
19+
As a terminology note, while the paper uses "Assistant" throughout, the same speaker is generally referred to as the "System" throughout code and documentation.
20+
21+
## Conversation structure
22+
23+
In task oriented dialogue, we have a user (with some goal) that requests some form of action out of an assistant system. This assistant system normally has some external knowledge-base with to which it can interact with via APIs.
24+
25+
To model this, we begin each episode with a grounding stage where:
26+
1. an api schema agent gives a description string of the API to an api call and api response agents
27+
2. a goal agent gives a target goal string to a user utterance agent to start the conversation
28+
29+
During the 'parlay' or normal conversational phase, we have four agents that speak in looped turns:
30+
1. User utt agent
31+
2. System API call agent
32+
3. API response agent
33+
4. System utt agent
34+
35+
In analogy to more traditional TOD-setups, one can think of the api call agent as dialogue state tracking and the system utt agent as natural language generation. Since many TOD-systems these days combine both dialogue state tracking and natural language generation into one model, we assume that the api call and system agents are the same.
36+
37+
To prevent leakage of information between agents during the parlay phase, each agent only observes only its own output and that of the agent which speaks immediately before.
38+
39+
## Dataset setup + Model Training
40+
41+
See `parlai/core/tod/tod_agents.py` for information on how to build agents and teachers for a specific dataset.
42+
43+
Of the agents described in the conversation, only the User and System need to be trained with generative models. These can be trained as normal ParlAI models (ie.`parlai train_model -t <insert task> -mf <model file path> -m <model type>`) using System- and UserSimulator- Teachers created via the documentation in the `tod_agents.py` file mentioned above.
44+
45+
## Simulation Script Usage
46+
Use `python parlai/scripts/tod_world_script.py` or `parlai tod_world_script` (or the corresponding `distributed_` prefixed versions) to generate model-model chats. Arguments to the script are listed in file. Note that it is oftentimes preferable to use the `python ..` rather than `parlai ..` form of this command, especially if one has model or agent specific flags, due to argument order parsing.
47+
48+
As a quick example, we provide
49+
50+
`parlai tod_world_script -o projects/tod_simulator/tod_world_configs/google_sgd_simulation_dump_data.json`
51+
52+
as an example of printing the validation data from Google SGD Out of Domain through the simulation script.
53+
54+
Additionally, use this to specify a conversation where all of the agents take human input from the command line:
55+
56+
```
57+
parlai tod_world_script --system-model parlai.agents.local_human.local_human:LocalHumanAgent --user-model parlai.agents.local_human.local_human:LocalHumanAgent --api-resp-model parlai.agents.local_human.local_human:LocalHumanAgent --api-schema-grounding-model parlai.agents.local_human.local_human:LocalHumanAgent --goal-grounding-model parlai.agents.local_human.local_human:LocalHumanAgent
58+
```
59+
60+
(which is the same as `parlai tod_world_script -o projects/tod_simulator/tod_world_configs/all_human.json`, included for convenience)
61+
62+
Defaults are provided for the grounding agents but must be specified for the rest. Pretrained model locations can also be specified for the user and system with `--user-model-file` and `--system-model-file` arguments respectively. Since the system agent + api call agent are assumed to be the same, we only specify the 5 distinct agents, rather than 6.
63+
64+
Further documentation of the simulation world and simulation world metrics are described in `parlai/core/tod/tod_world.py` and `parlai/core/tod/world_metrics.py`, respectively.
65+
66+
## Scripts in `script` directory of this folder
67+
68+
**cleanup\_conversation.py**
69+
As a convenience, we also add a script for parsing the output conversation of the TOD Script into a format slightly more ameniable to ACUTE-Eval. While the raw output of the TOD Script could be used as well, the provided cleanup script does things like remove API utterances + Goals.
70+
71+
**do\_get\_passing\_only\_on\_dir.py**
72+
Uses `get_passing_only.py` internaly to run on a directory
73+
74+
**get\_al\_samples\_for\_gsgd.py**
75+
Gets active learning samples out of Google SGD's OutDomainSystemTeacher train set based on worst-performing API calls as extracted from `get_passing_only.py`.
76+
77+
**get\_api\_data.py**
78+
For models trained with `tod_distributed_uber_script.py` that have `--api-jga-record` set to `True`, this will automatically pull per-api Google SGD Out-of-Domain JGA and simulation success statistics.
79+
80+
**get\_interdistinct\_on\_conversations.py**
81+
Deprecated script to calculate interdistinct metrics for simulation conversations. (Included for completeness.)
82+
83+
**get\_passing\_only.py**
84+
Given a conversation generated from `tod_world_script`, outputs statistics about performance of different APIs.
85+
86+
**get\_quick\_eval\_stats.py**
87+
For models trained with `tod_distributed_uber_script.py`, this quickly grabs evaluation and model-model simulation data into a comma-separated format.
88+
89+
**tod\_distributed\_uber\_multiwoz\_script.py**
90+
Version of `tod_distributed_uber_script.py` but with MultiWoz v2.2 as the primary task rather than Google SGD Out-of-Domain. (Included for completeness.)
91+
92+
**tod\_distributed\_uber\_script.py**
93+
Multi-step train, evaluation, and data generation script used in Simulations paper. Uses Google SGD Out-of-Domain as primary dataset; note "STANDALONE\_API\_FILE\_PATH" that needs to be set in file. Makes use of `do_get_passing_only_on_dir.py` and `get_al_samples_for_gsgd.py`; use `get_passing_only.py` and `get_api_data.py` after the fact for analysis.
94+
95+
Note that this script is intended to be run in a SLURM environment matching that of the Simulations paper authors. It is unknown how the script performs in other settings but is included as a reference.
96+
97+
## Tasks used in the paper
98+
99+
See the appendix of [the paper](https://arxiv.org/abs/2110.06905) (or the description of the task in ParlAI Task List) for explanations of these datasets. Below, we include the dataset name, the command to run the `SystemTeacher` relevant for each of the datasets, and any other notable details. Other agents and teachers for the dataset are specified in the relevant task `agent.py` files.
100+
101+
### Pretraining Tasks
102+
103+
* Google SGD In-Domain
104+
* `parlai dd -t google_sgd_simulation_splits:InDomainSystemTeacher`
105+
* MetalWoz
106+
* `parlai dd -t metalwoz:SystemTeacher`
107+
* MSR_E2E
108+
* `parlai dd -t msr_e2e:SystemTeacher`
109+
* Note that due to the lack of annotations in this dataset, this System Teacher *only* includes utterance turns
110+
* Multidogo
111+
* `parlai dd -t multidogo:SystemTeacher`
112+
* MultiWoz
113+
* We use a fb-internal pre-processing of MultiWoz, based on MultiWoz v2.1 and do not open source it at this time.
114+
* Taskmaster
115+
* `parlai dd -t taskmaster:SystemTeacher`
116+
* Taskmaster2
117+
* `parlai dd -t taskmaster2:SystemTeacher`
118+
* Taskmaster3 (TicketTalk)
119+
* `parlai dd -t taskmaster3:SystemTeacher`
120+
121+
### Experimentation Tasks
122+
123+
* Google SGD Out-of-Domain
124+
* `parlai dd -t google_sgd_simulation_splits:OutDomainSystemTeacher`
125+
* MultiWoz (not currently included in paper)
126+
* `parlai dd -t multiwoz_v22:SystemTeacher`
127+
* This is a preprocessing of the dataset based on MultiWoz v2.2. Though utterances are the same as used for pre-training, API Call and API Response structures aer different.
128+
129+
See "scripts in project directory" for scripts associated with training, evaluation, and data generation.
130+
131+
## Pretrained models
132+
133+
We release Schema-Aware and Schema-Agnostic version of our intermediate task-pretraining. One can see the outputs of these models by running
134+
135+
```
136+
parlai dd -t google_sgd_simulation_splits:OutDomainSystemTeacher -mf zoo:tod/tod_base_yes_api/model --skip-generation false --api-schemas true
137+
```
138+
139+
for the Schema-Aware version of the model and
140+
141+
```
142+
parlai dd -t google_sgd_simulation_splits:OutDomainSystemTeacher -mf zoo:tod/tod_base_no_api/model --skip-generation false --api-schemas false
143+
```
144+
145+
for the Schema-Agnostic version.
146+
147+
Note the path names of the model files; they are `zoo:tod/tod_base_{yes,no}_api/mode` where "yes" corresponds to Schema-Aware and "no" corresponding to Schema-Agnostic. Care must be taken to specify `--api-schemas` correctly since task-setting flags are parsed from teacher-specific flags and not from model files.
148+
149+
These models are both based on a BART-large (400 million paramater) base model. Hyperparameters for training can be found in the paper; tasks are listed in "Pretraining Tasks" above.

projects/tod_simulator/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#!/usr/bin/env python3
2+
3+
# Copyright (c) Facebook, Inc. and its affiliates.
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#!/usr/bin/env python3
2+
3+
# Copyright (c) Facebook, Inc. and its affiliates.
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
#!/usr/bin/env python3
2+
3+
# Copyright (c) Facebook, Inc. and its affiliates.
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
"""
7+
Script for making light modifications to conversations from tod chats such that they are
8+
ready for the ACUTE format.
9+
10+
Notably, this does things that are slightly too much of a pain in the butt to do with
11+
regexes like "add suffixes to ids when multiple ids might have the same string" (and
12+
change metadata appropriately).
13+
14+
For example, the following
15+
16+
```
17+
python cleanup_conversation.py --source_file <insert filepath>_conversations.jsonl --report-path <insert filepath>.json --agent-suffixes user_utt_model _BASE_USER system_utt_model _BASE_SYSTEM --included-speakers goal_grounding_model user_utt_model system_utt_model
18+
```
19+
20+
strips the API call related turns and adds "_BASE_USER" and "_BASE_SYSTEM" (which otherwise would be the model type name, ex BART) to the latter two, respecitvely.
21+
"""
22+
23+
from parlai.core.params import ParlaiParser
24+
from parlai.utils.conversations import Conversations, Metadata
25+
from parlai.utils.io import PathManager
26+
from parlai.core.script import ParlaiScript, register_script
27+
28+
from parlai.core.tod.tod_core import TodAgentType, TOD_AGENT_TYPE_TO_PREFIX
29+
30+
import json
31+
32+
33+
@register_script("conversation_cleanup")
34+
class ConversationCleanup(ParlaiScript):
35+
@classmethod
36+
def setup_args(cls):
37+
parser = ParlaiParser(
38+
False,
39+
False,
40+
"Script for simplying conversations output from TOD. Input expected to be in conversations format, as is output",
41+
)
42+
# Following params are same as the `eval_model` script
43+
parser.add_argument(
44+
"--source-file",
45+
type=str,
46+
required=True,
47+
help="Source file in conversations format, generated from `tod_world_script.py`",
48+
)
49+
parser.add_argument(
50+
"--out-file", type=str, default=None, help="Output location."
51+
)
52+
parser.add_argument(
53+
"--included-speakers",
54+
nargs="*",
55+
type=str,
56+
choices=[e.value for e in TodAgentType],
57+
default=[TodAgentType.USER_UTT_AGENT, TodAgentType.SYSTEM_UTT_AGENT],
58+
help="Which of the speakers to not remove. Should match those in `tod_world`",
59+
)
60+
parser.add_argument(
61+
"--agent-suffixes",
62+
nargs="*",
63+
type=str,
64+
default=[
65+
TodAgentType.USER_UTT_AGENT,
66+
"_USER",
67+
TodAgentType.SYSTEM_UTT_AGENT,
68+
"_SYSTEM",
69+
],
70+
help="List of <speaker type, suffix> pairs. Speaker type should match those in `TodAgentType`; outputs (if included) will have the suffix added to the ID. This is useful when using multiple of the same out model (ex. Bart model for both the user and the system)",
71+
)
72+
parser.add_argument(
73+
"--num-conversations",
74+
default=400,
75+
help="Number of conversations to include. -1 for all",
76+
)
77+
parser.add_argument(
78+
"--report-path",
79+
required=True,
80+
help="path of the report saved from the tod_metrics_script",
81+
)
82+
return parser
83+
84+
def _get_turn_type(self, turn):
85+
for agent_type, prefix in TOD_AGENT_TYPE_TO_PREFIX.items():
86+
if prefix in turn["text"]:
87+
return agent_type
88+
89+
def run(self):
90+
opt = self.opt
91+
if int(len(self.opt["agent_suffixes"])) % 2 != 0:
92+
raise RuntimeError("Agent suffix input should be even")
93+
suffixes = {}
94+
for i in range(int(len(self.opt["agent_suffixes"]) / 2)):
95+
agent = self.opt["agent_suffixes"][2 * i]
96+
suffix = self.opt["agent_suffixes"][2 * i + 1]
97+
suffixes[agent] = suffix
98+
99+
with PathManager.open(opt["report_path"]) as r:
100+
report = json.load(r)["report"]
101+
tod_metrics = report["tod_metrics"]
102+
103+
if opt["num_conversations"] > -1:
104+
tod_metrics = tod_metrics[: opt["num_conversations"]]
105+
106+
source = self.opt["source_file"].replace(".jsonl", "")
107+
if self.opt["out_file"]:
108+
out = self.opt["out_file"]
109+
else:
110+
if (
111+
"conversations" in source
112+
): # just to make sure we don't overwrite anything...
113+
out = source.replace("conversations", "cleaned_conversations")
114+
else:
115+
out = "cleaned_" + source
116+
117+
speakers = []
118+
with PathManager.open(out + ".jsonl", "w") as f:
119+
conversations = Conversations(source + ".jsonl")
120+
for i, conversation in enumerate(conversations):
121+
if opt["num_conversations"] >= 0 and i >= opt["num_conversations"]:
122+
break
123+
cleaned_dialog = []
124+
for parlay_round in conversation.episode["dialog"]:
125+
cleaned_parlay_round = []
126+
for turn in parlay_round:
127+
turn_type = self._get_turn_type(turn)
128+
if turn_type in self.opt["included_speakers"]:
129+
if turn_type in suffixes:
130+
turn["id"] += suffixes[turn_type]
131+
if turn["id"] not in speakers:
132+
speakers.append(turn["id"])
133+
cleaned_parlay_round.append(turn)
134+
if len(cleaned_parlay_round) > 0:
135+
cleaned_dialog.append(cleaned_parlay_round)
136+
convo = {}
137+
convo["dialog"] = cleaned_dialog
138+
convo["metadata_path"] = Metadata._get_path(out)
139+
convo["context"] = [
140+
{
141+
"synthetic_task_success": tod_metrics[i][
142+
"synthetic_task_success"
143+
],
144+
"goal_text": tod_metrics[i]["goal"]["text"],
145+
}
146+
]
147+
json_convo = json.dumps(convo)
148+
f.write(json_convo + "\n")
149+
150+
old_meta = Metadata(source + ".jsonl")
151+
Metadata.save_metadata(
152+
out, old_meta.opt, old_meta.self_chat, speakers, **old_meta.extra_data
153+
)
154+
155+
156+
if __name__ == "__main__":
157+
ConversationCleanup.main()

0 commit comments

Comments
 (0)