Skip to content

Commit 0182fa5

Browse files
authored
Cache HF API results (#233)
* cache HF API results * add option for modifying cache expiration
1 parent 0acf569 commit 0182fa5

File tree

3 files changed

+60
-16
lines changed

3 files changed

+60
-16
lines changed

mii/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from .terminate import terminate
1010
from .constants import DeploymentType, Tasks
1111
from .aml_related.utils import aml_output_path
12-
12+
from .utils import get_supported_models
1313
from .config import MIIConfig, LoadBalancerConfig
1414
from .grpc_related.proto import modelresponse_pb2_grpc
1515

mii/constants.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,9 @@ class ModelProvider(enum.Enum):
104104
MII_CACHE_PATH = "MII_CACHE_PATH"
105105
MII_CACHE_PATH_DEFAULT = "/tmp/mii_cache"
106106

107+
MII_HF_CACHE_EXPIRATION = "MII_HF_CACHE_EXPIRATION"
108+
MII_HF_CACHE_EXPIRATION_DEFAULT = 60 * 60 # 1 hour
109+
107110
MII_DEBUG_MODE = "MII_DEBUG_MODE"
108111
MII_DEBUG_MODE_DEFAULT = "0"
109112

mii/utils.py

Lines changed: 56 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,12 @@
33

44
# DeepSpeed Team
55
import os
6+
import pickle
7+
import time
68
import importlib
79
import torch
810
import mii
11+
from types import SimpleNamespace
912
from huggingface_hub import HfApi
1013

1114
from mii.models.score.generate import generated_score_path
@@ -20,7 +23,9 @@
2023
SUPPORTED_MODEL_TYPES,
2124
ModelProvider,
2225
REQUIRED_KEYS_PER_TASK,
23-
TEXT2IMG_NAME)
26+
TEXT2IMG_NAME,
27+
MII_HF_CACHE_EXPIRATION,
28+
MII_HF_CACHE_EXPIRATION_DEFAULT)
2429

2530
from mii.constants import Tasks
2631

@@ -75,21 +80,55 @@ def get_task(task_name):
7580
assert False, f"Unknown Task {task_name}"
7681

7782

78-
def _get_hf_models_by_type(model_type, task=None):
79-
api = HfApi()
80-
models = api.list_models(filter=model_type)
81-
models = ([m.modelId for m in models]
82-
if task is None else [m.modelId for m in models if m.pipeline_tag == task])
83+
def _get_hf_models_by_type(model_type=None, task=None):
    """Return Hugging Face model IDs filtered by ``model_type`` and/or ``task``.

    The result of the (slow) ``HfApi.list_models()`` call is cached on disk as
    a pickle file in the MII cache directory and only refreshed once the cache
    is older than ``MII_HF_CACHE_EXPIRATION`` seconds.

    Args:
        model_type: HF tag a model must carry to be included, or ``None`` to
            disable this filter.
        task: HF pipeline tag a model must match, or ``None`` to disable this
            filter.

    Returns:
        list[str]: the matching model IDs. For text-generation, a few known
        models missing the proper HF tags are appended manually (see TODO).
    """
    cache_file_path = os.path.join(mii_cache_path(), "HF_model_cache.pkl")
    # os.getenv returns a *string* whenever the env var is set; coerce to int
    # so the expiry arithmetic below cannot raise TypeError.
    cache_expiration_seconds = int(
        os.getenv(MII_HF_CACHE_EXPIRATION,
                  MII_HF_CACHE_EXPIRATION_DEFAULT))

    # Load or initialize the cache. A corrupt, truncated, or otherwise
    # unreadable cache file is treated the same as an expired one rather than
    # crashing the caller.
    model_data = {"cache_time": 0, "model_list": []}
    if os.path.isfile(cache_file_path):
        try:
            with open(cache_file_path, 'rb') as f:
                model_data = pickle.load(f)
        except (pickle.UnpicklingError, EOFError, AttributeError, OSError):
            model_data = {"cache_time": 0, "model_list": []}

    current_time = time.time()

    # Update the cache if it has expired
    if (model_data["cache_time"] + cache_expiration_seconds) < current_time:
        api = HfApi()
        # Store plain SimpleNamespace copies so the pickled cache does not
        # depend on huggingface_hub's internal ModelInfo class layout.
        model_data["model_list"] = [
            SimpleNamespace(modelId=m.modelId,
                            pipeline_tag=m.pipeline_tag,
                            tags=m.tags) for m in api.list_models()
        ]
        model_data["cache_time"] = current_time

        # Save the updated cache
        with open(cache_file_path, 'wb') as f:
            pickle.dump(model_data, f)

    # Filter the model list
    models = model_data["model_list"]
    if model_type is not None:
        models = [m for m in models if model_type in m.tags]
    if task is not None:
        models = [m for m in models if m.pipeline_tag == task]

    # Extract model IDs
    model_ids = [m.modelId for m in models]

    if task == TEXT_GENERATION_NAME:
        # TODO: this is a temp solution to get around some HF models not
        # having the correct tags
        model_ids.extend([
            "microsoft/bloom-deepspeed-inference-fp16",
            "microsoft/bloom-deepspeed-inference-int8",
            "EleutherAI/gpt-neox-20b"
        ])

    return model_ids

91-
# TODO read this from a file containing list of files supported for each task
92-
def _get_supported_models_name(task):
130+
131+
def get_supported_models(task):
93132
supported_models = []
94133
task_name = get_task_name(task)
95134

@@ -109,16 +148,18 @@ def _get_supported_models_name(task):
109148

110149

111150
def check_if_task_and_model_is_supported(task, model_name):
    """Assert that ``model_name`` appears in the supported-model list for ``task``.

    Raises:
        AssertionError: when the model is not among the supported models,
            with a message pointing at ``mii.get_supported_models``.
    """
    models_for_task = get_supported_models(task)
    failure_msg = f"{task} is not supported by {model_name}. This task is supported by {len(models_for_task)} other models. See which models with `mii.get_supported_models(mii.{task})`."
    assert model_name in models_for_task, failure_msg
114155

115156

116157
def check_if_task_and_model_is_valid(task, model_name):
    """Assert that ``model_name`` is tagged on the HF Hub for the given ``task``.

    Raises:
        AssertionError: when the Hub does not list the model under the task's
            pipeline tag, with a message pointing at ``mii.get_supported_models``.
    """
    task_name = get_task_name(task)
    hub_models_for_task = _get_hf_models_by_type(None, task_name)
    failure_msg = f"{task_name} is not supported by {model_name}. This task is supported by {len(hub_models_for_task)} other models. See which models with `mii.get_supported_models(mii.{task})`."
    assert model_name in hub_models_for_task, failure_msg
122163

123164

124165
def full_model_path(model_path):

0 commit comments

Comments
 (0)