Skip to content

fix: first hash of prefix cache with same model name #1341

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,16 +257,21 @@ func hashPrompt(ctx context.Context, request *types.LLMRequest, cacheBlockSize i
// If the last block is smaller than cacheBlockSize, it will be ignored.
res := make([]BlockHash, 0, 1+len(prompt)/cacheBlockSize)
// Add the model to the first block hash so that different models have different hashes even with the same body.
if len(prompt) >= cacheBlockSize {
if len(prompt) < cacheBlockSize {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit nitpicking but can we just converge the if/else to the below. It's much concise and more readable in my opinion.

firstBlockSize = cacheBlockSize
if len(prompt) < cacheBlockSize {
  firstBlockSize = len(prompt)
}
firstBlock := prompt[0:firstBlockSize]
firstBlockWithModel := append([]byte(request.TargetModel), firstBlock...)
res = append(res, BlockHash(xxhash.Sum64(firstBlockWithModel)))

firstBlockSize := len(prompt)
firstBlock := prompt[0:firstBlockSize]
firstBlockWithModel := append([]byte(request.TargetModel), firstBlock...)
res = append(res, BlockHash(xxhash.Sum64(firstBlockWithModel)))
} else {
firstBlock := prompt[0:cacheBlockSize]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current implementation will lead to no hash if the prompt is smaller than cacheBlockSize. I suggest the below modification instead:

firstBlockSize = cacheBlockSize
if len(prompt) < cacheBlockSize {
  firstBlockSize = len(prompt)
}
firstBlock := prompt[0:firstBlockSize]
firstBlockWithModel := append([]byte(request.TargetModel), firstBlock...)
res = append(res, BlockHash(xxhash.Sum64(firstBlockWithModel)))

combined := append([]byte(request.TargetModel), firstBlock...)
res = append(res, BlockHash(xxhash.Sum64(combined)))
}
for i := cacheBlockSize; i+cacheBlockSize <= len(prompt); i += cacheBlockSize {
block := prompt[i : i+cacheBlockSize]
prevBlockHash := res[len(res)-1]
block = append(block, toBytes(prevBlockHash)...)
res = append(res, BlockHash(xxhash.Sum64(block)))
for i := cacheBlockSize; i+cacheBlockSize <= len(prompt); i += cacheBlockSize {
block := prompt[i : i+cacheBlockSize]
prevBlockHash := res[len(res)-1]
block = append(block, toBytes(prevBlockHash)...)
res = append(res, BlockHash(xxhash.Sum64(block)))
}
}
return res
}
Expand Down
24 changes: 12 additions & 12 deletions pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ func TestPrefixPlugin(t *testing.T) {
assert.NoError(t, err)
t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
// Input size is 6, hash block size is 4, the last 2 characters are ignored.
// Total hashes = 2 (the first one is for the model)
assert.Equal(t, 2, len(state.PrefixHashes), "number of hashes is incorrect")
// Total hashes = 1 (the first one is for the prefix with model)
assert.Equal(t, 1, len(state.PrefixHashes), "number of hashes is incorrect")
assert.Equal(t, 0, len(state.PrefixCacheServers), "there shouldn't be any cached servers")
assert.Equal(t, float64(0), scores[pod1], "score for pod1")
assert.Equal(t, float64(0), scores[pod2], "score for pod2")
Expand All @@ -76,8 +76,8 @@ func TestPrefixPlugin(t *testing.T) {
assert.NoError(t, err)
t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
// Input size is 6, hash block size is 4, the last 2 characters are ignored.
// Total hashes = 2 (the first one is for the model)
assert.Equal(t, 2, len(state.PrefixHashes), "number of hashes is incorrect")
// Total hashes = 1 (the first one is for the prefix with model)
assert.Equal(t, 1, len(state.PrefixHashes), "number of hashes is incorrect")
assert.Equal(t, 0, len(state.PrefixCacheServers), "there shouldn't be any cached servers")
assert.Equal(t, float64(0), scores[pod1], "score for pod1")
assert.Equal(t, float64(0), scores[pod2], "score for pod2")
Expand All @@ -96,10 +96,10 @@ func TestPrefixPlugin(t *testing.T) {
assert.NoError(t, err)
t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
// Input size is 8, hash block size is 4, so 2 hashes will be calculated.
// Total hashes = 3 (the first one is for the model)
assert.Equal(t, 3, len(state.PrefixHashes), "number of hashes is incorrect")
// Total hashes = 2 (the first one is for the prefix with model)
assert.Equal(t, 2, len(state.PrefixHashes), "number of hashes is incorrect")
assert.Equal(t, 1, len(state.PrefixCacheServers), "pod1 should have cached the aaaa prefix")
assert.Equal(t, float64(2)/float64(3), scores[pod1], "score should be 2/3 - the model and the first prefix block match")
assert.Equal(t, 0.5, scores[pod1], "score should be 0.5 - the model and the first prefix block match")
assert.Equal(t, float64(0), scores[pod2], "score for pod2")

plugin.PostCycle(context.Background(), cycleState3, &types.ProfileRunResult{TargetPods: []types.Pod{pod1}})
Expand All @@ -115,8 +115,8 @@ func TestPrefixPlugin(t *testing.T) {
assert.NoError(t, err)
t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
// Input size is 8, hash block size is 4, so 2 hashes will be calculated.
// Total hashes = 3 (the first one is for the model)
assert.Equal(t, 3, len(state.PrefixHashes), "number of hashes is incorrect")
// Total hashes = 2 (the first one is for the prefix with model)
assert.Equal(t, 2, len(state.PrefixHashes), "number of hashes is incorrect")
assert.Equal(t, 0, len(state.PrefixCacheServers), "pod1 should have cached the aaaa prefix")
assert.Equal(t, float64(0), scores[pod1], "score for pod1")
assert.Equal(t, float64(0), scores[pod2], "score for pod2")
Expand All @@ -134,10 +134,10 @@ func TestPrefixPlugin(t *testing.T) {
assert.NoError(t, err)
t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
// Input size is 12, hash block size is 4, so 3 hashes will be calculated.
// Total hashes = 4 (the first one is for the model)
assert.Equal(t, 4, len(state.PrefixHashes), "number of hashes is incorrect")
// Total hashes = 3 (the first one is for the prefix with model)
assert.Equal(t, 3, len(state.PrefixHashes), "number of hashes is incorrect")
assert.Equal(t, 1, len(state.PrefixCacheServers), "pod1 should have cached the aaaa prefix")
assert.Equal(t, 0.75, scores[pod1], "score should be 0.75 - the model and the first 2 prefix blocks match")
assert.Equal(t, 2./3, scores[pod1], "score should be 2./3 - the model and the first 2 prefix blocks match")
assert.Equal(t, float64(0), scores[pod2], "score for pod2")

plugin.PostCycle(context.Background(), cycleState5, &types.ProfileRunResult{TargetPods: []types.Pod{pod1}})
Expand Down