File tree Expand file tree Collapse file tree 1 file changed +0
-41
lines changed
Expand file tree Collapse file tree 1 file changed +0
-41
lines changed Original file line number Diff line number Diff line change 3232]
3333
3434
35- class TorchMHCCoeffs (nn .Module ):
36- def __init__ (
37- self ,
38- * ,
39- tmax : int ,
40- rms_eps : float ,
41- pre_eps : float ,
42- sinkhorn_eps : float ,
43- post_mult : float ,
44- ):
45- super ().__init__ ()
46- self .tmax = int (tmax )
47- self .rms_eps = float (rms_eps )
48- self .pre_eps = float (pre_eps )
49- self .sinkhorn_eps = float (sinkhorn_eps )
50- self .post_mult = float (post_mult )
51-
52- def forward (
53- self ,
54- x : torch .Tensor ,
55- phi : torch .Tensor ,
56- b : torch .Tensor ,
57- alpha_pre : torch .Tensor ,
58- alpha_post : torch .Tensor ,
59- alpha_res : torch .Tensor ,
60- ) -> tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
61- return mhc_coeffs_ref (
62- x ,
63- phi ,
64- b ,
65- alpha_pre ,
66- alpha_post ,
67- alpha_res ,
68- tmax = self .tmax ,
69- rms_eps = self .rms_eps ,
70- pre_eps = self .pre_eps ,
71- sinkhorn_eps = self .sinkhorn_eps ,
72- post_mult = self .post_mult ,
73- )
74-
75-
7635def mhc_sinkhorn_ref (logits : torch .Tensor , * , tmax : int , eps : float ) -> torch .Tensor :
7736 """
7837 logits: [N, HC, HC]
You can’t perform that action at this time.
0 commit comments