11from typing import List , Union
22
33import torch
4- from pytorch_metric_learning .distances import LpDistance
4+ from pytorch_metric_learning .distances import BatchedDistance , LpDistance
55from pytorch_metric_learning .utils import common_functions as pml_cf
66
77from ..utils import common_functions as c_f
88from . import utils as l_u
99
1010
11+ def check_batch_sizes (s , t , mmd_type ):
12+ if mmd_type == "quadratic" :
13+ return
14+ is_list = c_f .is_list_or_tuple (s )
15+ if (is_list and any (s [i ].shape != t [i ].shape for i in range (len (s )))) or (
16+ not is_list and s .shape != t .shape
17+ ):
18+ raise ValueError (
19+ "For mmd_type 'linear', source and target must have the same batch size."
20+ )
21+
22+
1123class MMDLoss (torch .nn .Module ):
1224 """
1325 Implementation of
@@ -18,7 +30,11 @@ class MMDLoss(torch.nn.Module):
1830 """
1931
2032 def __init__ (
21- self , kernel_scales : Union [float , torch .Tensor ] = 1 , mmd_type : str = "linear"
33+ self ,
34+ kernel_scales : Union [float , torch .Tensor ] = 1 ,
35+ mmd_type : str = "linear" ,
36+ dist_func = None ,
37+ bandwidth = None ,
2238 ):
2339 """
2440 Arguments:
@@ -28,7 +44,10 @@ def __init__(
2844 """
2945 super ().__init__ ()
3046 self .kernel_scales = kernel_scales
31- self .dist_func = LpDistance (normalize_embeddings = False , p = 2 , power = 2 )
47+ self .dist_func = c_f .default (
48+ dist_func , LpDistance (normalize_embeddings = False , p = 2 , power = 2 )
49+ )
50+ self .bandwidth = bandwidth
3251 self .mmd_type = mmd_type
3352 if mmd_type == "linear" :
3453 self .mmd_func = l_u .get_mmd_linear
@@ -50,7 +69,8 @@ def forward(
5069 Returns:
5170 MMD if the inputs are tensors, and Joint MMD (JMMD) if the inputs are lists of tensors.
5271 """
53- xx , yy , zz , scale = l_u .get_mmd_dist_mats (x , y , self .dist_func )
72+ check_batch_sizes (x , y , self .mmd_type )
73+ xx , yy , zz , scale = l_u .get_mmd_dist_mats (x , y , self .dist_func , self .bandwidth )
5474 if torch .is_tensor (self .kernel_scales ):
5575 s = scale [0 ] if c_f .is_list_or_tuple (scale ) else scale
5676 self .kernel_scales = pml_cf .to_device (self .kernel_scales , s , dtype = s .dtype )
@@ -66,3 +86,25 @@ def forward(
6686 def extra_repr (self ):
6787 """"""
6888 return c_f .extra_repr (self , ["mmd_type" , "kernel_scales" ])
89+
90+
91+ class MMDBatchedLoss (MMDLoss ):
92+ def __init__ (self , batch_size = 1024 , ** kwargs ):
93+ super ().__init__ (** kwargs )
94+ if self .mmd_type != "quadratic" :
95+ raise ValueError ("mmd_type must be 'quadratic'" )
96+ self .mmd_func = l_u .get_mmd_quadratic_batched
97+ self .dist_func = BatchedDistance (self .dist_func , batch_size = batch_size )
98+
99+ def forward (self , x : torch .Tensor , y : torch .Tensor ) -> torch .Tensor :
100+ """
101+ Arguments:
102+ x: features from one domain.
103+ y: features from the other domain.
104+ Returns:
105+ MMD
106+ """
107+ if c_f .is_list_or_tuple (x ) or c_f .is_list_or_tuple (y ):
108+ raise TypeError ("List of features not yet supported" )
109+ check_batch_sizes (x , y , self .mmd_type )
110+ return self .mmd_func (x , y , self .dist_func , self .kernel_scales , self .bandwidth )
0 commit comments