33
44# DeepSpeed Team
55import os
6+ import pickle
7+ import time
68import importlib
79import torch
810import mii
11+ from types import SimpleNamespace
912from huggingface_hub import HfApi
1013
1114from mii .models .score .generate import generated_score_path
2023 SUPPORTED_MODEL_TYPES ,
2124 ModelProvider ,
2225 REQUIRED_KEYS_PER_TASK ,
23- TEXT2IMG_NAME )
26+ TEXT2IMG_NAME ,
27+ MII_HF_CACHE_EXPIRATION ,
28+ MII_HF_CACHE_EXPIRATION_DEFAULT )
2429
2530from mii .constants import Tasks
2631
@@ -75,21 +80,55 @@ def get_task(task_name):
7580 assert False , f"Unknown Task { task_name } "
7681
7782
78- def _get_hf_models_by_type (model_type , task = None ):
79- api = HfApi ()
80- models = api .list_models (filter = model_type )
81- models = ([m .modelId for m in models ]
82- if task is None else [m .modelId for m in models if m .pipeline_tag == task ])
def _get_hf_models_by_type(model_type=None, task=None):
    """Return Hugging Face Hub model IDs, optionally filtered by type and task.

    The full hub listing is cached on disk (pickle) and refreshed once the
    cache is older than MII_HF_CACHE_EXPIRATION seconds.

    Args:
        model_type: if given, keep only models whose hub tags contain it.
        task: if given, keep only models whose pipeline_tag equals it.

    Returns:
        A list of model ID strings.
    """
    cache_file_path = os.path.join(mii_cache_path(), "HF_model_cache.pkl")
    # os.getenv returns a *string* when the variable is set in the
    # environment; cast so the arithmetic below never mixes str and number.
    cache_expiration_seconds = float(
        os.getenv(MII_HF_CACHE_EXPIRATION,
                  MII_HF_CACHE_EXPIRATION_DEFAULT))

    # Load or initialize the cache. A corrupted/truncated cache file is
    # treated the same as a missing one: force a refresh instead of crashing.
    model_data = {"cache_time": 0, "model_list": []}
    if os.path.isfile(cache_file_path):
        try:
            with open(cache_file_path, "rb") as f:
                model_data = pickle.load(f)
        except (pickle.UnpicklingError, EOFError, AttributeError, OSError):
            model_data = {"cache_time": 0, "model_list": []}

    current_time = time.time()

    # Update (and persist) the cache if it has expired
    if (model_data["cache_time"] + cache_expiration_seconds) < current_time:
        api = HfApi()
        # Store plain SimpleNamespace records so the pickle does not depend
        # on huggingface_hub's internal classes.
        model_data["model_list"] = [
            SimpleNamespace(modelId=m.modelId,
                            pipeline_tag=m.pipeline_tag,
                            tags=m.tags) for m in api.list_models()
        ]
        model_data["cache_time"] = current_time

        with open(cache_file_path, "wb") as f:
            pickle.dump(model_data, f)

    # Filter the model list
    models = model_data["model_list"]
    if model_type is not None:
        models = [m for m in models if model_type in m.tags]
    if task is not None:
        models = [m for m in models if m.pipeline_tag == task]

    # Extract model IDs
    model_ids = [m.modelId for m in models]

    if task == TEXT_GENERATION_NAME:
        # TODO: this is a temp solution to get around some HF models not
        # having the correct tags
        model_ids.extend([
            "microsoft/bloom-deepspeed-inference-fp16",
            "microsoft/bloom-deepspeed-inference-int8",
            "EleutherAI/gpt-neox-20b"
        ])

    return model_ids
90129
91- # TODO read this from a file containing list of files supported for each task
92- def _get_supported_models_name (task ):
130+
131+ def get_supported_models (task ):
93132 supported_models = []
94133 task_name = get_task_name (task )
95134
@@ -109,16 +148,18 @@ def _get_supported_models_name(task):
109148
110149
def check_if_task_and_model_is_supported(task, model_name):
    """Assert that *model_name* appears in the supported-model list for *task*.

    Raises:
        AssertionError: if the model is not in the supported list for the task.
    """
    supported_models = get_supported_models(task)
    # NOTE: message names the model as the unsupported thing (the task is a
    # fixed enum; it is the model that fails the check).
    assert (
        model_name in supported_models
    ), f"{model_name} is not supported by {task}. This task is supported by {len(supported_models)} other models. See which models with `mii.get_supported_models(mii.{task})`."
114155
115156
def check_if_task_and_model_is_valid(task, model_name):
    """Assert that *model_name* is tagged on the HF Hub with *task*'s pipeline tag.

    Raises:
        AssertionError: if the model is not listed for the task on the Hub.
    """
    task_name = get_task_name(task)
    valid_task_models = _get_hf_models_by_type(None, task_name)
    # NOTE: message names the model as the unsupported thing, mirroring
    # check_if_task_and_model_is_supported.
    assert (
        model_name in valid_task_models
    ), f"{model_name} is not supported by {task_name}. This task is supported by {len(valid_task_models)} other models. See which models with `mii.get_supported_models(mii.{task})`."
122163
123164
124165def full_model_path (model_path ):
0 commit comments