@@ -68,6 +68,34 @@ def modify_node(self, node):
6868 keyword .arg = self .new_param_name
6969
7070
71+ class MethodParamRenamer (ParamRenamer ):
72+ """Abstract class to handle parameter renames for methods that belong to objects.
73+
74+ This differs from ``ParamRenamer`` in that a node for a standalone function call
75+ (i.e. where ``node.func`` is an ``ast.Name`` rather than an ``ast.Attribute``) is not modified.
76+ """
77+
78+ def node_should_be_modified (self , node ):
79+ """Checks if the node matches any of the relevant functions and
80+ contains the parameter to be renamed.
81+
82+ This looks for a call of the form ``<object>.<method>``, and
83+ assumes the method cannot be called on its own.
84+
85+ Args:
86+ node (ast.Call): a node that represents a function call. For more,
87+ see https://docs.python.org/3/library/ast.html#abstract-grammar.
88+
89+ Returns:
90+ bool: If the ``ast.Call`` matches the relevant function calls and
91+ contains the parameter to be renamed.
92+ """
93+ if isinstance (node .func , ast .Name ):
94+ return False
95+
96+ return super (MethodParamRenamer , self ).node_should_be_modified (node )
97+
98+
7199class DistributionParameterRenamer (ParamRenamer ):
72100 """A class to rename the ``distributions`` attribute to ``distrbution`` in
73101 MXNet and TensorFlow estimators.
@@ -100,7 +128,7 @@ def new_param_name(self):
100128 return "distribution"
101129
102130
103- class S3SessionRenamer (ParamRenamer ):
131+ class S3SessionRenamer (MethodParamRenamer ):
104132 """A class to rename the ``session`` attribute to ``sagemaker_session`` in
105133 ``S3Uploader`` and ``S3Downloader``.
106134
@@ -139,15 +167,6 @@ def new_param_name(self):
139167 """The new name for the SageMaker session argument."""
140168 return "sagemaker_session"
141169
142- def node_should_be_modified (self , node ):
143- """Checks if the node is one of the S3 utility functions and
144- contains the ``session`` parameter.
145- """
146- if isinstance (node .func , ast .Name ):
147- return False
148-
149- return super (S3SessionRenamer , self ).node_should_be_modified (node )
150-
151170
152171class EstimatorImageURIRenamer (ParamRenamer ):
153172 """A class to rename the ``image_name`` attribute to ``image_uri`` in estimators."""
@@ -209,3 +228,93 @@ def old_param_name(self):
209228 def new_param_name (self ):
210229 """The new name for the image URI argument."""
211230 return "image_uri"
231+
232+
233+ class EstimatorCreateModelImageURIRenamer (MethodParamRenamer ):
234+ """A class to rename ``image`` to ``image_uri`` in estimator ``create_model()`` methods."""
235+
236+ @property
237+ def calls_to_modify (self ):
238+ """A mapping of ``create_model`` to common variable names for estimators."""
239+ return {
240+ "create_model" : (
241+ "estimator" ,
242+ "chainer" ,
243+ "mxnet" ,
244+ "mx" ,
245+ "pytorch" ,
246+ "rl" ,
247+ "sklearn" ,
248+ "tensorflow" ,
249+ "tf" ,
250+ "xgboost" ,
251+ "xgb" ,
252+ )
253+ }
254+
255+ @property
256+ def old_param_name (self ):
257+ """The previous name for the image URI argument."""
258+ return "image"
259+
260+ @property
261+ def new_param_name (self ):
262+ """The new name for the the image URI argument."""
263+ return "image_uri"
264+
265+
266+ class SessionCreateModelImageURIRenamer (MethodParamRenamer ):
267+ """A class to rename ``primary_container_image`` to ``image_uri``.
268+
269+ This looks for the following calls:
270+
271+ - ``sagemaker_session.create_model_from_job()``
272+ - ``sess.create_model_from_job()``
273+ """
274+
275+ @property
276+ def calls_to_modify (self ):
277+ """A mapping of ``create_model_from_job`` to common variable names for Session."""
278+ return {
279+ "create_model_from_job" : ("sagemaker_session" , "sess" ),
280+ }
281+
282+ @property
283+ def old_param_name (self ):
284+ """The previous name for the image URI argument."""
285+ return "primary_container_image"
286+
287+ @property
288+ def new_param_name (self ):
289+ """The new name for the the image URI argument."""
290+ return "image_uri"
291+
292+
293+ class SessionCreateEndpointImageURIRenamer (MethodParamRenamer ):
294+ """A class to rename ``deployment_image`` to ``image_uri``.
295+
296+ This looks for the following calls:
297+
298+ - ``sagemaker_session.endpoint_from_job()``
299+ - ``sess.endpoint_from_job()``
300+ - ``sagemaker_session.endpoint_from_model_data()``
301+ - ``sess.endpoint_from_model_data()``
302+ """
303+
304+ @property
305+ def calls_to_modify (self ):
306+ """A mapping of the ``endpoint_from_*`` functions to common variable names for Session."""
307+ return {
308+ "endpoint_from_job" : ("sagemaker_session" , "sess" ),
309+ "endpoint_from_model_data" : ("sagemaker_session" , "sess" ),
310+ }
311+
312+ @property
313+ def old_param_name (self ):
314+ """The previous name for the image URI argument."""
315+ return "deployment_image"
316+
317+ @property
318+ def new_param_name (self ):
319+ """The new name for the the image URI argument."""
320+ return "image_uri"
0 commit comments