Commit b38304b

Merge pull request #39 from p-lambda/dev
V1.1 updates
2 parents 28ef873 + 9c84fec


48 files changed: 2517 additions, 522 deletions (only some of the changed files are shown below)

README.md

Lines changed: 47 additions & 16 deletions
@@ -29,7 +29,7 @@ pip install wilds
 If you have already installed it, please check that you have the latest version:
 ```bash
 python -c "import wilds; print(wilds.__version__)"
-# This should print "1.0.0". If it doesn't, update by running:
+# This should print "1.1.0". If it doesn't, update by running:
 pip install -U wilds
 ```

@@ -42,15 +42,15 @@ pip install -e .

 ### Requirements
 - numpy>=1.19.1
+- ogb>=1.2.6
+- outdated>=0.2.0
 - pandas>=1.1.0
 - pillow>=7.2.0
-- torch>=1.7.0
-- tqdm>=4.53.0
 - pytz>=2020.4
-- outdated>=0.2.0
-- ogb>=1.2.3
+- torch>=1.7.0
 - torch-scatter>=2.0.5
 - torch-geometric>=1.6.1
+- tqdm>=4.53.0

 Running `pip install wilds` or `pip install -e .` will automatically check for and install all of these requirements
 except for the `torch-scatter` and `torch-geometric` packages, which require a [quick manual install](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html#installation-via-binaries).
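As a side note on that manual install, a minimal sketch of the usual pattern, assuming torch 1.7.0 with CUDA 10.1 (the versions the baselines were run on); the exact wheel index URL depends on your torch/CUDA versions and is documented at the link above:

```bash
# Sketch only: substitute your own torch/CUDA versions in the wheel index URL
pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.0+cu101.html
pip install torch-geometric
```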
@@ -70,39 +70,69 @@ To run these scripts, you will need to install these additional dependencies:

 All baseline experiments in the paper were run on Python 3.8.5 and CUDA 10.1.

-## Usage
-### Default models
-In the `examples/` folder, we provide a set of scripts that we used to train models on the WILDS package. These scripts are configured with the default models and hyperparameters that we used for all of the baselines described in our paper. All baseline results in the paper can be easily replicated with commands like:
+
+## Using the example scripts
+
+In the `examples/` folder, we provide a set of scripts that can be used to download WILDS datasets and train models on them.
+These scripts are configured with the default models and hyperparameters that we used for all of the baselines described in our paper. All baseline results in the paper can be easily replicated with commands like:

 ```bash
-cd examples
-python run_expt.py --dataset iwildcam --algorithm ERM --root_dir data
-python run_expt.py --dataset civilcomments --algorithm groupDRO --root_dir data
+python examples/run_expt.py --dataset iwildcam --algorithm ERM --root_dir data
+python examples/run_expt.py --dataset civilcomments --algorithm groupDRO --root_dir data
 ```

 The scripts are set up to facilitate general-purpose algorithm development: new algorithms can be added to `examples/algorithms` and then run on all of the WILDS datasets using the default models.

 The first time you run these scripts, you might need to download the datasets. You can do so with the `--download` argument, for example:
 ```
-python run_expt.py --dataset civilcomments --algorithm groupDRO --root_dir data --download
+python examples/run_expt.py --dataset civilcomments --algorithm groupDRO --root_dir data --download
 ```

+Alternatively, you can use the standalone `wilds/download_datasets.py` script to download the datasets, for example:
+
+```bash
+python wilds/download_datasets.py --root_dir data
+```
+
+This will download all datasets to the specified `data` folder. You can also use the `--datasets` argument to download particular datasets.
+
+These are the sizes of each of our datasets, as well as the approximate time it takes to train and evaluate the default model for a single ERM run on an NVIDIA V100 GPU.
+
+| Dataset command | Modality | Download size (GB) | Size on disk (GB) | Train+eval time (hours) |
+|-----------------|----------|--------------------|-------------------|-------------------------|
+| iwildcam        | Image    | 11                 | 25                | 7                       |
+| camelyon17      | Image    | 10                 | 15                | 2                       |
+| ogb-molpcba     | Graph    | 0.04               | 2                 | 15                      |
+| civilcomments   | Text     | 0.1                | 0.3               | 4.5                     |
+| fmow            | Image    | 50                 | 55                | 6                       |
+| poverty         | Image    | 12                 | 14                | 5                       |
+| amazon          | Text     | 6.6                | 7                 | 5                       |
+| py150           | Text     | 0.1                | 0.8               | 9.5                     |
+
+While the `camelyon17` dataset is small and fast to train on, we advise against using it as the only dataset to prototype methods on, as the test performance of models trained on it tends to exhibit a large degree of variability over random seeds.
+
+The image datasets (`iwildcam`, `camelyon17`, `fmow`, and `poverty`) tend to have high disk I/O usage. If training is much slower for you than the approximate times listed above, consider checking whether I/O is a bottleneck (e.g., by moving to a local disk if you are using a network drive, or by increasing the number of data loader workers). To speed up training, you could also disable evaluation at each epoch or for all splits by toggling `--evaluate_all_splits` and related arguments.
+
+We have an [executable version](https://wilds.stanford.edu/codalab) of our paper on CodaLab that contains the exact commands, code, and data used for the experiments reported in our paper. Trained model weights for all datasets can also be found there.
+
+
+## Using the WILDS package
 ### Data loading

 The WILDS package provides a simple, standardized interface for all datasets in the benchmark.
 This short Python snippet covers all of the steps of getting started with a WILDS dataset, including dataset download and initialization, accessing various splits, and preparing a user-customizable data loader.

 ```py
->>> from wilds.datasets.iwildcam_dataset import IWildCamDataset
+>>> from wilds import get_dataset
 >>> from wilds.common.data_loaders import get_train_loader
 >>> import torchvision.transforms as transforms

 # Load the full dataset, and download it if necessary
->>> dataset = IWildCamDataset(download=True)
+>>> dataset = get_dataset(dataset='iwildcam', download=True)

 # Get the training set
 >>> train_data = dataset.get_subset('train',
-...                                 transform=transforms.Compose([transforms.Resize((224,224)),
+...                                 transform=transforms.Compose([transforms.Resize((448,448)),
 ...                                                               transforms.ToTensor()]))

 # Prepare the standard data loader
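The hunk above is cut off at the data loader step; for completeness, a minimal continuation consistent with the imports shown (the batch size here is illustrative, not taken from this diff):

```py
# Sketch of the step the hunk truncates: wrap the training subset in the
# standard (random-shuffling) train loader and iterate over (x, y, metadata).
>>> train_loader = get_train_loader('standard', train_data, batch_size=16)
>>> for x, y_true, metadata in train_loader:
...     ...  # train on the batch
```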
@@ -171,11 +201,12 @@ Invoking the `eval` method of each dataset yields all metrics reported in the paper:
 >>> dataset.eval(all_y_pred, all_y_true, all_metadata)
 {'recall_macro_all': 0.66, ...}
 ```
+Most `eval` methods take in predicted labels for `all_y_pred` by default, but the default inputs vary across datasets and are documented in the `eval` docstrings of the corresponding dataset class.

 ## Citing WILDS
 If you use WILDS datasets in your work, please cite [our paper](https://arxiv.org/abs/2012.07421) ([Bibtex](https://wilds.stanford.edu/assets/files/bibtex.md)):

-- **WILDS: A Benchmark of in-the-Wild Distribution Shifts** (2020). Pang Wei Koh*, Shiori Sagawa*, Henrik Marklund, Sang Michael Xie, Marvin Zhang, Akshay Balsubramani, Weihua Hu, Michihiro Yasunaga, Richard Lanas Phillips, Sara Beery, Jure Leskovec, Anshul Kundaje, Emma Pierson, Sergey Levine, Chelsea Finn, and Percy Liang.
+- **WILDS: A Benchmark of in-the-Wild Distribution Shifts** (2021). Pang Wei Koh*, Shiori Sagawa*, Henrik Marklund, Sang Michael Xie, Marvin Zhang, Akshay Balsubramani, Weihua Hu, Michihiro Yasunaga, Richard Lanas Phillips, Irena Gao, Tony Lee, Etienne David, Ian Stavness, Wei Guo, Berton A. Earnshaw, Imran S. Haque, Sara Beery, Jure Leskovec, Anshul Kundaje, Emma Pierson, Sergey Levine, Chelsea Finn, and Percy Liang.

 Please also cite the original papers that introduce the datasets, as listed on the [datasets page](https://wilds.stanford.edu/datasets/).

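For readers new to the package, a hedged sketch of how the `all_y_pred`, `all_y_true`, and `all_metadata` arguments to `eval` are typically assembled; `get_eval_loader` is part of `wilds.common.data_loaders`, while `model` and the accumulation details below are illustrative assumptions, not part of this diff:

```py
>>> import torch
>>> from wilds.common.data_loaders import get_eval_loader

# Sequential (non-shuffling) loader over the OOD test split,
# using the same transform as the training snippet above
>>> test_data = dataset.get_subset('test',
...     transform=transforms.Compose([transforms.Resize((448,448)),
...                                   transforms.ToTensor()]))
>>> test_loader = get_eval_loader('standard', test_data, batch_size=16)

# Accumulate predictions, labels, and metadata, then call eval as shown above
>>> ys_pred, ys_true, metas = [], [], []
>>> for x, y_true, metadata in test_loader:
...     ys_pred.append(model(x).argmax(-1))  # `model` is assumed trained
...     ys_true.append(y_true)
...     metas.append(metadata)
>>> dataset.eval(torch.cat(ys_pred), torch.cat(ys_true), torch.cat(metas))
```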

dataset_preprocessing/amazon_yelp/subsample_amazon.py

Lines changed: 157 additions & 0 deletions

import argparse
import csv
import os

import pandas as pd
import numpy as np

# Fix the seed for reproducibility
np.random.seed(0)

"""
Subsample the Amazon dataset.

Usage:
    python dataset_preprocessing/amazon_yelp/subsample_amazon.py <path> <frac>
"""

NOT_IN_DATASET = -1
# Split: {'train': 0, 'val': 1, 'id_val': 2, 'test': 3, 'id_test': 4}
TRAIN, OOD_VAL, ID_VAL, OOD_TEST, ID_TEST = range(5)


def main(dataset_path, frac=0.25):
    def output_dataset_sizes(split_df):
        print("-" * 50)
        print(f'Train size: {len(split_df[split_df["split"] == TRAIN])}')
        print(f'Val size: {len(split_df[split_df["split"] == OOD_VAL])}')
        print(f'ID Val size: {len(split_df[split_df["split"] == ID_VAL])}')
        print(f'Test size: {len(split_df[split_df["split"] == OOD_TEST])}')
        print(f'ID Test size: {len(split_df[split_df["split"] == ID_TEST])}')
        print(
            f'Number of examples not included: {len(split_df[split_df["split"] == NOT_IN_DATASET])}'
        )
        print("-" * 50)
        print("\n")

    data_df = pd.read_csv(
        os.path.join(dataset_path, "reviews.csv"),
        dtype={
            "reviewerID": str,
            "asin": str,
            "reviewTime": str,
            "unixReviewTime": int,
            "reviewText": str,
            "summary": str,
            "verified": bool,
            "category": str,
            "reviewYear": int,
        },
        keep_default_na=False,
        na_values=[],
        quoting=csv.QUOTE_NONNUMERIC,
    )

    user_csv_path = os.path.join(dataset_path, "splits", "user.csv")
    split_df = pd.read_csv(user_csv_path)
    output_dataset_sizes(split_df)

    train_data_df = data_df[split_df["split"] == 0]
    train_reviewer_ids = train_data_df.reviewerID.unique()
    print(f"Number of unique reviewers in train set: {len(train_reviewer_ids)}")

    # Randomly sample (1 - frac) x number of reviewers
    # Blackout all the reviews belonging to the randomly sampled reviewers
    subsampled_reviewers_count = int((1 - frac) * len(train_reviewer_ids))
    subsampled_reviewers = np.random.choice(
        train_reviewer_ids, subsampled_reviewers_count, replace=False
    )
    print(subsampled_reviewers)

    blackout_indices = train_data_df[
        train_data_df["reviewerID"].isin(subsampled_reviewers)
    ].index

    # Mark all the corresponding reviews of blackout_indices as -1
    split_df.loc[blackout_indices, "split"] = NOT_IN_DATASET
    output_dataset_sizes(split_df)

    # Mark duplicates
    duplicated_within_user = data_df[["reviewerID", "reviewText"]].duplicated()
    df_deduplicated_within_user = data_df[~duplicated_within_user]
    duplicated_text = df_deduplicated_within_user[
        df_deduplicated_within_user["reviewText"]
        .apply(lambda x: x.lower())
        .duplicated(keep=False)
    ]["reviewText"]
    duplicated_text = set(duplicated_text.values)
    data_df["duplicate"] = (
        data_df["reviewText"].isin(duplicated_text)
    ) | duplicated_within_user

    # Mark html candidates
    data_df["contains_html"] = data_df["reviewText"].apply(
        lambda x: "<" in x and ">" in x
    )

    # Mark clean ones
    data_df["clean"] = ~data_df["duplicate"] & ~data_df["contains_html"]

    # Clear ID val and ID test since we're regenerating
    split_df.loc[split_df["split"] == ID_VAL, "split"] = NOT_IN_DATASET
    split_df.loc[split_df["split"] == ID_TEST, "split"] = NOT_IN_DATASET

    # Regenerate ID val and ID test
    train_reviewer_ids = data_df[split_df["split"] == TRAIN]["reviewerID"].unique()
    np.random.shuffle(train_reviewer_ids)
    cutoff = int(len(train_reviewer_ids) / 2)
    id_val_reviewer_ids = train_reviewer_ids[:cutoff]
    id_test_reviewer_ids = train_reviewer_ids[cutoff:]
    split_df.loc[
        (split_df["split"] == NOT_IN_DATASET)
        & data_df["clean"]
        & data_df["reviewerID"].isin(id_val_reviewer_ids),
        "split",
    ] = ID_VAL
    split_df.loc[
        (split_df["split"] == NOT_IN_DATASET)
        & data_df["clean"]
        & data_df["reviewerID"].isin(id_test_reviewer_ids),
        "split",
    ] = ID_TEST

    # Sanity check
    assert (
        data_df[(split_df["split"] == ID_VAL)]["reviewerID"].value_counts().min() == 75
    )
    assert (
        data_df[(split_df["split"] == ID_VAL)]["reviewerID"].value_counts().max() == 75
    )
    assert (
        data_df[(split_df["split"] == ID_TEST)]["reviewerID"].value_counts().min() == 75
    )
    assert (
        data_df[(split_df["split"] == ID_TEST)]["reviewerID"].value_counts().max() == 75
    )

    # Write out the new splits to user.csv
    output_dataset_sizes(split_df)
    split_df.to_csv(user_csv_path, index=False)
    print("Done.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Subsample the Amazon dataset.")
    parser.add_argument(
        "path",
        type=str,
        help="Path to the Amazon dataset",
    )
    parser.add_argument(
        "frac",
        type=float,
        help="Subsample fraction",
    )

    args = parser.parse_args()
    main(args.path, args.frac)
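As a quick illustration of the two-stage de-duplication above (exact duplicates within a user first, then case-insensitive text duplicates across the remainder), here is a self-contained toy example; the data is made up purely for demonstration:

```python
import pandas as pd

# Toy reviews: rows 0 and 1 are an exact duplicate within one user;
# rows 2 and 3 are the same text (up to case) from different users.
df = pd.DataFrame({
    "reviewerID": ["u1", "u1", "u2", "u3", "u4"],
    "reviewText": ["Great", "Great", "Nice item", "nice item", "Unique"],
})

# Stage 1: mark repeat (reviewerID, reviewText) pairs, keeping first occurrences
dup_within_user = df[["reviewerID", "reviewText"]].duplicated()

# Stage 2: among the remaining rows, mark every copy of any repeated
# lowercased text (keep=False flags all copies, not just later ones)
deduped = df[~dup_within_user]
dup_text = set(
    deduped[deduped["reviewText"].str.lower().duplicated(keep=False)]["reviewText"]
)

df["duplicate"] = df["reviewText"].isin(dup_text) | dup_within_user
print(df["duplicate"].tolist())  # [False, True, True, True, False]
```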
Lines changed: 28 additions & 0 deletions

import os
import argparse
import numpy as np
from PIL import Image
from pathlib import Path
from tqdm import tqdm

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--root_dir', required=True,
                        help='The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).')
    config = parser.parse_args()
    data_dir = Path(config.root_dir) / 'fmow_v1.0'
    image_dir = Path(config.root_dir) / 'fmow_v1.0_images_jpg'
    os.makedirs(image_dir, exist_ok=True)

    # Convert the fMoW .npy image chunks to individual JPG files.
    # mmap_mode='r' memory-maps each chunk so images are read lazily
    # instead of loading the whole multi-GB array into RAM at once.
    img_counter = 0
    for chunk in tqdm(range(101)):
        npy_chunk = np.load(data_dir / f'rgb_all_imgs_{chunk}.npy', mmap_mode='r')
        for i in range(len(npy_chunk)):
            npy_image = npy_chunk[i]
            img = Image.fromarray(npy_image, mode='RGB')
            img.save(image_dir / f'rgb_img_{img_counter}.jpg')
            img_counter += 1

if __name__ == '__main__':
    main()
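A hypothetical invocation (this view does not show the script's filename, so the name below is a placeholder), assuming `data/fmow_v1.0/` already contains the downloaded `rgb_all_imgs_*.npy` chunks:

```bash
# Placeholder filename; use the script's actual path in the repo
python convert_npy_to_jpg.py --root_dir data
```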
