Skip to content

Commit 48ba680

Browse files
authored
Enable users to trust remote code in samsum dataset (meta-llama#628)
1 parent 9b3dabc commit 48ba680

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

src/llama_recipes/configs/datasets.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@ class samsum_dataset:
99
dataset: str = "samsum_dataset"
1010
train_split: str = "train"
1111
test_split: str = "validation"
12+
trust_remote_code: bool = False
1213

1314

1415
@dataclass
@@ -37,4 +38,4 @@ class custom_dataset:
3738
class llamaguard_toxicchat_dataset:
3839
dataset: str = "llamaguard_toxicchat_dataset"
3940
train_split: str = "train"
40-
test_split: str = "test"
41+
test_split: str = "test"

src/llama_recipes/datasets/samsum_dataset.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,9 @@
88

99

1010
def get_preprocessed_samsum(dataset_config, tokenizer, split):
11-
dataset = datasets.load_dataset("samsum", split=split)
11+
if not hasattr(dataset_config, "trust_remote_code") or not dataset_config.trust_remote_code:
12+
raise ValueError("The repository for samsum contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/samsum. To activate `trust_remote_code` option use this config: --samsum_dataset.trust_remote_code=True")
13+
dataset = datasets.load_dataset("samsum", split=split, trust_remote_code=dataset_config.trust_remote_code)
1214

1315
prompt = (
1416
f"Summarize this dialog:\n{{dialog}}\n---\nSummary:\n"

0 commit comments

Comments
 (0)