16 changes: 16 additions & 0 deletions cmd/batch/main.go
@@ -0,0 +1,16 @@
package main

import (
"os"

ctrl "sigs.k8s.io/controller-runtime"

"github.com/llm-d/llm-d-inference-scheduler/cmd/batch/runner"
)

func main() {
	// Exit with a non-zero status on any runner error; the runner logs the
	// failure before returning.
	if err := runner.NewRunner().Run(ctrl.SetupSignalHandler()); err != nil {
		os.Exit(1)
	}
}
99 changes: 99 additions & 0 deletions cmd/batch/runner/runner.go
@@ -0,0 +1,99 @@
package runner

import (
"context"
"flag"
"net/http"

"github.com/llm-d/llm-d-inference-scheduler/pkg/batch"
"github.com/llm-d/llm-d-inference-scheduler/pkg/batch/redis"
uberzap "go.uber.org/zap"
"go.uber.org/zap/zapcore"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/log/zap"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
)

// Runner wires together and runs the batch processor components.
type Runner struct{}

var (
setupLog = ctrl.Log.WithName("setup")
logVerbosity = flag.Int("v", logging.DEFAULT, "number for the log level verbosity")
concurrency = flag.Int("concurrency", 8, "number of concurrent workers")
endpoint = flag.String("endpoint", "", "inference endpoint")
)

// NewRunner returns a new, empty Runner.
func NewRunner() *Runner {
	return &Runner{}
}

// Run parses flags, initializes logging, and starts the batch processor
// workers and the underlying flow implementation.
func (r *Runner) Run(ctx context.Context) error {
opts := zap.Options{
Development: true,
}
opts.BindFlags(flag.CommandLine)
flag.Parse()
initLogging(&opts)

	// TODO: optionally initialize tracing and log build information (commit SHA,
	// build ref) once those utilities are wired into this binary.

// Validate flags
if err := validateFlags(); err != nil {
setupLog.Error(err, "Failed to validate flags")
return err
}

// Print all flag values
	flags := make(map[string]any)
	flag.VisitAll(func(f *flag.Flag) {
		flags[f.Name] = f.Value.String()
	})
setupLog.Info("Flags processed", "flags", flags)

httpClient := &http.Client{
// TODO: configure
}
var policy batch.RequestPolicy = batch.NewRandomRobinPolicy()

	var impl batch.Flow = redis.NewRedisMQFlow("localhost:6379") // TODO: make the Redis address configurable
requestChannel := policy.MergeRequestChannels(impl.RequestChannels()).Channel
for w := 1; w <= *concurrency; w++ {
go batch.Worker(ctx, *endpoint, httpClient, requestChannel, impl.RetryChannel(), impl.ResultChannel())
}

impl.Start(ctx)

return nil
}

// initLogging configures the controller-runtime logger from flags.
// TODO: is this a duplicate of existing logging initialization elsewhere?
func initLogging(opts *zap.Options) {
// Unless -zap-log-level is explicitly set, use -v
useV := true
flag.Visit(func(f *flag.Flag) {
if f.Name == "zap-log-level" {
useV = false
}
})
if useV {
// See https://pkg.go.dev/sigs.k8s.io/controller-runtime/pkg/log/zap#Options.Level
lvl := -1 * (*logVerbosity)
opts.Level = uberzap.NewAtomicLevelAt(zapcore.Level(int8(lvl)))
}

logger := zap.New(zap.UseFlagOptions(opts), zap.RawZapOpts(uberzap.AddCaller()))
ctrl.SetLogger(logger)
}

// validateFlags ensures required flags were provided.
func validateFlags() error {
	if *endpoint == "" {
		return errors.New("the --endpoint flag is required")
	}
	return nil
}
61 changes: 61 additions & 0 deletions pkg/batch/README.md
@@ -0,0 +1,61 @@
# Batch Processor

## Overview
The batch processor (BP) provides an asynchronous workflow for handling inference requests with varying SLOs.


## Architecture

An underlying implementation provides persistent messaging and adheres to the `Flow` interface defined in [api.go](api.go).

A pluggable request policy merges the multiple request channels into the single request channel that the batch workers listen on.

An example for such a policy is a [Random Robin Policy](random_robin_policy.go).

Each [Batch Processor worker](worker.go) pulls requests from the merged request channel, submits them to the Inference Gateway (IGW), and applies retry logic when needed. A wiring sketch of these pieces follows below.
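
A minimal wiring sketch in Go, mirroring `cmd/batch/runner` from this change; the endpoint URL and worker count below are placeholders, not fixed values:

```go
package main

import (
	"context"
	"net/http"

	"github.com/llm-d/llm-d-inference-scheduler/pkg/batch"
	"github.com/llm-d/llm-d-inference-scheduler/pkg/batch/redis"
)

func main() {
	ctx := context.Background()

	// Persistent messaging backend implementing the Flow interface.
	var flow batch.Flow = redis.NewRedisMQFlow("localhost:6379")

	// Pluggable policy that merges all request channels into one.
	policy := batch.NewRandomRobinPolicy()
	merged := policy.MergeRequestChannels(flow.RequestChannels()).Channel

	// Workers pull from the merged channel, submit to the IGW endpoint,
	// and feed retries and results back through the Flow's channels.
	httpClient := &http.Client{}
	for i := 0; i < 4; i++ { // placeholder worker count
		go batch.Worker(ctx, "http://igw.example:8080", httpClient, merged,
			flow.RetryChannel(), flow.ResultChannel())
	}

	flow.Start(ctx)
}
```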



### Requests

Request messages should have the following format:
```json
{
"id" : "unique identifier for result mapping",
"deadline" : "deadline in Unix seconds",
"payload" : {regular inference payload}
}
```

Example:
```json
{
"id" : "19933123533434",
"deadline" : "1764045130",
"payload": {"model":"food-review","prompt":"hi", "max_tokens":10,"temperature":0}
}
```

> **Review comment (Collaborator):** @shimib This only supports /v1/completions. Is that on purpose?
>
> **Reply (Author):** It shouldn't; the endpoint will be configurable and will be covered in the user guide.
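
On the producer side, the same message can be built with the `RequestMessage` type from [api.go](api.go) and marshaled into the wire format above. A minimal sketch (the deadline and payload values are illustrative):

```go
package main

import (
	"encoding/json"
	"fmt"
	"strconv"
	"time"

	"github.com/llm-d/llm-d-inference-scheduler/pkg/batch"
)

func main() {
	msg := batch.RequestMessage{
		Id:              "19933123533434",
		DeadlineUnixSec: strconv.FormatInt(time.Now().Add(10*time.Minute).Unix(), 10),
		Payload: map[string]any{
			"model": "food-review", "prompt": "hi", "max_tokens": 10, "temperature": 0,
		},
	}

	body, err := json.Marshal(msg)
	if err != nil {
		panic(err)
	}
	// Prints a JSON document in the request format shown above;
	// retry_count is omitted because it is zero.
	fmt.Println(string(body))
}
```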

### Results

Messages on the results channel will have the following structure:

```json
{
"id" : "id mapped to the request",
"payload" : {/*inference payload*/} ,
// or
"error" : "error's reason"
}
```
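
When draining the result store, a consumer can distinguish success from failure by which field is present. A minimal sketch of a hypothetical helper (the `handleResult` name and the use of a plain map are illustrative; `ResultMessage` in [api.go](api.go) currently models only the `id`/`payload` case):

```go
package main

import (
	"encoding/json"
	"log"
)

// handleResult decodes one raw result message and branches on success vs. error.
func handleResult(raw []byte) {
	var res map[string]any
	if err := json.Unmarshal(raw, &res); err != nil {
		log.Printf("malformed result: %v", err)
		return
	}
	if reason, failed := res["error"]; failed {
		log.Printf("request %v failed: %v", res["id"], reason)
		return
	}
	log.Printf("request %v succeeded: %v", res["id"], res["payload"])
}

func main() {
	handleResult([]byte(`{"id":"19933123533434","payload":{"text":"..."}}`))
	handleResult([]byte(`{"id":"19933123533435","error":"deadline exceeded"}`))
}
```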


## Implementations

### Redis

An example implementation based on Redis is provided. It uses:

- Redis Lists as the request queues.
- A Redis Sorted Set to schedule retries with exponential backoff.
- A Redis List as the result queue.

A sketch of enqueuing a request onto this backend follows below.
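
A minimal sketch of enqueuing a request onto a Redis-backed queue with `go-redis`; the list key (`batch:requests`) and the push direction are placeholders here, the actual queue names are defined by the Redis flow implementation:

```go
package main

import (
	"context"

	"github.com/redis/go-redis/v9"
)

func main() {
	ctx := context.Background()
	rdb := redis.NewClient(&redis.Options{Addr: "localhost:6379"})

	// JSON document in the request format described above.
	req := `{"id":"19933123533434","deadline":"1764045130",` +
		`"payload":{"model":"food-review","prompt":"hi","max_tokens":10,"temperature":0}}`

	// "batch:requests" is a placeholder key; see the Redis flow for the real one.
	if err := rdb.LPush(ctx, "batch:requests", req).Err(); err != nil {
		panic(err)
	}
}
```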
41 changes: 41 additions & 0 deletions pkg/batch/api.go
@@ -0,0 +1,41 @@
package batch

import "context"

// Flow is implemented by a persistent messaging backend that feeds requests
// into the batch processor and accepts retries and results.
type Flow interface {
	// Start begins processing requests.
	Start(ctx context.Context)

	// RequestChannels returns the channels carrying incoming requests.
	// The implementation is responsible for populating these channels.
	RequestChannels() []RequestChannel
	// RetryChannel returns the channel that accepts messages to be retried, along with their backoff delay.
	RetryChannel() chan RetryMessage
	// ResultChannel returns the channel used for storing the results.
	ResultChannel() chan ResultMessage
}

// RequestPolicy merges multiple request channels into the single channel that
// the batch workers consume from.
type RequestPolicy interface {
	MergeRequestChannels(channels []RequestChannel) RequestChannel
}

// RequestMessage is a single inference request pulled from a request queue.
type RequestMessage struct {
	Id              string         `json:"id"`
	RetryCount      int            `json:"retry_count,omitempty"`
	DeadlineUnixSec string         `json:"deadline"`
	Payload         map[string]any `json:"payload"`
}

// RequestChannel couples a channel of requests with implementation-specific metadata.
type RequestChannel struct {
	Channel  chan RequestMessage
	Metadata map[string]any
}

// RetryMessage is a request that should be retried after the given backoff delay.
type RetryMessage struct {
	RequestMessage
	BackoffDurationSeconds float64
}

// ResultMessage carries the outcome of a request, keyed by the original request Id.
type ResultMessage struct {
	Id      string         `json:"id"`
	Payload map[string]any `json:"payload"`
}
47 changes: 47 additions & 0 deletions pkg/batch/random_robin_policy.go
@@ -0,0 +1,47 @@
package batch

import "reflect"

// NewRandomRobinPolicy returns a RequestPolicy that merges request channels,
// picking pseudo-randomly among the channels that are ready.
func NewRandomRobinPolicy() RequestPolicy {
	return &RandomRobinPolicy{}
}

// RandomRobinPolicy merges request channels using reflect.Select, which chooses
// uniformly at random among the ready channels.
type RandomRobinPolicy struct{}
// MergeRequestChannels fans requests from all the given channels into a single
// merged channel; closed inputs are dropped, and the merged channel is closed
// once every input channel has been closed.
func (r *RandomRobinPolicy) MergeRequestChannels(channels []RequestChannel) RequestChannel {
mergedChannel := make(chan RequestMessage)

cases := make([]reflect.SelectCase, len(channels))
for i, ch := range channels {
cases[i] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ch.Channel)}
}

	go func() {
		for {
			// reflect.Select blocks until one of the remaining channels is ready and
			// chooses pseudo-randomly when more than one is ready.
			i1, val, ok := reflect.Select(cases)
if !ok {
// one of the channels is closed, remove it
newCases := make([]reflect.SelectCase, 0, len(cases)-1)
for i2, c := range cases {
if i2 != i1 {
newCases = append(newCases, c)
}
}
cases = newCases
if len(cases) == 0 {
close(mergedChannel)
break
}
} else {
mergedChannel <- val.Interface().(RequestMessage)
			}
		}
	}()

return RequestChannel{
Channel: mergedChannel,
Metadata: map[string]any{},
}
}
41 changes: 41 additions & 0 deletions pkg/batch/random_robin_policy_test.go
@@ -0,0 +1,41 @@
package batch

import (
"testing"
)

// TestProcessAllChannels verifies that every message sent on the input channels
// is delivered on the merged channel, including after the inputs are closed.
func TestProcessAllChannels(t *testing.T) {
msgsPerChannel := 5
channels := []RequestChannel{
{Channel: make(chan RequestMessage, msgsPerChannel), Metadata: map[string]any{}},
{Channel: make(chan RequestMessage, msgsPerChannel), Metadata: map[string]any{}},
{Channel: make(chan RequestMessage, msgsPerChannel), Metadata: map[string]any{}},
}
policy := NewRandomRobinPolicy()

// Send messages to each channel
for i, ch := range channels {
for range msgsPerChannel {
ch.Channel <- RequestMessage{Id: string(rune('A' + i))}
}
}
mergedChannel := policy.MergeRequestChannels(channels).Channel
close(channels[0].Channel)
close(channels[1].Channel)
close(channels[2].Channel)

counts := map[string]int{}
totalMessages := msgsPerChannel * 3
	for range totalMessages {
		msg := <-mergedChannel
		counts[msg.Id]++
	}

for i := range 3 {
id := string(rune('A' + i))
if counts[id] != msgsPerChannel {
t.Errorf("Expected %d messages from channel %s, got %d", msgsPerChannel, id, counts[id])
}
}
}