diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go index 9c1a8d844..8a581ae1e 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go +++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go @@ -257,8 +257,16 @@ func hashPrompt(ctx context.Context, request *types.LLMRequest, cacheBlockSize i // If the last block is smaller than cacheBlockSize, it will be ignored. res := make([]BlockHash, 0, 1+len(prompt)/cacheBlockSize) // Add the model to the first block hash so that different models have different hashes even with the same body. - res = append(res, BlockHash(xxhash.Sum64String(request.TargetModel))) - for i := 0; i+cacheBlockSize <= len(prompt); i += cacheBlockSize { + + firstBlockSize := cacheBlockSize + if len(prompt) < cacheBlockSize { + firstBlockSize = len(prompt) + } + firstBlock := prompt[0:firstBlockSize] + firstBlockWithModel := append([]byte(request.TargetModel), firstBlock...) + res = append(res, BlockHash(xxhash.Sum64(firstBlockWithModel))) + + for i := cacheBlockSize; i+cacheBlockSize <= len(prompt); i += cacheBlockSize { block := prompt[i : i+cacheBlockSize] prevBlockHash := res[len(res)-1] block = append(block, toBytes(prevBlockHash)...) diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go index aaf68f0d8..441cbc9d9 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go +++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go @@ -55,8 +55,8 @@ func TestPrefixPlugin(t *testing.T) { assert.NoError(t, err) t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers) // Input size is 6, hash block size is 4, the last 2 characters are ignored. - // Total hashes = 2 (the first one is for the model) - assert.Equal(t, 2, len(state.PrefixHashes), "number of hashes is incorrect") + // Total hashes = 1 (the first one is for the prefix with model) + assert.Equal(t, 1, len(state.PrefixHashes), "number of hashes is incorrect") assert.Equal(t, 0, len(state.PrefixCacheServers), "there shouldn't be any cached servers") assert.Equal(t, float64(0), scores[pod1], "score for pod1") assert.Equal(t, float64(0), scores[pod2], "score for pod2") @@ -76,8 +76,8 @@ func TestPrefixPlugin(t *testing.T) { assert.NoError(t, err) t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers) // Input size is 6, hash block size is 4, the last 2 characters are ignored. - // Total hashes = 2 (the first one is for the model) - assert.Equal(t, 2, len(state.PrefixHashes), "number of hashes is incorrect") + // Total hashes = 1 (the first one is for the prefix with model) + assert.Equal(t, 1, len(state.PrefixHashes), "number of hashes is incorrect") assert.Equal(t, 0, len(state.PrefixCacheServers), "there shouldn't be any cached servers") assert.Equal(t, float64(0), scores[pod1], "score for pod1") assert.Equal(t, float64(0), scores[pod2], "score for pod2") @@ -96,10 +96,10 @@ func TestPrefixPlugin(t *testing.T) { assert.NoError(t, err) t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers) // Input size is 8, hash block size is 4, so 2 hashes will be calculated. - // Total hashes = 3 (the first one is for the model) - assert.Equal(t, 3, len(state.PrefixHashes), "number of hashes is incorrect") + // Total hashes = 2 (the first one is for the prefix with model) + assert.Equal(t, 2, len(state.PrefixHashes), "number of hashes is incorrect") assert.Equal(t, 1, len(state.PrefixCacheServers), "pod1 should have cached the aaaa prefix") - assert.Equal(t, float64(2)/float64(3), scores[pod1], "score should be 2/3 - the model and the first prefix block match") + assert.Equal(t, 0.5, scores[pod1], "score should be 0.5 - the model and the first prefix block match") assert.Equal(t, float64(0), scores[pod2], "score for pod2") plugin.PostCycle(context.Background(), cycleState3, &types.ProfileRunResult{TargetPods: []types.Pod{pod1}}) @@ -115,8 +115,8 @@ func TestPrefixPlugin(t *testing.T) { assert.NoError(t, err) t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers) // Input size is 8, hash block size is 4, so 2 hashes will be calculated. - // Total hashes = 3 (the first one is for the model) - assert.Equal(t, 3, len(state.PrefixHashes), "number of hashes is incorrect") + // Total hashes = 2 (the first one is for the prefix with model) + assert.Equal(t, 2, len(state.PrefixHashes), "number of hashes is incorrect") assert.Equal(t, 0, len(state.PrefixCacheServers), "pod1 should have cached the aaaa prefix") assert.Equal(t, float64(0), scores[pod1], "score for pod1") assert.Equal(t, float64(0), scores[pod2], "score for pod2") @@ -134,10 +134,10 @@ func TestPrefixPlugin(t *testing.T) { assert.NoError(t, err) t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers) // Input size is 12, hash block size is 4, so 3 hashes will be calculated. - // Total hashes = 4 (the first one is for the model) - assert.Equal(t, 4, len(state.PrefixHashes), "number of hashes is incorrect") + // Total hashes = 3 (the first one is for the prefix with model) + assert.Equal(t, 3, len(state.PrefixHashes), "number of hashes is incorrect") assert.Equal(t, 1, len(state.PrefixCacheServers), "pod1 should have cached the aaaa prefix") - assert.Equal(t, 0.75, scores[pod1], "score should be 0.75 - the model and the first 2 prefix blocks match") + assert.Equal(t, 2./3, scores[pod1], "score should be 2./3 - the model and the first 2 prefix blocks match") assert.Equal(t, float64(0), scores[pod2], "score for pod2") plugin.PostCycle(context.Background(), cycleState5, &types.ProfileRunResult{TargetPods: []types.Pod{pod1}}) @@ -186,7 +186,7 @@ func BenchmarkPrefixPluginStress(b *testing.B) { // Second cycle: validate internal state state, err := types.ReadCycleStateKey[*SchedulingContextState](cycleState, PrefixCachePluginType) assert.NoError(b, err) - expectedHashes := int(math.Min(float64(maxPrefixBlocks+1), float64(len(req.Prompt)/blockSize+1))) // the extra one is for the model. + expectedHashes := int(math.Min(float64(maxPrefixBlocks), float64(len(req.Prompt)/blockSize))) assert.Equal(b, expectedHashes, len(state.PrefixHashes), "number of hashes is incorrect") } }