Skip to content

Commit 868e23e

Browse files
xaduprexiaowuhu
andauthored
Fix mixed types when converting a LightGbm model (#591)
* Fix mixed types when converting a LightGbm model Signed-off-by: xadupre <[email protected]> * fix type issue Signed-off-by: xadupre <[email protected]> Signed-off-by: xadupre <[email protected]> Co-authored-by: xiaowuhu <[email protected]>
1 parent ece7e36 commit 868e23e

File tree

1 file changed

+22
-0
lines changed
  • onnxmltools/convert/lightgbm/operator_converters

1 file changed

+22
-0
lines changed

onnxmltools/convert/lightgbm/operator_converters/LightGbm.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,17 @@ def convert_lightgbm(scope, operator, container):
516516
'probability_tensor')
517517
label_tensor_name = scope.get_unique_variable_name('label_tensor')
518518

519+
# onnx does not support int and float values for a float tensor
520+
update = {}
521+
for k, v in attrs.items():
522+
if not isinstance(v, list):
523+
continue
524+
tps = set(map(type, v))
525+
if len(tps) == 2:
526+
if tps == {int, float}:
527+
update[k] = [float(x) for x in v]
528+
attrs.update(update)
529+
519530
container.add_node(
520531
'TreeEnsembleClassifier', operator.input_full_names,
521532
[label_tensor_name, probability_tensor_name],
@@ -601,6 +612,17 @@ def convert_lightgbm(scope, operator, container):
601612
attrs['target' + k[5:]] = copy.deepcopy(attrs[k])
602613
del attrs[k]
603614

615+
# onnx does not support int and float values for a float tensor
616+
update = {}
617+
for k, v in attrs.items():
618+
if not isinstance(v, list):
619+
continue
620+
tps = set(map(type, v))
621+
if len(tps) == 2:
622+
if tps == {int, float}:
623+
update[k] = [float(x) for x in v]
624+
attrs.update(update)
625+
604626
split = getattr(operator, 'split', None)
605627
if split in (None, -1):
606628
container.add_node(

0 commit comments

Comments
 (0)