File tree Expand file tree Collapse file tree 2 files changed +9
-3
lines changed Expand file tree Collapse file tree 2 files changed +9
-3
lines changed Original file line number Diff line number Diff line change @@ -123,6 +123,7 @@ def __init__(
123123 content_type = None ,
124124 content_template = None ,
125125 custom_attributes = None ,
126+ accelerator_type = None ,
126127 ):
127128 """Initializes a configuration of a model and the endpoint to be created for it.
128129
@@ -151,6 +152,9 @@ def __init__(
151152 Section 3.3.6. Field Value Components (
152153 https://tools.ietf.org/html/rfc7230#section-3.2.6) of the Hypertext Transfer
153154 Protocol (HTTP/1.1).
155+ accelerator_type (str): The Elastic Inference accelerator type to deploy to the model
156+ endpoint instance for making inferences to the model, see
157+ https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html.
154158 """
155159 self .predictor_config = {
156160 "model_name" : model_name ,
@@ -178,9 +182,8 @@ def __init__(
178182 f" Please include a placeholder $features."
179183 )
180184 self .predictor_config ["content_template" ] = content_template
181-
182- if custom_attributes is not None :
183- self .predictor_config ["custom_attributes" ] = custom_attributes
185+ _set (custom_attributes , "custom_attributes" , self .predictor_config )
186+ _set (accelerator_type , "accelerator_type" , self .predictor_config )
184187
185188 def get_predictor_config (self ):
186189 """Returns part of the predictor dictionary of the analysis config."""
Original file line number Diff line number Diff line change @@ -92,13 +92,15 @@ def test_model_config():
9292 accept_type = "text/csv"
9393 content_type = "application/jsonlines"
9494 custom_attributes = "c000b4f9-df62-4c85-a0bf-7c525f9104a4"
95+ accelerator_type = "ml.eia1.medium"
9596 model_config = ModelConfig (
9697 model_name = model_name ,
9798 instance_type = instance_type ,
9899 instance_count = instance_count ,
99100 accept_type = accept_type ,
100101 content_type = content_type ,
101102 custom_attributes = custom_attributes ,
103+ accelerator_type = accelerator_type ,
102104 )
103105 expected_config = {
104106 "model_name" : model_name ,
@@ -107,6 +109,7 @@ def test_model_config():
107109 "accept_type" : accept_type ,
108110 "content_type" : content_type ,
109111 "custom_attributes" : custom_attributes ,
112+ "accelerator_type" : accelerator_type ,
110113 }
111114 assert expected_config == model_config .get_predictor_config ()
112115
You can’t perform that action at this time.
0 commit comments