Commit c99e0ed
Add nvidia-ml-py to set USE_FLASH_ATTENTION based on compute cap
Parent: 99d353c

2 files changed: 12 additions, 2 deletions

tests/requirements.txt

Lines changed: 3 additions & 2 deletions
@@ -1,4 +1,5 @@
-pytest==8.3.2
-GPUtil==1.4.0
 docker==7.1.0
+GPUtil==1.4.0
+pytest==8.3.2
+nvidia-ml-py==12.560.30
 transformers==4.44.2
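
For context, the nvidia-ml-py package ships the pynvml module imported in the test below; it is NVIDIA's Python binding for NVML. A minimal sanity check that the binding can reach the driver might look like the following sketch (illustrative only, not part of the commit):

    import pynvml

    # Initialize NVML, report the driver version and the number of
    # visible GPUs, then release the library handle.
    pynvml.nvmlInit()
    print("Driver version:", pynvml.nvmlSystemGetDriverVersion())
    print("GPU count:", pynvml.nvmlDeviceGetCount())
    pynvml.nvmlShutdown()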

tests/tgi/test_tgi.py

Lines changed: 9 additions & 0 deletions
@@ -7,6 +7,7 @@
 import pytest
 import requests
 
+import pynvml
 from docker.types.containers import DeviceRequest
 from transformers import AutoTokenizer
 
@@ -43,6 +44,14 @@ def test_text_generation_inference(
 
     client = docker.from_env()
 
+    # If the GPU compute capability is lower than 8.0 (Ampere), then set `USE_FLASH_ATTENTION=false`
+    pynvml.nvmlInit()
+    handle = pynvml.nvmlDeviceGetHandleByIndex(0)
+    compute_capability = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
+    if compute_capability[0] < 8:
+        text_generation_launcher_kwargs["USE_FLASH_ATTENTION"] = "false"
+    pynvml.nvmlShutdown()
+
     logging.info(
         f"Starting container for {text_generation_launcher_kwargs.get('MODEL_ID', None)}..."
     )
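
Taken out of the test, the new block reads GPU 0's (major, minor) CUDA compute capability via NVML and disables flash attention when the major version is below 8, i.e. on pre-Ampere parts such as a T4 (7.5) or V100 (7.0). A standalone sketch of the same check follows; the helper name and the try/finally guard (so NVML is shut down even if the query raises) are additions for illustration, not part of the commit:

    import pynvml

    def flash_attention_supported(gpu_index: int = 0) -> bool:
        """True if the GPU's compute capability is 8.0 (Ampere) or newer."""
        pynvml.nvmlInit()
        try:
            handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_index)
            major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
            return major >= 8
        finally:
            pynvml.nvmlShutdown()

    # On a T4 (compute capability 7.5) this prints False, so the test
    # would launch the container with USE_FLASH_ATTENTION=false.
    print(flash_attention_supported())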
