Skip to content

Commit cb29670

Browse files
committed
address split_on_split_columns issue where a split was performed in the DataModule directly. Remove GenQualityDataModule. Replace az:// paths with public hf:// ones
1 parent 502d766 commit cb29670

File tree

8 files changed

+31
-247
lines changed

8 files changed

+31
-247
lines changed

mttl/datamodule/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99

1010
def maybe_filter_hf_dataset_by_task(
11-
dataset, task_field, task_names: str = None, n_proc=16
11+
dataset, task_field, task_names: str = None, n_proc=16, should_split_on_split_column=True
1212
):
1313
"""Filter a HuggingFace dataset by task names."""
1414

@@ -48,7 +48,9 @@ def maybe_filter_hf_dataset_by_task(
4848
dev_dataset is None
4949
and test_dataset is None
5050
and "split" in train_dataset.features
51+
and should_split_on_split_column
5152
):
53+
logger.info("Splitting train dataset on 'split' column.")
5254
train_dataset, dev_dataset, test_dataset = split_on_split_column(
5355
train_dataset, num_proc=n_proc
5456
)

projects/kms/train_km_simple.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,4 +332,7 @@ def train_km(training_args: KMArguments):
332332
logger.info("Model already trained, skipping")
333333
exit(0)
334334

335+
# The configs still contain pointers to internal az://mttldata paths, replace them
336+
args.dataset = args.dataset.replace('az://mttldata', BASE_PREFIX)
337+
335338
train_km(args)

projects/kms/utils/km_datamodule.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,7 @@ def setup_dataset(self):
350350
assert len(dataset) == 1, "all dataset should be in `train`"
351351

352352
# Let's first filter out unused tasks
353+
# NOTE: do not split on split column here, we will do custom train / dev split later
353354
(
354355
self._task_names,
355356
self._task_to_id,
@@ -360,6 +361,7 @@ def setup_dataset(self):
360361
dataset,
361362
self.config.task_name_field,
362363
self.config.finetune_task_name,
364+
should_split_on_split_column=False,
363365
n_proc=n_proc,
364366
)
365367

projects/kms/utils/longhealth_datamodule.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def setup_dataset(self):
8686
_,
8787
_,
8888
) = maybe_filter_hf_dataset_by_task(
89-
dataset, self.config.task_name_field, self.config.finetune_task_name
89+
dataset, self.config.task_name_field, self.config.finetune_task_name, should_split_on_split_column=False
9090
)
9191

9292
# Let's make sure that the full prompt is always in context
@@ -186,6 +186,8 @@ def expand_questions(examples, tokenizer, len_template):
186186
if self.tokenizer.chat_template is None:
187187
self.tokenizer.apply_chat_template = lambda x, **kwargs: x[0]["content"]
188188

189+
# TODO: refactor code to leverage `split_on_split_column` in
190+
# `maybe_filter_hf_dataset_by_task`
189191
if "split" in train_dataset.features:
190192
self.train_dataset, self.dev_dataset, self.test_dataset = (
191193
split_on_split_column(train_dataset)

projects/kms/utils/nqa_datamodule.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,12 @@ def setup_dataset(self):
3535
(
3636
self._task_names,
3737
self._task_to_id,
38-
train_dataset,
39-
_,
40-
_,
38+
self.train_dataset,
39+
self.dev_dataset,
40+
self.test_dataset,
4141
) = maybe_filter_hf_dataset_by_task(
42-
dataset, self.config.task_name_field, self.config.finetune_task_name
43-
)
44-
45-
self.train_dataset, self.dev_dataset, self.test_dataset = split_on_split_column(
46-
train_dataset
42+
dataset, self.config.task_name_field, self.config.finetune_task_name,
43+
should_split_on_split_column=False
4744
)
4845

4946
def expand_questions(examples, tokenizer):

projects/kms/utils/pit_datamodule.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ def expand_targets_and_chat(example):
129129
self.config.task_name_field,
130130
self.config.finetune_task_name,
131131
n_proc=n_proc,
132+
should_split_on_split_column=False,
132133
)
133134

134135
train_dataset = train_dataset.map(

projects/kms/utils/quality_datamodule.py

Lines changed: 14 additions & 235 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@ def setup_dataset(self):
3535
(
3636
self._task_names,
3737
self._task_to_id,
38-
train_dataset,
39-
_,
40-
_,
38+
self.train_dataset,
39+
self.dev_dataset,
40+
self.test_dataset,
4141
) = maybe_filter_hf_dataset_by_task(
4242
dataset, self.config.task_name_field, self.config.finetune_task_name
4343
)
@@ -124,243 +124,22 @@ def expand_questions(examples, tokenizer):
124124
if self.tokenizer.chat_template is None:
125125
self.tokenizer.apply_chat_template = lambda x, **kwargs: x[0]["content"]
126126

127-
if "split" in train_dataset.features:
128-
self.train_dataset, self.dev_dataset, self.test_dataset = (
129-
split_on_split_column(train_dataset)
130-
)
131-
self.train_dataset = self.train_dataset.map(
132-
lambda examples: expand_questions(examples, self.tokenizer),
133-
batched=True,
134-
batch_size=1000,
135-
num_proc=1,
136-
remove_columns=train_dataset.column_names,
137-
)
127+
self.train_dataset = self.train_dataset.map(
128+
lambda examples: expand_questions(examples, self.tokenizer),
129+
batched=True,
130+
batch_size=1000,
131+
num_proc=1,
132+
remove_columns=self.train_dataset.column_names,
133+
)
134+
if self.dev_dataset:
138135
self.dev_dataset = self.dev_dataset.map(
139136
lambda examples: expand_questions(examples, self.tokenizer),
140137
batched=True,
141138
batch_size=1000,
142139
num_proc=1,
143-
remove_columns=train_dataset.column_names,
140+
remove_columns=self.dev_dataset.column_names,
144141
)
145-
self.test_dataset = self.dev_dataset
146142
else:
147-
train_dataset = train_dataset.map(
148-
lambda examples: expand_questions(examples, self.tokenizer),
149-
batched=True,
150-
batch_size=1000,
151-
num_proc=1,
152-
remove_columns=train_dataset.column_names,
153-
)
154-
self.train_dataset = self.dev_dataset = self.test_dataset = train_dataset
155-
156-
157-
prompt_template_w_docs = """
158-
--------------BEGIN CONTEXT--------------
159-
160-
{documents}
161-
162-
--------------END CONTEXT--------------
163-
164-
{question_text}
165-
{options}
166-
167-
Please answer using the following format:
168-
0. Begin your answer with the phrase "The correct answer is".
169-
1. State the letter of the correct option (e.g., A, B, C, D).
170-
2. Follow the letter with a colon and the exact text of the option you chose.
171-
3. Make sure your answer is a single, concise sentence.
172-
173-
For example, if the correct answer to a question is option C, and the text for C is 'Acute Bronchitis', your answer should be:
174-
'The correct answer is C: Acute bronchitis.'
175-
"""
176-
177-
prompt_template_no_docs = """
178-
{question_text}
179-
{options}
180-
181-
Please answer using the following format:
182-
1. Begin your answer with the phrase "The correct answer is".
183-
2. State the letter of the correct option (e.g., A, B, C, D).
184-
3. Follow the letter with a colon and the exact text of the option you chose.
185-
4. Make sure your answer is a single, concise sentence.
186-
187-
For example, if the correct answer to a question is option C, and the text for C is 'Acute Bronchitis', your answer should be:
188-
'The correct answer is C: Acute bronchitis.'
189-
"""
190-
191-
max_new_tokens = 50
192-
193-
194-
@dataclass
195-
class GenQualityDatasetConfig(DatasetConfig):
196-
task_name_field: str = "document_id"
197-
task_source_field: str = "document_id"
198-
prompt: str = (
199-
"Answer the following question. Give only the answer, and no extra commentary, formatting, or chattiness. Question: "
200-
)
201-
include_context: bool = False
202-
topk_context: int = 10
203-
include_all_answers: bool = True
204-
205-
206-
@DataModule.register("gen_quality", config_cls=GenQualityDatasetConfig)
207-
class GenQualityDataModule(DataModule):
208-
def setup_dataset(self):
209-
from mttl.models.library.dataset_library import DatasetLibrary
210-
211-
dataset = DatasetLibrary.pull_dataset(self.config.dataset)
212-
213-
# Instead of always working with the large datasets, we can subsample it
214-
if self.config.custom_split_file:
215-
dataset = apply_custom_split_file(dataset, self.config.custom_split_file)
216-
217-
(
218-
self._task_names,
219-
self._task_to_id,
220-
train_dataset,
221-
_,
222-
_,
223-
) = maybe_filter_hf_dataset_by_task(
224-
dataset, self.config.task_name_field, self.config.finetune_task_name
225-
)
226-
227-
# Let's make sure that the full prompt is always in context
228-
len_template = len(self.tokenizer.encode(prompt_template_w_docs))
229-
230-
def expand_questions(examples, tokenizer, len_template):
231-
batch = {
232-
"source": [],
233-
"target": [],
234-
"document_id": [],
235-
}
143+
self.dev_dataset = self.train_dataset
236144

237-
for i in range(len(examples["document_id"])):
238-
for j in range(len(examples["questions"][i])):
239-
document_id = examples["document_id"][i]
240-
question = examples["questions"][i][j]
241-
options = examples["options"][i][j]
242-
gold_label = examples["gold_label"][i][j]
243-
if gold_label == -1:
244-
gold_label = label_index = None
245-
else:
246-
label_index = gold_label - 1
247-
248-
""" NEW """
249-
letters = ["A", "B", "C", "D"]
250-
option_str = "\n".join(
251-
[f"{letters[i]}: {option}" for i, option in enumerate(options)]
252-
)
253-
len_question = len(tokenizer.encode(question))
254-
len_options = len(tokenizer.encode(option_str))
255-
len_suffix = len(tokenizer.encode("The correct answer is: "))
256-
257-
total_len = len_question + len_options + len_template + len_suffix
258-
259-
if self.config.include_context:
260-
context = examples["text"][i]
261-
262-
if isinstance(context, list):
263-
# following Alan's approach
264-
context = " ".join(
265-
[
266-
f"Passage {k+1}: {context[k]}\n\n"
267-
for k in range(
268-
min(self.config.topk_context, len(context))
269-
)[::-1]
270-
]
271-
)
272-
assert (
273-
type(context) == str
274-
), f"Context should be a string, but got {type(context)}"
275-
276-
# Let's do some rough trucation if needed
277-
context_ids = tokenizer.encode(context)
278-
len_context = len(context_ids)
279-
space_left = self.config.max_input_length - total_len
280-
281-
if space_left < len_context:
282-
context_ids = context_ids[: max(0, space_left - 20)]
283-
context = tokenizer.decode(
284-
context_ids, skip_special_tokens=True
285-
)
286-
287-
prompt = prompt_template_w_docs.format(
288-
documents=context,
289-
question_text=question,
290-
options=option_str,
291-
)
292-
else:
293-
prompt = prompt_template_no_docs.format(
294-
question_text=question,
295-
options=option_str,
296-
)
297-
298-
"""
299-
source = [
300-
{
301-
"role": "system",
302-
"content": sys_prompt,
303-
},
304-
{
305-
"role": "user",
306-
"content": prompt,
307-
},
308-
]
309-
"""
310-
source = [
311-
{
312-
"role": "user",
313-
"content": prompt,
314-
}
315-
]
316-
317-
batch["source"].append(
318-
tokenizer.apply_chat_template(
319-
source, add_generation_prompt=True, tokenize=False
320-
)
321-
+ "The correct answer is"
322-
)
323-
batch["target"].append(
324-
letters[label_index]
325-
) # [options[label_index]])
326-
batch["document_id"].append(examples["document_id"][i])
327-
328-
return batch
329-
330-
if self.tokenizer.chat_template is None:
331-
self.tokenizer.apply_chat_template = lambda x, **kwargs: x[0]["content"]
332-
333-
if "split" in train_dataset.features:
334-
self.train_dataset, self.dev_dataset, self.test_dataset = (
335-
split_on_split_column(train_dataset)
336-
)
337-
self.train_dataset = self.train_dataset.map(
338-
lambda examples: expand_questions(
339-
examples, self.tokenizer, len_template
340-
),
341-
batched=True,
342-
batch_size=1000,
343-
num_proc=1,
344-
remove_columns=train_dataset.column_names,
345-
)
346-
self.dev_dataset = self.dev_dataset.map(
347-
lambda examples: expand_questions(
348-
examples, self.tokenizer, len_template
349-
),
350-
batched=True,
351-
batch_size=1000,
352-
num_proc=1,
353-
remove_columns=train_dataset.column_names,
354-
)
355-
self.test_dataset = self.dev_dataset
356-
else:
357-
train_dataset = train_dataset.map(
358-
lambda examples: expand_questions(
359-
examples, self.tokenizer, len_template
360-
),
361-
batched=True,
362-
batch_size=1000,
363-
num_proc=1,
364-
remove_columns=train_dataset.column_names,
365-
)
366-
self.train_dataset = self.dev_dataset = self.test_dataset = train_dataset
145+
self.test_dataset = self.dev_dataset

projects/kms/utils/quality_evaluator.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515
from mttl.logging import logger, warn_once
1616
from projects.kms.utils.nqa_datamodule import NQADatamodule, NQADatasetConfig
1717
from projects.kms.utils.quality_datamodule import (
18-
GenQualityDataModule,
19-
GenQualityDatasetConfig,
2018
QualityDatamodule,
2119
QualityDatasetConfig,
2220
)

0 commit comments

Comments
 (0)