@@ -277,13 +277,20 @@ class FactorizationMachinesPredictor(Predictor):
277277 to fit the model this Predictor performs inference on.
278278
279279 :meth:`predict()` returns a list of
280- :class:`~sagemaker.amazon.record_pb2.Record` objects, one for each row in
280+ :class:`~sagemaker.amazon.record_pb2.Record` objects (assuming the default
281+ recordio-protobuf ``deserializer`` is used), one for each row in
281282 the input ``ndarray``. The prediction is stored in the ``"score"`` key of
282283 the ``Record.label`` field. Please refer to the formats details described:
283284 https://docs.aws.amazon.com/sagemaker/latest/dg/fm-in-formats.html
284285 """
285286
286- def __init__ (self , endpoint_name , sagemaker_session = None ):
287+ def __init__ (
288+ self ,
289+ endpoint_name ,
290+ sagemaker_session = None ,
291+ serializer = RecordSerializer (),
292+ deserializer = RecordDeserializer (),
293+ ):
287294 """
288295 Args:
289296 endpoint_name (str): Name of the Amazon SageMaker endpoint to which
@@ -292,12 +299,16 @@ def __init__(self, endpoint_name, sagemaker_session=None):
292299 object, used for SageMaker interactions (default: None). If not
293300 specified, one is created using the default AWS configuration
294301 chain.
302+ serializer (sagemaker.serializers.BaseSerializer): Optional. Default
303+ serializes input data to x-recordio-protobuf format.
304+ deserializer (sagemaker.deserializers.BaseDeserializer): Optional.
305+ Default parses responses from x-recordio-protobuf format.
295306 """
296307 super (FactorizationMachinesPredictor , self ).__init__ (
297308 endpoint_name ,
298309 sagemaker_session ,
299- serializer = RecordSerializer () ,
300- deserializer = RecordDeserializer () ,
310+ serializer = serializer ,
311+ deserializer = deserializer ,
301312 )
302313
303314
0 commit comments