
Commit 575eb40

Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	docs/multimodal/minicpmv4.0.md
#	examples/model-conversion/Makefile
#	examples/model-conversion/README.md
#	examples/model-conversion/logits.cpp
#	examples/model-conversion/scripts/causal/modelcard.template
#	examples/model-conversion/scripts/utils/hf-create-model.py
#	ggml/src/ggml-opencl/ggml-opencl.cpp
#	tests/test-backend-ops.cpp
#	tools/batched-bench/batched-bench.cpp

2 parents 75c919c + 79a5462 commit 575eb40

21 files changed: +756 -438 lines changed


convert_hf_to_gguf.py

Lines changed: 43 additions & 4 deletions
@@ -3159,7 +3159,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         yield from super().modify_tensors(data_torch, name, bid)
 
 
-@ModelBase.register("Ernie4_5_ForCausalLM")
+@ModelBase.register("Ernie4_5_ForCausalLM", "Ernie4_5ForCausalLM")
 class Ernie4_5Model(TextModel):
     model_arch = gguf.MODEL_ARCH.ERNIE4_5
 
@@ -6254,9 +6254,11 @@ def prepare_tensors(self):
                 raise ValueError(f"Unprocessed experts: {experts}")
 
 
-@ModelBase.register("DeepseekV2ForCausalLM")
-@ModelBase.register("DeepseekV3ForCausalLM")
-@ModelBase.register("KimiVLForConditionalGeneration")
+@ModelBase.register(
+    "DeepseekV2ForCausalLM",
+    "DeepseekV3ForCausalLM",
+    "KimiVLForConditionalGeneration",
+)
 class DeepseekV2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.DEEPSEEK2
 
@@ -8507,6 +8509,43 @@ def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", "
             return "mm.2.weight"
         return super().map_tensor_name(name, try_suffixes)
 
+
+@ModelBase.register("KimiVLForConditionalGeneration")
+class KimiVLModel(MmprojModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert self.hparams_vision is not None
+        self.hparams_vision["image_size"] = 64 * 14 # for compatibility
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.KIMIVL)
+        self.gguf_writer.add_vision_use_gelu(True)
+        self.gguf_writer.add_vision_projector_scale_factor(2)
+        # eps is the same as pytorch's default value
+        assert self.hparams_vision is not None
+        self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams_vision.get("layer_norm_eps", 1e-5))
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid # unused
+        is_vision_tensor = "vision_tower" in name or "multi_modal_projector" in name
+
+        if is_vision_tensor:
+            if "pos_emb.weight" in name:
+                data_torch = data_torch.view(data_torch.shape[0] * data_torch.shape[1], data_torch.shape[2])
+            elif "wqkv" in name:
+                split_dim = 0 if "weight" in name else -1
+                wq, wk, wv = data_torch.chunk(3, dim=split_dim)
+                return [
+                    (self.map_tensor_name(name.replace("wqkv", "wq")), wq),
+                    (self.map_tensor_name(name.replace("wqkv", "wk")), wk),
+                    (self.map_tensor_name(name.replace("wqkv", "wv")), wv)
+                ]
+
+            return [(self.map_tensor_name(name), data_torch)]
+
+        return [] # skip other tensors
+
 ###### CONVERSION LOGIC ######
 
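To make the new mmproj conversion concrete: Kimi-VL's vision tower stores the attention projections as a single fused `wqkv` tensor, and `modify_tensors` above splits it into separate `wq`/`wk`/`wv` entries for GGUF. Below is a minimal, standalone sketch of that split; the tensor width is a made-up placeholder, not a value taken from the model config.

```python
# Standalone sketch of the wqkv split performed by KimiVLModel.modify_tensors above.
# The 1152 width is hypothetical; only the chunking logic mirrors the diff.
import torch

n_embd = 1152                                  # placeholder vision embedding width
wqkv_weight = torch.randn(3 * n_embd, n_embd)  # fused "wqkv.weight"
wqkv_bias   = torch.randn(3 * n_embd)          # fused "wqkv.bias"

# Weights split along dim 0 (rows), biases along the last dim,
# i.e. split_dim = 0 if "weight" in name else -1.
wq, wk, wv = wqkv_weight.chunk(3, dim=0)
bq, bk, bv = wqkv_bias.chunk(3, dim=-1)

assert wq.shape == (n_embd, n_embd)
assert bq.shape == (n_embd,)
```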

ggml/src/ggml-cuda/common.cuh

Lines changed: 20 additions & 8 deletions
@@ -424,16 +424,28 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 
 template<int width = WARP_SIZE>
 static __device__ __forceinline__ int warp_reduce_all(int x) {
-#ifdef GGML_USE_HIP
+    if (width == ggml_cuda_get_physical_warp_size()) {
+        return __all_sync(0xffffffff, x);
+    } else {
 #pragma unroll
-    for (int offset = width/2; offset > 0; offset >>= 1) {
-        x = x && __shfl_xor_sync(0xffffffff, x, offset, width);
+        for (int offset = width/2; offset > 0; offset >>= 1) {
+            x = __shfl_xor_sync(0xffffffff, x, offset, width) && x;
+        }
+        return x;
+    }
+}
+
+template<int width = WARP_SIZE>
+static __device__ __forceinline__ int warp_reduce_any(int x) {
+    if (width == ggml_cuda_get_physical_warp_size()) {
+        return __any_sync(0xffffffff, x);
+    } else {
+#pragma unroll
+        for (int offset = width/2; offset > 0; offset >>= 1) {
+            x = __shfl_xor_sync(0xffffffff, x, offset, width) || x;
+        }
+        return x;
     }
-    return x;
-#else
-    static_assert(width == WARP_SIZE, "width != WARP_SIZE not implemented");
-    return __all_sync(0xffffffff, x);
-#endif // GGML_USE_HIP
 }
 
 template<int width = WARP_SIZE>
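The fallback path in both `warp_reduce_all` and the new `warp_reduce_any` is a standard XOR-shuffle (butterfly) reduction: after log2(width) steps every lane holds the AND (respectively OR) of all lanes' inputs. A small Python emulation of the `warp_reduce_any` pattern follows, with a list standing in for the per-lane registers; this is illustrative only, not ggml code.

```python
# Emulate the butterfly OR-reduction across a "warp" of `width` lanes.
# x[i ^ offset] plays the role of __shfl_xor_sync(0xffffffff, x, offset, width).
def warp_reduce_any_emulated(lanes, width=32):
    assert width & (width - 1) == 0, "width must be a power of two"
    x = list(lanes)
    offset = width // 2
    while offset > 0:
        x = [int(x[i ^ offset] or x[i]) for i in range(width)]  # all lanes update in lockstep
        offset //= 2
    return x  # every lane now holds 1 iff any lane started non-zero

print(warp_reduce_any_emulated([0] * 31 + [1]))  # prints 32 ones
```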

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 21 additions & 1 deletion
@@ -207,6 +207,8 @@ static ggml_cuda_device_info ggml_cuda_init() {
 //#endif // GGML_CUDA_FORCE_CUBLAS
     GGML_LOG_INFO("---\nInitializing CUDA/HIP, please wait, the following step may take a few minutes (only for first launch)...\n---\n");
     GGML_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count);
+
+    std::vector<std::pair<int, std::string>> turing_devices_without_mma;
     for (int id = 0; id < info.device_count; ++id) {
         int device_vmm = 0;
 
@@ -264,7 +266,25 @@
         info.devices[id].cc = 100*prop.major + 10*prop.minor;
         GGML_LOG_INFO("  Device %d: %s, compute capability %d.%d, VMM: %s\n",
                       id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
-#endif // defined(GGML_USE_HIP)
+        std::string device_name(prop.name);
+        if (device_name == "NVIDIA GeForce MX450") {
+            turing_devices_without_mma.push_back({ id, device_name });
+        } else if (device_name == "NVIDIA GeForce MX550") {
+            turing_devices_without_mma.push_back({ id, device_name });
+        } else if (device_name.substr(0, 21) == "NVIDIA GeForce GTX 16") {
+            turing_devices_without_mma.push_back({ id, device_name });
+        }
+#endif // defined(GGML_USE_HIP)
+    }
+
+    if (ggml_cuda_highest_compiled_arch(GGML_CUDA_CC_TURING) >= GGML_CUDA_CC_TURING && !turing_devices_without_mma.empty()) {
+        GGML_LOG_INFO("The following devices will have suboptimal performance due to a lack of tensor cores:\n");
+        for (size_t device_pos = 0; device_pos < turing_devices_without_mma.size(); device_pos++) {
+            GGML_LOG_INFO(
+                "  Device %d: %s\n", turing_devices_without_mma[device_pos].first, turing_devices_without_mma[device_pos].second.c_str());
+        }
+        GGML_LOG_INFO(
+            "Consider compiling with CMAKE_CUDA_ARCHITECTURES=61-virtual;80-virtual and DGGML_CUDA_FORCE_MMQ to force the use of the Pascal code for Turing.\n");
     }
 
     for (int id = 0; id < info.device_count; ++id) {