@@ -155,10 +155,12 @@ def DeepFMModel(
155155 deep_block : Optional [Block ] = None ,
156156 input_block : Optional [Block ] = None ,
157157 wide_input_block : Optional [Block ] = None ,
158+ wide_logit_block : Optional [Block ] = None ,
159+ deep_logit_block : Optional [Block ] = None ,
158160 prediction_tasks : Optional [
159161 Union [PredictionTask , List [PredictionTask ], ParallelPredictionBlock ]
160162 ] = None ,
161- ** deep_tower_kwargs ,
163+ ** kwargs ,
162164) -> Model :
163165 """DeepFM-model architecture, which is the sum of the 1-dim output
164166 of a Factorization Machine [2] and a Deep Neural Network
@@ -198,6 +200,15 @@ def DeepFMModel(
198200 The input for the wide block. If not provided,
199201 creates a default block that encodes categorical features
200202 with one-hot / multi-hot representation and also includes the continuous features.
203+ wide_logit_block: Optional[Block], by default None
204+ The output layer of the wide input. The last dimension needs to be 1.
205+ You might want to provide your own output logit block if you want to add
206+ dropout or kernel regularization to the wide block.
207+ deep_logit_block: Optional[Block], by default MLPBlock([1], activation="linear", use_bias=True)
208+ The output layer of the deep block. The last dimension needs to be 1.
209+ You might want to provide your own output logit block if you want to add
210+ dropout or kernel regularization to the wide block.
211+
201212 prediction_tasks: optional
202213 The prediction tasks to be used, by default this will be inferred from the Schema.
203214 Defaults to None
@@ -217,14 +228,15 @@ def DeepFMModel(
217228 schema ,
218229 fm_input_block = input_block ,
219230 wide_input_block = wide_input_block ,
231+ wide_logit_block = wide_logit_block ,
220232 )
221233
222234 if deep_block is None :
223235 deep_block = MLPBlock ([64 ])
224236 deep_block = deep_block .prepare (aggregation = ConcatFeatures ())
225- deep_tower = input_block . connect ( deep_block ). connect (
226- MLPBlock ([1 ], activation = "linear" , use_bias = True , ** deep_tower_kwargs )
227- )
237+
238+ deep_logit_block = deep_logit_block or MLPBlock ([1 ], activation = "linear" , use_bias = True )
239+ deep_tower = input_block . connect ( deep_block ). connect ( deep_logit_block )
228240
229241 deep_fm = ParallelBlock ({"fm" : fm_tower , "deep" : deep_tower }, aggregation = "element-wise-sum" )
230242
0 commit comments