Skip to content

Commit 5aed31f

Browse files
d4l3kfacebook-github-bot
authored andcommitted
lightning_classy_vision: add model interpretability example (#54)
Summary: This is a model interpretability example. It loads the model from a training step and then runs integrated gradients on it via captum (https://captum.ai/tutorials/CIFAR_TorchVision_Interpret) and then writes out 5 samples to the output path. Pull Request resolved: #54 Test Plan: ``` cd examples/apps && torchx run --scheduler local lightning_classy_vision/component.py:interpret --image (pwd) --output_path /tmp/output --data_path /tmp/data.tar.gz --load_path /tmp/model-out/last.ckpt scripts/kfpint.py ``` http://5ab6bab9-istiosystem-istio-2af2-1926929629.us-west-2.elb.amazonaws.com/_/pipeline/?ns=torchx-dev#/runs/details/0a4a26f2-e1b2-4b99-be4e-3c0d74875770 Reviewed By: kiukchung Differential Revision: D29085948 Pulled By: d4l3k fbshipit-source-id: 3c3922e4f8eb6cdb97251eb20f3f403b4cb2bbbb
1 parent 20040ba commit 5aed31f

File tree

13 files changed

+542
-225
lines changed

13 files changed

+542
-225
lines changed

.pyre_configuration

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
{
22
"source_directories": [
3+
"examples/apps/lightning_classy_vision",
34
"."
45
],
56
"strict": true,

dev-requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,4 @@ classy-vision>=0.5.0
1111
flake8==3.9.0
1212
ts>=0.5.1
1313
torchserve>=0.4.0
14+
captum>=0.3.1

docs/Makefile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ help:
1919
clean:
2020
@echo "Deleting build directory"
2121
rm -rf "$(BUILDDIR)"
22+
rm -rf "$(SOURCEDIR)/examples_apps" "$(SOURCEDIR)/examples_pipelines"
2223

2324
.PHONY: help Makefile clean livehtml
2425

examples/apps/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
FROM pytorch/pytorch
22

3-
RUN pip install classy_vision pytorch-lightning fsspec[s3] torch-model-archiver
3+
RUN pip install classy_vision pytorch-lightning fsspec[s3] torch-model-archiver captum
44

55
WORKDIR /app
66

examples/apps/lightning_classy_vision/component.py

Lines changed: 54 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,44 +11,81 @@
1111
This is a component definition that runs the example lightning_classy_vision app.
1212
"""
1313

14+
from typing import Optional
15+
1416
import torchx.specs.api as torchx
17+
from torchx.components.base import named_resource
18+
from torchx.components.base.binary_component import binary_component
1519

1620

17-
def classy_vision(
21+
def trainer(
    image: str,
    output_path: str,
    data_path: str,
    load_path: str = "",
    log_dir: str = "/logs",
    resource: Optional[str] = None,
) -> torchx.AppDef:
    """Builds the app definition for the example lightning_classy_vision trainer.

    Args:
        image: image to run (e.g. foobar:latest)
        output_path: output path for model checkpoints (e.g. file:///foo/bar)
        data_path: path to the data to load
        load_path: path to load pretrained model from
        log_dir: path to save tensorboard logs to
        resource: the resources to use
    """
    # Use the named resource when one is given, otherwise fall back to a
    # small CPU-only resource.
    if resource:
        role_resource = named_resource(resource)
    else:
        role_resource = torchx.Resource(cpu=1, gpu=0, memMB=1024)

    # Command line flags forwarded verbatim to the train.py entrypoint.
    train_args = [
        "--output_path",
        output_path,
        "--load_path",
        load_path,
        "--log_dir",
        log_dir,
        "--data_path",
        data_path,
    ]

    return binary_component(
        name="examples-lightning_classy_vision-trainer",
        entrypoint="lightning_classy_vision/train.py",
        args=train_args,
        image=image,
        resource=role_resource,
    )
5357

54-
return torchx.AppDef("examples-lightning_classy_vision").of(trainer_role)
58+
59+
def interpret(
    image: str,
    load_path: str,
    data_path: str,
    output_path: str,
    resource: Optional[str] = None,
) -> torchx.AppDef:
    """Runs the model interpretability app on the model outputted by the training
    component.

    Args:
        image: image to run (e.g. foobar:latest)
        load_path: path to load pretrained model from
        data_path: path to the data to load
        output_path: output path for the analysis results (e.g. file:///foo/bar)
        resource: the resources to use
    """
    return binary_component(
        # NOTE: the app name previously contained the typo "intepret".
        name="examples-lightning_classy_vision-interpret",
        entrypoint="lightning_classy_vision/interpret.py",
        args=[
            "--load_path",
            load_path,
            "--data_path",
            data_path,
            "--output_path",
            output_path,
        ],
        image=image,
        # Use the named resource when one is given, otherwise fall back to a
        # small CPU-only resource.
        resource=named_resource(resource)
        if resource
        else torchx.Resource(cpu=1, gpu=0, memMB=1024),
    )
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
Trainer Datasets Example
9+
========================
10+
11+
These are the datasets used for the training example. They use stock PyTorch
12+
Lightning + Classy Vision libraries.
13+
"""
14+
15+
import os.path
16+
import tarfile
17+
from typing import Optional, Callable
18+
19+
import fsspec
20+
import pytorch_lightning as pl
21+
from classy_vision.dataset.classy_dataset import ClassyDataset
22+
from torch.utils.data import DataLoader
23+
from torchvision import datasets, transforms
24+
25+
# %%
26+
# This uses classy vision to define a dataset that we will then later use in our
27+
# Pytorch Lightning data module.
28+
29+
30+
class TinyImageNetDataset(ClassyDataset):
    """
    ClassyDataset backed by a torchvision ``ImageFolder`` rooted at the given
    tiny imagenet data path.
    """

    def __init__(self, data_path: str, transform: Callable[[object], object]) -> None:
        # The settings below are fixed for this example; only the data location
        # and the per-sample transform are configurable.
        image_folder = datasets.ImageFolder(data_path)
        super().__init__(
            # pyre-fixme[6]
            image_folder,
            16,  # batchsize_per_replica
            False,  # shuffle
            transform,
            1000,  # num_samples
        )
48+
49+
50+
# %%
51+
# For ease of use, we define a lightning data module so we can reuse it across
52+
# our trainer and other components that need to load data.
53+
54+
# pyre-fixme[13]: Attribute `test_ds` is never initialized.
55+
# pyre-fixme[13]: Attribute `train_ds` is never initialized.
56+
# pyre-fixme[13]: Attribute `val_ds` is never initialized.
57+
class TinyImageNetDataModule(pl.LightningDataModule):
    """
    LightningDataModule that exposes the ``train``, ``val`` and ``test``
    subdirectories of a tiny imagenet data directory as DataLoaders.
    """

    train_ds: TinyImageNetDataset
    val_ds: TinyImageNetDataset
    test_ds: TinyImageNetDataset

    def __init__(self, data_dir: str, batch_size: int = 16) -> None:
        super().__init__()
        self.batch_size = batch_size
        self.data_dir = data_dir

    def setup(self, stage: Optional[str] = None) -> None:
        # One shared image transform: grayscale conversion followed by
        # conversion to a tensor.
        to_gray_tensor = transforms.Compose(
            [
                transforms.Grayscale(),
                transforms.ToTensor(),
            ]
        )

        def build_split(split: str) -> TinyImageNetDataset:
            # Samples are indexed as (image, label); only the image is
            # transformed, the label is passed through untouched.
            return TinyImageNetDataset(
                data_path=os.path.join(self.data_dir, split),
                transform=lambda sample: (to_gray_tensor(sample[0]), sample[1]),
            )

        self.train_ds = build_split("train")
        self.val_ds = build_split("val")
        self.test_ds = build_split("test")

    def train_dataloader(self) -> DataLoader:
        # pyre-fixme[6]
        return DataLoader(self.train_ds, batch_size=self.batch_size)

    def val_dataloader(self) -> DataLoader:
        # pyre-fixme[6]
        return DataLoader(self.val_ds, batch_size=self.batch_size)

    def test_dataloader(self) -> DataLoader:
        # pyre-fixme[6]
        return DataLoader(self.test_ds, batch_size=self.batch_size)

    def teardown(self, stage: Optional[str] = None) -> None:
        # Nothing to clean up.
        return None
107+
108+
109+
# %%
110+
# To pass data between the different components we use fsspec which allows us to
111+
# read/write to cloud or local file storage.
112+
113+
114+
def download_data(remote_path: str, tmpdir: str) -> str:
    """
    download_data downloads the training data archive from the specified remote
    path via fsspec, stores it as ``data.tar.gz`` inside tmpdir and extracts it
    into the ``data`` subdirectory of tmpdir.

    Args:
        remote_path: fsspec path of the tar archive; must resolve to exactly
            one path
        tmpdir: local directory to place the archive and the extracted data in

    Returns:
        The path of the directory the archive was extracted into.

    Raises:
        ValueError: if remote_path does not resolve to exactly one path or the
            archive contains a member that would be written outside of the
            extraction directory.
    """
    tar_path = os.path.join(tmpdir, "data.tar.gz")
    print(f"downloading dataset from {remote_path} to {tar_path}...")
    fs, _, rpaths = fsspec.get_fs_token_paths(remote_path)
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if len(rpaths) != 1:
        raise ValueError(f"must have single path, got {len(rpaths)}: {remote_path}")
    fs.get(rpaths[0], tar_path)

    data_path = os.path.join(tmpdir, "data")
    print(f"extracting {tar_path} to {data_path}...")
    with tarfile.open(tar_path, mode="r") as f:
        # The archive comes from a remote location, so refuse member names
        # (absolute paths, "..") that would land outside of data_path.
        # NOTE(review): this only validates member names; link targets inside
        # the archive are not inspected.
        dest_root = os.path.realpath(data_path)
        for member in f.getmembers():
            target = os.path.realpath(os.path.join(dest_root, member.name))
            if os.path.commonpath([dest_root, target]) != dest_root:
                raise ValueError(
                    f"refusing to extract {member.name!r} outside of {data_path}"
                )
        f.extractall(data_path)

    return data_path
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Facebook, Inc. and its affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
"""
9+
Model Interpretability App Example
10+
=============================================
11+
12+
This is an example TorchX app that uses captum to analyze inputs for model
13+
interpretability purposes. It consumes the trained model from the trainer app
14+
example and the preprocessed examples from the datapreproc app example. The
15+
output is a series of images with integrated gradient attributions overlaid on
16+
them.
17+
18+
See https://captum.ai/tutorials/CIFAR_TorchVision_Interpret for more info on
19+
using captum.
20+
"""
21+
22+
import argparse
23+
import itertools
24+
import os.path
25+
import sys
26+
import tempfile
27+
from typing import List
28+
29+
import fsspec
30+
import torch
31+
from data import TinyImageNetDataModule, download_data
32+
from model import TinyImageNetModel
33+
34+
35+
# FIXME: captum must be imported after torch otherwise it causes python to crash
36+
if True:
37+
import numpy as np
38+
from captum.attr import IntegratedGradients
39+
from captum.attr import visualization as viz
40+
41+
42+
def parse_args(argv: List[str]) -> argparse.Namespace:
    """
    Parses the command line arguments of the interpretability app.

    All three flags (``--load_path``, ``--data_path``, ``--output_path``) are
    required string options.
    """
    # (flag, help text) for every required string option.
    required_flags = (
        ("--load_path", "checkpoint path to load model weights from"),
        ("--data_path", "path to load the training data from"),
        ("--output_path", "path to place analysis results"),
    )

    parser = argparse.ArgumentParser(description="example TorchX captum app")
    for flag, help_text in required_flags:
        parser.add_argument(flag, type=str, help=help_text, required=True)

    return parser.parse_args(argv)
64+
65+
66+
def convert_to_rgb(arr: torch.Tensor) -> np.ndarray:
    """
    This converts a single-channel image from a torch tensor with size
    (1, 1, H, W) (e.g. (1, 1, 64, 64)) to a numpy array with size (H, W, 3) by
    repeating the channel three times.

    The tensor is detached and moved to the CPU first so that tensors that
    require grad or live on another device can be converted as well.
    """
    # squeeze() drops every size-1 dimension, leaving the (H, W) image.
    squeezed = arr.detach().cpu().squeeze()
    stacked = torch.stack([squeezed, squeezed, squeezed], dim=2)
    print("img stats: ", stacked.count_nonzero(), stacked.mean())
    return stacked.numpy()
75+
76+
77+
def main(argv: List[str]) -> None:
    """
    Loads the trained model from ``--load_path``, downloads the dataset from
    ``--data_path``, runs integrated gradients on the first 5 test samples and
    writes one heatmap image per analyzed sample to ``--output_path``.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        args = parse_args(argv)

        # Init our model
        print(f"loading checkpoint: {args.load_path}...")
        # load_from_checkpoint is a classmethod that returns a new model with
        # the checkpoint weights; its result has to be kept, otherwise the
        # analysis runs on a freshly initialized model.
        model = TinyImageNetModel.load_from_checkpoint(checkpoint_path=args.load_path)

        # Download and setup the data module
        data_path = download_data(args.data_path, tmpdir)
        data = TinyImageNetDataModule(
            data_dir=data_path,
            batch_size=1,
        )

        ig = IntegratedGradients(model)

        data.setup("test")
        dataloader = data.test_dataloader()

        # process first 5 images
        for i, (image, label) in enumerate(itertools.islice(dataloader, 5)):
            print(f"analyzing example {i}")
            model.zero_grad()
            attr_ig, delta = ig.attribute(
                image,
                target=label,
                # all-zero baseline with the same shape as the input
                baselines=image * 0,
                return_convergence_delta=True,
            )

            if attr_ig.count_nonzero() == 0:
                # Our toy model sometimes has no IG results.
                print("skipping due to zero gradients")
                continue

            fig, axis = viz.visualize_image_attr(
                convert_to_rgb(attr_ig),
                convert_to_rgb(image),
                method="blended_heat_map",
                sign="all",
                show_colorbar=True,
                title="Overlayed Integrated Gradients",
            )
            out_path = os.path.join(args.output_path, f"ig_{i}.png")
            print(f"saving heatmap to {out_path}")
            # fsspec lets the output path be local or cloud file storage.
            with fsspec.open(out_path, "wb") as f:
                fig.savefig(f)
127+
128+
129+
if __name__ == "__main__":
130+
main(sys.argv[1:])

0 commit comments

Comments
 (0)