
Commit d3467c9: "update" (initial commit, 0 parents)


48 files changed: 278,952 additions, 0 deletions

.gitignore

Lines changed: 2 additions & 0 deletions

```
**/__pycache__
checkpoints
```

README.md

Lines changed: 80 additions & 0 deletions
# [Hierarchical Prompt Learning Using CLIP for Multi-label Classification with Single Positive Labels](https://dl.acm.org/doi/pdf/10.1145/3581783.3611988)

Official PyTorch implementation of **HSPNet**, from the following paper:

[Hierarchical Prompt Learning Using CLIP for Multi-label Classification with Single Positive Labels](https://dl.acm.org/doi/pdf/10.1145/3581783.3611988). ACM MM 2023.

> Ao Wang, Hui Chen, Zijia Lin, Zixuan Ding, Pengzhang Liu, Yongjun Bao, Weipeng Yan, and Guiguang Ding

**Abstract**

Collecting full annotations to construct multi-label datasets is difficult and labor-intensive. As an effective way to relieve the annotation burden, single positive multi-label learning (SPML) has drawn increasing attention from both academia and industry. It annotates each image with only one positive label, leaving the other labels unobserved. Existing methods therefore strive to exploit cues about the unobserved labels to compensate for the insufficient label supervision. Although they achieve promising performance, they generally treat labels independently, ignoring the inherent hierarchical semantic relationships among labels, which reveal that labels can be clustered into groups. In this paper, we propose a hierarchical prompt learning method with a novel Hierarchical Semantic Prompt Network (HSPNet) that harnesses such hierarchical semantic relationships for SPML using a large-scale pretrained vision-and-language model, i.e., CLIP. We first introduce a Hierarchical Conditional Prompt (HCP) strategy to grasp the hierarchical label-group dependency. We then equip it with a Hierarchical Graph Convolutional Network (HGCN) to capture the high-order inter-label and inter-group dependencies. Comprehensive experiments and analyses on several benchmark datasets show that our method significantly outperforms state-of-the-art methods, demonstrating its superiority and effectiveness.

## Credit to previous work

This repository is built upon the code bases of [ASL](https://github.com/Alibaba-MIIL/ASL) and [SPLC](https://github.com/xinyu1205/robust-loss-mlml). Many thanks!

## Performance

| Dataset | mAP | Ckpt | Log |
| :---: | :---: | :---: | :---: |
| COCO | 75.7 | [hspnet+coco.ckpt](https://github.com/jameslahm/HSPNet/releases/download/v1.0/hspnet+coco.ckpt) | [hspnet+coco.txt](logs/hspnet+coco.txt) |
| VOC | 90.4 | [hspnet+voc.ckpt](https://github.com/jameslahm/HSPNet/releases/download/v1.0/hspnet+voc.ckpt) | [hspnet+voc.txt](logs/hspnet+voc.txt) |
| NUSWIDE | 61.8 | [hspnet+nuswide.ckpt](https://github.com/jameslahm/HSPNet/releases/download/v1.0/hspnet+nuswide.ckpt) | [hspnet+nuswide.txt](logs/hspnet+nuswide.txt) |
| CUB | 24.3 | [hspnet+cub.ckpt]() | [hspnet+cub.txt](logs/hspnet+cub.txt) |

## Training

### COCO
```bash
python train.py -c configs/hspnet+coco.yaml
```

### VOC
```bash
python train.py -c configs/hspnet+voc.yaml
```

### NUSWIDE
```bash
python train.py -c configs/hspnet+nuswide.yaml
```

### CUB
```bash
python train.py -c configs/hspnet+cub.yaml
```

## Inference

> Note: please place the pretrained checkpoint at checkpoints/hspnet+coco/round1/model-highest.ckpt (and likewise for the other datasets).

### COCO
```bash
python train.py -c configs/hspnet+coco.yaml -t -r 1
```

### VOC
```bash
python train.py -c configs/hspnet+voc.yaml -t -r 1
```

### NUSWIDE
```bash
python train.py -c configs/hspnet+nuswide.yaml -t -r 1
```

### CUB
```bash
python train.py -c configs/hspnet+cub.yaml -t -r 1
```

## Citation

```bibtex
@inproceedings{wang2023hierarchical,
  title={Hierarchical prompt learning using clip for multi-label classification with single positive labels},
  author={Wang, Ao and Chen, Hui and Lin, Zijia and Ding, Zixuan and Liu, Pengzhang and Bao, Yongjun and Yan, Weipeng and Ding, Guiguang},
  booktitle={Proceedings of the 31st ACM International Conference on Multimedia},
  pages={5594--5604},
  year={2023}
}
```
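The inference note in the README expects the downloaded checkpoint at `checkpoints/hspnet+coco/round1/model-highest.ckpt`. Below is a minimal sketch for staging a release checkpoint into that layout; the `stage_checkpoint` helper is illustrative (not part of this commit), and the assumption that other datasets follow the same `checkpoints/<config>/round<N>/` pattern is inferred from the `-r` flag rather than stated in the repo.

```python
# Hypothetical helper: copy a downloaded release checkpoint to the path that
# `python train.py -c configs/hspnet+coco.yaml -t -r 1` expects (per the README note).
import os
import shutil


def stage_checkpoint(downloaded_ckpt: str,
                     config_name: str = "hspnet+coco",
                     round_id: int = 1) -> str:
    """Place `downloaded_ckpt` at checkpoints/<config_name>/round<round_id>/model-highest.ckpt."""
    target_dir = os.path.join("checkpoints", config_name, f"round{round_id}")
    os.makedirs(target_dir, exist_ok=True)      # create the directory if it is missing
    target = os.path.join(target_dir, "model-highest.ckpt")
    shutil.copyfile(downloaded_ckpt, target)    # keep the original download untouched
    return target


if __name__ == "__main__":
    # e.g. after downloading hspnet+coco.ckpt from the release page in the Performance table
    print(stage_checkpoint("hspnet+coco.ckpt"))
```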

args.py

Lines changed: 16 additions & 0 deletions
```python
import argparse

parser = argparse.ArgumentParser(description='PyTorch MS_COCO Training')
parser.add_argument('-c',
                    '--config-file',
                    help='config file',
                    default='configs/base.yaml',
                    type=str)
parser.add_argument('-t',
                    '--test',
                    help='run test',
                    default=False,
                    action="store_true")
parser.add_argument('-r', '--round', help='round', default=1, type=int)
parser.add_argument('--resume', default=False, action='store_true')
args = parser.parse_args()
```
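As a quick sanity check, the flags used by the README commands map onto the parsed namespace as shown below. This is a standalone sketch that mirrors the parser in `args.py` but passes an explicit argument list instead of reading `sys.argv`, so it can be run outside `train.py`.

```python
import argparse

# Mirror of the parser defined in args.py (same flags, defaults, and types).
parser = argparse.ArgumentParser(description='PyTorch MS_COCO Training')
parser.add_argument('-c', '--config-file', help='config file',
                    default='configs/base.yaml', type=str)
parser.add_argument('-t', '--test', help='run test',
                    default=False, action='store_true')
parser.add_argument('-r', '--round', help='round', default=1, type=int)
parser.add_argument('--resume', default=False, action='store_true')

# The README inference command: python train.py -c configs/hspnet+coco.yaml -t -r 1
args = parser.parse_args(['-c', 'configs/hspnet+coco.yaml', '-t', '-r', '1'])
print(args.config_file)  # 'configs/hspnet+coco.yaml'
print(args.test)         # True  (evaluation mode)
print(args.round)        # 1     (round index; the README places the checkpoint under round1)
print(args.resume)       # False
```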

clip/__init__.py

Lines changed: 1 addition & 0 deletions
```python
from .clip import *
```

clip/bpe_simple_vocab_16e6.txt.gz

1.29 MB
Binary file not shown.

clip/clip.py

Lines changed: 258 additions & 0 deletions
```python
import hashlib
import os
import urllib.request
import warnings
from typing import List, Union

import torch
from PIL import Image
from torchvision.transforms import (CenterCrop, Compose, Normalize, Resize,
                                    ToTensor)
from tqdm import tqdm

from .model import build_model
from .simple_tokenizer import SimpleTokenizer as _Tokenizer

try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC

if torch.__version__.split(".") < ["1", "7", "1"]:
    warnings.warn("PyTorch version 1.7.1 or higher is recommended")

__all__ = ["available_models", "load", "tokenize"]
_tokenizer = _Tokenizer()

_MODELS = {
    "RN50":
    "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
    "RN101":
    "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
    "RN50x4":
    "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
    "RN50x16":
    "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
    "ViT-B/32":
    "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
    "ViT-B/16":
    "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
}


def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(
            f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if hashlib.sha256(open(download_target,
                               "rb").read()).hexdigest() == expected_sha256:
            return download_target
        else:
            warnings.warn(
                f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
            )

    with urllib.request.urlopen(url) as source, open(download_target,
                                                     "wb") as output:
        with tqdm(total=int(source.info().get("Content-Length")),
                  ncols=80,
                  unit='iB',
                  unit_scale=True) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    if hashlib.sha256(open(download_target,
                           "rb").read()).hexdigest() != expected_sha256:
        raise RuntimeError(
            "Model has been downloaded but the SHA256 checksum does not match")

    return download_target


def _transform(n_px):
    return Compose([
        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        lambda image: image.convert("RGB"),
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073),
                  (0.26862954, 0.26130258, 0.27577711)),
    ])


def available_models() -> List[str]:
    """Returns the names of available CLIP models"""
    return list(_MODELS.keys())


def load(name: str,
         device: Union[str, torch.device] = "cuda"
         if torch.cuda.is_available() else "cpu",
         jit=False):
    """Load a CLIP model

    Parameters
    ----------
    name : str
        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict

    device : Union[str, torch.device]
        The device to put the loaded model

    jit : bool
        Whether to load the optimized JIT model or the more hackable non-JIT model (default).

    Returns
    -------
    model : torch.nn.Module
        The CLIP model

    preprocess : Callable[[PIL.Image], torch.Tensor]
        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
    """
    if name in _MODELS:
        model_path = _download(_MODELS[name])
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(
            f"Model {name} not found; available models = {available_models()}")

    try:
        # loading JIT archive
        model = torch.jit.load(model_path,
                               map_location=device if jit else "cpu").eval()
        state_dict = None
    except RuntimeError:
        # loading saved state dict
        if jit:
            warnings.warn(
                f"File {model_path} is not a JIT archive. Loading as a state dict instead"
            )
            jit = False
        state_dict = torch.load(model_path, map_location="cpu")

    if not jit:
        model = build_model(state_dict or model.state_dict()).to(device)
        if str(device) == "cpu":
            model.float()
        return model, _transform(model.visual.input_resolution)

    # patch the device names
    device_holder = torch.jit.trace(
        lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
    device_node = [
        n for n in device_holder.graph.findAllNodes("prim::Constant")
        if "Device" in repr(n)
    ][-1]

    def patch_device(module):
        try:
            graphs = [module.graph] if hasattr(module, "graph") else []
        except RuntimeError:
            graphs = []

        if hasattr(module, "forward1"):
            graphs.append(module.forward1.graph)

        for graph in graphs:
            for node in graph.findAllNodes("prim::Constant"):
                if "value" in node.attributeNames() and str(
                        node["value"]).startswith("cuda"):
                    node.copyAttributes(device_node)

    model.apply(patch_device)
    patch_device(model.encode_image)
    patch_device(model.encode_text)

    # patch dtype to float32 on CPU
    if str(device) == "cpu":
        float_holder = torch.jit.trace(lambda: torch.ones([]).float(),
                                       example_inputs=[])
        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
        float_node = float_input.node()

        def patch_float(module):
            try:
                graphs = [module.graph] if hasattr(module, "graph") else []
            except RuntimeError:
                graphs = []

            if hasattr(module, "forward1"):
                graphs.append(module.forward1.graph)

            for graph in graphs:
                for node in graph.findAllNodes("aten::to"):
                    inputs = list(node.inputs())
                    # dtype can be the second or third argument to aten::to()
                    for i in [1, 2]:
                        if inputs[i].node()["value"] == 5:
                            inputs[i].node().copyAttributes(float_node)

        model.apply(patch_float)
        patch_float(model.encode_image)
        patch_float(model.encode_text)

        model.float()

    return model, _transform(model.input_resolution.item())


def tokenize(texts: Union[str, List[str]],
             context_length: int = 77,
             truncate: bool = False) -> torch.LongTensor:
    """
    Returns the tokenized representation of given input string(s)

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize

    context_length : int
        The context length to use; all CLIP models use 77 as the context length

    truncate : bool
        Whether to truncate the text in case its encoding is longer than the context length

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
    """
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token]
                  for text in texts]
    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            if truncate:
                tokens = tokens[:context_length]
                tokens[-1] = eot_token
            else:
                raise RuntimeError(
                    f"Input {texts[i]} is too long for context length {context_length}"
                )
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result
```
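The docstrings above describe the public API of this module (`available_models`, `load`, `tokenize`). A minimal usage sketch follows, mirroring standard CLIP usage rather than anything specific to this commit; the image path `example.jpg` and the prompt strings are placeholders.

```python
import torch
from PIL import Image

import clip  # the clip/ package added in this commit

device = "cuda" if torch.cuda.is_available() else "cpu"
print(clip.available_models())  # ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'ViT-B/32', 'ViT-B/16']

# load() downloads the weights listed in _MODELS (with SHA256 verification) and
# returns the model plus the matching preprocessing transform.
model, preprocess = clip.load("ViT-B/16", device=device)

image = preprocess(Image.open("example.jpg")).unsqueeze(0).to(device)  # any local image
text = clip.tokenize(["a photo of a dog", "a photo of a cat"]).to(device)  # shape [2, 77]

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)

# cosine similarity between the image and each prompt
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
print(image_features @ text_features.T)
```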
