@@ -21,12 +21,18 @@ import (
2121 "os"
2222
2323 "github.com/go-logr/logr"
24+ "github.com/llm-d/llm-d-inference-sim/pkg/common"
25+ openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api"
2426 . "github.com/onsi/ginkgo/v2"
2527 . "github.com/onsi/gomega"
2628
2729 _ "github.com/mattn/go-sqlite3"
2830)
2931
32+ const (
33+ testPrompt = "Hello world!"
34+ )
35+
3036var _ = Describe ("CustomDataset" , func () {
3137 var (
3238 dataset * CustomDataset
@@ -90,20 +96,20 @@ var _ = Describe("CustomDataset", func() {
9096 err := dataset .Init (validDBPath , "" , "" )
9197 Expect (err ).NotTo (HaveOccurred ())
9298
93- row := dataset .db .QueryRow ("SELECT n_gen_tokens FROM llmd WHERE prompt_hash=X'b94d27b9934d041c52e5b721d7373f13a07ed5e79179d63c5d8a0c102a9d00b2 ';" )
99+ row := dataset .db .QueryRow ("SELECT n_gen_tokens FROM llmd WHERE prompt_hash=X'74bf14c09c038321cba39717dae1dc732823ae4abd8e155959367629a3c109a8 ';" )
94100 var n_gen_tokens int
95101 err = row .Scan (& n_gen_tokens )
96102 Expect (err ).NotTo (HaveOccurred ())
97- Expect (n_gen_tokens ).To (Equal (3 ))
103+ Expect (n_gen_tokens ).To (Equal (4 ))
98104
99105 var jsonStr string
100- row = dataset .db .QueryRow ("SELECT gen_tokens FROM llmd WHERE prompt_hash=X'b94d27b9934d041c52e5b721d7373f13a07ed5e79179d63c5d8a0c102a9d00b2 ';" )
106+ row = dataset .db .QueryRow ("SELECT gen_tokens FROM llmd WHERE prompt_hash=X'74bf14c09c038321cba39717dae1dc732823ae4abd8e155959367629a3c109a8 ';" )
101107 err = row .Scan (& jsonStr )
102108 Expect (err ).NotTo (HaveOccurred ())
103109 var tokens []string
104110 err = json .Unmarshal ([]byte (jsonStr ), & tokens )
105111 Expect (err ).NotTo (HaveOccurred ())
106- Expect (tokens ).To (Equal ([]string {"Hello" , "world" , "!" }))
112+ Expect (tokens ).To (Equal ([]string {"Hello" , " llm-d " , " world" , "!" }))
107113
108114 })
109115
@@ -136,4 +142,41 @@ var _ = Describe("CustomDataset", func() {
136142 Expect (err ).To (HaveOccurred ())
137143 Expect (err .Error ()).To (ContainSubstring ("incorrect type" ))
138144 })
145+
146+ It ("should return correct prompt hash in bytes" , func () {
147+ // b't\xbf\x14\xc0\x9c\x03\x83!\xcb\xa3\x97\x17\xda\xe1\xdcs(#\xaeJ\xbd\x8e\x15YY6v)\xa3\xc1\t\xa8'
148+ expectedHashBytes := []byte {0x74 , 0xbf , 0x14 , 0xc0 , 0x9c , 0x03 , 0x83 , 0x21 , 0xcb , 0xa3 , 0x97 , 0x17 , 0xda , 0xe1 , 0xdc , 0x73 , 0x28 , 0x23 , 0xae , 0x4a , 0xbd , 0x8e , 0x15 , 0x59 , 0x59 , 0x36 , 0x76 , 0x29 , 0xa3 , 0xc1 , 0x09 , 0xa8 }
149+
150+ req := & openaiserverapi.TextCompletionRequest {
151+ Prompt : testPrompt ,
152+ }
153+
154+ hashBytes := dataset .GetPromptHash (req )
155+ Expect (hashBytes ).To (Equal (expectedHashBytes ))
156+ })
157+
158+ It ("should return correct prompt hash in hex" , func () {
159+ expectedHashHex := "74bf14c09c038321cba39717dae1dc732823ae4abd8e155959367629a3c109a8"
160+
161+ req := & openaiserverapi.TextCompletionRequest {
162+ Prompt : testPrompt ,
163+ }
164+
165+ hashBytes := dataset .GetPromptHash (req )
166+ hashHex := dataset .GetPromptHashHex (hashBytes )
167+ Expect (hashHex ).To (Equal (expectedHashHex ))
168+ })
169+
170+ It ("should return tokens for existing prompt" , func () {
171+ err := dataset .Init (validDBPath , "" , "" )
172+ Expect (err ).NotTo (HaveOccurred ())
173+
174+ req := & openaiserverapi.TextCompletionRequest {
175+ Prompt : testPrompt ,
176+ }
177+ tokens , finishReason , err := dataset .GetTokens (req , common .ModeRandom )
178+ Expect (err ).NotTo (HaveOccurred ())
179+ Expect (finishReason ).To (Equal (StopFinishReason ))
180+ Expect (tokens ).To (Equal ([]string {"Hello" , " llm-d " , "world" , "!" }))
181+ })
139182})
0 commit comments