
Commit 6685846

fix for cpu-only config server
1 parent c843243 commit 6685846
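
What changed: several Triton kernel modules probed the GPU at module import time via torch.cuda.get_device_name(0) and torch.cuda.get_device_capability(), both of which raise a RuntimeError on a host with no CUDA device, so a CPU-only config server failed at import. The commit routes those probes through cached helpers in lightllm/utils/device_utils.py that fall back to harmless defaults when CUDA is unavailable. A minimal sketch of the pattern (a standalone illustration, not the committed code verbatim):

    import torch
    from functools import lru_cache

    # Before: evaluated at module import, raises on CPU-only hosts.
    # TESLA = "Tesla" in torch.cuda.get_device_name(0)

    @lru_cache(maxsize=None)
    def get_cuda_device_name():
        # Fall back to "" so device-name checks are simply False without a GPU.
        if not torch.cuda.is_available():
            return ""
        return torch.cuda.get_device_name(0)

    # After: degrades gracefully when no CUDA device is present.
    TESLA = "Tesla" in get_cuda_device_name()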

File tree

9 files changed: +42, -15 lines

lightllm/models/deepseek2/triton_kernel/context_flashattention_nopad.py

Lines changed: 3 additions & 2 deletions

@@ -4,9 +4,10 @@
 import triton.language as tl
 import math
 import torch.nn.functional as F
+from lightllm.utils.device_utils import get_cuda_device_name, get_device_capability

-TESLA = "Tesla" in torch.cuda.get_device_name(0)
-CUDA_CAPABILITY = torch.cuda.get_device_capability()
+TESLA = "Tesla" in get_cuda_device_name()
+CUDA_CAPABILITY = get_device_capability()


 @triton.jit

lightllm/models/deepseek2/triton_kernel/context_flashattention_nopad_fp8.py

Lines changed: 3 additions & 2 deletions

@@ -4,9 +4,10 @@
 import triton.language as tl
 import math
 import torch.nn.functional as F
+from lightllm.utils.device_utils import get_cuda_device_name, get_device_capability

-TESLA = "Tesla" in torch.cuda.get_device_name(0)
-CUDA_CAPABILITY = torch.cuda.get_device_capability()
+TESLA = "Tesla" in get_cuda_device_name()
+CUDA_CAPABILITY = get_device_capability()


 @triton.jit

lightllm/models/deepseek2/triton_kernel/context_flashattention_nopad_with_v.py

Lines changed: 3 additions & 2 deletions

@@ -4,9 +4,10 @@
 import triton.language as tl
 import math
 import torch.nn.functional as F
+from lightllm.utils.device_utils import get_cuda_device_name, get_device_capability

-TESLA = "Tesla" in torch.cuda.get_device_name(0)
-CUDA_CAPABILITY = torch.cuda.get_device_capability()
+TESLA = "Tesla" in get_cuda_device_name()
+CUDA_CAPABILITY = get_device_capability()


 @triton.jit

lightllm/models/deepseek2/triton_kernel/sample_kv.py

Lines changed: 4 additions & 2 deletions

@@ -3,8 +3,10 @@
 import triton
 import triton.language as tl

-TESLA = "Tesla" in torch.cuda.get_device_name(0)
-CUDA_CAPABILITY = torch.cuda.get_device_capability()
+from lightllm.utils.device_utils import get_cuda_device_name, get_device_capability
+
+TESLA = "Tesla" in get_cuda_device_name()
+CUDA_CAPABILITY = get_device_capability()


 @triton.jit

lightllm/models/llama/triton_kernel/context_flashattention_nopad.py

Lines changed: 3 additions & 1 deletion

@@ -7,7 +7,9 @@
 import math
 import torch.nn.functional as F

-TESLA = "Tesla" in torch.cuda.get_device_name(0)
+from lightllm.utils.device_utils import get_cuda_device_name
+
+TESLA = "Tesla" in get_cuda_device_name()


 @triton.jit

lightllm/models/phi3/triton_kernel/context_flashattention_nopad.py

Lines changed: 3 additions & 1 deletion

@@ -5,7 +5,9 @@
 import math
 import torch.nn.functional as F

-TESLA = "Tesla" in torch.cuda.get_device_name(0)
+from lightllm.utils.device_utils import get_cuda_device_name
+
+TESLA = "Tesla" in get_cuda_device_name()


 @triton.jit

lightllm/models/vit/triton_kernel/flashattention_nopad.py

Lines changed: 6 additions & 5 deletions

@@ -5,12 +5,13 @@
 import math
 import torch.nn.functional as F

-TESLA = "Tesla" in torch.cuda.get_device_name(0)
+from lightllm.utils.device_utils import get_cuda_device_name, get_device_capability
+
 HOPPER = (
-    "H100" in torch.cuda.get_device_name(0)
-    or "H200" in torch.cuda.get_device_name(0)
-    or "H800" in torch.cuda.get_device_name(0)
-    or "Hopper" in torch.cuda.get_device_name(0)
+    "H100" in get_cuda_device_name()
+    or "H200" in get_cuda_device_name()
+    or "H800" in get_cuda_device_name()
+    or "Hopper" in get_cuda_device_name()
 )

lightllm/utils/device_utils.py

Lines changed: 15 additions & 0 deletions

@@ -1,5 +1,6 @@
 import os
 import time
+import torch
 import shutil
 import subprocess
 from functools import lru_cache

@@ -8,6 +9,20 @@
 logger = init_logger(__name__)


+@lru_cache(maxsize=None)
+def get_cuda_device_name():
+    if not torch.cuda.is_available():
+        return ""
+    return torch.cuda.get_device_name(0)
+
+
+@lru_cache(maxsize=None)
+def get_device_capability():
+    if not torch.cuda.is_available():
+        return (-1, -1)
+    return torch.cuda.get_device_capability()
+
+
 @lru_cache(maxsize=None)
 def get_device_sm_count():
     import triton
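
With these fallbacks, CPU-only behavior follows from the sentinel values: the empty string makes every substring check on the device name False, and (-1, -1) fails any minimum-capability comparison. Expected behavior on a machine without a CUDA device (inferred from the fallbacks above):

    from lightllm.utils.device_utils import get_cuda_device_name, get_device_capability

    print(get_cuda_device_name())   # ""        -> "Tesla" in "" is False
    print(get_device_capability())  # (-1, -1)  -> fails checks like capability[0] >= 8

Both helpers are wrapped in @lru_cache(maxsize=None), so the underlying torch.cuda query runs at most once per process.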

lightllm/utils/envs_utils.py

Lines changed: 2 additions & 0 deletions

@@ -24,6 +24,8 @@ def get_unique_server_name():


 def set_cuda_arch(args):
+    if not torch.cuda.is_available():
+        return
     if args.enable_flashinfer_prefill or args.enable_flashinfer_decode:
         capability = torch.cuda.get_device_capability()
         arch = f"{capability[0]}.{capability[1]}"
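
The early return turns set_cuda_arch into a no-op when CUDA is unavailable, so a CPU-only config server never reaches the torch.cuda.get_device_capability() call below it; the diff adds no import here, which suggests torch was already imported in this module.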
