Skip to content

Commit 90fe00c

Browse files
Merge pull request #57 from KevinMusgrave/dev
v0.0.70
2 parents 3c1acc1 + d88747c commit 90fe00c

33 files changed

+1347
-301
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
"torch",
3636
"torchvision",
3737
"torchmetrics",
38-
"pytorch-metric-learning >= 1.1.0",
38+
"pytorch-metric-learning >= 1.3.1.dev0",
3939
],
4040
extras_require={
4141
"ignite": extras_require_ignite,

src/pytorch_adapt/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.0.61"
1+
__version__ = "0.0.70"

src/pytorch_adapt/layers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from .mcc_loss import MCCLoss
1717
from .mcd_loss import GeneralMCDLoss, MCDLoss
1818
from .mean_dist_loss import MeanDistLoss
19-
from .mmd_loss import MMDLoss
19+
from .mmd_loss import MMDBatchedLoss, MMDLoss
2020
from .model_with_bridge import ModelWithBridge
2121
from .multiple_models import MultipleModels
2222
from .neighborhood_aggregation import NeighborhoodAggregation

src/pytorch_adapt/layers/confidence_weights.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,6 @@ def __init__(self, normalizer: Callable[[torch.Tensor], torch.Tensor] = None):
2525
super().__init__()
2626
self.normalizer = c_f.default(normalizer, NoNormalizer())
2727

28-
def forward(self, logits):
28+
def forward(self, preds):
2929
""""""
30-
return self.normalizer(torch.max(logits, dim=1)[0])
30+
return self.normalizer(torch.max(preds, dim=1)[0])

src/pytorch_adapt/layers/ist_loss.py

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,42 @@
77
from .entropy_loss import EntropyLoss
88

99

10+
def get_probs(mat, mask, y, dist_is_inverted):
    """Turn a pairwise similarity matrix into per-sample domain probabilities.

    Arguments:
        mat: (n, m) similarity/distance matrix with self-comparisons already
            removed (m == n - 1 in the caller).
        mask: (n, n) boolean mask that was used to remove the self-comparisons;
            reused here to align ``y`` with ``mat``'s columns.
        y: domain labels for the n samples (1 = target, 0 = source)
            — assumed float/0-1 valued; TODO confirm against callers.
        dist_is_inverted: True when larger values of ``mat`` mean "more similar".
    Returns:
        (n, 2) tensor: column 0 is each sample's source probability,
        column 1 its target probability.
    """
    if not dist_is_inverted:
        # Negate out of place so the caller's tensor is not mutated
        # (the original in-place `mat *= -1` clobbered the input).
        mat = -mat
    mat = F.softmax(mat, dim=1)
    n, m = mat.shape
    # Repeat the label row n times and apply the same self-comparison mask
    # so labels line up column-for-column with `mat`.
    y = y.repeat(n, 1)[mask].view(n, m)

    target_probs = torch.sum(mat * y, dim=1, keepdim=True)
    src_probs = torch.sum(mat * (1 - y), dim=1, keepdim=True)
    return torch.cat([src_probs, target_probs], dim=1)
20+
21+
22+
def get_loss(probs, ent_fn, div_fn, with_ent, with_div):
    """Combine the selected, negated loss terms.

    Arguments:
        probs: input passed to each loss callable.
        ent_fn: entropy loss callable.
        div_fn: diversity loss callable.
        with_ent: include the (negated) entropy term if True.
        with_div: include the (negated) diversity term if True.
    Returns:
        The negated sum of the selected terms, or 0 if neither is selected.
    """
    terms = []
    if with_ent:
        terms.append(ent_fn(probs))
    if with_div:
        terms.append(div_fn(probs))
    return -sum(terms) if terms else 0
29+
30+
1031
class ISTLoss(torch.nn.Module):
1132
"""
1233
Implementation of the I_st loss from
1334
[Information-Theoretical Learning of Discriminative Clusters for Unsupervised Domain Adaptation](https://icml.cc/2012/papers/566.pdf)
1435
"""
1536

16-
def __init__(self, distance=None, with_div=True):
37+
def __init__(self, distance=None, with_ent=True, with_div=True):
1738
super().__init__()
1839
self.distance = c_f.default(distance, CosineSimilarity, {})
40+
if not (with_ent or with_div):
41+
raise ValueError("At least one of with_ent or with_div must be True")
42+
self.with_ent = with_ent
1943
self.with_div = with_div
2044
self.ent_loss_fn = EntropyLoss(after_softmax=True)
21-
if self.with_div:
22-
self.div_loss_fn = DiversityLoss(after_softmax=True)
45+
self.div_loss_fn = DiversityLoss(after_softmax=True)
2346

2447
def forward(self, x, y):
2548
"""
@@ -35,23 +58,13 @@ def forward(self, x, y):
3558

3659
mat = self.distance(x)
3760
# remove self comparisons
38-
mask = torch.eye(n, dtype=torch.bool)
39-
mat = mat[~mask].view(n, n - 1)
40-
if not self.distance.is_inverted:
41-
mat *= -1
42-
mat = F.softmax(mat, dim=1)
43-
44-
y = y.repeat(n, 1)[~mask].view(n, n - 1)
45-
46-
target_probs = torch.sum(mat * y, dim=1, keepdims=True)
47-
src_probs = torch.sum(mat * (1 - y), dim=1, keepdims=True)
48-
probs = torch.cat([src_probs, target_probs], dim=1)
49-
50-
ent_loss = self.ent_loss_fn(probs)
61+
mask = ~torch.eye(n, dtype=torch.bool)
62+
mat = mat[mask].view(n, n - 1)
63+
probs = get_probs(mat, mask, y, self.distance.is_inverted)
5164

52-
if self.with_div:
53-
return -self.div_loss_fn(probs) - ent_loss
54-
return -ent_loss
65+
return get_loss(
66+
probs, self.ent_loss_fn, self.div_loss_fn, self.with_ent, self.with_div
67+
)
5568

5669
def extra_repr(self):
5770
""""""

src/pytorch_adapt/layers/mmd_loss.py

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,25 @@
11
from typing import List, Union
22

33
import torch
4-
from pytorch_metric_learning.distances import LpDistance
4+
from pytorch_metric_learning.distances import BatchedDistance, LpDistance
55
from pytorch_metric_learning.utils import common_functions as pml_cf
66

77
from ..utils import common_functions as c_f
88
from . import utils as l_u
99

1010

11+
def check_batch_sizes(s, t, mmd_type):
    """Raise if source/target shapes differ when using linear MMD.

    Linear-time MMD pairs samples across the two domains, so the batches
    must match in shape; quadratic MMD has no such requirement.

    Arguments:
        s: source features — a tensor, or a list/tuple of tensors.
        t: target features of the same kind as ``s``.
        mmd_type: "linear" or "quadratic".
    Raises:
        ValueError: if mmd_type is not "quadratic" and any shape differs.
    """
    if mmd_type == "quadratic":
        return
    if c_f.is_list_or_tuple(s):
        mismatch = any(s[i].shape != t[i].shape for i in range(len(s)))
    else:
        mismatch = s.shape != t.shape
    if mismatch:
        raise ValueError(
            "For mmd_type 'linear', source and target must have the same batch size."
        )
21+
22+
1123
class MMDLoss(torch.nn.Module):
1224
"""
1325
Implementation of
@@ -18,7 +30,11 @@ class MMDLoss(torch.nn.Module):
1830
"""
1931

2032
def __init__(
21-
self, kernel_scales: Union[float, torch.Tensor] = 1, mmd_type: str = "linear"
33+
self,
34+
kernel_scales: Union[float, torch.Tensor] = 1,
35+
mmd_type: str = "linear",
36+
dist_func=None,
37+
bandwidth=None,
2238
):
2339
"""
2440
Arguments:
@@ -28,7 +44,10 @@ def __init__(
2844
"""
2945
super().__init__()
3046
self.kernel_scales = kernel_scales
31-
self.dist_func = LpDistance(normalize_embeddings=False, p=2, power=2)
47+
self.dist_func = c_f.default(
48+
dist_func, LpDistance(normalize_embeddings=False, p=2, power=2)
49+
)
50+
self.bandwidth = bandwidth
3251
self.mmd_type = mmd_type
3352
if mmd_type == "linear":
3453
self.mmd_func = l_u.get_mmd_linear
@@ -50,7 +69,8 @@ def forward(
5069
Returns:
5170
MMD if the inputs are tensors, and Joint MMD (JMMD) if the inputs are lists of tensors.
5271
"""
53-
xx, yy, zz, scale = l_u.get_mmd_dist_mats(x, y, self.dist_func)
72+
check_batch_sizes(x, y, self.mmd_type)
73+
xx, yy, zz, scale = l_u.get_mmd_dist_mats(x, y, self.dist_func, self.bandwidth)
5474
if torch.is_tensor(self.kernel_scales):
5575
s = scale[0] if c_f.is_list_or_tuple(scale) else scale
5676
self.kernel_scales = pml_cf.to_device(self.kernel_scales, s, dtype=s.dtype)
@@ -66,3 +86,25 @@ def forward(
6686
def extra_repr(self):
6787
""""""
6888
return c_f.extra_repr(self, ["mmd_type", "kernel_scales"])
89+
90+
91+
class MMDBatchedLoss(MMDLoss):
    """Quadratic MMD computed chunk-by-chunk to limit peak memory.

    Wraps the distance function in a ``BatchedDistance`` so the full
    pairwise distance matrix is never materialized at once.
    """

    def __init__(self, batch_size=1024, **kwargs):
        """
        Arguments:
            batch_size: number of distance-matrix rows computed per chunk.
            **kwargs: forwarded to ``MMDLoss``; ``mmd_type`` must be "quadratic".
        """
        super().__init__(**kwargs)
        if self.mmd_type != "quadratic":
            raise ValueError("mmd_type must be 'quadratic'")
        self.mmd_func = l_u.get_mmd_quadratic_batched
        self.dist_func = BatchedDistance(self.dist_func, batch_size=batch_size)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Arguments:
            x: features from one domain.
            y: features from the other domain.
        Returns:
            MMD
        """
        # Joint MMD over lists of features is not implemented for the
        # batched variant, unlike the parent class.
        if any(c_f.is_list_or_tuple(features) for features in (x, y)):
            raise TypeError("List of features not yet supported")
        check_batch_sizes(x, y, self.mmd_type)
        return self.mmd_func(x, y, self.dist_func, self.kernel_scales, self.bandwidth)

src/pytorch_adapt/layers/neighborhood_aggregation.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,16 +77,16 @@ def get_pseudo_labels(self, normalized_features, idx):
7777
for di in range(dis.size(0)):
7878
dis[di, idx[di]] = torch.min(dis)
7979
_, indices = torch.topk(dis, k=self.k, dim=1)
80-
logits = torch.mean(self.pred_memory[indices], dim=1)
81-
pseudo_labels = torch.argmax(logits, dim=1)
82-
return pseudo_labels, logits
80+
preds = torch.mean(self.pred_memory[indices], dim=1)
81+
pseudo_labels = torch.argmax(preds, dim=1)
82+
return pseudo_labels, preds
8383

8484
def update_memory(self, normalized_features, logits, idx):
85-
logits = F.softmax(logits, dim=1)
85+
preds = F.softmax(logits, dim=1)
8686
p = 1.0 / self.T
87-
logits = (logits**p) / torch.sum(logits**p, dim=0)
87+
preds = (preds**p) / torch.sum(preds**p, dim=0)
8888
self.feat_memory[idx] = normalized_features
89-
self.pred_memory[idx] = logits
89+
self.pred_memory[idx] = preds
9090

9191
def extra_repr(self):
9292
""""""

src/pytorch_adapt/layers/utils.py

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22
import torch
3+
from pytorch_metric_learning.utils import common_functions as pml_cf
34

45
from ..utils import common_functions as c_f
56

@@ -18,31 +19,36 @@ def get_kernel_scales(low=-8, high=8, num_kernels=33, base=2.0):
1819
return torch.from_numpy(np.logspace(low, high, num=num_kernels, base=base))
1920

2021

21-
def _mmd_dist_mats(x, y, dist_func):
22+
def _mmd_dist_mats(x, y, dist_func, bandwidth=None):
2223
xx = dist_func(x, x)
2324
yy = dist_func(y, y)
2425
zz = dist_func(x, y)
2526

2627
with torch.no_grad():
2728
# https://arxiv.org/pdf/1409.6041.pdf
2829
# https://arxiv.org/pdf/1707.07269.pdf
29-
scale = -1.0 / torch.median(xx)
30+
denom = (
31+
torch.median(xx)
32+
if bandwidth is None
33+
else torch.tensor([bandwidth], dtype=xx.dtype, device=xx.device)
34+
)
35+
scale = -1.0 / denom
3036

3137
return xx, yy, zz, scale
3238

3339

34-
def get_mmd_dist_mats(x, y, dist_func):
40+
def get_mmd_dist_mats(x, y, dist_func, bandwidth):
3541
if c_f.is_list_or_tuple(x):
3642
xx, yy, zz, scale = [], [], [], []
3743
for i in range(len(x)):
38-
_xx, _yy, _zz, _scale = _mmd_dist_mats(x[i], y[i], dist_func)
44+
_xx, _yy, _zz, _scale = _mmd_dist_mats(x[i], y[i], dist_func, bandwidth)
3945
xx.append(_xx)
4046
yy.append(_yy)
4147
zz.append(_zz)
4248
scale.append(_scale)
4349
return xx, yy, zz, scale
4450
else:
45-
return _mmd_dist_mats(x, y, dist_func)
51+
return _mmd_dist_mats(x, y, dist_func, bandwidth)
4652

4753

4854
def get_default_kernel_weights(scale):
@@ -124,3 +130,44 @@ def get_mmd_linear(xx, yy, zz, scale, weights=None):
124130

125131
loss = loss1 + loss2 - loss3 - loss4
126132
return torch.sum(loss) / float(B // 2)
133+
134+
135+
def _mmd_quadratic_batched(rsum, scale, weights, query_is_ref):
    """Create an ``iter_fn`` callback that accumulates quadratic-MMD kernel sums.

    Arguments:
        rsum: single-element list used as a mutable running-sum accumulator.
        scale: kernel scale(s) forwarded to ``_mmd_quadratic``.
        weights: kernel weights forwarded to ``_mmd_quadratic``.
        query_is_ref: True when query and reference sets are identical, in
            which case each chunk's self-comparisons are masked out.
    Returns:
        A callable ``(mat, start_idx, *_)`` for use as ``BatchedDistance.iter_fn``.
    """

    def accumulate(mat, start_idx, *_):
        if query_is_ref:
            mat = c_f.mask_out_self(mat, start_idx)
        rsum[0] += torch.sum(_mmd_quadratic(mat, scale, weights))

    return accumulate
142+
143+
144+
def get_median_of_medians(x, dist_func):
    """Approximate the median pairwise distance of ``x`` without the full matrix.

    Records the median of every distance-matrix chunk that ``dist_func``
    yields, then returns the median of those medians — an approximation
    of the global median that keeps memory bounded.

    Arguments:
        x: features whose self-distance matrix is computed in chunks.
        dist_func: a batched distance object whose ``iter_fn`` callback is
            invoked once per chunk.
    Returns:
        0-dim tensor: the median of the per-chunk medians.
    """
    chunk_medians = []

    def record(mat, *_):
        # Bandwidth estimation should not participate in autograd.
        with torch.no_grad():
            chunk_medians.append(torch.median(mat))

    dist_func.iter_fn = record
    dist_func(x, x)
    return torch.median(torch.stack(chunk_medians))
154+
155+
156+
def get_mmd_quadratic_batched(x, y, dist_func, kernel_scales, bandwidth, weights=None):
    """Compute quadratic MMD between x and y in memory-bounded batches.

    Arguments:
        x: features from one domain.
        y: features from the other domain.
        dist_func: batched distance object; its ``iter_fn`` attribute is
            (re)assigned here to accumulate kernel sums per chunk.
        kernel_scales: scalar or tensor of per-kernel scales.
        bandwidth: kernel bandwidth; if None, it is estimated from ``x``
            via the median-of-medians of chunked self-distances.
        weights: optional kernel weights; defaults are derived from ``scale``.
    Returns:
        0-dim tensor: MMD(x, y) = mean k(x,x) + mean k(y,y) - 2 * mean k(x,y).
    """
    if torch.is_tensor(kernel_scales):
        kernel_scales = pml_cf.to_device(kernel_scales, x, dtype=x.dtype)
    if bandwidth is None:
        bandwidth = get_median_of_medians(x, dist_func)
    scale = -kernel_scales / bandwidth
    weights = c_f.default(weights, get_default_kernel_weights(scale))

    sums = []
    # The three kernel-sum terms of quadratic MMD: (x,x), (y,y), (x,y).
    for s, t in [(x, x), (y, y), (x, y)]:
        # Single-element list acts as a mutable accumulator shared
        # with the iter_fn closure.
        rsum = [0]
        query_is_ref = s is t
        dist_func.iter_fn = _mmd_quadratic_batched(rsum, scale, weights, query_is_ref)
        dist_func(s, t)
        # Self-comparisons are masked out when query == reference, so the
        # normalizer is n*(n-1) instead of n*m.
        denom = (len(s) * (len(s) - 1)) if query_is_ref else (len(s) * len(t))
        sums.append(torch.sum(rsum[0]) / denom)

    return sums[0] + sums[1] - 2 * sums[2]

src/pytorch_adapt/utils/common_functions.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -551,3 +551,15 @@ def subset_of_dict(x, subset):
551551
if isinstance(subset, dict):
552552
return {k: subset_of_dict(x[k], v) for k, v in subset.items()}
553553
raise TypeError("subset argument must be dict or set")
554+
555+
556+
def mask_out_self(sim_mat, start_idx, return_mask=False):
    """Remove each row's "self" entry from a similarity matrix.

    Row ``i`` is assumed to correspond to column ``i + start_idx`` (useful
    when ``sim_mat`` is a horizontal batch of rows from a larger square
    matrix, e.g. as yielded by a batched distance computation).

    Arguments:
        sim_mat: (num_rows, num_cols) similarity matrix.
        start_idx: column offset of the first row's self-comparison.
        return_mask: if True, also return the boolean mask that was applied.
    Returns:
        (num_rows, num_cols - 1) matrix with the self entries removed,
        plus the (num_rows, num_cols) mask if ``return_mask`` is True.
    """
    num_rows, num_cols = sim_mat.shape
    # Create the mask and index tensors on sim_mat's device — building them
    # on the default (CPU) device crashes when sim_mat lives on the GPU.
    mask = torch.ones(num_rows, num_cols, dtype=torch.bool, device=sim_mat.device)
    rows = torch.arange(num_rows, device=sim_mat.device)
    cols = rows + start_idx
    mask[rows, cols] = False
    sim_mat = sim_mat[mask].view(num_rows, num_cols - 1)
    if return_mask:
        return sim_mat, mask
    return sim_mat
Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
from .accuracy_validator import AccuracyValidator
22
from .base_validator import BaseValidator
3+
from .class_cluster_validator import ClassClusterValidator
34
from .deep_embedded_validator import DeepEmbeddedValidator
45
from .diversity_validator import DiversityValidator
56
from .entropy_validator import EntropyValidator
67
from .error_validator import ErrorValidator
78
from .im_validator import IMValidator
8-
9-
# from .knn_validator import ClusterValidator, KNNValidator
9+
from .ist_validator import ISTValidator
10+
from .knn_validator import KNNValidator
11+
from .mmd_validator import MMDValidator
1012
from .multiple_validators import MultipleValidators
13+
from .per_class_validator import PerClassValidator
1114
from .score_history import ScoreHistories, ScoreHistory
12-
from .silhouette_score_validator import SilhouetteScoreValidator
1315
from .snd_validator import SNDValidator
16+
from .target_knn_validator import TargetKNNValidator

0 commit comments

Comments
 (0)