@@ -58,6 +58,20 @@ class ProblemTypes(str, Enum):
5858 TABULAR_CLASSIFICATION = "Classification"
5959
6060
61+ class Frameworks (str , Enum ):
62+ """Possible frameworks for JumpStart models"""
63+
64+ TENSORFLOW = "Tensorflow Hub"
65+ PYTORCH = "Pytorch Hub"
66+ HUGGINGFACE = "HuggingFace"
67+ CATBOOST = "Catboost"
68+ GLUONCV = "GluonCV"
69+ LIGHTGBM = "LightGBM"
70+ XGBOOST = "XGBoost"
71+ SCIKIT_LEARN = "ScikitLearn"
72+ SOURCE = "Source"
73+
74+
6175JUMPSTART_REGION = "eu-west-2"
6276SDK_MANIFEST_FILE = "models_manifest.json"
6377JUMPSTART_BUCKET_BASE_URL = "https://jumpstart-cache-prod-{}.s3.{}.amazonaws.com" .format (
@@ -82,6 +96,61 @@ class ProblemTypes(str, Enum):
8296 Tasks .TABULAR_CLASSIFICATION : ProblemTypes .TABULAR_CLASSIFICATION ,
8397}
8498
99+ TO_FRAMEWORK = {
100+ "Tensorflow Hub" : Frameworks .TENSORFLOW ,
101+ "Pytorch Hub" : Frameworks .PYTORCH ,
102+ "HuggingFace" : Frameworks .HUGGINGFACE ,
103+ "Catboost" : Frameworks .CATBOOST ,
104+ "GluonCV" : Frameworks .GLUONCV ,
105+ "LightGBM" : Frameworks .LIGHTGBM ,
106+ "XGBoost" : Frameworks .XGBOOST ,
107+ "ScikitLearn" : Frameworks .SCIKIT_LEARN ,
108+ "Source" : Frameworks .SOURCE ,
109+ }
110+
111+
112+ MODALITY_MAP = {
113+ (Tasks .IC , Frameworks .PYTORCH ): "algorithms/vision/image_classification_pytorch.rst" ,
114+ (Tasks .IC , Frameworks .TENSORFLOW ): "algorithms/vision/image_classification_tensorflow.rst" ,
115+ (Tasks .IC_EMBEDDING , Frameworks .TENSORFLOW ): "algorithms/vision/image_embedding_tensorflow.rst" ,
116+ (Tasks .IS , Frameworks .GLUONCV ): "algorithms/vision/instance_segmentation_mxnet.rst" ,
117+ (Tasks .OD , Frameworks .GLUONCV ): "algorithms/vision/object_detection_mxnet.rst" ,
118+ (Tasks .OD , Frameworks .PYTORCH ): "algorithms/vision/object_detection_pytorch.rst" ,
119+ (Tasks .OD , Frameworks .TENSORFLOW ): "algorithms/vision/object_detection_tensorflow.rst" ,
120+ (Tasks .SEMSEG , Frameworks .GLUONCV ): "algorithms/vision/semantic_segmentation_mxnet.rst" ,
121+ (
122+ Tasks .TRANSLATION ,
123+ Frameworks .HUGGINGFACE ,
124+ ): "algorithms/text/machine_translation_hugging_face.rst" ,
125+ (Tasks .NER , Frameworks .GLUONCV ): "algorithms/text/named_entity_recognition_hugging_face.rst" ,
126+ (Tasks .EQA , Frameworks .PYTORCH ): "algorithms/text/question_answering_pytorch.rst" ,
127+ (
128+ Tasks .SPC ,
129+ Frameworks .HUGGINGFACE ,
130+ ): "algorithms/text/sentence_pair_classification_hugging_face.rst" ,
131+ (
132+ Tasks .SPC ,
133+ Frameworks .TENSORFLOW ,
134+ ): "algorithms/text/sentence_pair_classification_tensorflow.rst" ,
135+ (Tasks .TC , Frameworks .TENSORFLOW ): "algorithms/text/text_classification_tensorflow.rst" ,
136+ (
137+ Tasks .TC_EMBEDDING ,
138+ Frameworks .GLUONCV ,
139+ ): "algorithms/vision/text_embedding_tensorflow_mxnet.rst" ,
140+ (
141+ Tasks .TC_EMBEDDING ,
142+ Frameworks .TENSORFLOW ,
143+ ): "algorithms/vision/text_embedding_tensorflow_mxnet.rst" ,
144+ (
145+ Tasks .TEXT_GENERATION ,
146+ Frameworks .HUGGINGFACE ,
147+ ): "algorithms/text/text_generation_hugging_face.rst" ,
148+ (
149+ Tasks .SUMMARIZATION ,
150+ Frameworks .HUGGINGFACE ,
151+ ): "algorithms/text/text_summarization_hugging_face.rst" ,
152+ }
153+
85154
86155def get_jumpstart_sdk_manifest ():
87156 url = "{}/{}" .format (JUMPSTART_BUCKET_BASE_URL , SDK_MANIFEST_FILE )
@@ -102,6 +171,10 @@ def get_model_task(id):
102171 return TASK_MAP [task_short ] if task_short in TASK_MAP else "Source"
103172
104173
174+ def get_string_model_task (id ):
175+ return id .split ("-" )[1 ]
176+
177+
105178def get_model_source (url ):
106179 if "tfhub" in url :
107180 return "Tensorflow Hub"
@@ -113,8 +186,6 @@ def get_model_source(url):
113186 return "Catboost"
114187 if "gluon" in url :
115188 return "GluonCV"
116- if "catboost" in url :
117- return "Catboost"
118189 if "lightgbm" in url :
119190 return "LightGBM"
120191 if "xgboost" in url :
@@ -138,58 +209,97 @@ def create_jumpstart_model_table():
138209 ) < Version (model ["version" ]):
139210 sdk_manifest_top_versions_for_models [model ["model_id" ]] = model
140211
141- file_content = []
212+ file_content_intro = []
142213
143- file_content .append (".. _all-pretrained-models:\n \n " )
144- file_content .append (".. |external-link| raw:: html\n \n " )
145- file_content .append (' <i class="fa fa-external-link"></i>\n \n ' )
214+ file_content_intro .append (".. _all-pretrained-models:\n \n " )
215+ file_content_intro .append (".. |external-link| raw:: html\n \n " )
216+ file_content_intro .append (' <i class="fa fa-external-link"></i>\n \n ' )
146217
147- file_content .append ("================================================\n " )
148- file_content .append ("Built-in Algorithms with pre-trained Model Table\n " )
149- file_content .append ("================================================\n " )
150- file_content .append (
218+ file_content_intro .append ("================================================\n " )
219+ file_content_intro .append ("Built-in Algorithms with pre-trained Model Table\n " )
220+ file_content_intro .append ("================================================\n " )
221+ file_content_intro .append (
151222 """
152223 The SageMaker Python SDK uses model IDs and model versions to access the necessary
153224 utilities for pre-trained models. This table serves to provide the core material plus
154225 some extra information that can be useful in selecting the correct model ID and
155226 corresponding parameters.\n """
156227 )
157- file_content .append (
228+ file_content_intro .append (
158229 """
159230 If you want to automatically use the latest version of the model, use "*" for the `model_version` attribute.
160231 We highly suggest pinning an exact model version however.\n """
161232 )
162- file_content .append (
233+ file_content_intro .append (
163234 """
164235 These models are also available through the
165236 `JumpStart UI in SageMaker Studio <https://docs.aws.amazon.com/sagemaker/latest/dg/studio-jumpstart.html>`__\n """
166237 )
167- file_content .append ("\n " )
168- file_content .append (".. list-table:: Available Models\n " )
169- file_content .append (" :widths: 50 20 20 20 30 20\n " )
170- file_content .append (" :header-rows: 1\n " )
171- file_content .append (" :class: datatable\n " )
172- file_content .append ("\n " )
173- file_content .append (" * - Model ID\n " )
174- file_content .append (" - Fine Tunable?\n " )
175- file_content .append (" - Latest Version\n " )
176- file_content .append (" - Min SDK Version\n " )
177- file_content .append (" - Problem Type\n " )
178- file_content .append (" - Source\n " )
238+ file_content_intro .append ("\n " )
239+ file_content_intro .append (".. list-table:: Available Models\n " )
240+ file_content_intro .append (" :widths: 50 20 20 20 30 20\n " )
241+ file_content_intro .append (" :header-rows: 1\n " )
242+ file_content_intro .append (" :class: datatable\n " )
243+ file_content_intro .append ("\n " )
244+ file_content_intro .append (" * - Model ID\n " )
245+ file_content_intro .append (" - Fine Tunable?\n " )
246+ file_content_intro .append (" - Latest Version\n " )
247+ file_content_intro .append (" - Min SDK Version\n " )
248+ file_content_intro .append (" - Problem Type\n " )
249+ file_content_intro .append (" - Source\n " )
250+
251+ dynamic_table_files = []
252+ file_content_entries = []
179253
180254 for model in sdk_manifest_top_versions_for_models .values ():
181255 model_spec = get_jumpstart_sdk_spec (model ["spec_key" ])
182256 model_task = get_model_task (model_spec ["model_id" ])
257+ string_model_task = get_string_model_task (model_spec ["model_id" ])
183258 model_source = get_model_source (model_spec ["url" ])
184- file_content .append (" * - {}\n " .format (model_spec ["model_id" ]))
185- file_content .append (" - {}\n " .format (model_spec ["training_supported" ]))
186- file_content .append (" - {}\n " .format (model ["version" ]))
187- file_content .append (" - {}\n " .format (model ["min_version" ]))
188- file_content .append (" - {}\n " .format (model_task ))
189- file_content .append (
259+ file_content_entries .append (" * - {}\n " .format (model_spec ["model_id" ]))
260+ file_content_entries .append (" - {}\n " .format (model_spec ["training_supported" ]))
261+ file_content_entries .append (" - {}\n " .format (model ["version" ]))
262+ file_content_entries .append (" - {}\n " .format (model ["min_version" ]))
263+ file_content_entries .append (" - {}\n " .format (model_task ))
264+ file_content_entries .append (
190265 " - `{} <{}>`__ |external-link|\n " .format (model_source , model_spec ["url" ])
191266 )
192267
193- f = open ("doc_utils/pretrainedmodels.rst" , "w" )
194- f .writelines (file_content )
268+ if (string_model_task , TO_FRAMEWORK [model_source ]) in MODALITY_MAP :
269+ file_content_single_entry = []
270+
271+ if (
272+ MODALITY_MAP [(string_model_task , TO_FRAMEWORK [model_source ])]
273+ not in dynamic_table_files
274+ ):
275+ file_content_single_entry .append ("\n " )
276+ file_content_single_entry .append (".. list-table:: Available Models\n " )
277+ file_content_single_entry .append (" :widths: 50 20 20 20 20\n " )
278+ file_content_single_entry .append (" :header-rows: 1\n " )
279+ file_content_single_entry .append (" :class: datatable\n " )
280+ file_content_single_entry .append ("\n " )
281+ file_content_single_entry .append (" * - Model ID\n " )
282+ file_content_single_entry .append (" - Fine Tunable?\n " )
283+ file_content_single_entry .append (" - Latest Version\n " )
284+ file_content_single_entry .append (" - Min SDK Version\n " )
285+ file_content_single_entry .append (" - Source\n " )
286+
287+ dynamic_table_files .append (
288+ MODALITY_MAP [(string_model_task , TO_FRAMEWORK [model_source ])]
289+ )
290+
291+ file_content_single_entry .append (" * - {}\n " .format (model_spec ["model_id" ]))
292+ file_content_single_entry .append (" - {}\n " .format (model_spec ["training_supported" ]))
293+ file_content_single_entry .append (" - {}\n " .format (model ["version" ]))
294+ file_content_single_entry .append (" - {}\n " .format (model ["min_version" ]))
295+ file_content_single_entry .append (
296+ " - `{} <{}>`__\n " .format (model_source , model_spec ["url" ])
297+ )
298+ f = open (MODALITY_MAP [(string_model_task , TO_FRAMEWORK [model_source ])], "a" )
299+ f .writelines (file_content_single_entry )
300+ f .close ()
301+
302+ f = open ("doc_utils/pretrainedmodels.rst" , "a" )
303+ f .writelines (file_content_intro )
304+ f .writelines (file_content_entries )
195305 f .close ()
0 commit comments