Skip to content

Commit 8ab8742

Browse files
Fixed failing test
1 parent 1d2b14a commit 8ab8742

File tree

2 files changed

+3
-4
lines changed

2 files changed

+3
-4
lines changed

merlin/models/tf/blocks/interaction.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ def FMBlock(
294294
wide_input_block = wide_input_block or ParallelBlock(
295295
{
296296
"categorical": CategoryEncoding(cat_schema, output_mode="multi_hot", sparse=True),
297-
"continuous": Filter(cont_schema, post=ToSparse()),
297+
"continuous": Filter(cont_schema).connect(ToSparse()),
298298
},
299299
aggregation="concat",
300300
)
@@ -304,7 +304,7 @@ def FMBlock(
304304

305305
fm_input_block = fm_input_block or InputBlockV2(
306306
cat_schema,
307-
categorical=Embeddings(schema, dim=factors_dim),
307+
categorical=Embeddings(cat_schema, dim=factors_dim),
308308
aggregation=None,
309309
)
310310
pairwise_interaction = SequentialBlock(

merlin/models/tf/models/ranking.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from merlin.models.tf.core.aggregation import ConcatFeatures
99
from merlin.models.tf.core.base import Block
1010
from merlin.models.tf.core.combinators import ParallelBlock, TabularBlock
11-
from merlin.models.tf.core.tabular import Filter
1211
from merlin.models.tf.inputs.base import InputBlockV2
1312
from merlin.models.tf.inputs.embedding import EmbeddingOptions, Embeddings
1413
from merlin.models.tf.models.base import Model
@@ -216,7 +215,7 @@ def DeepFMModel(
216215

217216
fm_tower = FMBlock(
218217
schema,
219-
fm_input_block=input_block.connect(Filter(Tags.CATEGORICAL)),
218+
fm_input_block=input_block,
220219
wide_input_block=wide_input_block,
221220
)
222221

0 commit comments

Comments
 (0)