Skip to content

Commit c81d40a

Browse files
committed
initial online batch support
Signed-off-by: Shimi Bandiel <[email protected]>
1 parent e4dbca4 commit c81d40a

File tree

11 files changed

+841
-1
lines changed

11 files changed

+841
-1
lines changed

cmd/batch/main.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package main
2+
3+
import (
4+
"os"
5+
6+
ctrl "sigs.k8s.io/controller-runtime"
7+
8+
"github.com/llm-d/llm-d-inference-scheduler/cmd/batch/runner"
9+
)
10+
11+
func main() {
12+
13+
if err := runner.NewRunner().Run(ctrl.SetupSignalHandler()); err != nil {
14+
os.Exit(1)
15+
}
16+
}

cmd/batch/runner/runner.go

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
package runner
2+
3+
import (
4+
"context"
5+
"flag"
6+
"net/http"
7+
8+
"github.com/llm-d/llm-d-inference-scheduler/pkg/batch"
9+
"github.com/llm-d/llm-d-inference-scheduler/pkg/batch/redis"
10+
uberzap "go.uber.org/zap"
11+
"go.uber.org/zap/zapcore"
12+
ctrl "sigs.k8s.io/controller-runtime"
13+
"sigs.k8s.io/controller-runtime/pkg/log/zap"
14+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
15+
)
16+
17+
type Runner struct {
18+
}
19+
20+
var (
21+
setupLog = ctrl.Log.WithName("setup")
22+
logVerbosity = flag.Int("v", logging.DEFAULT, "number for the log level verbosity")
23+
concurrency = flag.Int("concurrency", 8, "number of concurrent workers")
24+
endpoint = flag.String("endpoint", "", "inference endpoint")
25+
)
26+
27+
func NewRunner() *Runner {
28+
return &Runner{}
29+
}
30+
31+
func (r *Runner) Run(ctx context.Context) error {
32+
opts := zap.Options{
33+
Development: true,
34+
}
35+
opts.BindFlags(flag.CommandLine)
36+
flag.Parse()
37+
initLogging(&opts)
38+
39+
/*if *tracing {
40+
err := common.InitTracing(ctx, setupLog)
41+
if err != nil {
42+
return err
43+
}
44+
}*/
45+
46+
////////setupLog.Info("GIE build", "commit-sha", version.CommitSHA, "build-ref", version.BuildRef)
47+
48+
// Validate flags
49+
if err := validateFlags(); err != nil {
50+
setupLog.Error(err, "Failed to validate flags")
51+
return err
52+
}
53+
54+
// Print all flag values
55+
flags := make(map[string]any)
56+
flag.VisitAll(func(f *flag.Flag) {
57+
flags[f.Name] = f.Value
58+
})
59+
setupLog.Info("Flags processed", "flags", flags)
60+
61+
httpClient := &http.Client{
62+
// TODO: configure
63+
}
64+
var policy batch.RequestPolicy = batch.NewRandomRobinPolicy()
65+
66+
var impl batch.Flow = redis.NewRedisMQFlow("localhost:6379")
67+
requestChannel := policy.MergeRequestChannels(impl.RequestChannels()).Channel
68+
for w := 1; w <= *concurrency; w++ {
69+
go batch.Worker(ctx, *endpoint, httpClient, requestChannel, impl.RetryChannel(), impl.ResultChannel())
70+
}
71+
72+
impl.Start(ctx)
73+
74+
return nil
75+
}
76+
77+
// TODO: is this dup of
78+
func initLogging(opts *zap.Options) {
79+
// Unless -zap-log-level is explicitly set, use -v
80+
useV := true
81+
flag.Visit(func(f *flag.Flag) {
82+
if f.Name == "zap-log-level" {
83+
useV = false
84+
}
85+
})
86+
if useV {
87+
// See https://pkg.go.dev/sigs.k8s.io/controller-runtime/pkg/log/zap#Options.Level
88+
lvl := -1 * (*logVerbosity)
89+
opts.Level = uberzap.NewAtomicLevelAt(zapcore.Level(int8(lvl)))
90+
}
91+
92+
logger := zap.New(zap.UseFlagOptions(opts), zap.RawZapOpts(uberzap.AddCaller()))
93+
ctrl.SetLogger(logger)
94+
}
95+
96+
func validateFlags() error {
97+
98+
return nil
99+
}

pkg/batch/README.md

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Batch Processor
2+
3+
## Overview
4+
The batch processor (BP) provides asynchronous workflows for variable SLO-based inference requests.
5+
6+
7+
## Architecture
8+
9+
An underlying implementation should provide persistent messaging that adhere to the interface defined in [api.go](api.go).
10+
11+
A pluggable request policy is used to merge multiple request channels into a single request channel on which the batch worker is listening.
12+
13+
An example for such a policy is a [Random Robin Policy](random_robin_policy.go).
14+
15+
Each [Batch Processor worker](worker.go) is responsible for pulling requests from the merged request channel, submit to the IGW and apply retry logic if needed.
16+
17+
18+
19+
### Requests
20+
21+
Request messages should have the following format:
22+
```json
23+
{
24+
"id" : "unique identifier for result mapping",
25+
"deadline" : "deadline in Unix seconds",
26+
"payload" : {regular inference payload}
27+
}
28+
```
29+
30+
Example:
31+
```json
32+
{
33+
"id" : "19933123533434",
34+
"deadline" : "1764045130",
35+
"payload": {"model":"food-review","prompt":"hi", "max_tokens":10,"temperature":0}
36+
}
37+
```
38+
39+
### Results
40+
41+
Messages on the results channel will have the following structure:
42+
43+
```json
44+
{
45+
"id" : "id mapped to the request",
46+
"payload" : {/*inference payload*/} ,
47+
// or
48+
"error" : "error's reason"
49+
}
50+
```
51+
52+
53+
## Implementations
54+
55+
### Redis
56+
57+
An example implementation based on Redis is provided which behaves as follows:
58+
59+
- Redis Lists as the request queues.
60+
- Redis Sorted Set as the retry exponential backoff implementation.
61+
- Redis List as the result queue.

pkg/batch/api.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
package batch
2+
3+
import "context"
4+
5+
type Flow interface {
6+
// starts processing requests.
7+
Start(ctx context.Context)
8+
9+
// returns the channel for requests. Implementation is responsible for populating this channel.
10+
RequestChannels() []RequestChannel
11+
// returns the channel that accepts messages to be retries with their backoff delay.
12+
RetryChannel() chan RetryMessage
13+
// returns the channel for storing the results.
14+
ResultChannel() chan ResultMessage
15+
}
16+
17+
type RequestPolicy interface {
18+
MergeRequestChannels(channels []RequestChannel) RequestChannel
19+
}
20+
21+
type RequestMessage struct {
22+
Id string `json:"id"`
23+
RetryCount int `json:"retry_count,omitempty"`
24+
DeadlineUnixSec string `json:"deadline"`
25+
Payload map[string]any `json:"payload"`
26+
}
27+
28+
type RequestChannel struct {
29+
Channel chan RequestMessage
30+
Metadata map[string]any
31+
}
32+
33+
type RetryMessage struct {
34+
RequestMessage
35+
BackoffDurationSeconds float64
36+
}
37+
38+
type ResultMessage struct {
39+
Id string `json:"id"`
40+
Payload map[string]any `json:"payload"`
41+
}

pkg/batch/random_robin_policy.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package batch
2+
3+
import "reflect"
4+
5+
func NewRandomRobinPolicy() RequestPolicy {
6+
return &RandomRobinPolicy{}
7+
}
8+
9+
type RandomRobinPolicy struct {
10+
}
11+
12+
func (r *RandomRobinPolicy) MergeRequestChannels(channels []RequestChannel) RequestChannel {
13+
mergedChannel := make(chan RequestMessage)
14+
15+
cases := make([]reflect.SelectCase, len(channels))
16+
for i, ch := range channels {
17+
cases[i] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ch.Channel)}
18+
}
19+
20+
go func() {
21+
for {
22+
i1, val, ok := reflect.Select(cases)
23+
if !ok {
24+
// one of the channels is closed, remove it
25+
newCases := make([]reflect.SelectCase, 0, len(cases)-1)
26+
for i2, c := range cases {
27+
if i2 != i1 {
28+
newCases = append(newCases, c)
29+
}
30+
}
31+
cases = newCases
32+
if len(cases) == 0 {
33+
close(mergedChannel)
34+
break
35+
}
36+
} else {
37+
mergedChannel <- val.Interface().(RequestMessage)
38+
}
39+
40+
}
41+
}()
42+
43+
return RequestChannel{
44+
Channel: mergedChannel,
45+
Metadata: map[string]any{},
46+
}
47+
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
package batch
2+
3+
import (
4+
"testing"
5+
)
6+
7+
func TestProcessAllChannels(t *testing.T) {
8+
msgsPerChannel := 5
9+
channels := []RequestChannel{
10+
{Channel: make(chan RequestMessage, msgsPerChannel), Metadata: map[string]any{}},
11+
{Channel: make(chan RequestMessage, msgsPerChannel), Metadata: map[string]any{}},
12+
{Channel: make(chan RequestMessage, msgsPerChannel), Metadata: map[string]any{}},
13+
}
14+
policy := NewRandomRobinPolicy()
15+
16+
// Send messages to each channel
17+
for i, ch := range channels {
18+
for range msgsPerChannel {
19+
ch.Channel <- RequestMessage{Id: string(rune('A' + i))}
20+
}
21+
}
22+
mergedChannel := policy.MergeRequestChannels(channels).Channel
23+
close(channels[0].Channel)
24+
close(channels[1].Channel)
25+
close(channels[2].Channel)
26+
27+
counts := map[string]int{}
28+
totalMessages := msgsPerChannel * 3
29+
for range totalMessages {
30+
msg := <-mergedChannel
31+
counts[msg.Id]++
32+
33+
}
34+
35+
for i := range 3 {
36+
id := string(rune('A' + i))
37+
if counts[id] != msgsPerChannel {
38+
t.Errorf("Expected %d messages from channel %s, got %d", msgsPerChannel, id, counts[id])
39+
}
40+
}
41+
}

0 commit comments

Comments
 (0)