Skip to content

Commit fadede2

Browse files
committed
Fixed issue #53 related to the use of some transformer models without categorical columns
1 parent 6540cd3 commit fadede2

File tree

8 files changed

+37
-19
lines changed

8 files changed

+37
-19
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
1212
[![Maintenance](https://img.shields.io/badge/Maintained%3F-yes-green.svg)](https://github.com/jrzaurin/pytorch-widedeep/graphs/commit-activity)
1313
[![contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/jrzaurin/pytorch-widedeep/issues)
14+
[![Slack](https://img.shields.io/badge/slack-chat-green.svg?logo=slack)](https://join.slack.com/t/pytorch-widedeep/shared_invite/zt-soss7stf-iXpVuLeKZz8lGTnxxtHtTw)
1415

1516
# pytorch-widedeep
1617

VERSION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
1.0.9
1+
1.0.10

pypi_README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
77
[![Maintenance](https://img.shields.io/badge/Maintained%3F-yes-green.svg)](https://github.com/jrzaurin/pytorch-widedeep/graphs/commit-activity)
88
[![contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/jrzaurin/pytorch-widedeep/issues)
9+
[![Slack](https://img.shields.io/badge/slack-chat-green.svg?logo=slack)](https://join.slack.com/t/pytorch-widedeep/shared_invite/zt-soss7stf-iXpVuLeKZz8lGTnxxtHtTw)
910

1011

1112
# pytorch-widedeep

pytorch_widedeep/models/transformers/ft_transformer.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ class FTTransformer(nn.Module):
134134
def __init__(
135135
self,
136136
column_idx: Dict[str, int],
137-
embed_input: List[Tuple[str, int]],
137+
embed_input: Optional[List[Tuple[str, int]]] = None,
138138
embed_dropout: float = 0.1,
139139
full_embed_dropout: bool = False,
140140
shared_embed: bool = False,
@@ -194,11 +194,6 @@ def __init__(
194194
self.n_cont = len(continuous_cols) if continuous_cols is not None else 0
195195
self.n_feats = self.n_cat + self.n_cont
196196

197-
if self.n_cont and not self.n_cat and not self.embed_continuous:
198-
raise ValueError(
199-
"If only continuous features are used 'embed_continuous' must be set to 'True'"
200-
)
201-
202197
self.cat_and_cont_embed = CatAndContEmbeddings(
203198
input_dim,
204199
column_idx,

pytorch_widedeep/models/transformers/saint.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ class SAINT(nn.Module):
120120
def __init__(
121121
self,
122122
column_idx: Dict[str, int],
123-
embed_input: List[Tuple[str, int]],
123+
embed_input: Optional[List[Tuple[str, int]]] = None,
124124
embed_dropout: float = 0.1,
125125
full_embed_dropout: bool = False,
126126
shared_embed: bool = False,
@@ -173,11 +173,6 @@ def __init__(
173173
self.n_cont = len(continuous_cols) if continuous_cols is not None else 0
174174
self.n_feats = self.n_cat + self.n_cont
175175

176-
if self.n_cont and not self.n_cat and not self.embed_continuous:
177-
raise ValueError(
178-
"If only continuous features are used 'embed_continuous' must be set to 'True'"
179-
)
180-
181176
self.cat_and_cont_embed = CatAndContEmbeddings(
182177
input_dim,
183178
column_idx,

pytorch_widedeep/models/transformers/tab_fastformer.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -182,11 +182,6 @@ def __init__(
182182
self.n_cont = len(continuous_cols) if continuous_cols is not None else 0
183183
self.n_feats = self.n_cat + self.n_cont
184184

185-
if self.n_cont and not self.n_cat and not self.embed_continuous:
186-
raise ValueError(
187-
"If only continuous features are used 'embed_continuous' must be set to 'True'"
188-
)
189-
190185
self.cat_and_cont_embed = CatAndContEmbeddings(
191186
input_dim,
192187
column_idx,

pytorch_widedeep/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "1.0.9"
1+
__version__ = "1.0.10"

tests/test_model_components/test_mc_transformers.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,3 +449,34 @@ def test_ft_transformer_mlp(mlp_first_h, shoud_work):
449449
else:
450450
with pytest.raises(AssertionError):
451451
model = _build_model("fttransformer", params) # noqa: F841
452+
453+
454+
###############################################################################
455+
# Test transformers with only continuous cols
456+
###############################################################################
457+
458+
459+
X_tab_only_cont = torch.from_numpy(
460+
np.vstack([np.random.rand(10) for _ in range(4)]).transpose()
461+
)
462+
colnames_only_cont = list(string.ascii_lowercase)[:4]
463+
464+
465+
@pytest.mark.parametrize(
466+
"model_name",
467+
[
468+
"fttransformer",
469+
"saint",
470+
"tabfastformer",
471+
],
472+
)
473+
def test_transformers_only_cont(model_name):
474+
params = {
475+
"column_idx": {k: v for v, k in enumerate(colnames_only_cont)},
476+
"continuous_cols": colnames_only_cont,
477+
}
478+
479+
model = _build_model(model_name, params)
480+
out = model(X_tab_only_cont)
481+
482+
assert out.size(0) == 10 and out.size(1) == model.output_dim

0 commit comments

Comments
 (0)