
Commit 175aee0

fix: Spark Imputer conversion with multiple input cols (#608)
* fix: Spark Imputer conversion with multiple input cols
  Signed-off-by: Jason Wang <[email protected]>
* remove whitespace
  Signed-off-by: Jason Wang <[email protected]>
---------
Signed-off-by: Jason Wang <[email protected]>
1 parent db71727 commit 175aee0
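
In short: fitting Spark's Imputer on several columns (inputCols/outputCols) and converting the fitted model used to fail, because the exported graph imputes a concatenated tensor and then splits it back into per-column outputs with an invalid split specification. Below is a minimal usage sketch of the multi-column case this commit fixes; the local SparkSession, toy data, column names, and tensor shapes are illustrative assumptions, not code from this repository.

    # Hypothetical end-to-end sketch: convert an Imputer fitted on two columns,
    # the case that previously produced a broken Split node in the ONNX graph.
    from pyspark.sql import SparkSession
    from pyspark.ml.feature import Imputer
    from onnxmltools import convert_sparkml
    from onnxmltools.convert.common.data_types import FloatTensorType

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    data = spark.createDataFrame(
        [(1.0, float("nan")), (2.0, 3.0), (float("nan"), 5.0)], ["a", "b"])

    model = Imputer(inputCols=["a", "b"], outputCols=["out_a", "out_b"]).fit(data)

    # One ONNX input per Spark input column, named after the columns.
    initial_types = [("a", FloatTensorType([None, 1])), ("b", FloatTensorType([None, 1]))]
    onnx_model = convert_sparkml(model, "SparkmlImputerMulti", initial_types)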

File tree: 3 files changed, +23 -12 lines changed


onnxmltools/convert/sparkml/operator_converters/imputer.py

Lines changed: 7 additions & 6 deletions
@@ -5,17 +5,19 @@
 from ...common.data_types import Int64TensorType, FloatTensorType
 from ...common.utils import check_input_and_output_numbers, check_input_and_output_types
 from ...common._registration import register_converter, register_shape_calculator
+from ...common._topology import Operator, Scope
+from pyspark.ml.feature import ImputerModel
+from typing import List

-
-def convert_imputer(scope, operator, container):
-    op = operator.raw_operator
-
+def convert_imputer(scope: Scope, operator: Operator, container):
+    op: ImputerModel = operator.raw_operator
     op_type = 'Imputer'
     name = scope.get_unique_operator_name(op_type)
     attrs = {'name': name}
     input_type = operator.inputs[0].type
     surrogates = op.surrogateDF.toPandas().values[0].tolist()
     value = op.getOrDefault('missingValue')
+
     if isinstance(input_type, FloatTensorType):
         attrs['imputed_value_floats'] = surrogates
         attrs['replaced_value_float'] = value
@@ -37,13 +39,12 @@ def convert_imputer(scope, operator, container):
                            name=scope.get_unique_operator_name('Split'),
                            op_version=2,
                            axis=1,
-                           split=range(1, len(operator.output_full_names)))
+                           split=[1] * len(operator.output_full_names))
     else:
         container.add_node(op_type, operator.inputs[0].full_name, operator.output_full_names[0],
                            op_domain='ai.onnx.ml',
                            **attrs)

-
 register_converter('pyspark.ml.feature.ImputerModel', convert_imputer)

 def calculate_imputer_output_shapes(operator):
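
Why the one-line Split change matters: at the op_version=2 used here (and in every opset before 13, where split is still an attribute rather than an input), the split attribute lists the size of each output chunk along axis, so slicing an N-column imputed tensor back into N single-column outputs needs [1] * N. The old value, range(1, N), supplies only N-1 chunk sizes for N outputs, which is the "Cannot split using values in 'split'" failure the previously skipped test reported. A standalone sketch of the corrected node shape, with illustrative tensor names and opset 12 assumed:

    # Minimal sketch (not repo code): a Split node returning one column per output.
    import onnx
    from onnx import TensorProto, helper

    imputed = helper.make_tensor_value_info("imputed", TensorProto.FLOAT, [None, 2])
    out_a = helper.make_tensor_value_info("out_a", TensorProto.FLOAT, [None, 1])
    out_b = helper.make_tensor_value_info("out_b", TensorProto.FLOAT, [None, 1])

    # split=[1, 1]: two chunks of width 1, matching the two declared outputs.
    split_node = helper.make_node(
        "Split", inputs=["imputed"], outputs=["out_a", "out_b"], axis=1, split=[1, 1])

    graph = helper.make_graph([split_node], "split_per_column", [imputed], [out_a, out_b])
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 12)])
    onnx.checker.check_model(model)  # structurally valid: chunk count matches output count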

onnxmltools/convert/sparkml/ops_input_output.py

Lines changed: 6 additions & 2 deletions
@@ -122,8 +122,12 @@ def build_io_name_map():
             lambda model: [model.getOrDefault("predictionCol")],
         ),
         "pyspark.ml.feature.ImputerModel": (
-            lambda model: model.getOrDefault("inputCols"),
-            lambda model: model.getOrDefault("outputCols"),
+            lambda model: model.getOrDefault("inputCols")
+            if model.isSet("inputCols")
+            else [model.getOrDefault("inputCol")],
+            lambda model: model.getOrDefault("outputCols")
+            if model.isSet("outputCols")
+            else [model.getOrDefault("outputCol")],
         ),
         "pyspark.ml.feature.MaxAbsScalerModel": (
             lambda model: [model.getOrDefault("inputCol")],
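
Background for the fallback above: Spark's Imputer can be parameterised either with the multi-column inputCols/outputCols params or, in newer Spark releases, with single-column inputCol/outputCol, and a fitted model normally has only one of the two pairs explicitly set, hence the isSet check before choosing which to read. A small self-contained sketch of the pattern (the local SparkSession and toy DataFrame are assumptions for illustration):

    # Sketch (not repo code) of the isSet-based single/multi column fallback.
    from pyspark.ml.feature import Imputer
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    df = spark.createDataFrame(
        [(1.0, float("nan")), (float("nan"), 3.0), (4.0, 5.0)], ["a", "b"])

    model = Imputer(inputCols=["a", "b"], outputCols=["out_a", "out_b"]).fit(df)

    def input_names(m):
        # Prefer the multi-column param when it is set; otherwise wrap the
        # single-column param in a list so callers always get a list back.
        return (m.getOrDefault("inputCols")
                if m.isSet("inputCols")
                else [m.getOrDefault("inputCol")])

    print(input_names(model))  # ['a', 'b'] for this multi-column model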

tests/sparkml/test_imputer.py

Lines changed: 10 additions & 4 deletions
@@ -27,7 +27,6 @@ class TestSparkmlImputer(SparkMlTestCase):
     def test_imputer_single(self):
         self._imputer_test_single()

-    @unittest.skipIf(True, reason="Name:'Split' Status Message: Cannot split using values in 'split")
     @unittest.skipIf(sys.version_info < (3, 8),
                      reason="pickle fails on python 3.7")
     def test_imputer_multi(self):
@@ -52,13 +51,20 @@ def _imputer_test_multi(self):

         # run the model
         predicted = model.transform(data)
-        expected = predicted.select("out_a", "out_b").toPandas().values.astype(numpy.float32)
+
+        expected = {
+            "out_a": predicted.select("out_a").toPandas().values.astype(numpy.int64),
+            "out_b": predicted.select("out_b").toPandas().values.astype(numpy.int64),
+        }
+
         data_np = data.toPandas().values.astype(numpy.float32)
         data_np = {'a': data_np[:, :1], 'b': data_np[:, 1:]}
         paths = save_data_models(data_np, expected, model, model_onnx, basename="SparkmlImputerMulti")
         onnx_model_path = paths[-1]
-        output, output_shapes = run_onnx_model(['out_a', 'out_b'], data_np, onnx_model_path)
-        compare_results(expected, output, decimal=5)
+        output_names = ['out_a', 'out_b']
+        output, output_shapes = run_onnx_model(output_names, data_np, onnx_model_path)
+        actual_output = dict(zip(output_names, output))
+        compare_results(expected, actual_output, decimal=5)

     def _imputer_test_single(self):
         data = self.spark.createDataFrame([
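
The multi-column test now builds its expectations per output name and compares dictionaries instead of one concatenated float array; the dict(zip(...)) step presumes the result arrays come back in the same order as the requested output names, which is how onnxruntime behaves. A generic sketch of that pairing pattern (the run_named helper and provider choice are illustrative assumptions, not part of the test suite):

    # Sketch (not repo code): map onnxruntime's positional outputs back to names.
    import onnxruntime as ort

    def run_named(model_path, input_feed, output_names):
        sess = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
        outputs = sess.run(output_names, input_feed)  # list, same order as output_names
        return dict(zip(output_names, outputs))       # e.g. {'out_a': ..., 'out_b': ...}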

0 commit comments
