Skip to content

Commit d6df124

Browse files
Making DeepFM output layer configurable
1 parent 8ab8742 commit d6df124

File tree

1 file changed

+16
-4
lines changed

1 file changed

+16
-4
lines changed

merlin/models/tf/models/ranking.py

Lines changed: 16 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -155,10 +155,12 @@ def DeepFMModel(
155155
deep_block: Optional[Block] = None,
156156
input_block: Optional[Block] = None,
157157
wide_input_block: Optional[Block] = None,
158+
wide_logit_block: Optional[Block] = None,
159+
deep_logit_block: Optional[Block] = None,
158160
prediction_tasks: Optional[
159161
Union[PredictionTask, List[PredictionTask], ParallelPredictionBlock]
160162
] = None,
161-
**deep_tower_kwargs,
163+
**kwargs,
162164
) -> Model:
163165
"""DeepFM-model architecture, which is the sum of the 1-dim output
164166
of a Factorization Machine [2] and a Deep Neural Network
@@ -198,6 +200,15 @@ def DeepFMModel(
198200
The input for the wide block. If not provided,
199201
creates a default block that encodes categorical features
200202
with one-hot / multi-hot representation and also includes the continuous features.
203+
wide_logit_block: Optional[Block], by default None
204+
The output layer of the wide input. The last dimension needs to be 1.
205+
You might want to provide your own output logit block if you want to add
206+
dropout or kernel regularization to the wide block.
207+
deep_logit_block: Optional[Block], by default MLPBlock([1], activation="linear", use_bias=True)
208+
The output layer of the deep block. The last dimension needs to be 1.
209+
You might want to provide your own output logit block if you want to add
210+
dropout or kernel regularization to the deep block.
211+
201212
prediction_tasks: optional
202213
The prediction tasks to be used, by default this will be inferred from the Schema.
203214
Defaults to None
@@ -217,14 +228,15 @@ def DeepFMModel(
217228
schema,
218229
fm_input_block=input_block,
219230
wide_input_block=wide_input_block,
231+
wide_logit_block=wide_logit_block,
220232
)
221233

222234
if deep_block is None:
223235
deep_block = MLPBlock([64])
224236
deep_block = deep_block.prepare(aggregation=ConcatFeatures())
225-
deep_tower = input_block.connect(deep_block).connect(
226-
MLPBlock([1], activation="linear", use_bias=True, **deep_tower_kwargs)
227-
)
237+
238+
deep_logit_block = deep_logit_block or MLPBlock([1], activation="linear", use_bias=True)
239+
deep_tower = input_block.connect(deep_block).connect(deep_logit_block)
228240

229241
deep_fm = ParallelBlock({"fm": fm_tower, "deep": deep_tower}, aggregation="element-wise-sum")
230242

0 commit comments

Comments
 (0)