Skip to content

Commit 08320be

Browse files
fix: fix early cancel when RequestTimeout is provided for streaming requests (#221)
1 parent 978707d commit 08320be

File tree

3 files changed

+169
-7
lines changed

3 files changed

+169
-7
lines changed

client_test.go

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ package openai_test
55
import (
66
"context"
77
"fmt"
8+
"io"
89
"net/http"
910
"reflect"
1011
"testing"
@@ -281,3 +282,115 @@ func TestContextDeadline(t *testing.T) {
281282
}
282283
}
283284
}
285+
286+
func TestContextDeadlineStreaming(t *testing.T) {
287+
testTimeout := time.After(3 * time.Second)
288+
testDone := make(chan struct{})
289+
290+
deadline := time.Now().Add(100 * time.Millisecond)
291+
deadlineCtx, cancel := context.WithDeadline(context.Background(), deadline)
292+
defer cancel()
293+
294+
go func() {
295+
client := openai.NewClient(
296+
option.WithHTTPClient(&http.Client{
297+
Transport: &closureTransport{
298+
fn: func(req *http.Request) (*http.Response, error) {
299+
return &http.Response{
300+
StatusCode: 200,
301+
Status: "200 OK",
302+
Body: io.NopCloser(
303+
io.Reader(readerFunc(func([]byte) (int, error) {
304+
<-req.Context().Done()
305+
return 0, req.Context().Err()
306+
})),
307+
),
308+
}, nil
309+
},
310+
},
311+
}),
312+
)
313+
stream := client.Chat.Completions.NewStreaming(deadlineCtx, openai.ChatCompletionNewParams{
314+
Messages: openai.F([]openai.ChatCompletionMessageParamUnion{openai.ChatCompletionDeveloperMessageParam{
315+
Content: openai.F([]openai.ChatCompletionContentPartTextParam{{Text: openai.F("text"), Type: openai.F(openai.ChatCompletionContentPartTextTypeText)}}),
316+
Role: openai.F(openai.ChatCompletionDeveloperMessageParamRoleDeveloper),
317+
}}),
318+
Model: openai.F(openai.ChatModelO3Mini),
319+
})
320+
for stream.Next() {
321+
_ = stream.Current()
322+
}
323+
if stream.Err() == nil {
324+
t.Error("expected there to be a deadline error")
325+
}
326+
close(testDone)
327+
}()
328+
329+
select {
330+
case <-testTimeout:
331+
t.Fatal("client didn't finish in time")
332+
case <-testDone:
333+
if diff := time.Since(deadline); diff < -30*time.Millisecond || 30*time.Millisecond < diff {
334+
t.Fatalf("client did not return within 30ms of context deadline, got %s", diff)
335+
}
336+
}
337+
}
338+
339+
func TestContextDeadlineStreamingWithRequestTimeout(t *testing.T) {
340+
testTimeout := time.After(3 * time.Second)
341+
testDone := make(chan struct{})
342+
deadline := time.Now().Add(100 * time.Millisecond)
343+
344+
go func() {
345+
client := openai.NewClient(
346+
option.WithHTTPClient(&http.Client{
347+
Transport: &closureTransport{
348+
fn: func(req *http.Request) (*http.Response, error) {
349+
return &http.Response{
350+
StatusCode: 200,
351+
Status: "200 OK",
352+
Body: io.NopCloser(
353+
io.Reader(readerFunc(func([]byte) (int, error) {
354+
<-req.Context().Done()
355+
return 0, req.Context().Err()
356+
})),
357+
),
358+
}, nil
359+
},
360+
},
361+
}),
362+
)
363+
stream := client.Chat.Completions.NewStreaming(
364+
context.Background(),
365+
openai.ChatCompletionNewParams{
366+
Messages: openai.F([]openai.ChatCompletionMessageParamUnion{openai.ChatCompletionDeveloperMessageParam{
367+
Content: openai.F([]openai.ChatCompletionContentPartTextParam{{Text: openai.F("text"), Type: openai.F(openai.ChatCompletionContentPartTextTypeText)}}),
368+
Role: openai.F(openai.ChatCompletionDeveloperMessageParamRoleDeveloper),
369+
}}),
370+
Model: openai.F(openai.ChatModelO3Mini),
371+
},
372+
option.WithRequestTimeout((100 * time.Millisecond)),
373+
)
374+
for stream.Next() {
375+
_ = stream.Current()
376+
}
377+
if stream.Err() == nil {
378+
t.Error("expected there to be a deadline error")
379+
}
380+
close(testDone)
381+
}()
382+
383+
select {
384+
case <-testTimeout:
385+
t.Fatal("client didn't finish in time")
386+
case <-testDone:
387+
if diff := time.Since(deadline); diff < -30*time.Millisecond || 30*time.Millisecond < diff {
388+
t.Fatalf("client did not return within 30ms of context deadline, got %s", diff)
389+
}
390+
}
391+
}
392+
393+
type readerFunc func([]byte) (int, error)
394+
395+
func (f readerFunc) Read(p []byte) (int, error) { return f(p) }
396+
func (f readerFunc) Close() error { return nil }

internal/requestconfig/requestconfig.go

Lines changed: 52 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,41 @@ func parseRetryAfterHeader(resp *http.Response) (time.Duration, bool) {
294294
return 0, false
295295
}
296296

297+
// isBeforeContextDeadline reports whether the non-zero Time t is
298+
// before ctx's deadline. If ctx does not have a deadline, it
299+
// always reports true (the deadline is considered infinite).
300+
func isBeforeContextDeadline(t time.Time, ctx context.Context) bool {
301+
d, ok := ctx.Deadline()
302+
if !ok {
303+
return true
304+
}
305+
return t.Before(d)
306+
}
307+
308+
// bodyWithTimeout is an io.ReadCloser which can observe a context's cancel func
309+
// to handle timeouts etc. It wraps an existing io.ReadCloser.
310+
type bodyWithTimeout struct {
311+
stop func() // stops the time.Timer waiting to cancel the request
312+
rc io.ReadCloser
313+
}
314+
315+
func (b *bodyWithTimeout) Read(p []byte) (n int, err error) {
316+
n, err = b.rc.Read(p)
317+
if err == nil {
318+
return n, nil
319+
}
320+
if err == io.EOF {
321+
return n, err
322+
}
323+
return n, err
324+
}
325+
326+
func (b *bodyWithTimeout) Close() error {
327+
err := b.rc.Close()
328+
b.stop()
329+
return err
330+
}
331+
297332
func retryDelay(res *http.Response, retryCount int) time.Duration {
298333
// If the API asks us to wait a certain amount of time (and it's a reasonable amount),
299334
// just do what it says.
@@ -355,12 +390,17 @@ func (cfg *RequestConfig) Execute() (err error) {
355390
shouldSendRetryCount := cfg.Request.Header.Get("X-Stainless-Retry-Count") == "0"
356391

357392
var res *http.Response
393+
var cancel context.CancelFunc
358394
for retryCount := 0; retryCount <= cfg.MaxRetries; retryCount += 1 {
359395
ctx := cfg.Request.Context()
360-
if cfg.RequestTimeout != time.Duration(0) {
361-
var cancel context.CancelFunc
396+
if cfg.RequestTimeout != time.Duration(0) && isBeforeContextDeadline(time.Now().Add(cfg.RequestTimeout), ctx) {
362397
ctx, cancel = context.WithTimeout(ctx, cfg.RequestTimeout)
363-
defer cancel()
398+
defer func() {
399+
// The cancel function is nil if it was handed off to be handled in a different scope.
400+
if cancel != nil {
401+
cancel()
402+
}
403+
}()
364404
}
365405

366406
req := cfg.Request.Clone(ctx)
@@ -428,10 +468,15 @@ func (cfg *RequestConfig) Execute() (err error) {
428468
return &aerr
429469
}
430470

431-
if cfg.ResponseBodyInto == nil {
432-
return nil
433-
}
434-
if _, ok := cfg.ResponseBodyInto.(**http.Response); ok {
471+
_, intoCustomResponseBody := cfg.ResponseBodyInto.(**http.Response)
472+
if cfg.ResponseBodyInto == nil || intoCustomResponseBody {
473+
// We aren't reading the response body in this scope, but whoever is will need the
474+
// cancel func from the context to observe request timeouts.
475+
// Put the cancel function in the response body so it can be handled elsewhere.
476+
if cancel != nil {
477+
res.Body = &bodyWithTimeout{rc: res.Body, stop: cancel}
478+
cancel = nil
479+
}
435480
return nil
436481
}
437482

packages/ssestream/ssestream.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,10 @@ func (s *eventStreamDecoder) Next() bool {
102102
}
103103
}
104104

105+
if s.scn.Err() != nil {
106+
s.err = s.scn.Err()
107+
}
108+
105109
return false
106110
}
107111

0 commit comments

Comments
 (0)