2828from typing import List
2929
3030import grpc
31- from envoy_data_plane .envoy .extensions .common .ratelimit .v3 import RateLimitDescriptor
32- from envoy_data_plane .envoy .extensions .common .ratelimit .v3 import RateLimitDescriptorEntry
33- from envoy_data_plane .envoy .service .ratelimit .v3 import RateLimitRequest
34- from envoy_data_plane .envoy .service .ratelimit .v3 import RateLimitResponse
35- from envoy_data_plane .envoy .service .ratelimit .v3 import RateLimitResponseCode
31+ from envoy .extensions .common .ratelimit .v3 import ratelimit_pb2
32+ from envoy .service .ratelimit .v3 import rls_pb2
33+ from envoy .service .ratelimit .v3 import rls_pb2_grpc
3634
3735from apache_beam .io .components import adaptive_throttler
3836from apache_beam .metrics import Metrics
@@ -114,30 +112,14 @@ def __init__(
114112 self ._stub = None
115113 self ._lock = threading .Lock ()
116114
117- class RateLimitServiceStub (object ):
118- """
119- Wrapper for gRPC stub to be compatible with envoy_data_plane messages.
120-
121- The envoy-data-plane package uses 'betterproto' which generates async stubs
122- for 'grpclib'. As Beam uses standard synchronous 'grpcio', RateLimitServiceStub
123- is a bridge class to use the betterproto Message types (RateLimitRequest)
124- with a standard grpcio Channel.
125- """
126- def __init__ (self , channel ):
127- self .ShouldRateLimit = channel .unary_unary (
128- '/envoy.service.ratelimit.v3.RateLimitService/ShouldRateLimit' ,
129- request_serializer = RateLimitRequest .SerializeToString ,
130- response_deserializer = RateLimitResponse .FromString ,
131- )
132-
133115 def init_connection (self ):
134116 if self ._stub is None :
135117 # Acquire lock to safegaurd againest multiple DoFn threads sharing the same
136118 # RateLimiter instance, which is the case when using Shared().
137119 with self ._lock :
138120 if self ._stub is None :
139121 channel = grpc .insecure_channel (self .service_address )
140- self ._stub = EnvoyRateLimiter .RateLimitServiceStub (channel )
122+ self ._stub = rls_pb2_grpc .RateLimitServiceStub (channel )
141123
142124 def throttle (self , hits_added : int = 1 ) -> bool :
143125 """Calls the Envoy RLS to check for rate limits.
@@ -156,10 +138,10 @@ def throttle(self, hits_added: int = 1) -> bool:
156138 for d in self .descriptors :
157139 entries = []
158140 for k , v in d .items ():
159- entries .append (RateLimitDescriptorEntry (key = k , value = v ))
160- proto_descriptors .append (RateLimitDescriptor (entries = entries ))
141+ entries .append (ratelimit_pb2 . RateLimitDescriptor . Entry (key = k , value = v ))
142+ proto_descriptors .append (ratelimit_pb2 . RateLimitDescriptor (entries = entries ))
161143
162- request = RateLimitRequest (
144+ request = rls_pb2 . RateLimitRequest (
163145 domain = self .domain ,
164146 descriptors = proto_descriptors ,
165147 hits_addend = hits_added )
@@ -188,22 +170,21 @@ def throttle(self, hits_added: int = 1) -> bool:
188170 "[EnvoyRateLimiter] Connection Failed, retrying: %s" , e )
189171 time .sleep (_RETRY_DELAY_SECONDS )
190172
191- if response .overall_code == RateLimitResponseCode .OK :
173+ if response .overall_code == rls_pb2 . RateLimitResponse .OK :
192174 self .requests_allowed .inc ()
193175 throttled = True
194176 break
195- elif response .overall_code == RateLimitResponseCode .OVER_LIMIT :
177+ elif response .overall_code == rls_pb2 . RateLimitResponse .OVER_LIMIT :
196178 self .requests_throttled .inc ()
197179 # Ratelimit exceeded, sleep for duration until reset and retry
198180 # multiple rules can be set in the RLS config, so we need to find the max duration
199181 sleep_s = 0.0
200182 if response .statuses :
201183 for status in response .statuses :
202- if status .code == RateLimitResponseCode .OVER_LIMIT :
184+ if status .code == rls_pb2 . RateLimitResponse .OVER_LIMIT :
203185 dur = status .duration_until_reset
204- # duration_until_reset is converted to timedelta by betterproto
205- # duration_until_reset has microsecond precision
206- val = dur .total_seconds ()
186+ # duration_until_reset is google.protobuf.Duration
187+ val = dur .seconds + dur .nanos / 1e9
207188 if val > sleep_s :
208189 sleep_s = val
209190
0 commit comments