From 231acf8685fa03c210a5f631a34190d795ea151c Mon Sep 17 00:00:00 2001 From: Nazar Kulyk Date: Tue, 14 Nov 2023 11:42:01 +0100 Subject: [PATCH 1/2] fix CmakeLists.txt for extras for static build #146 --- extras/CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/extras/CMakeLists.txt b/extras/CMakeLists.txt index 184b948f..2787b68b 100644 --- a/extras/CMakeLists.txt +++ b/extras/CMakeLists.txt @@ -11,9 +11,9 @@ function(rwkv_add_extra source) if(RWKV_HIPBLAS) message(FATAL_ERROR "Static linking not supported for HIP/ROCm") else() - get_target_property(target_LINK_OPTIONS ${TEST_TARGET} LINK_OPTIONS) + get_target_property(target_LINK_OPTIONS rwkv_${EXTRA_TARGET} LINK_OPTIONS) list(REMOVE_ITEM target_LINK_OPTIONS "-static") - set_target_properties(${TEST_TARGET} PROPERTIES LINK_OPTIONS "${target_LINK_OPTIONS}") + set_target_properties(rwkv_${EXTRA_TARGET} PROPERTIES LINK_OPTIONS "${target_LINK_OPTIONS}") endif() endif() endfunction() From f14e2ceb6218d4074f2f52f4d37247d24d0db2f6 Mon Sep 17 00:00:00 2001 From: Nazar Kulyk Date: Wed, 29 Nov 2023 14:41:31 +0100 Subject: [PATCH 2/2] refactoring to model version detection --- python/convert_pytorch_to_ggml.py | 31 +++++++++++++++++++------------ 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/python/convert_pytorch_to_ggml.py b/python/convert_pytorch_to_ggml.py index 99568449..77f62f5a 100644 --- a/python/convert_pytorch_to_ggml.py +++ b/python/convert_pytorch_to_ggml.py @@ -32,15 +32,21 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t n_vocab: int = emb_weight.shape[0] n_embed: int = emb_weight.shape[1] - is_v5_1_or_2: bool = 'blocks.0.att.ln_x.weight' in state_dict - is_v5_2: bool = 'blocks.0.att.gate.weight' in state_dict - - if is_v5_2: - print('Detected RWKV v5.2') - elif is_v5_1_or_2: - print('Detected RWKV v5.1') - else: - print('Detected RWKV v4') + version = 4 + keys = list(state_dict.keys()) + for k in keys: + if 'ln_x' in k: + version = max(5, version) + if 'gate.weight' in k: + version = max(5.1, version) + if int(version) == 5 and 'att.time_decay' in k: + if len(state_dict[k].shape) > 1: + if (state_dict[k].shape[1]) > 1: + version = max(5.2, version) + if "time_maa" in k: + version = max(6, version) + + print(f'Model detected v{version:.1f}') with open(dest_path, 'wb') as out_file: is_FP16: bool = data_type == 'FP16' or data_type == 'float16' @@ -57,15 +63,16 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t 1 if is_FP16 else 0 )) - for k in state_dict.keys(): + keys = list(state_dict.keys()) + for k in keys: tensor: torch.Tensor = state_dict[k].float() if '.time_' in k: tensor = tensor.squeeze() - if is_v5_1_or_2: + if int(version) == 5: if '.time_decay' in k: - if is_v5_2: + if version == 5.2: tensor = torch.exp(-torch.exp(tensor)).unsqueeze(-1) else: tensor = torch.exp(-torch.exp(tensor)).reshape(-1, 1, 1)