Skip to content

Commit af0e661

Browse files
committed
Remove TorchMHCCoeffs class from test_mhc.py
1 parent cb25136 commit af0e661

File tree

1 file changed

+0
-41
lines changed

1 file changed

+0
-41
lines changed

test/transformers/test_mhc.py

Lines changed: 0 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -32,47 +32,6 @@
3232
]
3333

3434

35-
class TorchMHCCoeffs(nn.Module):
36-
def __init__(
37-
self,
38-
*,
39-
tmax: int,
40-
rms_eps: float,
41-
pre_eps: float,
42-
sinkhorn_eps: float,
43-
post_mult: float,
44-
):
45-
super().__init__()
46-
self.tmax = int(tmax)
47-
self.rms_eps = float(rms_eps)
48-
self.pre_eps = float(pre_eps)
49-
self.sinkhorn_eps = float(sinkhorn_eps)
50-
self.post_mult = float(post_mult)
51-
52-
def forward(
53-
self,
54-
x: torch.Tensor,
55-
phi: torch.Tensor,
56-
b: torch.Tensor,
57-
alpha_pre: torch.Tensor,
58-
alpha_post: torch.Tensor,
59-
alpha_res: torch.Tensor,
60-
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
61-
return mhc_coeffs_ref(
62-
x,
63-
phi,
64-
b,
65-
alpha_pre,
66-
alpha_post,
67-
alpha_res,
68-
tmax=self.tmax,
69-
rms_eps=self.rms_eps,
70-
pre_eps=self.pre_eps,
71-
sinkhorn_eps=self.sinkhorn_eps,
72-
post_mult=self.post_mult,
73-
)
74-
75-
7635
def mhc_sinkhorn_ref(logits: torch.Tensor, *, tmax: int, eps: float) -> torch.Tensor:
7736
"""
7837
logits: [N, HC, HC]

0 commit comments

Comments
 (0)