1717
1818from sagemaker .cli .compatibility .v2 .modifiers .modifier import Modifier
1919
20+ FRAMEWORK_ARG = "framework_version"
21+ PY_ARG = "py_version"
22+
2023FRAMEWORK_DEFAULTS = {
2124 "Chainer" : "4.1.0" ,
2225 "MXNet" : "1.2.0" ,
2528 "TensorFlow" : "1.11.0" ,
2629}
2730
28- FRAMEWORKS = list (FRAMEWORK_DEFAULTS .keys ())
31+ FRAMEWORK_CLASSES = list (FRAMEWORK_DEFAULTS .keys ())
32+ MODEL_CLASSES = ["{}Model" .format (fw ) for fw in FRAMEWORK_CLASSES ]
33+
2934# TODO: check for sagemaker.tensorflow.serving.Model
30- FRAMEWORK_CLASSES = FRAMEWORKS + ["{}Model" .format (fw ) for fw in FRAMEWORKS ]
31- FRAMEWORK_MODULES = [fw .lower () for fw in FRAMEWORKS ]
35+ FRAMEWORK_MODULES = [fw .lower () for fw in FRAMEWORK_CLASSES ]
3236FRAMEWORK_SUBMODULES = ("model" , "estimator" )
3337
3438
@@ -39,7 +43,8 @@ class FrameworkVersionEnforcer(Modifier):
3943
4044 def node_should_be_modified (self , node ):
4145 """Checks if the ast.Call node instantiates a framework estimator or model,
42- but doesn't specify the ``framework_version`` parameter.
46+ but doesn't specify the ``framework_version`` and ``py_version`` parameter,
47+ as appropriate.
4348
4449 This looks for the following formats:
4550
@@ -56,49 +61,12 @@ def node_should_be_modified(self, node):
5661 bool: If the ``ast.Call`` is instantiating a framework class that
5762 should specify ``framework_version``, but doesn't.
5863 """
59- if self . _is_framework_constructor (node ):
60- return not self . _fw_version_in_keywords (node )
64+ if _is_named_constructor (node , FRAMEWORK_CLASSES ):
65+ return _version_args_needed (node , "image_name" )
6166
62- return False
67+ if _is_named_constructor (node , MODEL_CLASSES ):
68+ return _version_args_needed (node , "image" )
6369
64- def _is_framework_constructor (self , node ):
65- """Checks if the ``ast.Call`` node represents a call of the form
66- <Framework> or sagemaker.<framework>.<Framework>.
67- """
68- # Check for <Framework> call
69- if isinstance (node .func , ast .Name ):
70- return node .func .id in FRAMEWORK_CLASSES
71-
72- # Check for something.that.ends.with.<framework>.<Framework> call
73- if not (isinstance (node .func , ast .Attribute ) and node .func .attr in FRAMEWORK_CLASSES ):
74- return False
75-
76- # Check for sagemaker.<frameworks>.<estimator/model>.<Framework> call
77- if (
78- isinstance (node .func .value , ast .Attribute )
79- and node .func .value .attr in FRAMEWORK_SUBMODULES
80- ):
81- return self ._is_in_framework_module (node .func .value )
82-
83- # Check for sagemaker.<framework>.<Framework> call
84- return self ._is_in_framework_module (node .func )
85-
86- def _is_in_framework_module (self , node ):
87- """Checks if the node is an ``ast.Attribute`` that represents a
88- ``sagemaker.<framework>`` module.
89- """
90- return (
91- isinstance (node .value , ast .Attribute )
92- and node .value .attr in FRAMEWORK_MODULES
93- and isinstance (node .value .value , ast .Name )
94- and node .value .value .id == "sagemaker"
95- )
96-
97- def _fw_version_in_keywords (self , node ):
98- """Checks if the ``ast.Call`` node's keywords contain ``framework_version``."""
99- for kw in node .keywords :
100- if kw .arg == "framework_version" and kw .value :
101- return True
10270 return False
10371
10472 def modify_node (self , node ):
@@ -112,30 +80,146 @@ def modify_node(self, node):
11280 - SKLearn: "0.20.0"
11381 - TensorFlow: "1.11.0"
11482
83+ The ``py_version`` value is determined by the framework, framework_version, and if it is a
84+ model, whether the model accepts a py_version
85+
11586 Args:
11687 node (ast.Call): a node that represents the constructor of a framework class.
11788 """
118- framework = self ._framework_name_from_node (node )
119- node .keywords .append (
120- ast .keyword (arg = "framework_version" , value = ast .Str (s = FRAMEWORK_DEFAULTS [framework ]))
121- )
89+ framework , is_model = _framework_from_node (node )
12290
123- def _framework_name_from_node (self , node ):
124- """Retrieves the framework name based on the function call.
91+ # if framework_version is not supplied, get default and append keyword
92+ framework_version = _arg_value (node , FRAMEWORK_ARG )
93+ if framework_version is None :
94+ framework_version = FRAMEWORK_DEFAULTS [framework ]
95+ node .keywords .append (ast .keyword (arg = FRAMEWORK_ARG , value = ast .Str (s = framework_version )))
12596
126- Args:
127- node (ast.Call): a node that represents the constructor of a framework class.
128- This can represent either <Framework> or sagemaker.<framework>.<Framework>.
97+ # if py_version is not supplied, get a conditional default, and if not None, append keyword
98+ py_version = _arg_value (node , PY_ARG )
99+ if py_version is None :
100+ py_version = _py_version_defaults (framework , framework_version , is_model )
101+ if py_version :
102+ node .keywords .append (ast .keyword (arg = PY_ARG , value = ast .Str (s = py_version )))
129103
130- Returns:
131- str: the (capitalized) framework name.
132- """
133- if isinstance (node .func , ast .Name ):
134- framework = node .func .id
135- elif isinstance (node .func , ast .Attribute ):
136- framework = node .func .attr
137104
138- if framework .endswith ("Model" ):
139- framework = framework [: framework .find ("Model" )]
105+ def _py_version_defaults (framework , framework_version , is_model = False ):
106+ """Gets the py_version required for the framework_version and if it's a model
107+
108+ Args:
109+ framework (str): name of the framework
110+ framework_version (str): version of the framework
111+ is_model (bool): whether it is a constructor for a model or not
112+
113+ Returns:
114+ str: the default py version, as appropriate. None if no default py_version
115+ """
116+ if framework in ("Chainer" , "PyTorch" ):
117+ return "py3"
118+ if framework == "SKLearn" and not is_model :
119+ return "py3"
120+ if framework == "MXNet" :
121+ return "py2"
122+ if framework == "TensorFlow" and not is_model :
123+ return _tf_py_version_default (framework_version )
124+ return None
125+
126+
127+ def _tf_py_version_default (framework_version ):
128+ """Gets the py_version default based on framework_version for TensorFlow."""
129+ if not framework_version :
130+ return "py2"
131+ version = [int (s ) for s in framework_version .split ("." )]
132+ if version < [1 , 12 ]:
133+ return "py2"
134+ if version < [2 , 2 ]:
135+ return "py3"
136+ return "py37"
137+
138+
139+ def _framework_from_node (node ):
140+ """Retrieves the framework class name based on the function call, and if it was a model
141+
142+ Args:
143+ node (ast.Call): a node that represents the constructor of a framework class.
144+ This can represent either <Framework> or sagemaker.<framework>.<Framework>.
145+
146+ Returns:
147+ str, bool: the (capitalized) framework class name, and if it is a model class
148+ """
149+ if isinstance (node .func , ast .Name ):
150+ framework = node .func .id
151+ elif isinstance (node .func , ast .Attribute ):
152+ framework = node .func .attr
153+ else :
154+ framework = ""
155+
156+ is_model = framework .endswith ("Model" )
157+ if is_model :
158+ framework = framework [: framework .find ("Model" )]
159+
160+ return framework , is_model
161+
162+
163+ def _is_named_constructor (node , names ):
164+ """Checks if the ``ast.Call`` node represents a call to particular named constructors.
165+
166+ Forms that qualify are either <Framework> or sagemaker.<framework>.<Framework>
167+ where <Framework> belongs to the list of names passed in.
168+ """
169+ # Check for call from particular names of constructors
170+ if isinstance (node .func , ast .Name ):
171+ return node .func .id in names
172+
173+ # Check for something.that.ends.with.<framework>.<Framework> call for Framework in names
174+ if not (isinstance (node .func , ast .Attribute ) and node .func .attr in names ):
175+ return False
176+
177+ # Check for sagemaker.<frameworks>.<estimator/model>.<Framework> call
178+ if isinstance (node .func .value , ast .Attribute ) and node .func .value .attr in FRAMEWORK_SUBMODULES :
179+ return _is_in_framework_module (node .func .value )
180+
181+ # Check for sagemaker.<framework>.<Framework> call
182+ return _is_in_framework_module (node .func )
183+
184+
185+ def _is_in_framework_module (node ):
186+ """Checks if node is an ``ast.Attribute`` representing a ``sagemaker.<framework>`` module."""
187+ return (
188+ isinstance (node .value , ast .Attribute )
189+ and node .value .attr in FRAMEWORK_MODULES
190+ and isinstance (node .value .value , ast .Name )
191+ and node .value .value .id == "sagemaker"
192+ )
193+
194+
195+ def _version_args_needed (node , image_arg ):
196+ """Determines if image_arg or version_arg was supplied
197+
198+ Applies similar logic as ``validate_version_or_image_args``
199+ """
200+ # if image_arg is present, no need to supply version arguments
201+ image_name = _arg_value (node , image_arg )
202+ if image_name :
203+ return False
204+
205+ # if framework_version is None, need args
206+ framework_version = _arg_value (node , FRAMEWORK_ARG )
207+ if framework_version is None :
208+ return True
209+
210+ # check if we expect py_version and we don't get it -- framework and model dependent
211+ framework , is_model = _framework_from_node (node )
212+ expecting_py_version = _py_version_defaults (framework , framework_version , is_model )
213+ if expecting_py_version :
214+ py_version = _arg_value (node , PY_ARG )
215+ return py_version is None
216+
217+ return False
218+
140219
141- return framework
220+ def _arg_value (node , arg ):
221+ """Gets the value associated with the arg keyword, if present"""
222+ for kw in node .keywords :
223+ if kw .arg == arg and kw .value :
224+ return kw .value .s
225+ return None
0 commit comments