
Commit 1b37cce

Merge branch 'ggml-org:master' into master
2 parents: 7f46307 + 1f5accb


43 files changed (+1672, -174 lines)

.github/workflows/build-linux-cross.yml

Lines changed: 37 additions & 37 deletions
@@ -4,49 +4,49 @@ on:
   workflow_call:
 
 jobs:
-  ubuntu-24-riscv64-cpu-cross:
-    runs-on: ubuntu-24.04
+  # ubuntu-24-riscv64-cpu-cross:
+  #   runs-on: ubuntu-24.04
 
-    steps:
-      - uses: actions/checkout@v4
-      - name: Setup Riscv
-        run: |
-          sudo dpkg --add-architecture riscv64
+  #   steps:
+  #     - uses: actions/checkout@v4
+  #     - name: Setup Riscv
+  #       run: |
+  #         sudo dpkg --add-architecture riscv64
 
-          # Add arch-specific repositories for non-amd64 architectures
-          cat << EOF | sudo tee /etc/apt/sources.list.d/riscv64-ports.list
-          deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble main universe
-          deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-updates main universe
-          deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-security main universe
-          deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-backports main universe
-          EOF
+  #         # Add arch-specific repositories for non-amd64 architectures
+  #         cat << EOF | sudo tee /etc/apt/sources.list.d/riscv64-ports.list
+  #         deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble main universe
+  #         deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-updates main universe
+  #         deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-security main universe
+  #         deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-backports main universe
+  #         EOF
 
-          sudo apt-get update || true ;# Prevent failure due to missing URLs.
+  #         sudo apt-get update || true ;# Prevent failure due to missing URLs.
 
-          sudo apt-get install -y --no-install-recommends \
-                  build-essential \
-                  gcc-14-riscv64-linux-gnu \
-                  g++-14-riscv64-linux-gnu
+  #         sudo apt-get install -y --no-install-recommends \
+  #                 build-essential \
+  #                 gcc-14-riscv64-linux-gnu \
+  #                 g++-14-riscv64-linux-gnu
 
-      - name: Build
-        run: |
-          cmake -B build -DLLAMA_CURL=OFF \
-                         -DCMAKE_BUILD_TYPE=Release \
-                         -DGGML_OPENMP=OFF \
-                         -DLLAMA_BUILD_EXAMPLES=ON \
-                         -DLLAMA_BUILD_TOOLS=ON \
-                         -DLLAMA_BUILD_TESTS=OFF \
-                         -DCMAKE_SYSTEM_NAME=Linux \
-                         -DCMAKE_SYSTEM_PROCESSOR=riscv64 \
-                         -DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \
-                         -DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 \
-                         -DCMAKE_POSITION_INDEPENDENT_CODE=ON \
-                         -DCMAKE_FIND_ROOT_PATH=/usr/lib/riscv64-linux-gnu \
-                         -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \
-                         -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \
-                         -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH
+  #   - name: Build
+  #     run: |
+  #       cmake -B build -DLLAMA_CURL=OFF \
+  #                      -DCMAKE_BUILD_TYPE=Release \
+  #                      -DGGML_OPENMP=OFF \
+  #                      -DLLAMA_BUILD_EXAMPLES=ON \
+  #                      -DLLAMA_BUILD_TOOLS=ON \
+  #                      -DLLAMA_BUILD_TESTS=OFF \
+  #                      -DCMAKE_SYSTEM_NAME=Linux \
+  #                      -DCMAKE_SYSTEM_PROCESSOR=riscv64 \
+  #                      -DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \
+  #                      -DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 \
+  #                      -DCMAKE_POSITION_INDEPENDENT_CODE=ON \
+  #                      -DCMAKE_FIND_ROOT_PATH=/usr/lib/riscv64-linux-gnu \
+  #                      -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \
+  #                      -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \
+  #                      -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH
 
-          cmake --build build --config Release -j $(nproc)
+  #       cmake --build build --config Release -j $(nproc)
 
 # ubuntu-24-riscv64-vulkan-cross:
 #   runs-on: ubuntu-24.04

common/arg.cpp

Lines changed: 14 additions & 0 deletions
@@ -2768,6 +2768,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.image.emplace_back(value);
         }
     ).set_examples({LLAMA_EXAMPLE_MTMD}));
+    add_opt(common_arg(
+        {"--image-min-tokens"}, "N",
+        "minimum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)",
+        [](common_params & params, int value) {
+            params.image_min_tokens = value;
+        }
+    ).set_examples(mmproj_examples).set_env("LLAMA_ARG_IMAGE_MIN_TOKENS"));
+    add_opt(common_arg(
+        {"--image-max-tokens"}, "N",
+        "maximum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)",
+        [](common_params & params, int value) {
+            params.image_max_tokens = value;
+        }
+    ).set_examples(mmproj_examples).set_env("LLAMA_ARG_IMAGE_MAX_TOKENS"));
     if (llama_supports_rpc()) {
         add_opt(common_arg(
             {"--rpc"}, "SERVERS",

common/common.h

Lines changed: 2 additions & 0 deletions
@@ -406,6 +406,8 @@ struct common_params {
     bool mmproj_use_gpu = true;     // use GPU for multimodal model
     bool no_mmproj = false;         // explicitly disable multimodal model
     std::vector<std::string> image; // path to image file(s)
+    int image_min_tokens = -1;
+    int image_max_tokens = -1;
 
     // finetune
     struct lr_opt lr;

convert_hf_to_gguf.py

Lines changed: 107 additions & 0 deletions
@@ -9802,6 +9802,113 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
 
         return [(self.map_tensor_name(name), data_torch)]
 
+
+@ModelBase.register("JanusForConditionalGeneration")
+class JanusProModel(LlamaModel):
+    model_arch = gguf.MODEL_ARCH.LLAMA  # reuse Llama arch
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # Skip vision, aligner, and generation tensors
+        skip_prefixes = (
+            'model.vision_model.',
+            'model.aligner.',
+            'model.vqmodel.',
+            'model.generation_embeddings.',
+            'model.generation_aligner.',
+            'model.generation_head.',
+        )
+        if name.startswith(skip_prefixes):
+            return []
+
+        if name.startswith('model.language_model.'):
+            name = name.replace('model.language_model.', 'model.')
+        elif name.startswith('language_model.'):
+            name = name.replace('language_model.', '')
+
+        return super().modify_tensors(data_torch, name, bid)
+
+
+@ModelBase.register("JanusForConditionalGeneration")
+class JanusProVisionModel(MmprojModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert self.hparams_vision is not None
+        if "intermediate_size" not in self.hparams_vision:
+            mlp_ratio = self.hparams_vision.get("mlp_ratio")
+            hidden_size = self.hparams_vision.get("hidden_size")
+            if mlp_ratio is not None and hidden_size is not None:
+                self.hparams_vision["intermediate_size"] = int(round(hidden_size * mlp_ratio))
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        assert self.hparams_vision is not None
+
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.JANUS_PRO)
+
+        self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams_vision.get("layer_norm_eps", 1e-6))
+
+        hidden_act = str(self.hparams_vision.get("hidden_act", "")).lower()
+        if hidden_act == "gelu":
+            self.gguf_writer.add_vision_use_gelu(True)
+        elif hidden_act == "silu":
+            self.gguf_writer.add_vision_use_silu(True)
+
+    def _map_aligner_tensor(self, data_torch: Tensor, name: str) -> Iterable[tuple[str, Tensor]]:
+        """Map aligner tensors to projector format"""
+        suffix = ".bias" if name.endswith(".bias") else ".weight"
+
+        if name.startswith("model.aligner."):
+            local_name = name[len("model.aligner."):]
+        elif name.startswith("aligner."):
+            local_name = name[len("aligner."):]
+        else:
+            raise ValueError(f"Unsupported Janus aligner prefix: {name}")
+
+        if local_name.startswith("fc1."):
+            mm_index = 0
+        elif local_name.startswith("hidden_layers."):
+            parts = local_name.split(".", 2)
+            if len(parts) < 3:
+                raise ValueError(f"Unexpected Janus aligner tensor name: {name}")
+            mm_index = int(parts[1]) + 1
+        else:
+            raise ValueError(f"Unsupported Janus aligner tensor: {name}")
+
+        tensor_name = self.format_tensor_name(gguf.MODEL_TENSOR.V_MMPROJ, mm_index, suffix=suffix)
+        return [(tensor_name, data_torch)]
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+
+        # Skip language model tensors as they will be handled by `JanusProModel`
+        if name.startswith(('model.language_model.', 'language_model.')):
+            return []
+
+        # Skip generation-related components
+        skip_generation_prefixes = (
+            'model.vqmodel.',
+            'vqmodel.',
+            'model.generation_embeddings.',
+            'generation_embeddings.',
+            'model.generation_aligner.',
+            'generation_aligner.',
+            'model.generation_head.',
+            'generation_head.',
+        )
+        if name.startswith(skip_generation_prefixes):
+            return []
+
+        # Handle aligner tensors
+        if name.startswith(('model.aligner.', 'aligner.')):
+            return list(self._map_aligner_tensor(data_torch, name))
+
+        # Handle vision tensors
+        if name.startswith(('model.vision_model.', 'vision_model.')):
+            return [(self.map_tensor_name(name), data_torch)]
+
+        return []
+
+
 ###### CONVERSION LOGIC ######
 

examples/model-conversion/scripts/causal/run-org-model.py

Lines changed: 4 additions & 5 deletions
@@ -138,6 +138,9 @@ def fn(_m, input, output):
         "Model path must be specified either via --model-path argument or MODEL_PATH environment variable"
     )
 
+
+print("Loading model and tokenizer using AutoTokenizer:", model_path)
+tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
 
 print("Model type: ", config.model_type)
@@ -147,10 +150,6 @@ def fn(_m, input, output):
 print("BOS token id: ", config.bos_token_id)
 print("EOS token id: ", config.eos_token_id)
 
-print("Loading model and tokenizer using AutoTokenizer:", model_path)
-tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
-
 if unreleased_model_name:
     model_name_lower = unreleased_model_name.lower()
     unreleased_module_path = (
@@ -171,7 +170,7 @@ def fn(_m, input, output):
         exit(1)
 else:
     model = AutoModelForCausalLM.from_pretrained(
-        model_path, device_map="auto", offload_folder="offload", trust_remote_code=True
+        model_path, device_map="auto", offload_folder="offload", trust_remote_code=True, config=config
     )
 
 for name, module in model.named_modules():

ggml/src/ggml-cpu/arch/loongarch/quants.c

Lines changed: 4 additions & 5 deletions
@@ -700,7 +700,8 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
     for (; ib + 1 < nb; ib += 2) {
 
         // Compute combined scale for the block 0 and 1
-        const __m128 d_0_1 = (__m128)__lsx_vreplgr2vr_w( GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d) );
+        const float ft0 = GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d);
+        const __m128 d_0_1 = (__m128)(v4f32){ft0, ft0, ft0, ft0};
 
         const __m128i tmp_0_1 = __lsx_vld((const __m128i *)x[ib].qs, 0);
 
@@ -714,11 +715,9 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
         bx_1 = __lsx_vsub_b(bx_1, off);
         const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
 
-        //_mm_prefetch(&x[ib] + 2 * sizeof(block_q4_0), _MM_HINT_T0);
-        //_mm_prefetch(&y[ib] + 2 * sizeof(block_q8_0), _MM_HINT_T0);
-
         // Compute combined scale for the block 2 and 3
-        const __m128 d_2_3 = (__m128)__lsx_vreplgr2vr_w( GGML_CPU_FP16_TO_FP32(x[ib + 1].d) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) );
+        const float ft1 = GGML_CPU_FP16_TO_FP32(x[ib + 1].d) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d);
+        const __m128 d_2_3 = (__m128)(v4f32){ft1, ft1, ft1, ft1};
 
         const __m128i tmp_2_3 = __lsx_vld((const __m128i *)x[ib + 1].qs, 0);
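
The replaced lines appear to have pushed the float scale through the integer-replicate intrinsic and then reinterpreted the bits; the new form builds a v4f32 directly from the scalar product, so every lane holds the actual float value. A portable sketch of the same splat pattern, written with generic GCC/Clang vector extensions rather than the LoongArch LSX types (which are assumed to behave similarly):

    // Generic GCC/Clang vector extension standing in for ggml's v4f32 / __m128 on LSX.
    typedef float f32x4 __attribute__((vector_size(16)));

    static inline f32x4 splat_f32x4(float v) {
        // All four lanes are set from the float value itself, with no integer round-trip.
        return (f32x4){ v, v, v, v };
    }

    int main(void) {
        float d0 = 0.5f * 0.25f;        // e.g. the combined block scale x[ib].d * y[ib].d
        f32x4 d_0_1 = splat_f32x4(d0);
        return d_0_1[3] == d0 ? 0 : 1;  // every lane holds the scalar product
    }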

ggml/src/ggml-cpu/ggml-cpu-impl.h

Lines changed: 3 additions & 1 deletion
@@ -500,13 +500,15 @@ inline static int32x4_t ggml_vec_dot(int32x4_t acc, int8x16_t a, int8x16_t b) {
 
 #endif
 
-#if defined(__loongarch_asx)
+#if defined(__loongarch_sx)
 /* float type data load instructions */
 static __m128 __lsx_vreplfr2vr_s(const float val) {
     v4f32 res = {val, val, val, val};
     return (__m128)res;
 }
+#endif
 
+#if defined(__loongarch_asx)
 static __m256 __lasx_xvreplfr2vr_s(const float val) {
     v8f32 res = {val, val, val, val, val, val, val, val};
     return (__m256)res;

ggml/src/ggml-cpu/repack.cpp

Lines changed: 25 additions & 0 deletions
@@ -1678,10 +1678,24 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
     int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
     int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
 
+    // Ensure minimum chunk size to avoid alignment issues with high thread counts
+    // Minimum chunk size should be at least NB_COLS to prevent overlapping chunks after alignment
+    const int64_t min_chunk_size = NB_COLS;
+    if (nchunk > 0 && (nr / nchunk) < min_chunk_size && nr >= min_chunk_size) {
+        nchunk = (nr + min_chunk_size - 1) / min_chunk_size;
+    }
+
     if (nth == 1 || nchunk < nth || disable_chunking) {
         nchunk = nth;
     }
 
+    // Ensure nchunk doesn't exceed the number of rows divided by minimum chunk size
+    // This prevents creating too many tiny chunks that could overlap after alignment
+    const int64_t max_nchunk = (nr + min_chunk_size - 1) / min_chunk_size;
+    if (nchunk > max_nchunk) {
+        nchunk = max_nchunk;
+    }
+
     if (ith == 0) {
         // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
         ggml_threadpool_chunk_set(params->threadpool, nth);
@@ -1695,8 +1709,15 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
     while (current_chunk < nchunk) {
         int64_t src0_start = (current_chunk * ne01) / nchunk;
         int64_t src0_end = ((current_chunk + 1) * ne01) / nchunk;
+
+        // Align boundaries to NB_COLS - round up to ensure all data is included
+        // The chunk size limiting above ensures chunks are large enough to prevent overlaps
         src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
         src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
+        if (src0_end > ne01) {
+            src0_end = ne01;
+        }
+
         if (src0_start >= src0_end) {
             break;
         }
@@ -1808,8 +1829,12 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
     int64_t src0_cur_start = (ith * ne01) / nth;
     int64_t src0_cur_end = ((ith + 1) * ne01) / nth;
 
+    // Align boundaries to NB_COLS - round up to ensure all data is included
    src0_cur_start = (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
    src0_cur_end = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end;
+    if (src0_cur_end > ne01) {
+        src0_cur_end = ne01;
+    }
 
     if (src0_cur_start >= src0_cur_end) {
         return;