File tree Expand file tree Collapse file tree 1 file changed +4
-4
lines changed
tests/test_model_components Expand file tree Collapse file tree 1 file changed +4
-4
lines changed Original file line number Diff line number Diff line change 2222batch_size = 10
2323colnames = list (string .ascii_lowercase )[: (n_cols * 2 )]
2424embed_cols = [np .random .choice (np .arange (n_embed ), batch_size ) for _ in range (n_cols )]
25- embed_cols_with_cls_token = [[n_embed ] * batch_size ] + embed_cols
25+ embed_cols_with_cls_token = [[n_embed ] * batch_size ] + embed_cols # type: ignore[operator]
2626cont_cols = [np .random .rand (batch_size ) for _ in range (n_cols )]
2727
2828X_tab = torch .from_numpy (np .vstack (embed_cols + cont_cols ).transpose ())
2929X_tab_with_cls_token = torch .from_numpy (
30- np .vstack (embed_cols_with_cls_token + cont_cols ).transpose ()
30+ np .vstack (embed_cols_with_cls_token + cont_cols ).transpose () # type: ignore[operator]
3131)
3232
3333
@@ -169,11 +169,11 @@ def test_full_embed_dropout():
169169 bsz = 1
170170 cat = 10
171171 esz = 4
172- full_embedding_dropout = FullEmbeddingDropout (dropout = 0.5 )
172+ full_embedding_dropout = FullEmbeddingDropout (dropout = 0.8 )
173173 inp = torch .rand (bsz , cat , esz )
174174 out = full_embedding_dropout (inp )
175175 # simply check that at least 1 full row is all 0s
176- assert torch . any (torch .sum (out [0 ] == 0 , axis = 1 ) == esz )
176+ assert (torch .sum (out [0 ] == 0 , axis = 1 ) == esz ). sum () > 0
177177
178178
179179# ###############################################################################
You can’t perform that action at this time.
0 commit comments