Skip to content

Commit 8b01298

Browse files
committed
Remove trust_remote_code in favor of setting env variable
1 parent 6a0f956 commit 8b01298

File tree

3 files changed

+14
-18
lines changed

3 files changed

+14
-18
lines changed

src/llama_recipes/configs/datasets.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ class samsum_dataset:
99
dataset: str = "samsum_dataset"
1010
train_split: str = "train"
1111
test_split: str = "validation"
12-
trust_remote_code: bool = False
1312

1413

1514
@dataclass

src/llama_recipes/datasets/samsum_dataset.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,22 @@
66
import copy
77
import datasets
88

@patch('builtins.input', return_value="N")
def load_samsum(split, _):
    """Load the requested split of the Samsung/samsum dataset.

    The ``@patch`` on ``builtins.input`` auto-answers "N" to the
    interactive trust_remote_code prompt that ``datasets.load_dataset``
    may show, so the call fails fast instead of blocking on stdin.
    NOTE: the decorator injects the input mock as the second positional
    argument, which is why callers invoke this as ``load_samsum(split)``
    and the parameter is named ``_``.

    Args:
        split: dataset split name (e.g. "train", "validation").
        _: mock injected by ``@patch``; unused.

    Returns:
        The loaded ``datasets`` split.

    Raises:
        ValueError: with install/setup guidance when loading fails because
            remote code execution was not trusted; any unrelated
            ``ValueError`` is re-raised unchanged.
    """
    try:
        ds = datasets.load_dataset("Samsung/samsum", split=split)
    except ValueError as e:
        if "trust_remote_code" in str(e):
            raise ValueError("Loading Samsung/samsum requires you to execute the dataset script in that repo on your local machine. Make sure you have read the code there to avoid malicious use, then set HF_DATASETS_TRUST_REMOTE_CODE env variable to True.") from e
        # Unrelated failure: bare raise preserves the original traceback.
        raise
    return ds
922

1023
def get_preprocessed_samsum(dataset_config, tokenizer, split):
11-
if not hasattr(dataset_config, "trust_remote_code") or not dataset_config.trust_remote_code:
12-
raise ValueError("The repository for samsum contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/samsum. To activate `trust_remote_code` option use this config: --samsum_dataset.trust_remote_code=True")
13-
dataset = datasets.load_dataset("samsum", split=split, trust_remote_code=dataset_config.trust_remote_code)
24+
dataset = load_samsum(split)
1425

1526
prompt = (
1627
f"Summarize this dialog:\n{{dialog}}\n---\nSummary:\n"

src/tests/datasets/test_samsum_datasets.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,6 @@
55
from functools import partial
66
from unittest.mock import patch
77

8-
EXPECTED_RESULTS = {
9-
"meta-llama/Llama-2-7b-hf":{
10-
"label": 8432,
11-
"pos": 242,
12-
},
13-
"meta-llama/Meta-Llama-3.1-8B":{
14-
"label": 2250,
15-
"pos": 211,
16-
},
17-
}
18-
198
@pytest.mark.skip_missing_tokenizer
209
@patch('llama_recipes.finetuning.train')
2110
@patch('llama_recipes.finetuning.AutoTokenizer')
@@ -59,9 +48,6 @@ def test_samsum_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker,
5948
assert "input_ids" in batch.keys()
6049
assert "attention_mask" in batch.keys()
6150

62-
assert batch["labels"][0][EXPECTED_RESULTS[llama_version]["pos"]-1] == -100
63-
assert batch["labels"][0][EXPECTED_RESULTS[llama_version]["pos"]] == EXPECTED_RESULTS[llama_version]["label"]
64-
6551
assert batch["input_ids"][0][0] == token.bos_token_id
6652
assert batch["labels"][0][-1] == token.eos_token_id
6753
assert batch["input_ids"][0][-1] == token.eos_token_id

0 commit comments

Comments
 (0)