
Commit c91115d

Merge pull request IBM#77 from IBM/main
[pull] main from IBM:main
2 parents: 2d78aed + fb23def

11 files changed: +157, -63 lines

Makefile

Lines changed: 2 additions & 4 deletions
@@ -85,17 +85,15 @@ check-test-image:
 integration-tests: check-test-image ## Run integration tests
 	mkdir -p /tmp/transformers_cache
 	docker run --rm -v /tmp/transformers_cache:/transformers_cache \
-		-e HUGGINGFACE_HUB_CACHE=/transformers_cache \
-		-e TRANSFORMERS_CACHE=/transformers_cache \
+		-e HF_HUB_CACHE=/transformers_cache \
 		-w /usr/src/integration_tests \
 		$(TEST_IMAGE_NAME) make test
 
 .PHONY: python-tests
 python-tests: check-test-image ## Run Python tests
 	mkdir -p /tmp/transformers_cache
 	docker run --rm -v /tmp/transformers_cache:/transformers_cache \
-		-e HUGGINGFACE_HUB_CACHE=/transformers_cache \
-		-e TRANSFORMERS_CACHE=/transformers_cache \
+		-e HF_HUB_CACHE=/transformers_cache \
 		$(TEST_IMAGE_NAME) pytest -sv --ignore=server/tests/test_utils.py server/tests
 
 .PHONY: clean

README.md

Lines changed: 2 additions & 2 deletions
@@ -71,15 +71,15 @@ cd deployment
 
 ### Model configuration
 
-When deploying TGIS, the `MODEL_NAME` environment variable can contain either the full name of a model on the Hugging Face hub (such as `google/flan-ul2`) or an absolute path to a (mounted) model directory inside the container. In the former case, the `TRANSFORMERS_CACHE` and `HUGGINGFACE_HUB_CACHE` environment variables should be set to the path of a mounted directory containing a local HF hub model cache, see [this](deployment/base/patches/pvcs/pvc.yaml) kustomize patch as an example.
+When deploying TGIS, the `MODEL_NAME` environment variable can contain either the full name of a model on the Hugging Face hub (such as `google/flan-ul2`) or an absolute path to a (mounted) model directory inside the container. In the former case, the `HF_HUB_CACHE` environment variable should be set to the path of a mounted directory containing a local HF hub model cache, see [this](deployment/base/patches/pvcs/pvc.yaml) kustomize patch as an example.
 
 ### Downloading model weights
 
 TGIS will not download model data at runtime. To populate the local HF hub cache with models so that it can be used per above, the image can be run with the following command:
 ```shell
 text-generation-server download-weights model_name
 ```
-where `model_name` is the name of the model on the HF hub. Ensure that it's run with the same mounted directory and `TRANSFORMERS_CACHE` and `HUGGINGFACE_HUB_CACHE` environment variables, and that it has write access to this mounted filesystem.
+where `model_name` is the name of the model on the HF hub. Ensure that it's run with the same mounted directory and the `HF_HUB_CACHE` environment variable, and that it has write access to this mounted filesystem.
 
 This will attempt to download weights in `.safetensors` format, and if those aren't in the HF hub will download pytorch `.bin` weights and then convert them to `.safetensors`.
 
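To make the updated README guidance concrete, here is a rough sketch (not part of this commit) of populating a mounted cache and then serving from it with the single `HF_HUB_CACHE` variable. The image tag `tgis-image:latest` and the model `bigscience/mt0-small` are placeholder assumptions for illustration only.

```shell
# Sketch only: populate a mounted HF hub cache, then run the server against it.
mkdir -p /tmp/transformers_cache

# Download weights into the cache (the container needs write access to the mount).
docker run --rm -v /tmp/transformers_cache:/transformers_cache \
    -e HF_HUB_CACHE=/transformers_cache \
    tgis-image:latest \
    text-generation-server download-weights bigscience/mt0-small

# Serve the model from the same mount and cache variable;
# MODEL_NAME could instead be an absolute path to a mounted model directory.
docker run --rm -v /tmp/transformers_cache:/transformers_cache \
    -e HF_HUB_CACHE=/transformers_cache \
    -e MODEL_NAME=bigscience/mt0-small \
    tgis-image:latest
```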

integration_tests/text_generation_tests/test_server.py

Lines changed: 38 additions & 9 deletions
@@ -29,7 +29,7 @@ def start_server(
     master_port: int,
     timeout=30,
     model_path=None,
-    include_cache_env_vars=True,
+    env=None,
     output_special_tokens=False,
 ):
     # Download weights to the cache first
@@ -66,13 +66,12 @@ def start_server(
     if output_special_tokens:
         args.append("--output-special-tokens")
 
-    env = os.environ.copy()
+    if env is None:
+        env = os.environ.copy()
+
     env["RUST_BACKTRACE"] = "full"
     env["ESTIMATE_MEMORY"] = "manual"
     env["PREFIX_STORE_PATH"] = os.path.join(TESTS_DIR, "prompt_prefixes")
-    if not include_cache_env_vars:
-        env.pop("TRANSFORMERS_CACHE", None)
-        env.pop("HUGGING_FACE_HUB_CACHE", None)
 
     # Start the process
     process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, env=env)
@@ -455,17 +454,21 @@ async def test_time_limit_stopping(server_fixture):
 
 # Test loading when an explicit local path is provided
 def test_explicit_path():
-    # Test with and without providing TRANSFORMERS_CACHE env var
-    path = glob.glob(f'{os.environ["TRANSFORMERS_CACHE"]}/models--bigscience--mt0-small/snapshots/*')[0]
-    for include_env_vars in [False, True]:
+    path = glob.glob(f'{os.environ["HF_HUB_CACHE"]}/models--bigscience--mt0-small/snapshots/*')[0]
+
+    # Test with and without providing HF_HUB_CACHE
+    env_with = os.environ.copy()
+    env_without = os.environ.copy()
+    env_without.pop("HF_HUB_CACHE", None)
+    for env in [env_with, env_without]:
         p = start_server(
             "bigscience/mt0-small",
             ".bin,.json,.model",
             1,
             3000,
             29502,
             model_path=path,
-            include_cache_env_vars=include_env_vars,
+            env=env,
         )
         try:
             async def test_model_info() -> pb2.ModelInfoResponse:
@@ -481,6 +484,32 @@ async def test_model_info() -> pb2.ModelInfoResponse:
 
         assert p.wait(8.0) == 0
 
+# Test loading with only TRANSFORMERS_CACHE set
+def test_transformers_cache():
+    env = os.environ.copy()
+    env["TRANSFORMERS_CACHE"] = env.pop("HF_HUB_CACHE")
+    p = start_server(
+        "bigscience/mt0-small",
+        ".bin,.json,.model",
+        1,
+        3000,
+        29502,
+        env=env,
+    )
+    try:
+        async def test_model_info() -> pb2.ModelInfoResponse:
+            async with grpc.aio.insecure_channel('localhost:8033') as channel:
+                return await gpb2.GenerationServiceStub(channel).ModelInfo(pb2.ModelInfoRequest(model_id="unused"))
+
+        result = asyncio.get_event_loop().run_until_complete(test_model_info())
+        assert result.max_sequence_length == 200
+        assert result.max_new_tokens == 169
+        assert result.model_kind == pb2.ModelInfoResponse.ModelKind.ENCODER_DECODER
+    finally:
+        p.terminate()
+
+    assert p.wait(8.0) == 0
+
 
 # To avoid errors related to event loop shutdown timing
 @pytest.fixture(scope="session")

launcher/src/main.rs

Lines changed: 63 additions & 31 deletions
@@ -139,7 +139,54 @@ fn main() -> ExitCode {
     // Determine number of shards based on command line arg and env vars
     let num_shard = find_num_shards(args.num_shard);
 
-    let config_path: PathBuf = resolve_config_path(&args.model_name, args.revision.as_deref())
+    // Determine the model cache path and resolve from possible env vars:
+    // - HF_HUB_CACHE
+    // - TRANSFORMERS_CACHE (deprecated)
+    // - HUGGINGFACE_HUB_CACHE (deprecated)
+    //
+    // We allow multiple to be set for compatibility, but then the values must match.
+
+    let mut cache_env_var: String = "".to_string();
+    let mut cache_env_value: String = "".to_string();
+
+    if let Ok(t) = env::var("HF_HUB_CACHE") {
+        cache_env_var = "HF_HUB_CACHE".into();
+        cache_env_value = t.into();
+    }
+
+    for deprecated_env_var in vec!["TRANSFORMERS_CACHE", "HUGGINGFACE_HUB_CACHE"] {
+        match (
+            env::var(deprecated_env_var),
+            !cache_env_var.is_empty(),
+        ) {
+            (Ok(t), false) => {
+                cache_env_var = deprecated_env_var.into();
+                cache_env_value = t.into();
+            },
+            (Ok(t), true) if t != cache_env_value => panic!(
+                "{deprecated_env_var} and {cache_env_var} env vars can't be set to different values"
+            ),
+            (Ok(_), true) => warn!(
+                "{deprecated_env_var} is deprecated and should not be used. Use HF_HUB_CACHE instead."
+            ),
+            _ => (),
+        }
+    }
+
+    // ensure HF_HUB_CACHE is set for downstream usage
+    // default value to match huggingface_hub
+    // REF: https://github.com/huggingface/huggingface_hub/blob/5ff2d150d121d04799b78bc08f2343c21b8f07a9/docs/source/en/package_reference/environment_variables.md?plain=1#L32
+    let cache_path = if !cache_env_value.is_empty() {
+        PathBuf::from(cache_env_value)
+    } else if let Ok(hf_home) = env::var("HF_HOME") {
+        PathBuf::from(hf_home).join("hub")
+    } else if let Ok(home) = env::var("HOME") {
+        PathBuf::from(home).join(".cache").join("huggingface").join("hub")
+    } else {
+        PathBuf::new()
+    };
+
+    let config_path: PathBuf = resolve_config_path(cache_path.clone(), &args.model_name, args.revision.as_deref())
         .expect("Failed to resolve config path")
         .into();
 
@@ -223,15 +270,18 @@ fn main() -> ExitCode {
     let (status_sender, status_receiver) = mpsc::channel();
 
     // Start shard processes
+    let cache_path_string = cache_path.into_os_string();
     for rank in 0..num_shard {
         let args = args.clone();
+        let cache_path = cache_path_string.clone();
         let deployment_framework = deployment_framework.to_string();
         let status_sender = status_sender.clone();
         let shutdown = shutdown.clone();
         let shutdown_sender = shutdown_sender.clone();
         thread::spawn(move || {
             shard_manager(
                 args.model_name,
+                cache_path,
                 args.revision,
                 deployment_framework,
                 args.dtype.or(args.dtype_str),
@@ -548,6 +598,7 @@ enum ShardStatus {
 #[allow(clippy::too_many_arguments)]
 fn shard_manager(
     model_name: String,
+    cache_path: OsString,
     revision: Option<String>,
     deployment_framework: String,
     dtype: Option<String>,
@@ -620,19 +671,6 @@ fn shard_manager(
     // Copy current process env
     let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();
 
-    // Fix up TRANSFORMERS_CACHE and HUGGINGFACE_HUB_CACHE env vars
-    match (
-        env::var("TRANSFORMERS_CACHE"),
-        env::var("HUGGINGFACE_HUB_CACHE"),
-    ) {
-        (Ok(t), Err(_)) => env.push(("HUGGINGFACE_HUB_CACHE".into(), t.into())),
-        (Err(_), Ok(h)) => env.push(("TRANSFORMERS_CACHE".into(), h.into())),
-        (Ok(t), Ok(h)) if t != h => panic!(
-            "TRANSFORMERS_CACHE and HUGGINGFACE_HUB_CACHE env vars can't be set to different values"
-        ),
-        _ => (),
-    }
-
     if let Some(alloc_conf) = cuda_alloc_conf {
         if alloc_conf.is_empty() {
             // Remove it from env
@@ -665,6 +703,9 @@ fn shard_manager(
     // Ensure offline-only
     env.push(("HF_HUB_OFFLINE".into(), "1".into()));
 
+    // Ensure that we set the standard cache variable
+    env.push(("HF_HUB_CACHE".into(), cache_path.into()));
+
     // Start process
     info!("Starting shard {rank}");
     let mut p = match Command::new("text-generation-server")
@@ -776,18 +817,13 @@ fn write_termination_log(msg: &str) -> Result<(), io::Error> {
     Ok(())
 }
 
-fn resolve_config_path(model_name: &str, revision: Option<&str>) -> Result<String, io::Error> {
-    let cache = env::var("TRANSFORMERS_CACHE")
-        .or_else(|_| env::var("HUGGINGFACE_HUB_CACHE"))
-        .ok();
-    let mut model_dir = cache
-        .as_ref()
-        .map(|c| Path::new(&c).join(format!("models--{}", model_name.replace('/', "--"))));
-    if let Some(ref d) = model_dir {
-        if !d.try_exists()? {
-            model_dir = None;
-        }
-    }
+fn resolve_config_path(cache_path: PathBuf, model_name: &str, revision: Option<&str>) -> Result<String, io::Error> {
+    let model_hf_cache_dir = cache_path.join(format!("models--{}", model_name.replace('/', "--")));
+    let model_dir = if model_hf_cache_dir.try_exists()? {
+        Some(model_hf_cache_dir)
+    } else {
+        None
+    };
     if let Some(dir) = model_dir {
         let revision = revision.unwrap_or("main");
         let ref_path = dir.join("refs").join(revision);
@@ -811,11 +847,7 @@ fn resolve_config_path(model_name: &str, revision: Option<&str>) -> Result<String, io::Error> {
     if try_path.try_exists()? {
         Ok(try_path.to_string_lossy().into())
     } else {
-        let message = if cache.is_none() {
-            format!("Model path {model_name} not found (TRANSFORMERS_CACHE env var not set)")
-        } else {
-            format!("Model {model_name} not found in local cache")
-        };
+        let message = format!("Model {model_name} not found");
         error!(message);
         Err(io::Error::new(ErrorKind::NotFound, message))
     }
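For reference, here is a sketch of how the launcher's new cache-path precedence behaves for different environment combinations, expressed as shell settings. This is an illustration of the logic in the diff above, not part of the commit, and the paths shown are placeholders.

```shell
# Illustrative sketch: cache-path resolution now performed by launcher/src/main.rs.

# 1. HF_HUB_CACHE, if set, is used as the model cache path.
export HF_HUB_CACHE=/transformers_cache

# 2. TRANSFORMERS_CACHE / HUGGINGFACE_HUB_CACHE are still honoured but deprecated;
#    setting one alongside HF_HUB_CACHE logs a deprecation warning, and the values
#    must match or the launcher panics at startup.
export TRANSFORMERS_CACHE=/transformers_cache   # accepted, warns
# export TRANSFORMERS_CACHE=/some/other/path    # would abort: conflicting cache values

# 3. With none of the cache variables set, the path falls back to $HF_HOME/hub,
#    then $HOME/.cache/huggingface/hub, matching huggingface_hub's defaults.
#    The resolved path is then passed to each shard as HF_HUB_CACHE.
```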

server/text_generation_server/cli.py

Lines changed: 10 additions & 0 deletions
@@ -252,4 +252,14 @@ def convert_to_fast_tokenizer(
 
 
 if __name__ == "__main__":
+
+    # Use of TRANSFORMERS_CACHE is deprecated
+    if (tc := os.getenv("TRANSFORMERS_CACHE")) is not None:
+        print("WARNING: Using TRANSFORMERS_CACHE is deprecated. Use HF_HUB_CACHE instead.")
+        hc = os.getenv("HF_HUB_CACHE")
+        if tc != hc:
+            raise ValueError("Conflicting model cache values between TRANSFORMERS_CACHE and HF_HUB_CACHE")
+        if hc is None:
+            os.putenv("HF_HUB_CACHE", tc)
+
     app()

server/text_generation_server/models/custom_modeling/flash_llama_modeling.py

Lines changed: 8 additions & 4 deletions
@@ -156,10 +156,9 @@ def _load_gqa(config, prefix: str, weights):
     assert config.hidden_size % config.num_attention_heads == 0
     assert config.num_attention_heads % weights.process_group.size() == 0
 
+    prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
     weight = weights.get_multi_weights_col(
-        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-        quantize=config.quantize,
-        dim=0
+        prefixes=prefixes, quantize=config.quantize, dim=0
     )
 
     if config.quantize != "gptq":
@@ -173,7 +172,12 @@ def _load_gqa(config, prefix: str, weights):
         config.hidden_size,
     ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
 
-    return TensorParallelColumnLinear(get_linear(weight, bias=config.attention_bias, quantize=config.quantize))
+    if config.attention_bias:
+        bias = torch.cat([weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes], dim=0)
+    else:
+        bias = None
+
+    return TensorParallelColumnLinear(get_linear(weight, bias=bias, quantize=config.quantize))
 
 
 class FlashLlamaAttention(torch.nn.Module):

server/text_generation_server/models/custom_modeling/paged_llama_modeling.py

Lines changed: 8 additions & 4 deletions
@@ -156,10 +156,9 @@ def _load_gqa(config, prefix: str, weights):
     assert config.hidden_size % config.num_attention_heads == 0
     assert config.num_attention_heads % weights.process_group.size() == 0
 
+    prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
     weight = weights.get_multi_weights_col(
-        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-        quantize=config.quantize,
-        dim=0
+        prefixes=prefixes, quantize=config.quantize, dim=0
     )
 
     if config.quantize != "gptq":
@@ -173,7 +172,12 @@ def _load_gqa(config, prefix: str, weights):
         config.hidden_size,
     ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
 
-    return TensorParallelColumnLinear(get_linear(weight, bias=config.attention_bias, quantize=config.quantize))
+    if config.attention_bias:
+        bias = torch.cat([weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes], dim=0)
+    else:
+        bias = None
+
+    return TensorParallelColumnLinear(get_linear(weight, bias=bias, quantize=config.quantize))
 
 
 class PagedLlamaAttention(torch.nn.Module):

server/text_generation_server/models/paged_causal_lm.py

Lines changed: 4 additions & 1 deletion
@@ -327,12 +327,15 @@ def __init__(
             model_config.num_attention_heads,
             model_config.hidden_size,
             kv_heads=model_config.num_key_value_heads,
-            tensor_parallel_size=1,
+            tensor_parallel_size=self.engine.world_size,
            dtype=dtype,
            device=self.device,
            total_num_gpu_blocks=total_num_gpu_blocks,
        )
 
+        # log number of free blocks at init
+        print("[PagedKVCacheManager] number of free blocks: %d" % (len(self.kv_cache_manager.free_blocks)))
+
     @property
     def batch_type(self) -> Type[PagedCausalLMBatch]:
         return self._batch_type
