Skip to content

Commit 6ec2533

Browse files
committed
CAR (custom all-reduce) check is done elsewhere, as in upstream
Remove dead code
1 parent 4908f2c commit 6ec2533

File tree

2 files changed

+0
-33
lines changed

2 files changed

+0
-33
lines changed

csrc/layernorm_kernels.cu

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,6 @@
 10  10    #include <hipcub/hipcub.hpp>
 11  11    #endif
 12  12
 13      - #ifdef USE_ROCM
 14      -   #include "quantization/fp8/amd/quant_utils.cuh"
 15      - #else
 16      -   #include "quantization/fp8/nvidia/quant_utils.cuh"
 17      - #endif
 18      -
 19  13    namespace vllm {
 20  14
 21  15    // This kernel uses the _f16Vec to represent vectorized data.
@@ -191,26 +185,6 @@ fused_add_rms_norm_kernel(
 191 185      }
 192 186    }
 193 187
 194     - /* Function specialization in the case of FP16/BF16 tensors.
 195     -    Additional optimizations we can make in this case are
 196     -    packed and vectorized operations, which help with the
 197     -    memory latency bottleneck. */
 198     -
 199     - template <>
 200     - struct Vec<c10::Float8_e4m3fnuz, 8> {
 201     -   using Type = uint2;
 202     - };
 203     -
 204     - template <>
 205     - struct Vec<c10::Half, 8> {
 206     -   using Type = uint4;
 207     - };
 208     -
 209     - template <>
 210     - struct Vec<c10::BFloat16, 8> {
 211     -   using Type = bf16_8_t;
 212     - };
 213     -
 214 188    }  // namespace vllm
 215 189
 216 190    #define LAUNCH_RMS_NORM(width) \

vllm/config.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1951,13 +1951,6 @@ def __post_init__(self) -> None:
 1951 1951
 1952 1952            self._verify_args()
 1953 1953
 1954      -         from vllm.platforms.rocm import on_gfx1x
 1955      -         if on_gfx1x() and self.tensor_parallel_size > 1:
 1956      -             self.disable_custom_all_reduce = True
 1957      -             logger.info(
 1958      -                 "Disabled the custom all-reduce kernel because it is not "
 1959      -                 "working correctly on multiple AMD Radeon GPUs.")
 1960      -
 1961 1954        @property
 1962 1955        def use_ray(self) -> bool:
 1963 1956            return self.distributed_executor_backend == "ray" or (

0 commit comments

Comments
 (0)