1414from urllib import request
1515import json
1616from packaging .version import Version
17+ from enum import Enum
18+
19+
20+ class Tasks (str , Enum ):
21+ """The ML task name as referenced in the infix of the model ID."""
22+
23+ IC = "ic"
24+ OD = "od"
25+ OD1 = "od1"
26+ SEMSEG = "semseg"
27+ IS = "is"
28+ TC = "tc"
29+ SPC = "spc"
30+ EQA = "eqa"
31+ TEXT_GENERATION = "textgeneration"
32+ IC_EMBEDDING = "icembedding"
33+ TC_EMBEDDING = "tcembedding"
34+ NER = "ner"
35+ SUMMARIZATION = "summarization"
36+ TRANSLATION = "translation"
37+ TABULAR_REGRESSION = "regression"
38+ TABULAR_CLASSIFICATION = "classification"
39+
40+
41+ class ProblemTypes (str , Enum ):
42+ """Possible problem types for JumpStart models."""
43+
44+ IMAGE_CLASSIFICATION = "Image Classification"
45+ IMAGE_EMBEDDING = "Image Embedding"
46+ OBJECT_DETECTION = "Object Detection"
47+ SEMANTIC_SEGMENTATION = "Semantic Segmentation"
48+ INSTANCE_SEGMENTATION = "Instance Segmentation"
49+ TEXT_CLASSIFICATION = "Text Classification"
50+ TEXT_EMBEDDING = "Text Embedding"
51+ QUESTION_ANSWERING = "Question Answering"
52+ SENTENCE_PAIR_CLASSIFICATION = "Sentence Pair Classification"
53+ TEXT_GENERATION = "Text Generation"
54+ TEXT_SUMMARIZATION = "Text Summarization"
55+ MACHINE_TRANSLATION = "Machine Translation"
56+ NAMED_ENTITY_RECOGNITION = "Named Entity Recognition"
57+ TABULAR_REGRESSION = "Regression"
58+ TABULAR_CLASSIFICATION = "Classification"
59+
1760
1861JUMPSTART_REGION = "eu-west-2"
1962SDK_MANIFEST_FILE = "models_manifest.json"
2063JUMPSTART_BUCKET_BASE_URL = "https://jumpstart-cache-prod-{}.s3.{}.amazonaws.com" .format (
2164 JUMPSTART_REGION , JUMPSTART_REGION
2265)
66+ TASK_MAP = {
67+ Tasks .IC : ProblemTypes .IMAGE_CLASSIFICATION ,
68+ Tasks .IC_EMBEDDING : ProblemTypes .IMAGE_EMBEDDING ,
69+ Tasks .OD : ProblemTypes .OBJECT_DETECTION ,
70+ Tasks .OD1 : ProblemTypes .OBJECT_DETECTION ,
71+ Tasks .SEMSEG : ProblemTypes .SEMANTIC_SEGMENTATION ,
72+ Tasks .IS : ProblemTypes .INSTANCE_SEGMENTATION ,
73+ Tasks .TC : ProblemTypes .TEXT_CLASSIFICATION ,
74+ Tasks .TC_EMBEDDING : ProblemTypes .TEXT_EMBEDDING ,
75+ Tasks .EQA : ProblemTypes .QUESTION_ANSWERING ,
76+ Tasks .SPC : ProblemTypes .SENTENCE_PAIR_CLASSIFICATION ,
77+ Tasks .TEXT_GENERATION : ProblemTypes .TEXT_GENERATION ,
78+ Tasks .SUMMARIZATION : ProblemTypes .TEXT_SUMMARIZATION ,
79+ Tasks .TRANSLATION : ProblemTypes .MACHINE_TRANSLATION ,
80+ Tasks .NER : ProblemTypes .NAMED_ENTITY_RECOGNITION ,
81+ Tasks .TABULAR_REGRESSION : ProblemTypes .TABULAR_REGRESSION ,
82+ Tasks .TABULAR_CLASSIFICATION : ProblemTypes .TABULAR_CLASSIFICATION ,
83+ }
2384
2485
2586def get_jumpstart_sdk_manifest ():
@@ -36,6 +97,11 @@ def get_jumpstart_sdk_spec(key):
3697 return json .loads (model_spec )
3798
3899
100+ def get_model_task (id ):
101+ task_short = id .split ("-" )[1 ]
102+ return TASK_MAP [task_short ] if task_short in TASK_MAP else "Source"
103+
104+
39105def create_jumpstart_model_table ():
40106 sdk_manifest = get_jumpstart_sdk_manifest ()
41107 sdk_manifest_top_versions_for_models = {}
@@ -69,26 +135,29 @@ def create_jumpstart_model_table():
69135 )
70136 file_content .append (
71137 """
72- Each model id is linked to an external page that describes the model.\n
138+ Click on the Problem Type to navigate to the source of the model.\n
73139 """
74140 )
75141 file_content .append ("\n " )
76142 file_content .append (".. list-table:: Available Models\n " )
77- file_content .append (" :widths: 50 20 20 20\n " )
143+ file_content .append (" :widths: 50 20 20 20 30 \n " )
78144 file_content .append (" :header-rows: 1\n " )
79145 file_content .append (" :class: datatable\n " )
80146 file_content .append ("\n " )
81147 file_content .append (" * - Model ID\n " )
82148 file_content .append (" - Fine Tunable?\n " )
83149 file_content .append (" - Latest Version\n " )
84150 file_content .append (" - Min SDK Version\n " )
151+ file_content .append (" - Problem Type/Source\n " )
85152
86153 for model in sdk_manifest_top_versions_for_models .values ():
87154 model_spec = get_jumpstart_sdk_spec (model ["spec_key" ])
88- file_content .append (" * - `{} <{}>`_\n " .format (model_spec ["model_id" ], model_spec ["url" ]))
155+ model_task = get_model_task (model_spec ["model_id" ])
156+ file_content .append (" * - {}\n " .format (model_spec ["model_id" ]))
89157 file_content .append (" - {}\n " .format (model_spec ["training_supported" ]))
90158 file_content .append (" - {}\n " .format (model ["version" ]))
91159 file_content .append (" - {}\n " .format (model ["min_version" ]))
160+ file_content .append (" - `{} <{}>`__\n " .format (model_task , model_spec ["url" ]))
92161
93162 f = open ("doc_utils/jumpstart.rst" , "w" )
94163 f .writelines (file_content )
0 commit comments