@@ -196,8 +196,17 @@ def fit_embeddings(self, Xs, Xt):
196196
197197 self .Cs_ = self .lambda_ * cov_Xs + np .eye (Xs .shape [1 ])
198198 self .Ct_ = self .lambda_ * cov_Xt + np .eye (Xt .shape [1 ])
199- Xs_emb = np .matmul (Xs , linalg .inv (linalg .sqrtm (self .Cs_ )))
200- Xs_emb = np .matmul (Xs_emb , linalg .sqrtm (self .Ct_ ))
199+
200+ Cs_sqrt_inv = linalg .inv (linalg .sqrtm (self .Cs_ ))
201+ Ct_sqrt = linalg .sqrtm (self .Ct_ )
202+
203+ if np .iscomplexobj (Cs_sqrt_inv ):
204+ Cs_sqrt_inv = Cs_sqrt_inv .real
205+ if np .iscomplexobj (Ct_sqrt ):
206+ Ct_sqrt = Ct_sqrt .real
207+
208+ Xs_emb = np .matmul (Xs , Cs_sqrt_inv )
209+ Xs_emb = np .matmul (Xs_emb , Ct_sqrt )
201210
202211 if self .verbose :
203212 new_cov_Xs = np .cov (Xs_emb , rowvar = False )
@@ -280,8 +289,16 @@ def predict_features(self, X, domain="tgt"):
280289 if domain in ["tgt" , "target" ]:
281290 X_emb = X
282291 elif domain in ["src" , "source" ]:
283- X_emb = np .matmul (X , linalg .inv (linalg .sqrtm (self .Cs_ )))
284- X_emb = np .matmul (X_emb , linalg .sqrtm (self .Ct_ ))
292+ Cs_sqrt_inv = linalg .inv (linalg .sqrtm (self .Cs_ ))
293+ Ct_sqrt = linalg .sqrtm (self .Ct_ )
294+
295+ if np .iscomplexobj (Cs_sqrt_inv ):
296+ Cs_sqrt_inv = Cs_sqrt_inv .real
297+ if np .iscomplexobj (Ct_sqrt ):
298+ Ct_sqrt = Ct_sqrt .real
299+
300+ X_emb = np .matmul (X , Cs_sqrt_inv )
301+ X_emb = np .matmul (X_emb , Ct_sqrt )
285302 else :
286303 raise ValueError ("`domain `argument "
287304 "should be `tgt` or `src`, "
0 commit comments