Commit c1df4f8

Merge pull request facebookresearch#22 from facebookresearch/tuan/update_slurm_docs

Update documentation

2 parents: 93a9734 + 8471f67

6 files changed: +44 −26 lines
examples/evaluation/README.md

Lines changed: 31 additions & 18 deletions
@@ -7,19 +7,15 @@ After you have trained an LCM, the checkpoint will be saved in a folder under th
 Since an LCM expects input data in sentence level, we need to preprocess the evaluation datasets accordingly. This includes parsing the raw content and
 splitting texts into sentences, then embedding them into vectors using a Sonar encoder.
 
-The example below shows how we prepare the data for CNN Dailymail. We load the dataset from Huggingface using [`datasets` API](https://huggingface.co/docs/datasets/en/index). The sentence splitting is done using [wtpsplit](https://github.com/segment-any-text/wtpsplit). First, we install necessary libraries:
-
-```shell
-python -m pip install datasets wtpsplit
-```
+The example below shows how we prepare the data for CNN Dailymail. We load the dataset from Huggingface using [`datasets` API](https://huggingface.co/docs/datasets/en/index). The sentence splitting is done using [wtpsplit](https://github.com/segment-any-text/wtpsplit). Make sure to specify `--extra data` when installing the project to include these libraries.
 
 All processing logic is implemented in the file `prepare_evaluation_data.py`, as described below.
 
 ### Step 1.1: Process the split:
 Next, we download and parse the content (source text and summaries), saving different splits into JSON format
 
 ```shell
-python prepare_evaluation_data.py prepare_data \
+uv run --extra data prepare_evaluation_data.py prepare_data \
     --dataset_name=cnn_dailymail \
     --output_dir=jsonl_dataset \
     --source_text_column=article \
@@ -41,16 +37,30 @@ The output will be stored in different files `[split].jsonl` under the directory
 To perform sentence splitting and sonar embedding for each split, run the following command:
 
 ```shell
-python prepare_evaluation_data.py embed \
+uv run --extra data prepare_evaluation_data.py embed \
     --input_path=jsonl_dataset/cnn_dailymail/test.jsonl \
-    --input_column=article \
-    --output_column=highlights \
+    --source_text_column=prompt \
+    --target_text_column=answer \
     --output_dir=parquet_dataset/cnn_dailymail \
     --lang=eng_Latn \
-    --mode=slurm \
+    --mode=local \
     --log_dir=/tmp/logs/embed_cnndm
 ```
 
+Depending on your machine, this might take some time. Alternatively, you can run it on your SLURM cluster with the arguments `--mode=slurm --shards=NO_OF_PARALLEL_JOBS`. This requires changing your SLURM config accordingly. We use [submitit](https://github.com/facebookincubator/submitit) to configure the job launcher. Here is the relevant excerpt from the script:
+
+```python
+launcher = Launcher(
+    cache=None,
+    config_dump_dir=Path(log_dir) / "conf",
+    log_folder=Path(log_dir) / "logs",
+    cluster=mode,
+    update_parameters={"partition": "your_slurm_partition"},
+)
+_ = await launcher.schedule(inst_stopes_module)
+```
+
 
 ## Step 2: Choose the predictor for evaluation
 
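Put together, a SLURM run of the embed step above would look roughly like the sketch below; the shard count is an arbitrary example, and the partition configured in the launcher must be adapted to your cluster:

```shell
uv run --extra data prepare_evaluation_data.py embed \
    --input_path=jsonl_dataset/cnn_dailymail/test.jsonl \
    --source_text_column=prompt \
    --target_text_column=answer \
    --output_dir=parquet_dataset/cnn_dailymail \
    --lang=eng_Latn \
    --mode=slurm \
    --shards=8 \
    --log_dir=/tmp/logs/embed_cnndm
```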
@@ -121,7 +131,7 @@ uv run torchrun --standalone --nnodes=1 --nproc-per-node=1 -m lcm.evaluation \
   --dump_dir output_results
 ```
 
-Note the missing parameters `source_text_column` and `target_text_column` and the new parameters `source_prefix_text`, `target_prefix_text`, since in this case, we do not modify the column schema, therefore the original text columns ("article", "highlights") are kept and not specified in the CLI.
+> **_NOTE:_** The parameters `source_text_column` and `target_text_column` are omitted and the new parameters `source_prefix_text`, `target_prefix_text` are added because we do not modify the column schema. Therefore, the original text columns ("article", "highlights") are kept and not specified in the CLI.
 
 It is also possible to provide the prompt from a YAML file. This is handy when you have to engineer the prompts carefully and have a very long detailed text. We provide one example prompt in the file [instruction.yaml](./instruction.yaml). The example command is:
 
@@ -151,6 +161,10 @@ uv run torchrun --standalone --nnodes=1 --nproc-per-node=1 -m lcm.evaluation \
   --tasks lcm_generation \
   --task_args '{"max_gen_len": 200}' \
   --dataset.parquet_path parquet_dataset/cnn_dailymail \
+  --dataset.source_column prompt_sentences_sonar_emb \
+  --dataset.source_text_column prompt_sentences \
+  --dataset.target_column answer_sentences_sonar_emb \
+  --dataset.target_text_column answer_sentences \
   --data_loading.batch_size 16 \
   --dump_dir output_results
 ```
@@ -168,13 +182,12 @@ Similar to LLM evaluation, it is possible to specify the prompt prefix and suffi
 | `data_loading.batch_size` | Loading and evaluate data in batch. By default `batch_size=10` |
 | `dataset_dir` | The directory consists of different JSONL files processed in Step 1. Only used in LLM evaluation
 | `dataset.parquet_path` | The parquet path consists of different Parquet files files processed in Step 1. Only used in LCM evaluation
-| `dataset.source_column` | The column in the data that refers to the input embedding. Not applicable when evaluating LLMs
-| `dataset.source_text_column` | The column in the data that refers to the input text. Not applicable when evaluating LCMs
-| `dataset.source_text_column` | The column in the data that refers to the input text. Not applicable when evaluating LCMs
-| `dataset.target_column` | The column in the data that refers to the ground-truth embedding. Not applicable when evaluating LLMs
-| `dataset.target_text_column` | The column in the data that refers to the ground-truth text. Not applicable when evaluating LCMs
+| `dataset.source_column` | The column in the data that refers to the input embedding. Not applicable when evaluating LLMs.
+| `dataset.source_text_column` | The column in the data that refers to the input text.
+| `dataset.target_column` | The column in the data that refers to the ground-truth embedding. Not applicable when evaluating LLMs.
+| `dataset.target_text_column` | The column in the data that refers to the ground-truth text.
 | `dataset.source_text_prefix` | The text that will prepended to each input text to make the prompt for the model.
-| `dataset.source_text_prefix` | The text that will appended after each input text to make the prompt for the model.
+| `dataset.source_text_suffix` | The text that will be appended after each input text to make the prompt for the model.
 | `task_args` | The JSON-formatted string that represents the task arguments. See [task param list](#task_param_list) below.
 | `dump_dir` | The directory consisting output of the eval run. If successful, there should be a file `metrics.eval.jsonl` that consists of metric results, the directory `results` that capture the verbose command line used with the detailed output scores, and the directory `raw_results` that shows
 the model output for each individual sample, together with the per-sample metric results.
@@ -223,7 +236,7 @@ shards=NUMBER_OF_SLURM_NODES
 timeout_min=JOB_TIMEOUT_IN_MINUTES
 
 
-python -m lcm.evaluation \
+uv run -m lcm.evaluation \
     --predictor two_tower_diffusion_lcm \
     --model_card path/to/the/model_card.yaml \
    --generator_batch_size 16 \
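A note on the `--extra data` flag these commands pass to `uv run`: assuming the project is managed with uv (as the commands suggest), the optional data-processing dependencies (`datasets`, `wtpsplit`) can also be installed once up front instead of being resolved per invocation:

```shell
uv sync --extra data
```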

examples/evaluation/prepare_evaluation_data.py

Lines changed: 3 additions & 1 deletion
@@ -195,6 +195,7 @@ async def embed(
     target_text_column: Optional[str] = OUTPUT_KEY,
     lang: str = "eng_Latn",
     mode: Literal["local", "slurm"] = "local",
+    shards: int = 1,
     log_dir: Optional[str] = None,
 ):
     inst_sonar_config = SonarColumnRenameAndEmbedConfig(
@@ -212,6 +213,7 @@ async def embed(
         Path(input_path),
         batch_size=10,  # iterating by small number of documents
         batch_format=BatchFormat.ARROW,
+        num_shards=shards,
     )
 
     output_config = ParquetOutputConfig(output_dir)
@@ -230,7 +232,7 @@ async def embed(
         config_dump_dir=Path(log_dir) / "conf",
         log_folder=Path(log_dir) / "logs",
         cluster=mode,
-        update_parameters={"slurm_qos": "lcm_pretrain"},
+        update_parameters={"partition": "learn"},
     )
     _ = await launcher.schedule(inst_stopes_module)

lcm/train/two_tower_diffusion_lcm/criterion.py

Lines changed: 4 additions & 2 deletions
@@ -3,7 +3,7 @@
 #
 #
 
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import List, Tuple
 
 import torch
@@ -33,7 +33,9 @@ class TowerDiffusionLCMCriterionConfig(LCMCriterionConfig):
     Note that this requires the model to be set with
     `trained_with_cf_guidance = True`!
     """
-    step_sampling: StepsSamplerConfig = StepsSamplerConfig()
+    step_sampling: StepsSamplerConfig = field(
+        default_factory=lambda: StepsSamplerConfig()
+    )
 
     log_losses_per_timestep_bucket: bool = False
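Both this change and the matching one in `tests/units/training/test_get_trainer.py` below are the standard fix for Python 3.11+, which rejects a dataclass instance used directly as a field default (dataclasses with the default `eq=True` are unhashable, which trips the mutable-default check). A minimal sketch with a hypothetical stand-in config:

```python
from dataclasses import dataclass, field


@dataclass
class SamplerConfig:
    # hypothetical stand-in for StepsSamplerConfig
    nb_steps: int = 100


@dataclass
class CriterionConfig:
    # On Python >= 3.11, using an instance directly as the default raises
    # at class-definition time:
    #   ValueError: mutable default <class 'SamplerConfig'> for field
    #   sampling is not allowed: use default_factory
    #
    # sampling: SamplerConfig = SamplerConfig()

    # default_factory builds a fresh instance for every CriterionConfig(),
    # so configs never share a single mutable default object:
    sampling: SamplerConfig = field(default_factory=lambda: SamplerConfig())


cfg = CriterionConfig()
print(cfg.sampling.nb_steps)  # -> 100
```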

scripts/prepare_wikipedia.py

Lines changed: 2 additions & 1 deletion
@@ -87,8 +87,9 @@ def run(output_dir: Path):
         cache=None,
         cluster="local",
         # for SLURM you can set some parameters of the launcher here
+        # cluster="slurm",
         # update_parameters={
-        #     "slurm_partition": "YOURPARTITION",
+        #     "partition": "learn",
         # },
     )
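For reference, a sketch of the same launcher with the SLURM lines enabled; the `Launcher` import path is an assumption, and `"learn"` is just the partition from the diff, to be replaced with your own:

```python
from stopes.core.launcher import Launcher  # import path assumed

launcher = Launcher(
    cache=None,
    cluster="slurm",
    # site-specific SLURM settings go here
    update_parameters={
        "partition": "learn",
    },
)
```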

tests/units/training/test_get_trainer.py

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ def __post_init__(self):
 @dataclass
 class Config:
     foobar: str = "test"
-    cfg: Foo = Foo()
+    cfg: Foo = field(default_factory=lambda: Foo())
     c: float = field(init=False)
 
     def __post_init__(self):

uv.lock

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default.