 import torch
 from typing import List
+# The rwkv package reads these variables when rwkv.model is imported,
+# so they are set here, before the import below.
+os.environ['RWKV_JIT_ON'] = '1'
+os.environ['RWKV_CUDA_ON'] = '0'
 from rwkv.model import RWKV
 
 def parse_args():
     parser = argparse.ArgumentParser(description='Measure perplexity and per-token latency of an RWKV model on a given text file')
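
The body of `parse_args` lies outside this diff's context. Judging from the attributes used further down (`args.model_path`, `args.dataset_path`) and the `ignore_first_n_tokens` variable, it plausibly looks like the sketch below; the argument names and defaults are inferred, not part of the commit.

```python
import argparse

def parse_args():
    parser = argparse.ArgumentParser(
        description='Measure perplexity and per-token latency of an RWKV model on a given text file')
    # Inferred from later usage; the real script may define these differently.
    parser.add_argument('model_path', help='Path to the RWKV checkpoint (.pth)')
    parser.add_argument('dataset_path', help='Path to the text file to evaluate on')
    parser.add_argument('--ignore_first_n_tokens', type=int, default=0,
                        help='Do not count the loss of the first N tokens (warm-up)')
    return parser.parse_args()
```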
@@ -56,9 +58,10 @@ def format_loss_with_perplexity(loss: torch.Tensor) -> str:
 
 # ---
 device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
-# device= torch.device('cpu')
+# device = torch.device('cpu')
 
-model = RWKV(model=args.model_path, strategy='cuda fp16i8')
+# TODO: Why is perplexity so damn high?
+model = RWKV(model=args.model_path, strategy='cuda fp16')
 
 logits, state = None, None
 loss_sum: torch.Tensor = torch.tensor([0.0], device=device)
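
A plausible lead on the TODO above: the `fp16i8` strategy keeps layer weights quantized to int8, and that quantization can measurably inflate perplexity, which would explain switching to plain `fp16` here. For reference, a few strategy strings the rwkv package accepts:

```python
# Illustrative strategies for the rwkv pip package; pick one per model instance.
model = RWKV(model=args.model_path, strategy='cpu fp32')     # CPU, full precision
model = RWKV(model=args.model_path, strategy='cuda fp16')    # GPU, fp16 throughout
model = RWKV(model=args.model_path, strategy='cuda fp16i8')  # GPU, int8 weights to save VRAM
```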
@@ -72,7 +75,7 @@ def format_loss_with_perplexity(loss: torch.Tensor) -> str:
 for i in range(run_count):
     token: int = test_tokens[i]
     target: int = test_tokens[i + 1]
-
+
     logits, state = model.forward([token], None if i == 0 else state)
 
     if ignore_first_n_tokens == 0 or i + 1 >= ignore_first_n_tokens:
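
The loss accumulation inside this `if` is outside the diff context. Given `loss_sum`, `loss_count`, and the `format_loss_with_perplexity` signature in the hunk headers, it presumably boils down to cross-entropy against the true next token, with perplexity as its exponential. A minimal sketch, not the script's literal code:

```python
import torch
import torch.nn.functional as F

def format_loss_with_perplexity(loss: torch.Tensor) -> str:
    # Perplexity is the exponential of the mean cross-entropy loss.
    return f'loss {loss.item():.3f}, perplexity {torch.exp(loss).item():.3f}'

# Inside the if-branch above: cross_entropy expects a (batch, vocab)
# input and a (batch,) target holding the true next token's id.
loss_sum += F.cross_entropy(logits.unsqueeze(0),
                            torch.tensor([target], device=logits.device))
loss_count += 1
```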
@@ -105,7 +108,7 @@ def format_loss_with_perplexity(loss: torch.Tensor) -> str:
 print(f'Average latency: {int((time.time() - start) * 1000 / run_count)} ms per token')
 
 print()
-print(f'Model: {os.path.basename(args.model_path)}, '
-      f'data: {os.path.basename(args.dataset_path)} with {token_count} tokens, '
-      f'Ignored first {ignore_first_n_tokens} tokens, '
+print(f'Model: {os.path.basename(args.model_path)}\n'
+      f'data: {os.path.basename(args.dataset_path)} with {token_count} tokens\n'
+      f'Ignored first {ignore_first_n_tokens} tokens\n'
       f'averages: {format_loss_with_perplexity(loss_sum / loss_count)}')
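
With the newlines introduced above, the closing summary prints as a four-line block instead of one long comma-separated line; schematically (values are placeholders):

```
Model: <model file name>
data: <dataset file name> with <token_count> tokens
Ignored first <ignore_first_n_tokens> tokens
averages: <output of format_loss_with_perplexity>
```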