Skip to content

Commit 29c3dd8

Browse files
Manual download logic (#708)
* Add README * Allow to run manually downloaded datasets * Update README * make style * remove datasets dependency in tasks.py * Nit * Nicer error message in web interface * Fix error * \n * Fix character escape * ` is not a correct escape character * I don't know * fix data_dir for cases where `subset_name` is `None` * tiny grammarly fixes Co-authored-by: Victor Sanh <[email protected]>
1 parent dba1d41 commit 29c3dd8

File tree

5 files changed

+51
-6
lines changed

5 files changed

+51
-6
lines changed

README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,15 @@ collection = TemplateCollection()
7070
# and the value is an instance of DatasetTemplates
7171
print(collection.datasets_templates)
7272
```
73+
74+
## Running datasets that need manual download
75+
76+
Some datasets are not handled automatically by `datasets` and require users to download the dataset manually.
77+
78+
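To handle those datasets as well, we require users to download the dataset manually and place it under `~/.cache/promptsource`, the root directory containing all manually downloaded datasets. Each dataset goes in its own subdirectory of that root: `~/.cache/promptsource/<dataset_name>`, or `~/.cache/promptsource/<dataset_name>/<subset_name>` when the dataset has subsets.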
79+
80+
You can override this default path using the `PROMPTSOURCE_MANUAL_DATASET_DIR` environment variable. It should point to the root directory (the one containing the per-dataset subdirectories), not to an individual dataset.
81+
7382
## Contributing
7483
Contribution guidelines and step-by-step *HOW TO* are described [here](CONTRIBUTING.md).
7584

promptsource/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
DEFAULT_PROMPTSOURCE_CACHE_HOME = "~/.cache/promptsource"

promptsource/app.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def get_infos(d_name):
204204
fig.update_xaxes(visible=False, showticklabels=False)
205205
st.plotly_chart(fig, use_container_width=True)
206206
st.write(
207-
f"- Top 3 training subsets account for `{100*plot_df[:3]['Train size'].sum()/nb_training_instances:.2f}%` of the training instances."
207+
f"- Top 3 training subsets account for `{100 * plot_df[:3]['Train size'].sum() / nb_training_instances:.2f}%` of the training instances."
208208
)
209209
biggest_training_subset = plot_df.iloc[0]
210210
st.write(
@@ -257,7 +257,20 @@ def get_infos(d_name):
257257
if len(configs) > 0:
258258
conf_option = st.sidebar.selectbox("Subset", configs, index=0, format_func=lambda a: a.name)
259259

260-
dataset = get_dataset(dataset_key, str(conf_option.name) if conf_option else None)
260+
subset_name = str(conf_option.name) if conf_option else None
261+
try:
262+
dataset = get_dataset(dataset_key, subset_name)
263+
except OSError as e:
264+
st.error(
265+
f"Some datasets are not handled automatically by `datasets` and require users to download the "
266+
f"dataset manually. This applies to {dataset_key}{f'/{subset_name}' if subset_name is not None else ''}. "
267+
f"\n\nPlease download the raw dataset to `~/.cache/promptsource/{dataset_key}{f'/{subset_name}' if subset_name is not None else ''}`. "
268+
f"\n\nYou can choose another cache directory by overriding `PROMPTSOURCE_MANUAL_DATASET_DIR` environment "
269+
f"variable and downloading raw dataset to `$PROMPTSOURCE_MANUAL_DATASET_DIR/{dataset_key}{f'/{subset_name}' if subset_name is not None else ''}`"
270+
f"\n\nOriginal error:\n{str(e)}"
271+
)
272+
st.stop()
273+
261274
splits = list(dataset.keys())
262275
index = 0
263276
if "train" in splits:
@@ -596,7 +609,6 @@ def get_infos(d_name):
596609
st.write("Target")
597610
show_text(prompt[1], width=40)
598611

599-
600612
#
601613
# Must sync state at end
602614
#

promptsource/seqio_tasks/tasks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import functools
33
from typing import Dict, List, Optional, Tuple
44

5-
import datasets
65
import pkg_resources
76
import seqio
87
import t5
@@ -12,6 +11,7 @@
1211

1312
import promptsource.templates
1413
from promptsource.seqio_tasks import utils
14+
from promptsource.utils import load_dataset
1515

1616

1717
GET_METRICS = {
@@ -59,7 +59,7 @@ def postprocess_fn(output_or_target, example=None, is_target=False):
5959
def get_tf_dataset(split, shuffle_files, seed, dataset_name, subset_name, template, split_mapping):
6060
# HF datasets does not support file-level shuffling
6161
del shuffle_files, seed
62-
dataset = datasets.load_dataset(dataset_name, subset_name)
62+
dataset = load_dataset(dataset_name, subset_name)
6363
dataset = dataset[split_mapping[split]]
6464
dataset = utils.apply_template(dataset, template)
6565
return utils.hf_dataset_to_tf_dataset(dataset)

promptsource/utils.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
# coding=utf-8
2+
import os
23

34
import datasets
45
import requests
56

7+
from promptsource import DEFAULT_PROMPTSOURCE_CACHE_HOME
68
from promptsource.templates import INCLUDED_USERS
79

810

@@ -49,7 +51,28 @@ def get_dataset(path, conf=None):
4951
builder_instance.download_and_prepare()
5052
return builder_instance.as_dataset()
5153
else:
52-
return datasets.load_dataset(path, conf)
54+
return load_dataset(path, conf)
55+
56+
57+
def load_dataset(dataset_name, subset_name=None):
    """Load a dataset with `datasets`, falling back to a manually downloaded copy.

    The dataset is first loaded the regular way. If `datasets` raises
    `ManualDownloadError`, the load is retried with `data_dir` pointing at
    the location where the user is expected to have placed the raw files:

        <root>/<dataset_name>                  when `subset_name` is None
        <root>/<dataset_name>/<subset_name>    otherwise

    `<root>` is the value of the `PROMPTSOURCE_MANUAL_DATASET_DIR` environment
    variable when it is set, and `DEFAULT_PROMPTSOURCE_CACHE_HOME` otherwise.

    Args:
        dataset_name: Name of the dataset, forwarded to `datasets.load_dataset`.
        subset_name: Optional subset (configuration) name, forwarded to
            `datasets.load_dataset`. Defaults to None.

    Returns:
        Whatever `datasets.load_dataset` returns for the requested dataset.

    Raises:
        Any error raised by the second `datasets.load_dataset` call is
        propagated unchanged (e.g. when the manually downloaded files are
        missing -- NOTE(review): the web app catches `OSError` around this
        path; confirm that is the exception `datasets` raises in that case).
    """
    try:
        return datasets.load_dataset(dataset_name, subset_name)
    except datasets.builder.ManualDownloadError:
        # The default root starts with a literal "~"; expand it here so the
        # directory handed to `datasets` is a real filesystem path instead of
        # relying on every dataset builder to expand it on its own.
        cache_root_dir = os.path.expanduser(
            os.environ.get("PROMPTSOURCE_MANUAL_DATASET_DIR", DEFAULT_PROMPTSOURCE_CACHE_HOME)
        )
        path_parts = [cache_root_dir, dataset_name]
        if subset_name is not None:
            path_parts.append(subset_name)
        data_dir = "/".join(path_parts)
        return datasets.load_dataset(
            dataset_name,
            subset_name,
            data_dir=data_dir,
        )
5376

5477

5578
def get_dataset_confs(path):

0 commit comments

Comments
 (0)