@@ -18,6 +18,7 @@ package llmdinferencesim
1818
1919import  (
2020	"context" 
21+ 	"encoding/json" 
2122	"errors" 
2223	"fmt" 
2324	"io" 
@@ -29,6 +30,8 @@ import (
2930
3031	"github.com/llm-d/llm-d-inference-sim/pkg/common" 
3132	kvcache "github.com/llm-d/llm-d-inference-sim/pkg/kv-cache" 
33+ 	vllmapi "github.com/llm-d/llm-d-inference-sim/pkg/vllm-api" 
34+ 	"github.com/llm-d/llm-d-kv-cache-manager/pkg/tokenization" 
3235	. "github.com/onsi/ginkgo/v2" 
3336	. "github.com/onsi/gomega" 
3437	"github.com/openai/openai-go" 
@@ -39,6 +42,7 @@ import (
3942)
4043
// Shared string fixtures used by the simulator test suite.
const (
	// model is the placeholder model name used by specs in this suite.
	model = "my_model"
	// qwenModelName is the Hugging Face model ID passed as --model by the
	// /tokenize specs.
	qwenModelName = "Qwen/Qwen2-0.5B"
	// baseURL is the base URL used for API requests in this suite.
	baseURL = "http://localhost/v1"
	// userMessage is the canned user prompt sent by request-building specs.
	userMessage = "This is a test."
	// invalidMaxTokensErrMsg is the error text compared against when a
	// request carries a non-positive token limit.
	invalidMaxTokensErrMsg = "Max completion tokens and max tokens should be positive"
)
@@ -97,8 +101,17 @@ func startServerWithArgs(ctx context.Context, mode string, args []string, envs m
97101		return  nil , err 
98102	}
99103
104+ 	tokenizationConfig  :=  tokenization .DefaultConfig ()
105+ 	if  s .config .TokenizersCacheDir  !=  ""  {
106+ 		tokenizationConfig .TokenizersCacheDir  =  s .config .TokenizersCacheDir 
107+ 	}
108+ 	s .tokenizer , err  =  tokenization .NewCachedHFTokenizer (tokenizationConfig .HFTokenizerConfig )
109+ 	if  err  !=  nil  {
110+ 		return  nil , fmt .Errorf ("failed to create tokenizer: %w" , err )
111+ 	}
112+ 
100113	if  s .config .EnableKVCache  {
101- 		s .kvcacheHelper , err  =  kvcache .NewKVCacheHelper (s .config , s .logger , s .kvCacheUsageChan )
114+ 		s .kvcacheHelper , err  =  kvcache .NewKVCacheHelper (s .config , s .logger , s .kvCacheUsageChan ,  s . tokenizer )
102115		if  err  !=  nil  {
103116			return  nil , err 
104117		}
@@ -1065,7 +1078,71 @@ var _ = Describe("Simulator", func() {
10651078			Expect (factor ).To (BeNumerically (">" , 1.0 ))
10661079			Expect (factor ).To (BeNumerically ("<" , simulator .config .TimeFactorUnderLoad ))
10671080		})
1068- 
10691081	})
10701082
1083+ 	Context ("tokenize" , Ordered , func () {
1084+ 		tmpDir  :=  "./tests-tmp/" 
1085+ 		AfterAll (func () {
1086+ 			err  :=  os .RemoveAll (tmpDir )
1087+ 			Expect (err ).NotTo (HaveOccurred ())
1088+ 		})
1089+ 
1090+ 		It ("Should return correct response to /tokenize chat" , func () {
1091+ 			ctx  :=  context .TODO ()
1092+ 			args  :=  []string {"cmd" , "--model" , qwenModelName , "--mode" , common .ModeRandom ,
1093+ 				"--tokenizers-cache-dir" , tmpDir , "--max-model-len" , "2048" }
1094+ 			client , err  :=  startServerWithArgs (ctx , common .ModeRandom , args , nil )
1095+ 			Expect (err ).NotTo (HaveOccurred ())
1096+ 
1097+ 			reqBody  :=  `{ 
1098+ 				"messages": [{"role": "user", "content": "This is a test"}], 
1099+ 				"model": "Qwen/Qwen2-0.5B" 
1100+ 			}` 
1101+ 			resp , err  :=  client .Post ("http://localhost/tokenize" , "application/json" , strings .NewReader (reqBody ))
1102+ 			Expect (err ).NotTo (HaveOccurred ())
1103+ 			defer  func () {
1104+ 				err  :=  resp .Body .Close ()
1105+ 				Expect (err ).NotTo (HaveOccurred ())
1106+ 			}()
1107+ 
1108+ 			body , err  :=  io .ReadAll (resp .Body )
1109+ 			Expect (err ).NotTo (HaveOccurred ())
1110+ 
1111+ 			var  tokenizeResp  vllmapi.TokenizeResponse 
1112+ 			err  =  json .Unmarshal (body , & tokenizeResp )
1113+ 			Expect (err ).NotTo (HaveOccurred ())
1114+ 			Expect (tokenizeResp .Count ).To (Equal (4 ))
1115+ 			Expect (tokenizeResp .Tokens ).To (HaveLen (4 ))
1116+ 			Expect (tokenizeResp .MaxModelLen ).To (Equal (2048 ))
1117+ 		})
1118+ 
1119+ 		It ("Should return correct response to /tokenize text" , func () {
1120+ 			ctx  :=  context .TODO ()
1121+ 			args  :=  []string {"cmd" , "--model" , qwenModelName , "--mode" , common .ModeRandom ,
1122+ 				"--tokenizers-cache-dir" , tmpDir , "--max-model-len" , "2048" }
1123+ 			client , err  :=  startServerWithArgs (ctx , common .ModeRandom , args , nil )
1124+ 			Expect (err ).NotTo (HaveOccurred ())
1125+ 
1126+ 			reqBody  :=  `{ 
1127+ 				"prompt": "This is a test", 
1128+ 				"model": "Qwen/Qwen2-0.5B" 
1129+ 			}` 
1130+ 			resp , err  :=  client .Post ("http://localhost/tokenize" , "application/json" , strings .NewReader (reqBody ))
1131+ 			Expect (err ).NotTo (HaveOccurred ())
1132+ 			defer  func () {
1133+ 				err  :=  resp .Body .Close ()
1134+ 				Expect (err ).NotTo (HaveOccurred ())
1135+ 			}()
1136+ 
1137+ 			body , err  :=  io .ReadAll (resp .Body )
1138+ 			Expect (err ).NotTo (HaveOccurred ())
1139+ 
1140+ 			var  tokenizeResp  vllmapi.TokenizeResponse 
1141+ 			err  =  json .Unmarshal (body , & tokenizeResp )
1142+ 			Expect (err ).NotTo (HaveOccurred ())
1143+ 			Expect (tokenizeResp .Count ).To (Equal (4 ))
1144+ 			Expect (tokenizeResp .Tokens ).To (HaveLen (4 ))
1145+ 			Expect (tokenizeResp .MaxModelLen ).To (Equal (2048 ))
1146+ 		})
1147+ 	})
10711148})
0 commit comments