11package runtime
22
33import (
4+ "fmt"
5+ "net"
46 "net/http"
7+ "strconv"
58 "strings"
6-
7- "net"
9+ "time"
810
911 "golang.org/x/net/context"
1012 "google.golang.org/grpc/metadata"
1113)
1214
1315const metadataHeaderPrefix = "Grpc-Metadata-"
1416const metadataTrailerPrefix = "Grpc-Trailer-"
17+ const metadataGrpcTimeout = "Grpc-Timeout"
18+
1519const xForwardedFor = "X-Forwarded-For"
1620const xForwardedHost = "X-Forwarded-Host"
1721
22+ var (
23+ // DefaultContextTimeout is used for gRPC call context.WithTimeout whenever a Grpc-Timeout inbound
24+ // header isn't present. If the value is 0 the sent `context` will not have a timeout.
25+ DefaultContextTimeout = 0 * time .Second
26+ )
27+
1828/*
1929AnnotateContext adds context information such as metadata from the request.
2030
@@ -23,6 +33,10 @@ will be the same context.
2333*/
2434func AnnotateContext (ctx context.Context , req * http.Request ) context.Context {
2535 var pairs []string
36+ timeout := DefaultContextTimeout
37+ if tm := req .Header .Get (metadataGrpcTimeout ); tm != "" {
38+ timeout , _ = timeoutDecode (tm )
39+ }
2640 for key , vals := range req .Header {
2741 for _ , val := range vals {
2842 if key == "Authorization" {
@@ -47,7 +61,9 @@ func AnnotateContext(ctx context.Context, req *http.Request) context.Context {
4761 pairs = append (pairs , strings .ToLower (xForwardedFor ), req .Header .Get (xForwardedFor )+ ", " + remoteIp )
4862 }
4963 }
50-
64+ if timeout != 0 {
65+ ctx , _ = context .WithTimeout (ctx , timeout )
66+ }
5167 if len (pairs ) == 0 {
5268 return ctx
5369 }
@@ -72,3 +88,38 @@ func ServerMetadataFromContext(ctx context.Context) (md ServerMetadata, ok bool)
7288 md , ok = ctx .Value (serverMetadataKey {}).(ServerMetadata )
7389 return
7490}
91+
92+ func timeoutDecode (s string ) (time.Duration , error ) {
93+ size := len (s )
94+ if size < 2 {
95+ return 0 , fmt .Errorf ("timeout string is too short: %q" , s )
96+ }
97+ d , ok := timeoutUnitToDuration (s [size - 1 ])
98+ if ! ok {
99+ return 0 , fmt .Errorf ("timeout unit is not recognized: %q" , s )
100+ }
101+ t , err := strconv .ParseInt (s [:size - 1 ], 10 , 64 )
102+ if err != nil {
103+ return 0 , err
104+ }
105+ return d * time .Duration (t ), nil
106+ }
107+
108+ func timeoutUnitToDuration (u uint8 ) (d time.Duration , ok bool ) {
109+ switch u {
110+ case 'H' :
111+ return time .Hour , true
112+ case 'M' :
113+ return time .Minute , true
114+ case 'S' :
115+ return time .Second , true
116+ case 'm' :
117+ return time .Millisecond , true
118+ case 'u' :
119+ return time .Microsecond , true
120+ case 'n' :
121+ return time .Nanosecond , true
122+ default :
123+ }
124+ return
125+ }
0 commit comments