
Commit 86708ca

Commit message: init
1 parent: 5cb41e4

40 files changed: 279,364 additions and 1 deletion

.gitignore

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
**/__pycache__
checkpoints

README.md

Lines changed: 92 additions & 1 deletion
@@ -1 +1,92 @@
-# Official code for Exploring Structured Semantic Prior for Multi Label Recognition with Incomplete Labels (CVPR 2023)

# Exploring Structured Semantic Prior for Multi Label Recognition with Incomplete Labels

Official PyTorch implementation of the CVPR 2023 paper [Exploring Structured Semantic Prior for Multi Label Recognition with Incomplete Labels](https://openaccess.thecvf.com/content/CVPR2023/papers/Ding_Exploring_Structured_Semantic_Prior_for_Multi_Label_Recognition_With_Incomplete_CVPR_2023_paper.pdf).

> Zixuan Ding*, Ao Wang*, Hui Chen†, Qiang Zhang, Pengzhang Liu, Yongjun Bao, Weipeng Yan, Jungong Han
> <br/> Xidian University, Tsinghua University, JD.com

**Abstract**

Multi-label recognition (MLR) with incomplete labels is very challenging. Recent works strive to explore the image-to-label correspondence in the vision-language model, i.e., CLIP, to compensate for insufficient annotations. In spite of promising performance, they generally overlook the valuable prior about the label-to-label correspondence. In this paper, we advocate remedying the deficiency of label supervision for the MLR with incomplete labels by deriving a structured semantic prior about the label-to-label correspondence via a semantic prior prompter. We then present a novel Semantic Correspondence Prompt Network (SCPNet), which can thoroughly explore the structured semantic prior. A Prior-Enhanced Self-Supervised Learning method is further introduced to enhance the use of the prior. Comprehensive experiments and analyses on several widely used benchmark datasets show that our method significantly outperforms existing methods on all datasets, well demonstrating the effectiveness and the superiority of our method.

<p align="center">
  <table class="tg">
    <tr>
      <td class="tg-c3ow"><img src="./figures/overview.png" align="center" width="600" ></td>
    </tr>
  </table>
</p>

## Credit to previous work

This repository is built upon the codebases of [ASL](https://github.com/Alibaba-MIIL/ASL) and [SPLC](https://github.com/xinyu1205/robust-loss-mlml). Many thanks to their authors!

## Performance

| Dataset | mAP | Ckpt | Log |
|:---: | :---: | :---: | :---: |
| COCO | 76.4 | scpnet+coco.ckpt | [scpnet+coco.txt](logs/scpnet+coco.txt) |
| VOC | 91.2 | scpnet+voc.ckpt | [scpnet+voc.txt](logs/scpnet+voc.txt) |
| NUSWIDE | 62.0 | scpnet+nuswide.ckpt | [scpnet+nuswide.txt](logs/scpnet+nuswide.txt) |
| CUB | 25.7 | scpnet+cub.ckpt | [scpnet+cub.txt](logs/scpnet+cub.txt) |

## Training

### COCO
```bash
python train.py -c configs/scpnet+coco.yaml
```

### VOC
```bash
python train.py -c configs/scpnet+voc.yaml
```

### NUSWIDE
```bash
python train.py -c configs/scpnet+nuswide.yaml
```

### CUB
```bash
python train.py -c configs/scpnet+cub.yaml
```

## Inference

> Note: Please place the pretrained checkpoint at `checkpoints/scpnet+coco/round1/model-highest.ckpt` (a staging sketch follows this note).

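The snippet below is not part of the repository; it is a minimal staging sketch under the assumption that `scpnet+coco.ckpt` (the checkpoint named in the Performance table) has already been downloaded to the repository root, and it simply copies it to the path the note above expects.

```python
import os
import shutil

# Hypothetical staging helper: copy the downloaded checkpoint to the path
# that the COCO inference command expects.
src = "scpnet+coco.ckpt"  # downloaded checkpoint from the Performance table
dst = "checkpoints/scpnet+coco/round1/model-highest.ckpt"

os.makedirs(os.path.dirname(dst), exist_ok=True)
shutil.copyfile(src, dst)
print(f"Staged {src} at {dst}")
```
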
#### COCO
```bash
python train.py -c configs/scpnet+coco.yaml -t -r 1
```

#### VOC
```bash
python train.py -c configs/scpnet+voc.yaml -t -r 1
```

#### NUSWIDE
```bash
python train.py -c configs/scpnet+nuswide.yaml -t -r 1
```

#### CUB
```bash
python train.py -c configs/scpnet+cub.yaml -t -r 1
```

## Citation
```
@inproceedings{ding2023exploring,
  title={Exploring Structured Semantic Prior for Multi Label Recognition with Incomplete Labels},
  author={Ding, Zixuan and Wang, Ao and Chen, Hui and Zhang, Qiang and Liu, Pengzhang and Bao, Yongjun and Yan, Weipeng and Han, Jungong},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={3398--3407},
  year={2023}
}
```

args.py

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
import argparse

parser = argparse.ArgumentParser(description='PyTorch MS_COCO Training')
parser.add_argument('-c',
                    '--config-file',
                    help='config file',
                    default='configs/base.yaml',
                    type=str)
parser.add_argument('-t',
                    '--test',
                    help='run test',
                    default=False,
                    action="store_true")
parser.add_argument('-r', '--round', help='round', default=1, type=int)
parser.add_argument('--resume', default=False, action='store_true')
args = parser.parse_args()
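Not part of the commit: a minimal sketch of how these flags resolve into attributes, parsing a simulated command line (the same flags used by the README's inference commands) instead of `sys.argv`.

```python
# Hypothetical illustration only: feed argparse the flags from
# `python train.py -c configs/scpnet+coco.yaml -t -r 1`.
import argparse

parser = argparse.ArgumentParser(description='PyTorch MS_COCO Training')
parser.add_argument('-c', '--config-file', default='configs/base.yaml', type=str)
parser.add_argument('-t', '--test', default=False, action='store_true')
parser.add_argument('-r', '--round', default=1, type=int)
parser.add_argument('--resume', default=False, action='store_true')

args = parser.parse_args(['-c', 'configs/scpnet+coco.yaml', '-t', '-r', '1'])
print(args.config_file)  # configs/scpnet+coco.yaml
print(args.test)         # True
print(args.round)        # 1
print(args.resume)       # False
```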

clip/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
from .clip import *

clip/bpe_simple_vocab_16e6.txt.gz

1.29 MB (binary file, contents not shown)

clip/clip.py

Lines changed: 258 additions & 0 deletions
@@ -0,0 +1,258 @@
import hashlib
import os
import urllib
import warnings
from typing import List, Union

import torch
from PIL import Image
from torchvision.transforms import (CenterCrop, Compose, Normalize, Resize,
                                    ToTensor)
from tqdm import tqdm

from .model import build_model
from .simple_tokenizer import SimpleTokenizer as _Tokenizer

try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC

if torch.__version__.split(".") < ["1", "7", "1"]:
    warnings.warn("PyTorch version 1.7.1 or higher is recommended")

__all__ = ["available_models", "load", "tokenize"]
_tokenizer = _Tokenizer()

_MODELS = {
    "RN50":
    "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
    "RN101":
    "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
    "RN50x4":
    "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
    "RN50x16":
    "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
    "ViT-B/32":
    "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
    "ViT-B/16":
    "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
}


def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(
            f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if hashlib.sha256(open(download_target,
                               "rb").read()).hexdigest() == expected_sha256:
            return download_target
        else:
            warnings.warn(
                f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
            )

    with urllib.request.urlopen(url) as source, open(download_target,
                                                     "wb") as output:
        with tqdm(total=int(source.info().get("Content-Length")),
                  ncols=80,
                  unit='iB',
                  unit_scale=True) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    if hashlib.sha256(open(download_target,
                           "rb").read()).hexdigest() != expected_sha256:
        raise RuntimeError(
            "Model has been downloaded but the SHA256 checksum does not match"
        )

    return download_target


def _transform(n_px):
    return Compose([
        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        lambda image: image.convert("RGB"),
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073),
                  (0.26862954, 0.26130258, 0.27577711)),
    ])


def available_models() -> List[str]:
    """Returns the names of available CLIP models"""
    return list(_MODELS.keys())


def load(name: str,
         device: Union[str, torch.device] = "cuda"
         if torch.cuda.is_available() else "cpu",
         jit=False):
    """Load a CLIP model

    Parameters
    ----------
    name : str
        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict

    device : Union[str, torch.device]
        The device to put the loaded model

    jit : bool
        Whether to load the optimized JIT model or the more hackable non-JIT model (default).

    Returns
    -------
    model : torch.nn.Module
        The CLIP model

    preprocess : Callable[[PIL.Image], torch.Tensor]
        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
    """
    if name in _MODELS:
        model_path = _download(_MODELS[name])
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(
            f"Model {name} not found; available models = {available_models()}")

    try:
        # loading JIT archive
        model = torch.jit.load(model_path,
                               map_location=device if jit else "cpu").eval()
        state_dict = None
    except RuntimeError:
        # loading saved state dict
        if jit:
            warnings.warn(
                f"File {model_path} is not a JIT archive. Loading as a state dict instead"
            )
            jit = False
        state_dict = torch.load(model_path, map_location="cpu")

    if not jit:
        model = build_model(state_dict or model.state_dict()).to(device)
        if str(device) == "cpu":
            model.float()
        return model, _transform(model.visual.input_resolution)

    # patch the device names
    device_holder = torch.jit.trace(
        lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
    device_node = [
        n for n in device_holder.graph.findAllNodes("prim::Constant")
        if "Device" in repr(n)
    ][-1]

    def patch_device(module):
        try:
            graphs = [module.graph] if hasattr(module, "graph") else []
        except RuntimeError:
            graphs = []

        if hasattr(module, "forward1"):
            graphs.append(module.forward1.graph)

        for graph in graphs:
            for node in graph.findAllNodes("prim::Constant"):
                if "value" in node.attributeNames() and str(
                        node["value"]).startswith("cuda"):
                    node.copyAttributes(device_node)

    model.apply(patch_device)
    patch_device(model.encode_image)
    patch_device(model.encode_text)

    # patch dtype to float32 on CPU
    if str(device) == "cpu":
        float_holder = torch.jit.trace(lambda: torch.ones([]).float(),
                                       example_inputs=[])
        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
        float_node = float_input.node()

        def patch_float(module):
            try:
                graphs = [module.graph] if hasattr(module, "graph") else []
            except RuntimeError:
                graphs = []

            if hasattr(module, "forward1"):
                graphs.append(module.forward1.graph)

            for graph in graphs:
                for node in graph.findAllNodes("aten::to"):
                    inputs = list(node.inputs())
                    for i in [
                            1, 2
                    ]:  # dtype can be the second or third argument to aten::to()
                        if inputs[i].node()["value"] == 5:
                            inputs[i].node().copyAttributes(float_node)

        model.apply(patch_float)
        patch_float(model.encode_image)
        patch_float(model.encode_text)

        model.float()

    return model, _transform(model.input_resolution.item())


def tokenize(texts: Union[str, List[str]],
             context_length: int = 77,
             truncate: bool = False) -> torch.LongTensor:
    """
    Returns the tokenized representation of given input string(s)

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize

    context_length : int
        The context length to use; all CLIP models use 77 as the context length

    truncate : bool
        Whether to truncate the text in case its encoding is longer than the context length

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
    """
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token]
                  for text in texts]
    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            if truncate:
                tokens = tokens[:context_length]
                tokens[-1] = eot_token
            else:
                raise RuntimeError(
                    f"Input {texts[i]} is too long for context length {context_length}"
                )
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result
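
Not part of the commit: a minimal usage sketch of the vendored `clip` module above, assuming it is importable as `clip` from the repository root; `example.jpg` and the prompt strings are placeholders.

```python
import torch
from PIL import Image

import clip

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load a backbone listed in clip.available_models() and its matching preprocessing transform.
model, preprocess = clip.load("ViT-B/16", device=device, jit=False)

# Encode one image and a few candidate label prompts.
image = preprocess(Image.open("example.jpg")).unsqueeze(0).to(device)  # placeholder image path
texts = clip.tokenize(["a photo of a dog", "a photo of a cat"]).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(texts)

# Cosine similarities between the image and each prompt.
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
print((image_features @ text_features.t()).squeeze(0).tolist())
```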
