
Commit 895329e

Merge branch 'ggml-org:master' into master

2 parents 6e6934d + 51abc96

33 files changed: +500 −98 lines

.devops/rocm.Dockerfile

Lines changed: 12 additions & 7 deletions
@@ -17,14 +17,11 @@ FROM ${BASE_ROCM_DEV_CONTAINER} AS build
 # gfx906 is deprecated
 #check https://rocm.docs.amd.com/projects/install-on-linux/en/docs-6.4.1/reference/system-requirements.html

-ARG ROCM_DOCKER_ARCH='gfx803,gfx900,gfx906,gfx908,gfx90a,gfx942,gfx1010,gfx1030,gfx1032,gfx1100,gfx1101,gfx1102,gfx1200,gfx1201'
-#ARG ROCM_DOCKER_ARCH=gfx1100
+ARG ROCM_DOCKER_ARCH='gfx803;gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1010;gfx1030;gfx1032;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx1151'
+#ARG ROCM_DOCKER_ARCH='gfx1151'

-# Set ROCm architectured
+# Set ROCm architectures
 ENV AMDGPU_TARGETS=${ROCM_DOCKER_ARCH}
-# Enable ROCm
-# ENV CC=/opt/rocm/llvm/bin/clang
-# ENV CXX=/opt/rocm/llvm/bin/clang++

 RUN apt-get update \
     && apt-get install -y \
@@ -39,8 +36,16 @@ WORKDIR /app

 COPY . .

+RUN git clone https://github.com/rocm/rocwmma --branch develop --depth 1
+
 RUN HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -R)" \
-    cmake -S . -B build -DGGML_HIP=ON -DAMDGPU_TARGETS=$ROCM_DOCKER_ARCH -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DCMAKE_BUILD_TYPE=Release -DLLAMA_BUILD_TESTS=OFF \
+    cmake -S . -B build \
+        -DGGML_HIP=ON \
+        -DGGML_HIP_ROCWMMA_FATTN=ON \
+        -DCMAKE_HIP_FLAGS="-I$(pwd)/rocwmma/library/include/" \
+        -DAMDGPU_TARGETS="$ROCM_DOCKER_ARCH" \
+        -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON \
+        -DCMAKE_BUILD_TYPE=Release -DLLAMA_BUILD_TESTS=OFF \
    && cmake --build build --config Release -j$(nproc)

 RUN mkdir -p /app/lib \

.github/workflows/build.yml

Lines changed: 6 additions & 4 deletions
@@ -56,7 +56,7 @@ env:

 jobs:
   macOS-latest-cmake-arm64:
-    runs-on: macos-14
+    runs-on: macos-latest

     steps:
       - name: Clone
@@ -97,7 +97,7 @@ jobs:
           ctest -L 'main|curl' --verbose --timeout 900

   macOS-latest-cmake-x64:
-    runs-on: macos-13
+    runs-on: macos-latest

     steps:
       - name: Clone
@@ -138,7 +138,7 @@
           ctest -L main --verbose --timeout 900

   macOS-latest-cmake-arm64-webgpu:
-    runs-on: macos-14
+    runs-on: macos-latest

     steps:
       - name: Clone
@@ -1171,7 +1171,9 @@
           ./build-xcframework.sh

       - name: Build Xcode project
-        run: xcodebuild -project examples/llama.swiftui/llama.swiftui.xcodeproj -scheme llama.swiftui -sdk iphoneos CODE_SIGNING_REQUIRED=NO CODE_SIGN_IDENTITY= -destination 'generic/platform=iOS' FRAMEWORK_FOLDER_PATH=./build-ios build
+        run: |
+          xcodebuild -downloadPlatform iOS
+          xcodebuild -project examples/llama.swiftui/llama.swiftui.xcodeproj -scheme llama.swiftui -sdk iphoneos CODE_SIGNING_REQUIRED=NO CODE_SIGN_IDENTITY= -destination 'generic/platform=iOS' FRAMEWORK_FOLDER_PATH=./build-ios build

   android-build:
     runs-on: ubuntu-latest

.github/workflows/release.yml

Lines changed: 2 additions & 4 deletions
@@ -530,15 +530,13 @@ jobs:
     runs-on: windows-2022

     env:
-      # The ROCm version must correspond to the version used in the HIP SDK.
-      ROCM_VERSION: "6.4.2"
       HIPSDK_INSTALLER_VERSION: "25.Q3"

     strategy:
       matrix:
         include:
           - name: "radeon"
-            gpu_targets: "gfx1200;gfx1201;gfx1100;gfx1101;gfx1102;gfx1030;gfx1031;gfx1032"
+            gpu_targets: "gfx1151;gfx1200;gfx1201;gfx1100;gfx1101;gfx1102;gfx1030;gfx1031;gfx1032"

     steps:
       - name: Clone
@@ -548,7 +546,7 @@
       - name: Clone rocWMMA repository
         id: clone_rocwmma
         run: |
-          git clone https://github.com/rocm/rocwmma --branch rocm-${{ env.ROCM_VERSION }} --depth 1
+          git clone https://github.com/rocm/rocwmma --branch develop --depth 1

       - name: Cache ROCm Installation
         id: cache-rocm

CMakeLists.txt

Lines changed: 7 additions & 0 deletions
@@ -58,6 +58,12 @@ if (MSVC)
     add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:/bigobj>")
 endif()

+if (CMAKE_SYSTEM_NAME STREQUAL "iOS")
+    set(LLAMA_TOOLS_INSTALL_DEFAULT OFF)
+else()
+    set(LLAMA_TOOLS_INSTALL_DEFAULT ${LLAMA_STANDALONE})
+endif()
+
 #
 # option list
 #
@@ -82,6 +88,7 @@ option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
 option(LLAMA_BUILD_TOOLS "llama: build tools" ${LLAMA_STANDALONE})
 option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE})
 option(LLAMA_BUILD_SERVER "llama: build server example" ${LLAMA_STANDALONE})
+option(LLAMA_TOOLS_INSTALL "llama: install tools" ${LLAMA_TOOLS_INSTALL_DEFAULT})

 # 3rd party libs
 option(LLAMA_CURL "llama: use libcurl to download model from an URL" ON)

common/arg.cpp

Lines changed: 1 addition & 1 deletion
@@ -1704,7 +1704,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
        [](common_params & params, const std::string & value) {
            params.system_prompt = value;
        }
-    ).set_examples({LLAMA_EXAMPLE_MAIN}));
+    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_DIFFUSION}));
    add_opt(common_arg(
        {"--no-perf"},
        string_format("disable internal libllama performance timings (default: %s)", params.no_perf ? "true" : "false"),

convert_hf_to_gguf.py

Lines changed: 73 additions & 0 deletions
@@ -888,6 +888,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "a1e163ecab2e718a4c829d1148b6e86824ec36163bb71941c3dca9cd5ac25756":
             # ref: https://huggingface.co/JetBrains/Mellum-4b-base
             res = "mellum"
+        if chkhsh == "9b1be57e70d20d9501b2b3186e792d81181ae36ada3903c26f9fea418cf87206":
+            # ref: https://huggingface.co/inclusionAI/LLaDA-MoE-7B-A1B-Base
+            res = "llada-moe"

         if res is None:
             logger.warning("\n")
@@ -8239,6 +8242,76 @@ def prepare_tensors(self):
                 raise ValueError(f"Unprocessed experts: {experts}")


+@ModelBase.register("LLaDAMoEModel", "LLaDAMoEModelLM")
+class LLaDAMoEModel(TextModel):
+    model_arch = gguf.MODEL_ARCH.LLADA_MOE
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        if (n_experts := self.hparams.get("num_experts")) is not None:
+            self.gguf_writer.add_expert_count(n_experts)
+
+        if (expert_intermediate_size := self.hparams.get("expert_intermediate_size")) is not None:
+            self.gguf_writer.add_expert_feed_forward_length(expert_intermediate_size)
+
+        # number of experts used per token (top-k)
+        if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
+            self.gguf_writer.add_expert_used_count(n_experts_used)
+
+        self.gguf_writer.add_mask_token_id(156895)
+        self.gguf_writer.add_causal_attention(False)
+        self.gguf_writer.add_diffusion_shift_logits(False)
+
+    _experts: list[dict[str, Tensor]] | None = None
+
+    # Copied from: Qwen2MoeModel
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # process the experts separately
+        if name.find("experts") != -1:
+            n_experts = self.hparams["num_experts"]
+            assert bid is not None
+
+            if self._experts is None:
+                self._experts = [{} for _ in range(self.block_count)]
+
+            self._experts[bid][name] = data_torch
+
+            if len(self._experts[bid]) >= n_experts * 3:
+                tensors: list[tuple[str, Tensor]] = []
+
+                # merge the experts into a single 3d tensor
+                for w_name in ["down_proj", "gate_proj", "up_proj"]:
+                    datas: list[Tensor] = []
+
+                    for xid in range(n_experts):
+                        ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
+                        datas.append(self._experts[bid][ename])
+                        del self._experts[bid][ename]
+
+                    data_torch = torch.stack(datas, dim=0)
+
+                    merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
+
+                    new_name = self.map_tensor_name(merged_name)
+
+                    tensors.append((new_name, data_torch))
+                return tensors
+            else:
+                return []
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+    # Copied from: Qwen2MoeModel
+    def prepare_tensors(self):
+        super().prepare_tensors()
+
+        if self._experts is not None:
+            # flatten `list[dict[str, Tensor]]` into `list[str]`
+            experts = [k for d in self._experts for k in d.keys()]
+            if len(experts) > 0:
+                raise ValueError(f"Unprocessed experts: {experts}")
+
+
 @ModelBase.register("HunYuanDenseV1ForCausalLM")
 class HunYuanModel(TextModel):
     model_arch = gguf.MODEL_ARCH.HUNYUAN_DENSE
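
Note on the new LLaDAMoEModel converter above: like Qwen2MoeModel, it buffers each layer's per-expert down_proj/gate_proj/up_proj weights and, once all num_experts * 3 tensors of a block have arrived, stacks them along a new leading dimension so the GGUF file stores one 3D tensor per projection instead of one 2D tensor per expert. A minimal self-contained sketch of that stacking step (the expert count and shapes below are made up for illustration, not taken from the checkpoint):

import torch

n_experts, d_model, d_ff = 4, 8, 16

# pretend per-expert weights as they would arrive from the checkpoint, one 2D tensor each
experts = {
    f"model.layers.0.mlp.experts.{xid}.up_proj.weight": torch.randn(d_ff, d_model)
    for xid in range(n_experts)
}

# stack experts 0..n-1 along a new dim 0 -> one tensor of shape (n_experts, d_ff, d_model)
merged = torch.stack(
    [experts[f"model.layers.0.mlp.experts.{xid}.up_proj.weight"] for xid in range(n_experts)],
    dim=0,
)
print(merged.shape)  # torch.Size([4, 16, 8])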

convert_hf_to_gguf_update.py

Lines changed: 1 addition & 0 deletions
@@ -139,6 +139,7 @@ class TOKENIZER_TYPE(IntEnum):
     {"name": "lfm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LiquidAI/LFM2-Tokenizer"},
     {"name": "exaone4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B", },
     {"name": "mellum", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/JetBrains/Mellum-4b-base", },
+    {"name": "llada-moe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/LLaDA-MoE-7B-A1B-Base", },
 ]

 # some models are known to be broken upstream, so we will skip them as exceptions
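
The new "llada-moe" entry tells convert_hf_to_gguf_update.py which Hugging Face repo to fetch the tokenizer from; the script then records a hash of that tokenizer's output on a fixed probe string, which is what the chkhsh comparison in get_vocab_base_pre (see the convert_hf_to_gguf.py hunk above) matches against. Roughly, the hash is derived as in this sketch (illustrative only; the exact probe text and flow live in the two conversion scripts):

from hashlib import sha256
from transformers import AutoTokenizer  # assumption: the tokenizer loads via transformers

def pretokenizer_hash(repo_or_dir: str, probe_text: str) -> str:
    # encode a fixed probe string and hash the resulting token ids
    tokenizer = AutoTokenizer.from_pretrained(repo_or_dir)
    token_ids = tokenizer.encode(probe_text)
    return sha256(str(token_ids).encode()).hexdigest()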

examples/diffusion/diffusion-cli.cpp

Lines changed: 17 additions & 7 deletions
@@ -510,19 +510,27 @@ static void diffusion_generate(llama_context * ctx,
     n_generated = params.max_length;
 }

-static std::string format_input_text(const std::string & prompt, bool use_chat_template, llama_model * model) {
+static std::string format_input_text(const std::string & prompt, const std::string & system_prompt, bool use_chat_template, llama_model * model) {
     if (!use_chat_template) {
         return prompt;
     }

     auto chat_templates = common_chat_templates_init(model, "");
-
     common_chat_templates_inputs inputs;
-    common_chat_msg user_msg;
-    user_msg.role = "user";
-    user_msg.content = prompt;
-    inputs.add_generation_prompt = true;
+    common_chat_msg system_msg;
+
+    if (!system_prompt.empty()) {
+        system_msg.role = "system";
+        system_msg.content = system_prompt;
+        inputs.messages.push_back(system_msg);
+    }
+
+    common_chat_msg user_msg;
+    user_msg.role = "user";
+    user_msg.content = prompt;
+
     inputs.messages.push_back(user_msg);
+    inputs.add_generation_prompt = true;

     auto result = common_chat_templates_apply(chat_templates.get(), inputs);

@@ -579,7 +587,8 @@ int main(int argc, char ** argv) {
     llama_set_n_threads(ctx, params.cpuparams.n_threads, params.cpuparams_batch.n_threads);

     const llama_vocab * vocab = llama_model_get_vocab(model);
-    std::string formatted_prompt = format_input_text(params.prompt, params.enable_chat_template, model);
+
+    std::string formatted_prompt = format_input_text(params.prompt, params.system_prompt, params.enable_chat_template, model);

     std::vector<llama_token> input_tokens = common_tokenize(vocab,
                                                             formatted_prompt,
@@ -596,6 +605,7 @@
     }

     llama_token mask_token_id = llama_vocab_mask(vocab);
+
     GGML_ASSERT(mask_token_id != LLAMA_TOKEN_NULL);

     bool visual_mode = params.diffusion.visual_mode;

ggml/src/ggml-cuda/im2col.cu

Lines changed: 26 additions & 5 deletions
@@ -122,11 +122,14 @@ static __global__ void im2col_3d_kernel(
         int64_t OH_OW, int64_t KD_KH_KW, int64_t ID_IH_IW, int64_t KH_KW, int64_t IH_IW, int64_t IC_ID_IH_IW,
         int64_t IC_KD_KH_KW, int64_t OW_KD_KH_KW, int64_t OD_OH_OW_IC_KD_KH_KW, int64_t OH_OW_IC_KD_KH_KW,
         int64_t OW_IC_KD_KH_KW, int64_t N_OD_OH, int64_t OD_OH,
+        int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x,
         int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2) {
     const int64_t i = threadIdx.x + blockIdx.x * blockDim.x;
     if (i >= IC_KD_KH_KW) {
         return;
     }
+    GGML_UNUSED(N); GGML_UNUSED(OC); GGML_UNUSED(OH_OW); GGML_UNUSED(OD); GGML_UNUSED(OW); GGML_UNUSED(KD); GGML_UNUSED(KH);
+    GGML_UNUSED(ID_IH_IW); GGML_UNUSED(IH_IW); GGML_UNUSED(IC_ID_IH_IW); GGML_UNUSED(OW_KD_KH_KW);

     const int64_t iic = i / KD_KH_KW;
     const int64_t ikd = (i - iic * KD_KH_KW) / KH_KW;
@@ -148,7 +151,7 @@
             if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
                 dst[offset_dst] = 0.0f;
             } else {
-                const int64_t offset_src = in*IC_ID_IH_IW + iic*ID_IH_IW + iid*IH_IW + iih*IW + iiw;
+                const int64_t offset_src = ((in * IC + iic) * stride_q) + (iid * stride_z) + (iih * stride_y) + (iiw * stride_x);
                 dst[offset_dst] = src[offset_src];
             }
         }
@@ -159,6 +162,7 @@
 static void im2col_3d_cuda(const float * src, T* dst,
         int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
         int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
+        int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x,
         int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) {
     const int64_t OH_OW = OH*OW;
     const int64_t KD_KH_KW = KD*KH*KW;
@@ -179,23 +183,30 @@
                 OH_OW, KD_KH_KW, ID_IH_IW, KH_KW, IH_IW, IC_ID_IH_IW,
                 IC_KD_KH_KW, OW_KD_KH_KW, OD_OH_OW_IC_KD_KH_KW,
                 OH_OW_IC_KD_KH_KW, OW_IC_KD_KH_KW, N_OD_OH, OD_OH,
+                stride_q, stride_z, stride_y, stride_x,
                 s0, s1, s2, p0, p1, p2, d0, d1, d2);
 }

 static void im2col_3d_cuda_f16(const float * src, half * dst,
         int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
         int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
+        int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x,
         int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) {

-    im2col_3d_cuda<half>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
+    im2col_3d_cuda<half>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,
+                         stride_q, stride_z, stride_y, stride_x,
+                         s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
 }

 static void im2col_3d_cuda_f32(const float * src, float * dst,
         int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
         int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
+        int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x,
         int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) {

-    im2col_3d_cuda<float>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
+    im2col_3d_cuda<float>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,
+                          stride_q, stride_z, stride_y, stride_x,
+                          s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
 }

 void ggml_cuda_op_im2col_3d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -235,9 +246,19 @@ void ggml_cuda_op_im2col_3d(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     const int64_t OH = ne2;
     const int64_t OW = ne1;

+    const size_t es = ggml_element_size(src1);
+    const int64_t stride_x = src1->nb[0] / es;
+    const int64_t stride_y = src1->nb[1] / es;
+    const int64_t stride_z = src1->nb[2] / es;
+    const int64_t stride_q = src1->nb[3] / es;
+
     if(dst->type == GGML_TYPE_F16) {
-        im2col_3d_cuda_f16(src1_d, (half *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
+        im2col_3d_cuda_f16(src1_d, (half *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,
+                           stride_q, stride_z, stride_y, stride_x,
+                           s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
     } else {
-        im2col_3d_cuda_f32(src1_d, (float *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
+        im2col_3d_cuda_f32(src1_d, (float *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,
+                           stride_q, stride_z, stride_y, stride_x,
+                           s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
     }
 }
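
The im2col_3d change above replaces the packed-offset computation with per-dimension element strides taken from src1->nb[] (byte strides divided by the element size), so the kernel indexes non-contiguous source tensors (views, permutations) correctly. For a contiguous tensor the strides collapse to the old products and the offset is unchanged, as this small Python check illustrates (variable names here are mine, not from the kernel):

# Illustrative check only: with strides of a contiguous layout, the new stride-based
# offset reduces to the old packed formula used before this commit.
def old_offset(n, ic, d, h, w, IC, ID, IH, IW):
    return n*IC*ID*IH*IW + ic*ID*IH*IW + d*IH*IW + h*IW + w

def new_offset(n, ic, d, h, w, IC, stride_q, stride_z, stride_y, stride_x):
    return (n*IC + ic)*stride_q + d*stride_z + h*stride_y + w*stride_x

IC, ID, IH, IW = 3, 4, 5, 6
# element strides of a contiguous tensor (nb[i] / element_size)
sx, sy, sz, sq = 1, IW, IH*IW, ID*IH*IW
assert all(
    old_offset(n, ic, d, h, w, IC, ID, IH, IW) == new_offset(n, ic, d, h, w, IC, sq, sz, sy, sx)
    for n in range(2) for ic in range(IC) for d in range(ID) for h in range(IH) for w in range(IW)
)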
