
Commit 2358566

feat: deprecate TRANSFORMERS_CACHE, use HF_HUB_CACHE everywhere (IBM#89)
#### Motivation

`TRANSFORMERS_CACHE` is deprecated (slated for removal with Transformers v5) and `HUGGINGFACE_HUB_CACHE` is legacy. This PR standardizes on `HF_HUB_CACHE` to configure the cache. Also, not all operations/CLI commands were correctly pulling from `TRANSFORMERS_CACHE`, so we have been setting both env vars anyway. After this change, everything should work with only `HF_HUB_CACHE`.

#### Modifications

- Launcher inspects `HF_HUB_CACHE` to determine the model cache path (see the sketch below)
  - `TRANSFORMERS_CACHE` and `HUGGINGFACE_HUB_CACHE` are still checked as well, but a deprecation warning is printed
  - if multiple values are present and do not match, an error is raised
- Launcher can resolve the default `HF_HUB_CACHE` so it does not need to be set (`HF_HOME` or its default can be used instead)
- Server CLI checks `TRANSFORMERS_CACHE` and prints a warning if it is set
- Server CLI returns an error if both `TRANSFORMERS_CACHE` and `HF_HUB_CACHE` are set with different values

---------

Signed-off-by: Travis Johnson <[email protected]>
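For illustration, a minimal Python sketch of the resolution and conflict-checking behavior described in the Modifications above. The launcher implements this in Rust (see `launcher/src/main.rs` below); the function name, messages, and structure here are illustrative only, not the actual implementation.

```python
import os
from pathlib import Path

def resolve_cache_path() -> Path:
    # Prefer HF_HUB_CACHE; deprecated variables are still honored but warned about,
    # and conflicting values are treated as an error.
    value = os.environ.get("HF_HUB_CACHE", "")
    for deprecated in ("TRANSFORMERS_CACHE", "HUGGINGFACE_HUB_CACHE"):
        dep_value = os.environ.get(deprecated)
        if dep_value is None:
            continue
        if value and dep_value != value:
            raise RuntimeError(f"{deprecated} and HF_HUB_CACHE are set to different values")
        print(f"WARNING: {deprecated} is deprecated, use HF_HUB_CACHE instead")
        value = value or dep_value
    if value:
        return Path(value)
    # Fall back to the huggingface_hub default: $HF_HOME/hub, or ~/.cache/huggingface/hub
    hf_home = os.environ.get("HF_HOME", os.path.expanduser("~/.cache/huggingface"))
    return Path(hf_home) / "hub"
```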
1 parent ddc56ee commit 2358566

File tree

6 files changed · +115 −47 lines changed


Makefile

Lines changed: 2 additions & 4 deletions
@@ -85,17 +85,15 @@ check-test-image:
 integration-tests: check-test-image ## Run integration tests
     mkdir -p /tmp/transformers_cache
     docker run --rm -v /tmp/transformers_cache:/transformers_cache \
-        -e HUGGINGFACE_HUB_CACHE=/transformers_cache \
-        -e TRANSFORMERS_CACHE=/transformers_cache \
+        -e HF_HUB_CACHE=/transformers_cache \
         -w /usr/src/integration_tests \
         $(TEST_IMAGE_NAME) make test

 .PHONY: python-tests
 python-tests: check-test-image ## Run Python tests
     mkdir -p /tmp/transformers_cache
     docker run --rm -v /tmp/transformers_cache:/transformers_cache \
-        -e HUGGINGFACE_HUB_CACHE=/transformers_cache \
-        -e TRANSFORMERS_CACHE=/transformers_cache \
+        -e HF_HUB_CACHE=/transformers_cache \
         $(TEST_IMAGE_NAME) pytest -sv --ignore=server/tests/test_utils.py server/tests

 .PHONY: clean

README.md

Lines changed: 2 additions & 2 deletions
@@ -71,15 +71,15 @@ cd deployment

 ### Model configuration

-When deploying TGIS, the `MODEL_NAME` environment variable can contain either the full name of a model on the Hugging Face hub (such as `google/flan-ul2`) or an absolute path to a (mounted) model directory inside the container. In the former case, the `TRANSFORMERS_CACHE` and `HUGGINGFACE_HUB_CACHE` environment variables should be set to the path of a mounted directory containing a local HF hub model cache, see [this](deployment/base/patches/pvcs/pvc.yaml) kustomize patch as an example.
+When deploying TGIS, the `MODEL_NAME` environment variable can contain either the full name of a model on the Hugging Face hub (such as `google/flan-ul2`) or an absolute path to a (mounted) model directory inside the container. In the former case, the `HF_HUB_CACHE` environment variable should be set to the path of a mounted directory containing a local HF hub model cache, see [this](deployment/base/patches/pvcs/pvc.yaml) kustomize patch as an example.

 ### Downloading model weights

 TGIS will not download model data at runtime. To populate the local HF hub cache with models so that it can be used per above, the image can be run with the following command:
 ```shell
 text-generation-server download-weights model_name
 ```
-where `model_name` is the name of the model on the HF hub. Ensure that it's run with the same mounted directory and `TRANSFORMERS_CACHE` and `HUGGINGFACE_HUB_CACHE` environment variables, and that it has write access to this mounted filesystem.
+where `model_name` is the name of the model on the HF hub. Ensure that it's run with the same mounted directory and the `HF_HUB_CACHE` environment variable, and that it has write access to this mounted filesystem.

 This will attempt to download weights in `.safetensors` format, and if those aren't in the HF hub will download pytorch `.bin` weights and then convert them to `.safetensors`.

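As a rough illustration of how the cache layout relates to `HF_HUB_CACHE`, the snippet below uses `huggingface_hub` directly to populate the same kind of cache that TGIS reads at startup. The cache path mirrors the Makefile's container mount but is a placeholder here, the model and file patterns are the ones used by the integration tests, and the actual `download-weights` command may handle filtering and `.safetensors` conversion differently.

```python
import os

# huggingface_hub resolves HF_HUB_CACHE when it is imported, so set it first.
os.environ.setdefault("HF_HUB_CACHE", "/transformers_cache")

from huggingface_hub import snapshot_download

# Downloads into $HF_HUB_CACHE/models--bigscience--mt0-small/snapshots/<revision>
local_dir = snapshot_download("bigscience/mt0-small", allow_patterns=["*.json", "*.bin", "*.model"])
print(local_dir)
```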

integration_tests/text_generation_tests/test_server.py

Lines changed: 38 additions & 9 deletions
@@ -29,7 +29,7 @@ def start_server(
     master_port: int,
     timeout=30,
     model_path=None,
-    include_cache_env_vars=True,
+    env=None,
     output_special_tokens=False,
 ):
     # Download weights to the cache first
@@ -66,13 +66,12 @@ def start_server(
     if output_special_tokens:
         args.append("--output-special-tokens")

-    env = os.environ.copy()
+    if env is None:
+        env = os.environ.copy()
+
     env["RUST_BACKTRACE"] = "full"
     env["ESTIMATE_MEMORY"] = "manual"
     env["PREFIX_STORE_PATH"] = os.path.join(TESTS_DIR, "prompt_prefixes")
-    if not include_cache_env_vars:
-        env.pop("TRANSFORMERS_CACHE", None)
-        env.pop("HUGGING_FACE_HUB_CACHE", None)

     # Start the process
     process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, env=env)
@@ -455,17 +454,21 @@ async def test_time_limit_stopping(server_fixture):

 # Test loading when an explicit local path is provided
 def test_explicit_path():
-    # Test with and without providing TRANSFORMERS_CACHE env var
-    path = glob.glob(f'{os.environ["TRANSFORMERS_CACHE"]}/models--bigscience--mt0-small/snapshots/*')[0]
-    for include_env_vars in [False, True]:
+    path = glob.glob(f'{os.environ["HF_HUB_CACHE"]}/models--bigscience--mt0-small/snapshots/*')[0]
+
+    # Test with and without providing HF_HUB_CACHE
+    env_with = os.environ.copy()
+    env_without = os.environ.copy()
+    env_without.pop("HF_HUB_CACHE", None)
+    for env in [env_with, env_without]:
         p = start_server(
             "bigscience/mt0-small",
             ".bin,.json,.model",
             1,
             3000,
             29502,
             model_path=path,
-            include_cache_env_vars=include_env_vars,
+            env=env,
         )
         try:
             async def test_model_info() -> pb2.ModelInfoResponse:
@@ -481,6 +484,32 @@ async def test_model_info() -> pb2.ModelInfoResponse:

         assert p.wait(8.0) == 0

+# Test loading with only TRANSFORMERS_CACHE set
+def test_transformers_cache():
+    env = os.environ.copy()
+    env["TRANSFORMERS_CACHE"] = env.pop("HF_HUB_CACHE")
+    p = start_server(
+        "bigscience/mt0-small",
+        ".bin,.json,.model",
+        1,
+        3000,
+        29502,
+        env=env,
+    )
+    try:
+        async def test_model_info() -> pb2.ModelInfoResponse:
+            async with grpc.aio.insecure_channel('localhost:8033') as channel:
+                return await gpb2.GenerationServiceStub(channel).ModelInfo(pb2.ModelInfoRequest(model_id="unused"))
+
+        result = asyncio.get_event_loop().run_until_complete(test_model_info())
+        assert result.max_sequence_length == 200
+        assert result.max_new_tokens == 169
+        assert result.model_kind == pb2.ModelInfoResponse.ModelKind.ENCODER_DECODER
+    finally:
+        p.terminate()
+
+    assert p.wait(8.0) == 0
+

 # To avoid errors related to event loop shutdown timing
 @pytest.fixture(scope="session")

launcher/src/main.rs

Lines changed: 63 additions & 31 deletions
@@ -139,7 +139,54 @@ fn main() -> ExitCode {
     // Determine number of shards based on command line arg and env vars
     let num_shard = find_num_shards(args.num_shard);

-    let config_path: PathBuf = resolve_config_path(&args.model_name, args.revision.as_deref())
+    // Determine the model cache path and resolve from possible env vars:
+    // - HF_HUB_CACHE
+    // - TRANSFORMERS_CACHE (deprecated)
+    // - HUGGINGFACE_HUB_CACHE (deprecated)
+    //
+    // We allow multiple to be set for compatibility, but then the values must match.
+
+    let mut cache_env_var: String = "".to_string();
+    let mut cache_env_value: String = "".to_string();
+
+    if let Ok(t) = env::var("HF_HUB_CACHE") {
+        cache_env_var = "HF_HUB_CACHE".into();
+        cache_env_value = t.into();
+    }
+
+    for deprecated_env_var in vec!["TRANSFORMERS_CACHE", "HUGGINGFACE_HUB_CACHE"] {
+        match (
+            env::var(deprecated_env_var),
+            !cache_env_var.is_empty(),
+        ) {
+            (Ok(t), false) => {
+                cache_env_var = deprecated_env_var.into();
+                cache_env_value = t.into();
+            },
+            (Ok(t), true) if t != cache_env_value => panic!(
+                "{deprecated_env_var} and {cache_env_var} env vars can't be set to different values"
+            ),
+            (Ok(_), true) => warn!(
+                "{deprecated_env_var} is deprecated and should not be used. Use HF_HUB_CACHE instead."
+            ),
+            _ => (),
+        }
+    }
+
+    // ensure HF_HUB_CACHE is set for downstream usage
+    // default value to match huggingface_hub
+    // REF: https://github.com/huggingface/huggingface_hub/blob/5ff2d150d121d04799b78bc08f2343c21b8f07a9/docs/source/en/package_reference/environment_variables.md?plain=1#L32
+    let cache_path = if !cache_env_value.is_empty() {
+        PathBuf::from(cache_env_value)
+    } else if let Ok(hf_home) = env::var("HF_HOME") {
+        PathBuf::from(hf_home).join("hub")
+    } else if let Ok(home) = env::var("HOME") {
+        PathBuf::from(home).join(".cache").join("huggingface").join("hub")
+    } else {
+        PathBuf::new()
+    };
+
+    let config_path: PathBuf = resolve_config_path(cache_path.clone(), &args.model_name, args.revision.as_deref())
         .expect("Failed to resolve config path")
         .into();

@@ -223,15 +270,18 @@ fn main() -> ExitCode {
     let (status_sender, status_receiver) = mpsc::channel();

     // Start shard processes
+    let cache_path_string = cache_path.into_os_string();
     for rank in 0..num_shard {
         let args = args.clone();
+        let cache_path = cache_path_string.clone();
         let deployment_framework = deployment_framework.to_string();
         let status_sender = status_sender.clone();
         let shutdown = shutdown.clone();
         let shutdown_sender = shutdown_sender.clone();
         thread::spawn(move || {
             shard_manager(
                 args.model_name,
+                cache_path,
                 args.revision,
                 deployment_framework,
                 args.dtype.or(args.dtype_str),
@@ -548,6 +598,7 @@ enum ShardStatus {
 #[allow(clippy::too_many_arguments)]
 fn shard_manager(
     model_name: String,
+    cache_path: OsString,
     revision: Option<String>,
     deployment_framework: String,
     dtype: Option<String>,
@@ -620,19 +671,6 @@ fn shard_manager(
     // Copy current process env
     let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();

-    // Fix up TRANSFORMERS_CACHE and HUGGINGFACE_HUB_CACHE env vars
-    match (
-        env::var("TRANSFORMERS_CACHE"),
-        env::var("HUGGINGFACE_HUB_CACHE"),
-    ) {
-        (Ok(t), Err(_)) => env.push(("HUGGINGFACE_HUB_CACHE".into(), t.into())),
-        (Err(_), Ok(h)) => env.push(("TRANSFORMERS_CACHE".into(), h.into())),
-        (Ok(t), Ok(h)) if t != h => panic!(
-            "TRANSFORMERS_CACHE and HUGGINGFACE_HUB_CACHE env vars can't be set to different values"
-        ),
-        _ => (),
-    }
-
     if let Some(alloc_conf) = cuda_alloc_conf {
         if alloc_conf.is_empty() {
             // Remove it from env
@@ -665,6 +703,9 @@ fn shard_manager(
     // Ensure offline-only
     env.push(("HF_HUB_OFFLINE".into(), "1".into()));

+    // Ensure that we set the standard cache variable
+    env.push(("HF_HUB_CACHE".into(), cache_path.into()));
+
     // Start process
     info!("Starting shard {rank}");
     let mut p = match Command::new("text-generation-server")
@@ -776,18 +817,13 @@ fn write_termination_log(msg: &str) -> Result<(), io::Error> {
     Ok(())
 }

-fn resolve_config_path(model_name: &str, revision: Option<&str>) -> Result<String, io::Error> {
-    let cache = env::var("TRANSFORMERS_CACHE")
-        .or_else(|_| env::var("HUGGINGFACE_HUB_CACHE"))
-        .ok();
-    let mut model_dir = cache
-        .as_ref()
-        .map(|c| Path::new(&c).join(format!("models--{}", model_name.replace('/', "--"))));
-    if let Some(ref d) = model_dir {
-        if !d.try_exists()? {
-            model_dir = None;
-        }
-    }
+fn resolve_config_path(cache_path: PathBuf, model_name: &str, revision: Option<&str>) -> Result<String, io::Error> {
+    let model_hf_cache_dir = cache_path.join(format!("models--{}", model_name.replace('/', "--")));
+    let model_dir = if model_hf_cache_dir.try_exists()? {
+        Some(model_hf_cache_dir)
+    } else {
+        None
+    };
     if let Some(dir) = model_dir {
         let revision = revision.unwrap_or("main");
         let ref_path = dir.join("refs").join(revision);
@@ -811,11 +847,7 @@ fn resolve_config_path(model_name: &str, revision: Option<&str>) -> Result<Strin
     if try_path.try_exists()? {
         Ok(try_path.to_string_lossy().into())
     } else {
-        let message = if cache.is_none() {
-            format!("Model path {model_name} not found (TRANSFORMERS_CACHE env var not set)")
-        } else {
-            format!("Model {model_name} not found in local cache")
-        };
+        let message = format!("Model {model_name} not found");
         error!(message);
         Err(io::Error::new(ErrorKind::NotFound, message))
     }
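The fallback default added here is intended to match `huggingface_hub`'s own resolution (see the REF comment in the first hunk above). As a quick cross-check, and assuming a reasonably recent `huggingface_hub` is installed, the value the library itself would use can be printed directly:

```python
from huggingface_hub.constants import HF_HUB_CACHE

# Typically $HF_HOME/hub, falling back to ~/.cache/huggingface/hub.
print(HF_HUB_CACHE)
```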

server/text_generation_server/cli.py

Lines changed: 10 additions & 0 deletions
@@ -252,4 +252,14 @@ def convert_to_fast_tokenizer(


 if __name__ == "__main__":
+
+    # Use of TRANSFORMERS_CACHE is deprecated
+    if (tc := os.getenv("TRANSFORMERS_CACHE")) is not None:
+        print("WARNING: Using TRANSFORMERS_CACHE is deprecated. Use HF_HUB_CACHE instead.")
+        hc = os.getenv("HF_HUB_CACHE")
+        if tc != hc:
+            raise ValueError("Conflicting model cache values between TRANSFORMERS_CACHE and HF_HUB_CACHE")
+        if hc is None:
+            os.putenv("HF_HUB_CACHE", tc)
+
     app()

server/text_generation_server/utils/hub.py

Lines changed: 0 additions & 1 deletion
@@ -79,7 +79,6 @@ def get_model_path(model_name: str, revision: Optional[str] = None):
     try:
         config_path = try_to_load_from_cache(
             model_name, config_file,
-            cache_dir=os.getenv("TRANSFORMERS_CACHE"), # will fall back to HUGGINGFACE_HUB_CACHE
             revision=revision,
         )
         if config_path is not None:
