2 changes: 1 addition & 1 deletion rwkv_pip_package/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "rwkv"
version = "0.7.3"
version = "0.7.4"
authors = [
{ name="Bo PENG" },
]
29 changes: 28 additions & 1 deletion rwkv_pip_package/src/rwkv/model.py
@@ -72,7 +72,7 @@ def cuda_mm8_one(N: int, M: int, x, w, mx, rx, my, ry):
########################################################################################################

class RWKV(MyModule):
-    def __init__(self, model, strategy, verbose = True, convert_and_save_and_exit = None):
+    def __init__(self, model, strategy, lora, verbose = True, convert_and_save_and_exit = None):
        super().__init__()
        if verbose:
            prxxx = lambda *args, **kwargs: print(*args, **kwargs)
@@ -102,6 +102,33 @@ def __init__(self, model, strategy, verbose = True, convert_and_save_and_exit =
        gc.collect()
        w = self.w

        if lora.lora_r > 0:
            prxxx('Loading LoRA ...')
            # merge the LoRA-only slim checkpoint into the main weights
            w_lora = torch.load(lora.MODEL_LORA + '.pth', map_location='cpu')
            for k in w_lora.keys():
                w[k] = w_lora[k]
            # merge LoRA weights
            keys = set(w.keys())
            for k in keys:
                k: str
                if k.endswith('.weight'):
                    prefix = k[:-len('.weight')]
                    lora_A = prefix + '.lora_A'
                    lora_B = prefix + '.lora_B'
                    if lora_A in keys:
                        assert lora_B in keys
                        prxxx(f'merging {lora_A} and {lora_B} into {k}')
                        assert w[lora_B].shape[1] == w[lora_A].shape[0] == lora.lora_r
                        # merging needs a matmul, which is slow on CPU; work on the GPU if possible
                        if lora.RUN_DEVICE == 'cuda':
                            w[k] = w[k].cuda()
                            w[lora_A] = w[lora_A].cuda()
                            w[lora_B] = w[lora_B].cuda()
                        w[k] += w[lora_B] @ w[lora_A] * (lora.lora_alpha / lora.lora_r)
                        del w[lora_A]
                        del w[lora_B]

        ALREADY_CONVERTED = False
        if '_strategy' in w:
            ALREADY_CONVERTED = True
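For reference, the merge above applies the standard LoRA update W <- W + (lora_alpha / lora_r) * B * A, folding the adapter into the base weights so inference needs no extra matmuls. A minimal self-contained sketch of the same arithmetic, with hypothetical shapes:

import torch

out_dim, in_dim, r, alpha = 768, 768, 8, 8
W = torch.randn(out_dim, in_dim)   # frozen base weight, as in w[k]
A = torch.randn(r, in_dim)         # low-rank factor, as in w[prefix + '.lora_A']
B = torch.randn(out_dim, r)        # low-rank factor, as in w[prefix + '.lora_B']
W = W + B @ A * (alpha / r)        # same update the loop applies in place
assert W.shape == (out_dim, in_dim)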
10 changes: 9 additions & 1 deletion v2/chat.py
@@ -49,6 +49,14 @@
# args.strategy = 'cuda fp16i8 -> cpu fp32 *10'
# args.strategy = 'cuda fp16i8 *10+'


lora = types.SimpleNamespace()
lora.MODEL_LORA = './cp/rwkv-10'  # path to the LoRA checkpoint; '.pth' is appended at load time
lora.lora_r = 0  # r = 0 disables LoRA; set to the rank used during training to enable it
lora.lora_alpha = 8
lora.RUN_DEVICE = "cuda"


os.environ["RWKV_JIT_ON"] = '1' # '1' or '0', please use torch 1.13+ and benchmark speed
os.environ["RWKV_CUDA_ON"] = '0' # '1' to compile CUDA kernel (10x faster), requires c++ compiler & cuda libraries

@@ -119,7 +127,7 @@
# Load Model

print(f'Loading model - {args.MODEL_NAME}')
-model = RWKV(model=args.MODEL_NAME, strategy=args.strategy)
+model = RWKV(model=args.MODEL_NAME, strategy=args.strategy, lora=lora)
if not PILE_v2_MODEL:
    pipeline = PIPELINE(model, f"{current_path}/20B_tokenizer.json")
    END_OF_TEXT = 0
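To actually apply a trained LoRA checkpoint at inference time, lora_r must equal the rank the adapter was trained with (the shape assert in model.py enforces the match). A hedged usage sketch, with a hypothetical rank:

import types

lora = types.SimpleNamespace()
lora.MODEL_LORA = './cp/rwkv-10'  # model.py appends '.pth' when loading
lora.lora_r = 8                   # must equal the training rank; 0 skips merging entirely
lora.lora_alpha = 8               # the merged delta is scaled by lora_alpha / lora_r
lora.RUN_DEVICE = "cuda"          # merge on GPU, since the matmuls are slow on CPU
model = RWKV(model=args.MODEL_NAME, strategy=args.strategy, lora=lora)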