Commit e865fb2

Adjusted test full embed dropout
1 parent cf9c9a3 commit e865fb2

File tree

1 file changed (+4, -4 lines)

tests/test_model_components/test_mc_transformers.py

Lines changed: 4 additions & 4 deletions
@@ -22,12 +22,12 @@
 batch_size = 10
 colnames = list(string.ascii_lowercase)[: (n_cols * 2)]
 embed_cols = [np.random.choice(np.arange(n_embed), batch_size) for _ in range(n_cols)]
-embed_cols_with_cls_token = [[n_embed] * batch_size] + embed_cols
+embed_cols_with_cls_token = [[n_embed] * batch_size] + embed_cols  # type: ignore[operator]
 cont_cols = [np.random.rand(batch_size) for _ in range(n_cols)]

 X_tab = torch.from_numpy(np.vstack(embed_cols + cont_cols).transpose())
 X_tab_with_cls_token = torch.from_numpy(
-    np.vstack(embed_cols_with_cls_token + cont_cols).transpose()
+    np.vstack(embed_cols_with_cls_token + cont_cols).transpose()  # type: ignore[operator]
 )

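A note on the hunk above (my reading, not stated in the commit message): `embed_cols` and `cont_cols` are lists of NumPy arrays while `[[n_embed] * batch_size]` is a plain list of lists, so mypy rejects the `+` concatenations with an `operator` error even though they are fine at runtime; the added `# type: ignore[operator]` comments silence that check without changing behaviour.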

@@ -169,11 +169,11 @@ def test_full_embed_dropout():
     bsz = 1
     cat = 10
     esz = 4
-    full_embedding_dropout = FullEmbeddingDropout(dropout=0.5)
+    full_embedding_dropout = FullEmbeddingDropout(dropout=0.8)
     inp = torch.rand(bsz, cat, esz)
     out = full_embedding_dropout(inp)
     # simply check that at least 1 full row is all 0s
-    assert torch.any(torch.sum(out[0] == 0, axis=1) == esz)
+    assert (torch.sum(out[0] == 0, axis=1) == esz).sum() > 0


 # ###############################################################################
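
For context on the second hunk, a minimal sketch of a full-row embedding dropout layer that behaves the way the test expects; this is an assumption for illustration, not the repository's actual FullEmbeddingDropout implementation. It zeroes each of the `cat` embedding rows in its entirety with probability `dropout` and rescales the surviving rows by 1 / (1 - dropout):

import torch
from torch import Tensor, nn


class FullEmbeddingDropoutSketch(nn.Module):
    # Hypothetical stand-in for the FullEmbeddingDropout layer exercised by the test.
    def __init__(self, dropout: float):
        super().__init__()
        self.dropout = dropout

    def forward(self, X: Tensor) -> Tensor:
        # X has shape (batch_size, n_rows, embed_dim); drop whole rows, not single entries
        if not self.training:
            return X
        keep_prob = 1.0 - self.dropout
        # One Bernoulli draw per row, broadcast across the embedding dimension
        row_mask = torch.bernoulli(
            torch.full((X.size(0), X.size(1), 1), keep_prob, device=X.device)
        )
        return X * row_mask / keep_prob


# Mirroring the adjusted test: check that at least one row of out[0] is all zeros
layer = FullEmbeddingDropoutSketch(dropout=0.8)
out = layer(torch.rand(1, 10, 4))
assert (torch.sum(out[0] == 0, axis=1) == 4).sum() > 0

Raising the dropout from 0.5 to 0.8 presumably makes the check less flaky: with 10 rows, the probability that no row is fully zeroed drops from 0.5**10 ≈ 0.001 to 0.2**10 ≈ 1e-07. The rewritten assertion counts the rows of out[0] whose esz entries are all zero and requires that count to be positive, rather than relying on torch.any.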
