@@ -4348,12 +4348,14 @@ def nce(input,
4348
4348
4349
4349
def hsigmoid (input ,
4350
4350
label ,
4351
- num_classes ,
4352
- ptabl = None ,
4351
+ num_classes = None ,
4352
+ non_leaf_num = None ,
4353
+ ptable = None ,
4353
4354
pcode = None ,
4354
4355
param_attr = None ,
4355
4356
bias_attr = None ,
4356
- name = None ):
4357
+ name = None ,
4358
+ is_costum = False ):
4357
4359
"""
4358
4360
The hierarchical sigmoid operator is used to accelerate the training
4359
4361
process of language model. This operator organizes the classes into a
@@ -4373,7 +4375,8 @@ def hsigmoid(input,
4373
4375
and :math:`D` is the feature size.
4374
4376
label (Variable): The tensor variable contains labels of training data.
4375
4377
It's a tensor with shape is :math:`[N \\ times 1]`.
4376
- num_classes: (int), The number of classes, must not be less than 2.
4378
+ num_classes: (int), The number of classes, must not be less than 2. with default tree this has to be set
4379
+ non_leaf_num: this defines the number of non-leaf nodes in costumed tree
4377
4380
ptable: (Variable|None) this variable can store each batch of samples' path to root,
4378
4381
it should be in leaf -> root order
4379
4382
ptable should have the same shape with pcode, and for each sample i ptable[i] indicates a np.array like
@@ -4409,20 +4412,33 @@ def hsigmoid(input,
4409
4412
out = helper .create_variable_for_type_inference (dtype )
4410
4413
pre_out = helper .create_variable_for_type_inference (dtype )
4411
4414
dim = input .shape [1 ]
4412
- if num_classes < 2 :
4413
- raise ValueError ("num_classes must not be less than 2." )
4414
- if (ptable is not None ) and (pcode is None ):
4415
- raise ValueError ("pcode should not be None when ptable has been set" )
4416
- elif (ptable is None ) and (pcode is not None ):
4417
- raise ValueError ("ptable should not be None when pcode has been set" )
4415
+ if ((num_classes < 2 ) or (num_classes is None )) and (not is_costum ):
4416
+ raise ValueError (
4417
+ "num_classes must not be less than 2 with default tree" )
4418
+
4419
+ if (is_costum ) and (pcode is None ):
4420
+ raise ValueError ("pcode should not be None with costum tree" )
4421
+ elif (is_costum ) and (ptable is None ):
4422
+ raise ValueError ("ptable should not be None with costum tree" )
4423
+ elif (is_costum ) and (non_leaf_num is None ):
4424
+ raise ValueError ("non_leaf_num should not be None with costum tree" )
4418
4425
else :
4419
4426
pass
4420
4427
4421
- weights = helper .create_parameter (
4422
- attr = helper .param_attr ,
4423
- shape = [num_classes - 1 , dim ],
4424
- is_bias = False ,
4425
- dtype = input .dtype )
4428
+ weights = None
4429
+
4430
+ if not is_costum :
4431
+ weights = helper .create_parameter (
4432
+ attr = helper .param_attr ,
4433
+ shape = [num_classes - 1 , dim ],
4434
+ is_bias = False ,
4435
+ dtype = input .dtype )
4436
+ else :
4437
+ weights = helper .create_parameter (
4438
+ attr = helper .param_attr ,
4439
+ shape = [non_leaf_num , dim ],
4440
+ is_bias = False ,
4441
+ dtype = input .dtype )
4426
4442
inputs = {
4427
4443
"X" : input ,
4428
4444
"W" : weights ,
@@ -4431,12 +4447,20 @@ def hsigmoid(input,
4431
4447
"Label" : label
4432
4448
}
4433
4449
if helper .bias_attr :
4434
- bias = helper .create_parameter (
4435
- attr = helper .bias_attr ,
4436
- shape = [1 , num_classes - 1 ],
4437
- is_bias = True ,
4438
- dtype = input .dtype )
4439
- inputs ['Bias' ] = bias
4450
+ if not is_costum :
4451
+ bias = helper .create_parameter (
4452
+ attr = helper .bias_attr ,
4453
+ shape = [1 , num_classes - 1 ],
4454
+ is_bias = True ,
4455
+ dtype = input .dtype )
4456
+ inputs ['Bias' ] = bias
4457
+ else :
4458
+ bias = helper .create_parameter (
4459
+ attr = helper .bias_attr ,
4460
+ shape = [1 , non_leaf_num ],
4461
+ is_bias = True ,
4462
+ dtype = input .dtype )
4463
+ inputs ['Bias' ] = bias
4440
4464
helper .append_op (
4441
4465
type = "hierarchical_sigmoid" ,
4442
4466
inputs = inputs ,
0 commit comments