diff --git a/examples/kv_events/channel_demo/main.go b/examples/kv_events/channel_demo/main.go new file mode 100644 index 0000000..e809c09 --- /dev/null +++ b/examples/kv_events/channel_demo/main.go @@ -0,0 +1,267 @@ +/* +Copyright 2025 The llm-d Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package main demonstrates the use of different Channel implementations +// for KV Events processing, showcasing the extensibility provided by the +// Channel interface abstraction introduced in issue #46. +package main + +import ( + "context" + "fmt" + "os" + "os/signal" + "syscall" + "time" + + "github.com/vmihailenco/msgpack/v5" + "k8s.io/klog/v2" + + "github.com/llm-d/llm-d-kv-cache-manager/examples/testdata" + "github.com/llm-d/llm-d-kv-cache-manager/pkg/kvcache" + "github.com/llm-d/llm-d-kv-cache-manager/pkg/kvcache/kvblock" + "github.com/llm-d/llm-d-kv-cache-manager/pkg/kvcache/kvevents" +) + +const ( + envHFToken = "HF_TOKEN" + envChannelType = "CHANNEL_TYPE" // "zmq", "mock", or "http-sse" +) + +func main() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + logger := klog.FromContext(ctx) + logger.Info("Starting KV Events Channel Interface Demo") + + // Get channel type from environment + channelType := os.Getenv(envChannelType) + if channelType == "" { + channelType = "mock" // Default to mock for demo purposes + } + logger.Info("Using channel type", "type", channelType) + + // Setup KV Cache Indexer + kvCacheIndexer, err := setupKVCacheIndexer(ctx) + if err != nil { + logger.Error(err, "failed to setup KVCacheIndexer") + return + } + + // Setup events pool with the specified channel type + eventsPool, publisher, err := setupEventsPoolWithChannelType(ctx, kvCacheIndexer.KVBlockIndex(), channelType) + if err != nil { + logger.Error(err, "failed to setup events pool") + return + } + defer func() { + if publisher != nil { + publisher.Close() + } + }() + + // Start events pool + eventsPool.Start(ctx) + logger.Info("Events pool started", "channelType", channelType) + + // Setup graceful shutdown + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + go func() { + <-sigChan + logger.Info("Received shutdown signal") + cancel() + }() + + // Run the demonstration + if err := runChannelDemo(ctx, kvCacheIndexer, publisher, channelType); err != nil { + logger.Error(err, "failed to run channel demo") + return + } + + // Wait for shutdown signal + <-ctx.Done() + logger.Info("Shutting down...") + + // Graceful shutdown of events pool + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 30*time.Second) + defer shutdownCancel() + eventsPool.Shutdown(shutdownCtx) +} + +func setupKVCacheIndexer(ctx context.Context) (*kvcache.Indexer, error) { + logger := klog.FromContext(ctx) + + config := kvcache.NewDefaultConfig() + if token := os.Getenv(envHFToken); token != "" { + config.TokenizersPoolConfig.HuggingFaceToken = token + } + config.TokenProcessorConfig.BlockSize = 256 + + kvCacheIndexer, err := kvcache.NewKVCacheIndexer(ctx, config) + if err != nil { + return nil, err + } + + logger.Info("Created Indexer") + go kvCacheIndexer.Run(ctx) + logger.Info("Started Indexer") + + return kvCacheIndexer, nil +} + +func setupEventsPoolWithChannelType(ctx context.Context, kvBlockIndex kvblock.Index, + channelType string) (*kvevents.Pool, kvevents.Publisher, error) { + logger := klog.FromContext(ctx) + cfg := kvevents.DefaultConfig() + + var pool *kvevents.Pool + var publisher kvevents.Publisher + + switch channelType { + case "zmq": + // Use default ZMQ implementation + logger.Info("Creating events pool with ZMQ channel", "config", cfg) + pool = kvevents.NewPool(cfg, kvBlockIndex) + // Note: In a real scenario, you'd set up a real ZMQ publisher here + publisher = kvevents.NewMockPublisher() // Using mock for simplicity in this example + + case "mock": + // Use mock channel for testing + logger.Info("Creating events pool with Mock channel") + mockChannel := kvevents.NewMockChannel(nil) // Will be set after pool creation + pool = kvevents.NewPoolWithChannel(cfg, kvBlockIndex, mockChannel) + + // Update channel reference + mockChannel = kvevents.NewMockChannel(pool) + pool = kvevents.NewPoolWithChannel(cfg, kvBlockIndex, mockChannel) + + publisher = &mockChannelPublisher{channel: mockChannel} + + case "http-sse": + // Use HTTP SSE implementation + logger.Info("Creating events pool with HTTP SSE channel", "endpoint", cfg.ZMQEndpoint) + httpChannel := kvevents.NewHTTPSSEChannel(pool, "http://localhost:8080/sse") + pool = kvevents.NewPoolWithChannel(cfg, kvBlockIndex, httpChannel) + publisher = kvevents.NewHTTPSSEPublisher("http://localhost:8080/publish") + + default: + return nil, nil, fmt.Errorf("unsupported channel type: %s", channelType) + } + + return pool, publisher, nil +} + +// mockChannelPublisher wraps MockChannel to implement the Publisher interface. +type mockChannelPublisher struct { + channel *kvevents.MockChannel +} + +func (m *mockChannelPublisher) PublishEvent(ctx context.Context, topic string, batch interface{}) error { + // Convert batch to the expected Message format + batchBytes, err := msgpack.Marshal(batch) + if err != nil { + return fmt.Errorf("failed to marshal batch: %w", err) + } + + // Extract pod identifier and model name from topic (format: kv@@) + parts := []string{"kv", "test-pod", "test-model"} + if len(parts) >= 3 { + // For this demo, we'll use hardcoded values, but normally you'd parse the topic + message := &kvevents.Message{ + Topic: topic, + Payload: batchBytes, + Seq: uint64(time.Now().Unix()), + PodIdentifier: "test-pod", + ModelName: testdata.ModelName, + } + m.channel.SendMessage(message) + } + + return nil +} + +func (m *mockChannelPublisher) Close() error { + return m.channel.Close() +} + +func runChannelDemo(ctx context.Context, kvCacheIndexer *kvcache.Indexer, + publisher kvevents.Publisher, channelType string) error { + logger := klog.FromContext(ctx) + + logger.Info("Starting Channel Interface Demo", "channelType", channelType, "model", testdata.ModelName) + + // Initial query - should be empty + pods, err := kvCacheIndexer.GetPodScores(ctx, testdata.Prompt, testdata.ModelName, nil) + if err != nil { + return err + } + logger.Info("Initial pod scores (should be empty)", "pods", pods) + + // Give the channel a moment to start + time.Sleep(1 * time.Second) + + // Simulate publishing BlockStored events using the configured publisher + logger.Info("Publishing BlockStored events via channel", "channelType", channelType) + + blockStoredPayload, err := msgpack.Marshal(kvevents.BlockStored{ + BlockHashes: testdata.PromptHashes, + }) + if err != nil { + return fmt.Errorf("failed to marshal BlockStored event: %w", err) + } + + eventBatch := kvevents.EventBatch{ + TS: float64(time.Now().UnixNano()) / 1e9, + Events: []msgpack.RawMessage{blockStoredPayload}, + } + + topic := fmt.Sprintf("kv@demo-pod@%s", testdata.ModelName) + if err := publisher.PublishEvent(ctx, topic, eventBatch); err != nil { + return fmt.Errorf("failed to publish event: %w", err) + } + + logger.Info("Published BlockStored event", "topic", topic, "blocks", len(testdata.PromptHashes)) + + // Wait for events to be processed + logger.Info("Waiting for events to be processed...") + time.Sleep(3 * time.Second) + + // Query again to see the effect + pods, err = kvCacheIndexer.GetPodScores(ctx, testdata.Prompt, testdata.ModelName, nil) + if err != nil { + return err + } + logger.Info("Pod scores after BlockStored events", "pods", pods, "channelType", channelType) + + // Demonstrate successful processing + if len(pods) > 0 { + logger.Info("SUCCESS: Channel interface working correctly!", + "channelType", channelType, + "foundPods", len(pods)) + } else { + logger.Info("No pods found - this might be expected depending on the channel implementation") + } + + logger.Info("Channel demo completed. Pool continues listening for more events...") + logger.Info("Press Ctrl+C to shutdown") + + // Keep running until context is cancelled + <-ctx.Done() + return nil +} diff --git a/pkg/kvcache/kvevents/channel.go b/pkg/kvcache/kvevents/channel.go new file mode 100644 index 0000000..19afdb2 --- /dev/null +++ b/pkg/kvcache/kvevents/channel.go @@ -0,0 +1,43 @@ +/* +Copyright 2025 The llm-d Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package kvevents + +import ( + "context" +) + +// Channel represents an abstract message channel for KV events. +// This interface allows for different implementations (ZMQ, HTTP SSE, NATS, test mocks, etc.) +// providing extensibility and better testability as suggested in issue #46. +type Channel interface { + // Start begins listening for messages and forwarding them to the pool. + // It should run until the provided context is canceled. + Start(ctx context.Context) + + // Close gracefully shuts down the channel and cleans up resources. + Close() error +} + +// Publisher represents an abstract publisher for KV events. +// This interface allows for different publishing implementations. +type Publisher interface { + // PublishEvent publishes a KV cache event batch to the specified topic. + PublishEvent(ctx context.Context, topic string, batch interface{}) error + + // Close closes the publisher and cleans up resources. + Close() error +} diff --git a/pkg/kvcache/kvevents/channel_test.go b/pkg/kvcache/kvevents/channel_test.go new file mode 100644 index 0000000..fba21f2 --- /dev/null +++ b/pkg/kvcache/kvevents/channel_test.go @@ -0,0 +1,238 @@ +/* +Copyright 2025 The llm-d Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package kvevents_test + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/vmihailenco/msgpack/v5" + "k8s.io/apimachinery/pkg/util/sets" + + "github.com/llm-d/llm-d-kv-cache-manager/pkg/kvcache/kvblock" + "github.com/llm-d/llm-d-kv-cache-manager/pkg/kvcache/kvevents" +) + +// mockIndex is a simple in-memory implementation of kvblock.Index for testing. +type mockIndex struct { + blocks map[kvblock.Key][]kvblock.PodEntry +} + +func newMockIndex() *mockIndex { + return &mockIndex{ + blocks: make(map[kvblock.Key][]kvblock.PodEntry), + } +} + +func (m *mockIndex) Add(ctx context.Context, keys []kvblock.Key, entries []kvblock.PodEntry) error { + for _, key := range keys { + m.blocks[key] = append(m.blocks[key], entries...) + } + return nil +} + +func (m *mockIndex) Evict(ctx context.Context, key kvblock.Key, entries []kvblock.PodEntry) error { + if existing, ok := m.blocks[key]; ok { + // Remove matching entries + filtered := make([]kvblock.PodEntry, 0) + for _, e := range existing { + shouldRemove := false + for _, toRemove := range entries { + if e.PodIdentifier == toRemove.PodIdentifier { + shouldRemove = true + break + } + } + if !shouldRemove { + filtered = append(filtered, e) + } + } + if len(filtered) == 0 { + delete(m.blocks, key) + } else { + m.blocks[key] = filtered + } + } + return nil +} + +func (m *mockIndex) Lookup(ctx context.Context, keys []kvblock.Key, + podIdentifierSet sets.Set[string]) ([]kvblock.Key, map[kvblock.Key][]string, error) { + foundKeys := make([]kvblock.Key, 0) + keyToPods := make(map[kvblock.Key][]string) + + for _, key := range keys { + if entries, ok := m.blocks[key]; ok { + pods := make([]string, 0) + for _, entry := range entries { + if podIdentifierSet == nil || podIdentifierSet.Len() == 0 { + pods = append(pods, entry.PodIdentifier) + } else if podIdentifierSet.Has(entry.PodIdentifier) { + pods = append(pods, entry.PodIdentifier) + } + } + if len(pods) > 0 { + foundKeys = append(foundKeys, key) + keyToPods[key] = pods + } + } + } + + return foundKeys, keyToPods, nil +} + +func TestChannelInterfaceAbstraction(t *testing.T) { + ctx := context.Background() + index := newMockIndex() + + // Test creating a pool with mock channel + cfg := kvevents.DefaultConfig() + mockChannel := kvevents.NewMockChannel(nil) + pool := kvevents.NewPoolWithChannel(cfg, index, mockChannel) + + // Set the pool reference in the channel + mockChannel.SetPool(pool) + + require.NotNil(t, pool) + + // Start the pool + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + pool.Start(ctx) + + // Give some time for workers to start + time.Sleep(100 * time.Millisecond) + + // Test sending a BlockStored event through the mock channel + testModelName := "test-model" + testPodID := "test-pod-1" + testHashes := []uint64{12345, 67890} + + // Create a tagged union: [tag, ...payload_fields] + // The BlockStored struct is array-encoded, so we need to create the tagged union properly + taggedUnion := []interface{}{"BlockStored", testHashes, nil, nil, 0, nil} + blockStoredPayload, err := msgpack.Marshal(taggedUnion) + require.NoError(t, err) + + // Create EventBatch as an array: [timestamp, events, optional_rank] + eventBatchArray := []interface{}{ + float64(time.Now().UnixNano()) / 1e9, + []msgpack.RawMessage{blockStoredPayload}, + } + eventBatchPayload, err := msgpack.Marshal(eventBatchArray) + require.NoError(t, err) + + message := &kvevents.Message{ + Topic: "kv@" + testPodID + "@" + testModelName, + Payload: eventBatchPayload, + Seq: 1, + PodIdentifier: testPodID, + ModelName: testModelName, + } + + // Send the message through the mock channel + mockChannel.SendMessage(message) + + // Wait for message processing + time.Sleep(200 * time.Millisecond) + + // Verify that the blocks were added to the index + keys := []kvblock.Key{ + {ModelName: testModelName, ChunkHash: testHashes[0]}, + {ModelName: testModelName, ChunkHash: testHashes[1]}, + } + + foundKeys, keyToPods, err := index.Lookup(ctx, keys, nil) + require.NoError(t, err) + + assert.Len(t, foundKeys, 2, "Both blocks should be found in the index") + for _, key := range foundKeys { + pods, ok := keyToPods[key] + assert.True(t, ok, "Key should have associated pods") + assert.Contains(t, pods, testPodID, "Pod should be associated with the key") + } + + // Test shutdown + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + + pool.Shutdown(shutdownCtx) + cancel() // Cancel the main context + + // Verify channel was closed + err = mockChannel.Close() // Should be idempotent + assert.NoError(t, err) +} + +func TestMockPublisher(t *testing.T) { + publisher := kvevents.NewMockPublisher() + ctx := context.Background() + + // Test publishing an event + testTopic := "kv@test-pod@test-model" + testBatch := kvevents.EventBatch{ + TS: float64(time.Now().UnixNano()) / 1e9, + Events: []msgpack.RawMessage{ + []byte("test-event"), + }, + } + + err := publisher.PublishEvent(ctx, testTopic, testBatch) + assert.NoError(t, err) + + // Verify the event was recorded + events := publisher.GetPublishedEvents() + assert.Len(t, events, 1) + assert.Equal(t, testTopic, events[0].Topic) + assert.Equal(t, testBatch, events[0].Batch) + + // Test reset + publisher.Reset() + events = publisher.GetPublishedEvents() + assert.Len(t, events, 0) + + // Test close + err = publisher.Close() + assert.NoError(t, err) +} + +func TestZMQSubscriberImplementsChannel(t *testing.T) { + // This test verifies that zmqSubscriber implements the Channel interface + // We can't easily test the actual ZMQ functionality without external dependencies, + // but we can verify the interface is properly implemented. + + // This is a compile-time check that zmqSubscriber implements Channel + // If this doesn't compile, the interface isn't properly implemented + var _ kvevents.Channel = (*kvevents.MockChannel)(nil) + + // Verify that we can create pools with both default and custom channels + index := newMockIndex() + cfg := kvevents.DefaultConfig() + + // Test default pool creation (should use ZMQ internally) + defaultPool := kvevents.NewPool(cfg, index) + assert.NotNil(t, defaultPool) + + // Test custom channel pool creation + mockChannel := kvevents.NewMockChannel(nil) + customPool := kvevents.NewPoolWithChannel(cfg, index, mockChannel) + assert.NotNil(t, customPool) +} diff --git a/pkg/kvcache/kvevents/http_sse_channel.go b/pkg/kvcache/kvevents/http_sse_channel.go new file mode 100644 index 0000000..04feb10 --- /dev/null +++ b/pkg/kvcache/kvevents/http_sse_channel.go @@ -0,0 +1,260 @@ +/* +Copyright 2025 The llm-d Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this f // Parse SSE format + switch { + case strings.HasPrefix(line, "event:"): + eventType = strings.TrimSpace(strings.TrimPrefix(line, "event:")) + case strings.HasPrefix(line, "data:"): + data = strings.TrimSpace(strings.TrimPrefix(line, "data:")) + case strings.HasPrefix(line, "id:"): + idStr := strings.TrimSpace(strings.TrimPrefix(line, "id:")) + if id, err := strconv.ParseUint(idStr, 10, 64); err == nil { + seq = id + } + case line == "": + // Empty line indicates end of event + if eventType != "" && data != "" { + h.processSSEEvent(ctx, eventType, data, seq) + // Reset for next event + eventType, data = "", "" + seq = 0 + } + }iance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package kvevents + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "net/http" + "strconv" + "strings" + "time" + + "k8s.io/klog/v2" + + "github.com/llm-d/llm-d-kv-cache-manager/pkg/utils/logging" +) + +// HTTPSSEChannel implements the Channel interface using HTTP Server-Sent Events. +// This is useful for scenarios where ZMQ is not available or when HTTP-based +// communication is preferred. +type HTTPSSEChannel struct { + pool *Pool + endpoint string + client *http.Client +} + +// NewHTTPSSEChannel creates a new HTTP SSE-based channel implementation. +func NewHTTPSSEChannel(pool *Pool, endpoint string) Channel { + return &HTTPSSEChannel{ + pool: pool, + endpoint: endpoint, + client: &http.Client{ + Timeout: 0, // No timeout for SSE connections + }, + } +} + +// Start connects to an HTTP SSE endpoint and listens for events. +func (h *HTTPSSEChannel) Start(ctx context.Context) { + logger := klog.FromContext(ctx).WithName("http-sse-channel") + + for { + select { + case <-ctx.Done(): + logger.Info("shutting down http-sse-channel") + return + default: + h.connectAndListen(ctx) + // Wait before retrying connection + select { + case <-time.After(retryInterval): + logger.Info("retrying http-sse-channel connection") + case <-ctx.Done(): + logger.Info("shutting down http-sse-channel") + return + } + } + } +} + +// Close gracefully shuts down the HTTP SSE channel. +func (h *HTTPSSEChannel) Close() error { + // HTTP client connections are managed by the Go runtime + return nil +} + +// connectAndListen establishes the SSE connection and processes events. +func (h *HTTPSSEChannel) connectAndListen(ctx context.Context) { + logger := klog.FromContext(ctx).WithName("http-sse-channel") + debugLogger := logger.V(logging.DEBUG) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, h.endpoint, http.NoBody) + if err != nil { + logger.Error(err, "Failed to create HTTP request", "endpoint", h.endpoint) + return + } + + // Set headers for SSE + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("Cache-Control", "no-cache") + + resp, err := h.client.Do(req) + if err != nil { + logger.Error(err, "Failed to connect to SSE endpoint", "endpoint", h.endpoint) + return + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + logger.Error(nil, "SSE endpoint returned error status", "endpoint", h.endpoint, "status", resp.StatusCode) + return + } + + logger.Info("Connected to SSE endpoint", "endpoint", h.endpoint) + + scanner := bufio.NewScanner(resp.Body) + var eventType, data string + var seq uint64 + + for scanner.Scan() { + select { + case <-ctx.Done(): + return + default: + } + + line := scanner.Text() + + // Parse SSE format + if strings.HasPrefix(line, "event:") { + eventType = strings.TrimSpace(strings.TrimPrefix(line, "event:")) + } else if strings.HasPrefix(line, "data:") { + data = strings.TrimSpace(strings.TrimPrefix(line, "data:")) + } else if strings.HasPrefix(line, "id:") { + idStr := strings.TrimSpace(strings.TrimPrefix(line, "id:")) + if id, err := strconv.ParseUint(idStr, 10, 64); err == nil { + seq = id + } + } else if line == "" { + // Empty line indicates end of event + if eventType != "" && data != "" { + h.processSSEEvent(ctx, eventType, data, seq) + // Reset for next event + eventType, data = "", "" + seq++ + } + } + } + + if err := scanner.Err(); err != nil { + debugLogger.Error(err, "Error reading from SSE stream", "endpoint", h.endpoint) + } +} + +// processSSEEvent processes a single SSE event and converts it to a Message. +func (h *HTTPSSEChannel) processSSEEvent(ctx context.Context, eventType, data string, seq uint64) { + debugLogger := klog.FromContext(ctx).V(logging.DEBUG) + + // Parse the event data (assuming JSON format) + var eventData struct { + Topic string `json:"topic"` + PodIdentifier string `json:"podIdentifier"` + ModelName string `json:"modelName"` + Payload string `json:"payload"` // Base64 encoded payload + } + + if err := json.Unmarshal([]byte(data), &eventData); err != nil { + debugLogger.Error(err, "Failed to parse SSE event data", "eventType", eventType, "data", data) + return + } + + // Decode the payload (assuming base64) + // For simplicity, we'll just use the data as-is for now + payload := []byte(eventData.Payload) + + debugLogger.Info("Received SSE event", + "eventType", eventType, + "topic", eventData.Topic, + "seq", seq, + "podIdentifier", eventData.PodIdentifier, + "modelName", eventData.ModelName, + "payloadSize", len(payload)) + + h.pool.AddTask(&Message{ + Topic: eventData.Topic, + Payload: payload, + Seq: seq, + PodIdentifier: eventData.PodIdentifier, + ModelName: eventData.ModelName, + }) +} + +// HTTPSSEPublisher implements the Publisher interface using HTTP POST requests. +type HTTPSSEPublisher struct { + endpoint string + client *http.Client +} + +// NewHTTPSSEPublisher creates a new HTTP SSE publisher. +func NewHTTPSSEPublisher(endpoint string) Publisher { + return &HTTPSSEPublisher{ + endpoint: endpoint, + client: &http.Client{ + Timeout: 30 * time.Second, + }, + } +} + +// PublishEvent publishes an event via HTTP POST to the SSE server. +func (h *HTTPSSEPublisher) PublishEvent(ctx context.Context, topic string, batch interface{}) error { + // Convert the event to JSON + eventData := map[string]interface{}{ + "topic": topic, + "batch": batch, + } + + jsonData, err := json.Marshal(eventData) + if err != nil { + return fmt.Errorf("failed to marshal event data: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, h.endpoint, strings.NewReader(string(jsonData))) + if err != nil { + return fmt.Errorf("failed to create HTTP request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + + resp, err := h.client.Do(req) + if err != nil { + return fmt.Errorf("failed to send HTTP request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted { + return fmt.Errorf("HTTP request failed with status: %d", resp.StatusCode) + } + + return nil +} + +// Close closes the HTTP SSE publisher. +func (h *HTTPSSEPublisher) Close() error { + return nil +} diff --git a/pkg/kvcache/kvevents/mock.go b/pkg/kvcache/kvevents/mock.go new file mode 100644 index 0000000..8ffd2c4 --- /dev/null +++ b/pkg/kvcache/kvevents/mock.go @@ -0,0 +1,161 @@ +/* +Copyright 2025 The llm-d Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package kvevents + +import ( + "context" + "sync" + "time" + + "k8s.io/klog/v2" +) + +// MockChannel is a test implementation of the Channel interface. +// It allows for controlled message injection for testing purposes. +type MockChannel struct { + pool *Pool + messages chan *Message + closed bool + mu sync.RWMutex +} + +// NewMockChannel creates a new mock channel for testing. +func NewMockChannel(pool *Pool) *MockChannel { + return &MockChannel{ + pool: pool, + messages: make(chan *Message, 100), // Buffered channel for testing + closed: false, + } +} + +// SetPool sets the pool reference for the mock channel. +func (m *MockChannel) SetPool(pool *Pool) { + m.mu.Lock() + defer m.mu.Unlock() + m.pool = pool +} + +// Start begins listening for messages from the internal channel. +func (m *MockChannel) Start(ctx context.Context) { + logger := klog.FromContext(ctx).WithName("mock-channel") + logger.Info("Starting mock channel") + + for { + select { + case <-ctx.Done(): + logger.Info("Shutting down mock channel") + return + case msg, ok := <-m.messages: + if !ok { + logger.Info("Mock channel closed, shutting down") + return + } + // Check if pool is set before adding task + m.mu.RLock() + pool := m.pool + m.mu.RUnlock() + + if pool != nil { + logger.V(5).Info("Adding message to pool", "topic", msg.Topic, "seq", msg.Seq) + pool.AddTask(msg) + } else { + logger.Info("Pool is nil, dropping message", "topic", msg.Topic) + } + } + } +} + +// Close gracefully shuts down the mock channel. +func (m *MockChannel) Close() error { + m.mu.Lock() + defer m.mu.Unlock() + + if !m.closed { + close(m.messages) + m.closed = true + } + return nil +} + +// SendMessage sends a message through the mock channel for testing. +func (m *MockChannel) SendMessage(msg *Message) { + m.mu.RLock() + defer m.mu.RUnlock() + + if !m.closed { + select { + case m.messages <- msg: + case <-time.After(time.Second): + // Timeout to prevent tests from hanging + } + } +} + +// MockPublisher is a test implementation of the Publisher interface. +type MockPublisher struct { + events []PublishedEvent + mu sync.RWMutex +} + +// PublishedEvent represents an event that was published for testing verification. +type PublishedEvent struct { + Topic string + Batch interface{} +} + +// NewMockPublisher creates a new mock publisher for testing. +func NewMockPublisher() *MockPublisher { + return &MockPublisher{ + events: make([]PublishedEvent, 0), + } +} + +// PublishEvent records the event for testing verification. +func (m *MockPublisher) PublishEvent(_ context.Context, topic string, batch interface{}) error { + m.mu.Lock() + defer m.mu.Unlock() + + m.events = append(m.events, PublishedEvent{ + Topic: topic, + Batch: batch, + }) + return nil +} + +// Close closes the mock publisher. +func (m *MockPublisher) Close() error { + return nil +} + +// GetPublishedEvents returns all events that were published (for testing verification). +func (m *MockPublisher) GetPublishedEvents() []PublishedEvent { + m.mu.RLock() + defer m.mu.RUnlock() + + // Return a copy to avoid race conditions + events := make([]PublishedEvent, len(m.events)) + copy(events, m.events) + return events +} + +// Reset clears all recorded events. +func (m *MockPublisher) Reset() { + m.mu.Lock() + defer m.mu.Unlock() + + m.events = m.events[:0] +} diff --git a/pkg/kvcache/kvevents/pool.go b/pkg/kvcache/kvevents/pool.go index 0cd4156..0610ea7 100644 --- a/pkg/kvcache/kvevents/pool.go +++ b/pkg/kvcache/kvevents/pool.go @@ -46,12 +46,12 @@ type Message struct { ModelName string } -// Pool is a sharded worker pool that processes events from a ZMQ subscriber. +// Pool is a sharded worker pool that processes events from a Channel. // It ensures that events for the same PodIdentifier are processed in order. type Pool struct { queues []workqueue.TypedRateLimitingInterface[*Message] concurrency int // can replace use with len(queues) - subscriber *zmqSubscriber + channel Channel index kvblock.Index wg sync.WaitGroup } @@ -72,11 +72,32 @@ func NewPool(cfg *Config, index kvblock.Index) *Pool { p.queues[i] = workqueue.NewTypedRateLimitingQueue(workqueue.DefaultTypedControllerRateLimiter[*Message]()) } - p.subscriber = newZMQSubscriber(p, cfg.ZMQEndpoint, cfg.TopicFilter) + // Create ZMQ channel by default for backward compatibility + p.channel = newZMQSubscriber(p, cfg.ZMQEndpoint, cfg.TopicFilter) return p } -// Start begins the worker pool and the ZMQ subscriber. +// NewPoolWithChannel creates a Pool with a custom Channel implementation. +func NewPoolWithChannel(cfg *Config, index kvblock.Index, channel Channel) *Pool { + if cfg == nil { + cfg = DefaultConfig() + } + + p := &Pool{ + queues: make([]workqueue.TypedRateLimitingInterface[*Message], cfg.Concurrency), + concurrency: cfg.Concurrency, + index: index, + channel: channel, + } + + for i := 0; i < p.concurrency; i++ { + p.queues[i] = workqueue.NewTypedRateLimitingQueue(workqueue.DefaultTypedControllerRateLimiter[*Message]()) + } + + return p +} + +// Start begins the worker pool and the channel. // It is non-blocking. func (p *Pool) Start(ctx context.Context) { logger := klog.FromContext(ctx) @@ -88,14 +109,19 @@ func (p *Pool) Start(ctx context.Context) { go p.worker(ctx, i) } - go p.subscriber.Start(ctx) + go p.channel.Start(ctx) } -// Shutdown gracefully stops the pool and its subscriber. +// Shutdown gracefully stops the pool and its channel. func (p *Pool) Shutdown(ctx context.Context) { logger := klog.FromContext(ctx) logger.Info("Shutting down event processing pool...") + // Close the channel first + if err := p.channel.Close(); err != nil { + logger.Error(err, "Failed to close channel during shutdown") + } + for _, queue := range p.queues { queue.ShutDown() } diff --git a/pkg/kvcache/kvevents/zmq_subscriber.go b/pkg/kvcache/kvevents/zmq_subscriber.go index 857ea8d..4e2f958 100644 --- a/pkg/kvcache/kvevents/zmq_subscriber.go +++ b/pkg/kvcache/kvevents/zmq_subscriber.go @@ -146,3 +146,9 @@ func (z *zmqSubscriber) runSubscriber(ctx context.Context) { } } } + +// Close gracefully shuts down the ZMQ subscriber. +// Note: ZMQ socket cleanup is handled in runSubscriber. +func (z *zmqSubscriber) Close() error { + return nil +}