Commit be863c6

Merge pull request #186 from MollySophia/v7-new: Add initial support for RWKV v7

2 parents: 84fea22 + 5658a65

39 files changed: +937, -1066 lines

.github/workflows/build.yml (36 additions, 10 deletions)

```diff
@@ -186,9 +186,8 @@ jobs:
             defines: '-DRWKV_AVX512=ON'
           - build: 'cuda12'
             defines: '-DRWKV_CUBLAS=ON'
-          - build: 'rocm5.5'
-            defines: '-G "Unix Makefiles" -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DRWKV_HIPBLAS=ON -DCMAKE_BUILD_TYPE=Release -DAMDGPU_TARGETS="gfx1100;gfx1102;gfx1030"'
-
+          - build: 'hip'
+            defines: ''
     steps:
       - name: Clone
         id: checkout
@@ -206,25 +205,52 @@ jobs:

       - name: Install rocm-toolkit
         id: rocm-toolkit
-        if: ${{ matrix.build == 'rocm5.5' }}
-        uses: Cyberhan123/rocm-toolkit@v0.1.0
-        with:
-          rocm: '5.5.0'
+        if: ${{ matrix.build == 'hip' }}
+        run: |
+          $ErrorActionPreference = "Stop"
+          write-host "Downloading AMD HIP SDK Installer"
+          Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
+          write-host "Installing AMD HIP SDK"
+          Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait
+          write-host "Completed AMD HIP SDK installation"
+
+      - name: Verify ROCm
+        id: rocm-verify
+        if: ${{ matrix.build == 'hip' }}
+        run: |
+          & 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' --version

       - name: Install Ninja
         id: install-ninja
-        if: ${{ matrix.build == 'rocm5.5' }}
+        if: ${{ matrix.build == 'hip' }}
         uses: urkle/action-get-ninja@v1
         with:
           version: 1.11.1

+      - name: Install ccache
+        uses: hendrikmuhs/ccache-action@v1.2
+        with:
+          key: ${{ github.job }}
+
       - name: Build
         id: cmake_build
+        if: ${{ matrix.build != 'hip' }}
         run: |
           mkdir build
           cd build
           cmake .. ${{ matrix.defines }}
-          cmake --build . --config Release
+          cmake --build . --config Release -j ${env:NUMBER_OF_PROCESSORS}
+
+      - name: Build-hip
+        id: cmake_build_hip
+        if: ${{ matrix.build == 'hip' }}
+        run: |
+          mkdir build
+          cd build
+          $env:HIP_PATH=$(Resolve-Path 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | split-path | split-path)
+          $env:CMAKE_PREFIX_PATH="${env:HIP_PATH}"
+          cmake .. -G "Unix Makefiles" -DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" -DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" -DRWKV_HIPBLAS=ON -DGGML_HIP=ON -DCMAKE_BUILD_TYPE=Release
+          cmake --build . --config Release -j ${env:NUMBER_OF_PROCESSORS}

       - name: Check AVX512F support
         id: check_avx512f
@@ -242,7 +268,7 @@ jobs:
       - name: Test
         id: cmake_test
         # Test AVX-512 only when possible
-        if: ${{ (matrix.build != 'avx512' || env.HAS_AVX512F == '1') && matrix.build != 'cuda12' && matrix.build != 'rocm5.5'}}
+        if: ${{ (matrix.build != 'avx512' || env.HAS_AVX512F == '1') && matrix.build != 'cuda12' && matrix.build != 'hip'}}
         run: |
           cd build
           ctest -C Release --verbose
```
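Note: the `Build-hip` step derives `HIP_PATH` by resolving the SDK's bundled `clang.exe` and stripping two path components. For readers less familiar with the PowerShell idiom, here is a minimal Python sketch of the same lookup; it is illustrative only and not part of this change:

```python
# Illustrative sketch of the HIP_PATH lookup done by the Build-hip step:
# find the HIP SDK's clang.exe and take its grandparent directory.
from pathlib import Path

def find_hip_path() -> Path:
    matches = sorted(Path('C:/Program Files/AMD/ROCm').glob('*/bin/clang.exe'))
    if not matches:
        raise FileNotFoundError('AMD HIP SDK not found under C:/Program Files/AMD/ROCm')
    # .../ROCm/<version>/bin/clang.exe -> .../ROCm/<version>
    return matches[-1].parent.parent

if __name__ == '__main__':
    print(find_hip_path())
```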

CMakeLists.txt (32 additions, 4 deletions)

```diff
@@ -58,7 +58,7 @@ endfunction()

 set(GGML_ACCELERATE ${RWKV_ACCELERATE})
 set(GGML_CUDA ${RWKV_CUBLAS})
-set(GGML_HIPBLAS ${RWKV_HIPBLAS})
+set(GGML_HIP ${RWKV_HIPBLAS})
 set(GGML_METAL ${RWKV_METAL})
 if (RWKV_OPENBLAS)
     set(GGML_BLAS_VENDOR "OpenBLAS")
@@ -107,6 +107,7 @@ if (RWKV_ALL_WARNINGS)
         -Wcast-qual
         -Wno-unused-function
         -Wno-multichar
+        -Wno-nonnull
     )
 else()
     set(c_flags
@@ -234,7 +235,7 @@ if (GGML_METAL)
     )
 endif()

-if (GGML_HIPBLAS)
+if (GGML_HIP)
     # CMake on Windows doesn't support the HIP language yet
     if (WIN32)
         set(CXX_IS_HIPCC TRUE)
@@ -262,12 +263,39 @@ if (GGML_HIPBLAS)
 endif()

 target_include_directories(rwkv PUBLIC .)
-target_include_directories(rwkv PRIVATE ggml/include)
+target_include_directories(rwkv PRIVATE ggml/include ggml/src)
 target_compile_features(rwkv PUBLIC cxx_std_11)
-target_link_libraries(rwkv PRIVATE $<TARGET_OBJECTS:ggml> ${RWKV_EXTRA_LIBS})
+
+if (GGML_METAL)
+    set(RWKV_EXTRA_LIBS ${RWKV_EXTRA_LIBS} $<TARGET_OBJECTS:ggml-metal> $<TARGET_OBJECTS:ggml-blas>)
+endif()
+if (GGML_CUDA)
+    set(RWKV_EXTRA_LIBS ${RWKV_EXTRA_LIBS} $<TARGET_OBJECTS:ggml-cuda>)
+endif()
+if (GGML_HIP)
+    set(RWKV_EXTRA_LIBS ${RWKV_EXTRA_LIBS} $<TARGET_OBJECTS:ggml-hip>)
+endif()
+if (GGML_RPC)
+    set(RWKV_EXTRA_LIBS ${RWKV_EXTRA_LIBS} $<TARGET_OBJECTS:ggml-rpc>)
+endif()
+
+target_link_libraries(rwkv PRIVATE $<TARGET_OBJECTS:ggml> $<TARGET_OBJECTS:ggml-base> $<TARGET_OBJECTS:ggml-cpu> ${RWKV_EXTRA_LIBS})

 if (RWKV_BUILD_SHARED_LIBRARY)
     set_target_properties(ggml PROPERTIES POSITION_INDEPENDENT_CODE ON)
+    set_target_properties(ggml-base PROPERTIES POSITION_INDEPENDENT_CODE ON)
+    set_target_properties(ggml-cpu PROPERTIES POSITION_INDEPENDENT_CODE ON)
+    if (GGML_METAL)
+        set_target_properties(ggml-metal PROPERTIES POSITION_INDEPENDENT_CODE ON)
+        set_target_properties(ggml-blas PROPERTIES POSITION_INDEPENDENT_CODE ON)
+    endif()
+    if (GGML_CUDA)
+        set_target_properties(ggml-cuda PROPERTIES POSITION_INDEPENDENT_CODE ON)
+    endif()
+    if (GGML_HIP)
+        set_target_properties(ggml-hip PROPERTIES POSITION_INDEPENDENT_CODE ON)
+    endif()
+
     target_compile_definitions(ggml PRIVATE GGML_SHARED GGML_BUILD)
     set_target_properties(rwkv PROPERTIES POSITION_INDEPENDENT_CODE ON)
     target_compile_definitions(rwkv PRIVATE RWKV_SHARED RWKV_BUILD)
```

README.md (4 additions, 6 deletions)

```diff
@@ -6,20 +6,18 @@ Besides the usual **FP32**, it supports **FP16**, **quantized INT4, INT5 and INT

 This project provides [a C library rwkv.h](rwkv.h) and [a convinient Python wrapper](python%2Frwkv_cpp%2Frwkv_cpp_model.py) for it.

-[RWKV](https://arxiv.org/abs/2305.13048) is a large language model architecture, [with the largest model in the family having 14B parameters](https://huggingface.co/BlinkDL/rwkv-4-pile-14b). In contrast to Transformer with `O(n^2)` attention, RWKV requires only state from previous step to calculate logits. This makes RWKV very CPU-friendly on large context lenghts.
+[RWKV](https://arxiv.org/abs/2305.13048) is a large language model architecture. In contrast to Transformer with `O(n^2)` attention, RWKV requires only state from previous step to calculate logits. This makes RWKV very CPU-friendly on large context lenghts.

-[RWKV v5](https://huggingface.co/BlinkDL/rwkv-5-world) is a major upgrade to RWKV architecture, making it competitive with Transformers in quality. RWKV v5 models are supported.
-
-[RWKV v6](https://huggingface.co/BlinkDL/rwkv-6-world) is a further improvement to RWKV architecture, with better quality. RWKV v6 models are supported.
+This project supports RWKV [v4](https://huggingface.co/BlinkDL/rwkv-4-pile-14b), [v5](https://huggingface.co/BlinkDL/rwkv-5-world), [v6](https://huggingface.co/BlinkDL/rwkv-6-world) and the latest [v7](https://huggingface.co/BlinkDL/rwkv-7-world) architectures.

 Loading LoRA checkpoints in [Blealtan's format](https://github.com/Blealtan/RWKV-LM-LoRA) is supported through [merge_lora_into_ggml.py script](rwkv%2Fmerge_lora_into_ggml.py).

+<!-- TODO: Update data below -->
+
 ## Quality and performance

 If you use `rwkv.cpp` for anything serious, please [test all available formats for perplexity and latency](rwkv%2Fmeasure_pexplexity.py) on a representative dataset, and decide which trade-off is best for you.

-In general, **`RWKV v5` models are as fast as `RWKV v4` models**, with minor differencies in latency and memory consumption, and with having way higher quality than `v4`. Therefore, it is recommended to use `RWKV v5`.
-
 Below table is for reference only. Measurements were made on 4C/8T x86 CPU with AVX2, 4 threads. The models are `RWKV v4 Pile 169M`, `RWKV v4 Pile 1.5B`.

 | Format | Perplexity (169M) | Latency, ms (1.5B) | File size, GB (1.5B) |
```
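Note: the README's point that RWKV "requires only state from previous step to calculate logits" shows up directly in the wrapper API. A hedged sketch, assuming the modules import as in this repository's example scripts and that `RWKVModel.eval` returns a `(logits, state)` pair:

```python
# Hedged sketch of state-based inference; model path and token ids are
# placeholders. Each step consumes only the fixed-size recurrent state,
# not the whole token history, so per-token cost is flat in context length.
from rwkv_cpp import rwkv_cpp_model, rwkv_cpp_shared_library

library = rwkv_cpp_shared_library.load_rwkv_shared_library()
model = rwkv_cpp_model.RWKVModel(library, 'model.bin')

state = None
for token in [1, 2, 3]:
    logits, state = model.eval(token, state)
```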

extras/quantize.c (2 additions, 0 deletions)

```diff
@@ -25,8 +25,10 @@ bool QueryPerformanceCounter(uint64_t* lpPerformanceCount);
 static enum ggml_type type_from_string(const char * string) {
     if (strcmp(string, "Q4_0") == 0) return GGML_TYPE_Q4_0;
     if (strcmp(string, "Q4_1") == 0) return GGML_TYPE_Q4_1;
+    if (strcmp(string, "Q4_K") == 0) return GGML_TYPE_Q4_K;
     if (strcmp(string, "Q5_0") == 0) return GGML_TYPE_Q5_0;
     if (strcmp(string, "Q5_1") == 0) return GGML_TYPE_Q5_1;
+    if (strcmp(string, "Q5_K") == 0) return GGML_TYPE_Q5_K;
     if (strcmp(string, "Q8_0") == 0) return GGML_TYPE_Q8_0;
     return GGML_TYPE_COUNT;
 }
```
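Note: with `Q4_K` and `Q5_K` now recognized by `type_from_string`, those names can be passed wherever the quantizer accepts a format string. A hedged sketch through the Python wrapper, assuming a `rwkv_quantize_model_file` binding with this signature (file names are placeholders):

```python
# Hedged sketch: quantize an FP16 ggml model to the newly accepted Q5_K
# format, mirroring what extras/quantize.c does at the C level.
from rwkv_cpp import rwkv_cpp_shared_library

library = rwkv_cpp_shared_library.load_rwkv_shared_library()
library.rwkv_quantize_model_file('model-FP16.bin', 'model-Q5_K.bin', 'Q5_K')
```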

ggml (submodule): updated from 3e7e5e2 to c8bd0fe

python/chat_with_bot.py (9 additions, 8 deletions)

```diff
@@ -40,6 +40,7 @@

 parser = argparse.ArgumentParser(description='Provide terminal-based chat interface for RWKV model')
 parser.add_argument('model_path', help='Path to RWKV model in ggml format')
+parser.add_argument('-ngl', '--num_gpu_layers', type=int, default=99, help='Number of layers to run on GPU')
 add_tokenizer_argument(parser)
 args = parser.parse_args()

@@ -48,7 +49,7 @@
 with open(script_dir / 'prompt' / f'{LANGUAGE}-{PROMPT_TYPE}.json', 'r', encoding='utf8') as json_file:
     prompt_data = json.load(json_file)

-user, bot, separator, init_prompt = prompt_data['user'], prompt_data['bot'], prompt_data['separator'], prompt_data['prompt']
+user, assistant, separator, init_prompt = prompt_data['user'], prompt_data['assistant'], prompt_data['separator'], prompt_data['prompt']

 if init_prompt == '':
     raise ValueError('Prompt must not be empty')
@@ -57,7 +58,7 @@
 print(f'System info: {library.rwkv_get_system_info_string()}')

 print('Loading RWKV model')
-model = rwkv_cpp_model.RWKVModel(library, args.model_path)
+model = rwkv_cpp_model.RWKVModel(library, args.model_path, gpu_layer_count=args.num_gpu_layers)

 tokenizer_decode, tokenizer_encode = get_tokenizer(args.tokenizer, model.n_vocab)

@@ -154,7 +155,7 @@ def split_last_end_of_line(tokens: List[int]) -> List[int]:
     if msg == '+reset':
         load_thread_state('chat_init')
         save_thread_state('chat')
-        print(f'{bot}{separator} Chat reset.\n')
+        print(f'{assistant}{separator} Chat reset.\n')
         continue
     elif msg[:5].lower() == '+gen ' or msg[:3].lower() == '+i ' or msg[:4].lower() == '+qa ' or msg[:4].lower() == '+qq ' or msg.lower() == '+++' or msg.lower() == '++':

@@ -194,7 +195,7 @@ def split_last_end_of_line(tokens: List[int]) -> List[int]:
            load_thread_state('chat_init')

            real_msg = msg[4:].strip()
-           new = f'{user}{separator} {real_msg}\n\n{bot}{separator}'
+           new = f'{user}{separator} {real_msg}\n\n{assistant}{separator}'

            process_tokens(tokenizer_encode(new))
            save_thread_state('gen_0')
@@ -225,17 +226,17 @@ def split_last_end_of_line(tokens: List[int]) -> List[int]:
        except Exception as e:
            print(e)
        continue
-    # chat with bot
+    # chat with assistant
     else:
         load_thread_state('chat')
-        new = f'{user}{separator} {msg}\n\n{bot}{separator}'
+        new = f'{user}{separator} {msg}\n\n{assistant}{separator}'
         process_tokens(tokenizer_encode(new), new_line_logit_bias=-999999999)
         save_thread_state('chat_pre')

     thread = 'chat'

-    # Print bot response
-    print(f'> {bot}{separator}', end='')
+    # Print assistant response
+    print(f'> {assistant}{separator}', end='')

     start_index: int = len(processed_tokens)
     accumulated_tokens: List[int] = []
```

python/convert_pytorch_to_ggml.py (39 additions, 3 deletions)

```diff
@@ -35,8 +35,11 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t
     is_v5_1_or_2: bool = 'blocks.0.att.ln_x.weight' in state_dict
     is_v5_2: bool = 'blocks.0.att.gate.weight' in state_dict
     is_v6_0: bool = 'blocks.0.att.time_maa_x' in state_dict
+    is_v7_0: bool = 'blocks.0.att.k_k' in state_dict

-    if is_v6_0:
+    if is_v7_0:
+        print('Detected RWKV v7.0')
+    elif is_v6_0:
         print('Detected RWKV v6.0')
     elif is_v5_2:
         print('Detected RWKV v5.2')
@@ -45,6 +48,23 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t
     else:
         print('Detected RWKV v4')

+    if is_v7_0:
+        # concat to reduce some cpu overhead during ggml inference
+        state_dict_new = {}
+        for k in state_dict.keys():
+            if 'att.x_' in k:
+                l = int(k.split('.')[1].split('.')[0])
+                try:
+                    state_dict_new[f'blocks.{l}.att.x_rwkvag'] = torch.cat(
+                        [state_dict_new[f'blocks.{l}.att.x_rwkvag'], state_dict[k]], dim=0)
+                except KeyError:
+                    state_dict_new[f'blocks.{l}.att.x_rwkvag'] = state_dict[k]
+            else:
+                state_dict_new[k] = state_dict[k]
+
+            del state_dict[k]
+        state_dict = state_dict_new
+
     with open(dest_path, 'wb') as out_file:
         is_FP16: bool = data_type == 'FP16' or data_type == 'float16'

@@ -68,7 +88,16 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t
             if '.time_' in k:
                 tensor = tensor.squeeze()

-            if is_v6_0:
+            if is_v7_0:
+                if any(s in k for s in [
+                    '.w1', '.w2',
+                    '.a1', '.a2',
+                    '.v1', '.v2',
+                    '.g1', '.g2',
+                ]):
+                    tensor = tensor.transpose(0, 1)
+
+            elif is_v6_0:
                 if '.time_faaaa' in k:
                     tensor = tensor.unsqueeze(-1)
                 if '.time_maa_w1' in k or '.time_decay_w' in k:
@@ -95,7 +124,14 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t
                 tensor = -torch.exp(tensor)

             # Keep 1-dim vectors and small matrices in FP32
-            if is_FP16 and len(tensor.shape) > 1 and '.time_' not in k:
+            if is_FP16 and len(tensor.shape) > 1 and all(
+                s not in k for s in [
+                    '.time_',
+                    '.k_k', '.k_a', '.r_k',
+                    '.x_rwkvag', '.x_k',
+                    '.w0', '.a0', '.v0',
+                ]
+            ):
                 tensor = tensor.half()

             shape = tensor.shape
```
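Note: the v7 branch above folds the per-layer `att.x_*` interpolation vectors into a single `x_rwkvag` tensor (the name suggests the r, w, k, v, a, g mixing vectors), so the ggml graph reads one tensor per layer instead of six. A toy reproduction of that concatenation, with made-up shapes, purely for illustration:

```python
# Toy illustration of the x_* fusion performed by the converter above.
import torch

n_embd = 8  # hypothetical embedding size
parts = [torch.randn(1, n_embd) for _ in range(6)]  # stands in for att.x_r .. att.x_g

# torch.cat along dim 0 stacks the six (1, n_embd) vectors into one
# (6, n_embd) tensor, cutting per-layer tensor lookups during inference.
x_rwkvag = torch.cat(parts, dim=0)
assert x_rwkvag.shape == (6, n_embd)
```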

python/generate_completions.py (2 additions, 1 deletion)

```diff
@@ -29,6 +29,7 @@

 parser = argparse.ArgumentParser(description='Generate completions from RWKV model based on a prompt')
 parser.add_argument('model_path', help='Path to RWKV model in ggml format')
+parser.add_argument('-ngl', '--num_gpu_layers', type=int, default=99, help='Number of layers to run on GPU')
 add_tokenizer_argument(parser)
 args = parser.parse_args()

@@ -39,7 +40,7 @@
 print(f'System info: {library.rwkv_get_system_info_string()}')

 print('Loading RWKV model')
-model = rwkv_cpp_model.RWKVModel(library, args.model_path, gpu_layers_count=0)
+model = rwkv_cpp_model.RWKVModel(library, args.model_path, gpu_layers_count=args.num_gpu_layers)

 tokenizer_decode, tokenizer_encode = get_tokenizer(args.tokenizer, model.n_vocab)
```

python/inference_example.py (2 additions, 1 deletion)

```diff
@@ -10,12 +10,13 @@
 # Parse received arguments.
 parser = argparse.ArgumentParser(description='Generate some text with an RWKV model')
 parser.add_argument('model_path', help='Path to RWKV model in ggml format')
+parser.add_argument('-ngl', '--num_gpu_layers', type=int, default=99, help='Number of layers to run on GPU')
 add_tokenizer_argument(parser)
 args = parser.parse_args()

 # Load the model.
 library = rwkv_cpp_shared_library.load_rwkv_shared_library()
-model = rwkv_cpp_model.RWKVModel(library, args.model_path)
+model = rwkv_cpp_model.RWKVModel(library, args.model_path, gpu_layer_count=args.num_gpu_layers)

 # Set up the tokenizer.
 tokenizer_decode, tokenizer_encode = get_tokenizer(args.tokenizer, model.n_vocab)
```
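Note: all three example scripts now take `-ngl`/`--num_gpu_layers` and forward it to the model constructor; the default of 99 effectively offloads every layer for models shallower than that. A hedged usage sketch, with the keyword name taken from the `inference_example.py` diff above and a placeholder model path:

```python
# Hedged sketch of GPU layer offload; equivalent to the new CLI flag:
#   python python/inference_example.py model.bin -ngl 24
from rwkv_cpp import rwkv_cpp_model, rwkv_cpp_shared_library

library = rwkv_cpp_shared_library.load_rwkv_shared_library()
model = rwkv_cpp_model.RWKVModel(library, 'model.bin', gpu_layer_count=24)
```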

python/prompt/Chinese-Chat.json (1 addition, 1 deletion)

```diff
@@ -1,6 +1,6 @@
 {
     "user": "Bob",
-    "bot": "Alice",
+    "assistant": "Alice",
     "separator": ":",
     "prompt": "\nThe following is a coherent verbose detailed conversation between a Chinese girl named Alice and her friend Bob. Alice is very intelligent, creative and friendly. Alice likes to tell Bob a lot about herself and her opinions. Alice usually gives Bob kind, helpful and informative advices.\n\nBob: lhc\n\nAlice: LHC是指大型强子对撞机(Large Hadron Collider),是世界最大最强的粒子加速器,由欧洲核子中心(CERN)在瑞士日内瓦地下建造。LHC的原理是加速质子(氢离子)并让它们相撞,让科学家研究基本粒子和它们之间的相互作用,并在2012年证实了希格斯玻色子的存在。\n\nBob: 企鹅会飞吗\n\nAlice: 企鹅是不会飞的。企鹅的翅膀短而扁平,更像是游泳时的一对桨。企鹅的身体结构和羽毛密度也更适合在水中游泳,而不是飞行。\n\n"
 }
```
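Note: since `chat_with_bot.py` now reads `prompt_data['assistant']`, custom prompt files must use the renamed key. A minimal sanity check mirroring the fields the chat script reads (file path as in this repository):

```python
# Verify a prompt file matches the renamed schema ('assistant', not 'bot').
import json

with open('python/prompt/Chinese-Chat.json', 'r', encoding='utf8') as f:
    prompt_data = json.load(f)

for key in ('user', 'assistant', 'separator', 'prompt'):
    assert key in prompt_data, f'missing key: {key}'
```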
