Skip to content

Commit ab14c43

Browse files
authored
Support for EnvoyRateLimiter in Beam Python SDK (#37135)
* Support for EnvoyRateLimiter in Apache Beam * fix format issues * fix test formatting * Fix test and syntax * fix lint * Add dependency based on python version * revert setup to separete pr * fix lint * fix formatting * resolve comments
1 parent 78eb4a2 commit ab14c43

File tree

3 files changed

+462
-0
lines changed

3 files changed

+462
-0
lines changed
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
"""A simple example demonstrating usage of the EnvoyRateLimiter in a Beam
19+
pipeline.
20+
"""
21+
22+
import argparse
23+
import logging
24+
import time
25+
26+
import apache_beam as beam
27+
from apache_beam.io.components.rate_limiter import EnvoyRateLimiter
28+
from apache_beam.options.pipeline_options import PipelineOptions
29+
from apache_beam.utils import shared
30+
31+
32+
class SampleApiDoFn(beam.DoFn):
33+
"""A DoFn that simulates calling an external API with rate limiting."""
34+
def __init__(self, rls_address, domain, descriptors):
35+
self.rls_address = rls_address
36+
self.domain = domain
37+
self.descriptors = descriptors
38+
self._shared = shared.Shared()
39+
self.rate_limiter = None
40+
41+
def setup(self):
42+
# Initialize the rate limiter in setup()
43+
# We use shared.Shared() to ensure only one RateLimiter instance is created
44+
# per worker and shared across threads.
45+
def init_limiter():
46+
logging.info("Connecting to Envoy RLS at %s", self.rls_address)
47+
return EnvoyRateLimiter(
48+
service_address=self.rls_address,
49+
domain=self.domain,
50+
descriptors=self.descriptors,
51+
namespace='example_pipeline')
52+
53+
self.rate_limiter = self._shared.acquire(init_limiter)
54+
55+
def process(self, element):
56+
self.rate_limiter.throttle()
57+
58+
# Process the element mock API call
59+
logging.info("Processing element: %s", element)
60+
time.sleep(0.1)
61+
yield element
62+
63+
64+
def parse_known_args(argv):
65+
"""Parses args for the workflow."""
66+
parser = argparse.ArgumentParser()
67+
parser.add_argument(
68+
'--rls_address',
69+
default='localhost:8081',
70+
help='Address of the Envoy Rate Limit Service')
71+
return parser.parse_known_args(argv)
72+
73+
74+
def run(argv=None):
75+
known_args, pipeline_args = parse_known_args(argv)
76+
pipeline_options = PipelineOptions(pipeline_args)
77+
78+
with beam.Pipeline(options=pipeline_options) as p:
79+
_ = (
80+
p
81+
| 'Create' >> beam.Create(range(100))
82+
| 'RateLimit' >> beam.ParDo(
83+
SampleApiDoFn(
84+
rls_address=known_args.rls_address,
85+
domain="mongo_cps",
86+
descriptors=[{
87+
"database": "users"
88+
}])))
89+
90+
91+
if __name__ == '__main__':
92+
logging.getLogger().setLevel(logging.INFO)
93+
run()
Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
"""
19+
Rate Limiter classes for controlling access to external resources.
20+
"""
21+
22+
import abc
23+
import logging
24+
import math
25+
import random
26+
import threading
27+
import time
28+
from typing import Dict
29+
from typing import List
30+
31+
import grpc
32+
from envoy_data_plane.envoy.extensions.common.ratelimit.v3 import RateLimitDescriptor
33+
from envoy_data_plane.envoy.extensions.common.ratelimit.v3 import RateLimitDescriptorEntry
34+
from envoy_data_plane.envoy.service.ratelimit.v3 import RateLimitRequest
35+
from envoy_data_plane.envoy.service.ratelimit.v3 import RateLimitResponse
36+
from envoy_data_plane.envoy.service.ratelimit.v3 import RateLimitResponseCode
37+
38+
from apache_beam.io.components import adaptive_throttler
39+
from apache_beam.metrics import Metrics
40+
41+
_LOGGER = logging.getLogger(__name__)
42+
43+
_RPC_MAX_RETRIES = 5
44+
_RPC_RETRY_DELAY_SECONDS = 10
45+
46+
47+
class RateLimiter(abc.ABC):
48+
"""Abstract base class for RateLimiters."""
49+
def __init__(self, namespace: str = ""):
50+
# Metrics collected from the RateLimiter
51+
# Metric updates are thread safe
52+
self.throttling_signaler = adaptive_throttler.ThrottlingSignaler(
53+
namespace=namespace)
54+
self.requests_counter = Metrics.counter(namespace, 'RatelimitRequestsTotal')
55+
self.requests_allowed = Metrics.counter(
56+
namespace, 'RatelimitRequestsAllowed')
57+
self.requests_throttled = Metrics.counter(
58+
namespace, 'RatelimitRequestsThrottled')
59+
self.rpc_errors = Metrics.counter(namespace, 'RatelimitRpcErrors')
60+
self.rpc_retries = Metrics.counter(namespace, 'RatelimitRpcRetries')
61+
self.rpc_latency = Metrics.distribution(namespace, 'RatelimitRpcLatencyMs')
62+
63+
@abc.abstractmethod
64+
def throttle(self, **kwargs) -> bool:
65+
"""Check if request should be throttled.
66+
67+
Args:
68+
**kwargs: Keyword arguments specific to the RateLimiter implementation.
69+
70+
Returns:
71+
bool: True if the request is allowed, False if retries exceeded.
72+
73+
Raises:
74+
Exception: If an underlying infrastructure error occurs (e.g. RPC
75+
failure).
76+
"""
77+
pass
78+
79+
80+
class EnvoyRateLimiter(RateLimiter):
81+
"""
82+
Rate limiter implementation that uses an external Envoy Rate Limit Service.
83+
"""
84+
def __init__(
85+
self,
86+
service_address: str,
87+
domain: str,
88+
descriptors: List[Dict[str, str]],
89+
timeout: float = 5.0,
90+
block_until_allowed: bool = True,
91+
retries: int = 3,
92+
namespace: str = ""):
93+
"""
94+
Args:
95+
service_address: Address of the Envoy RLS (e.g., 'localhost:8081').
96+
domain: The rate limit domain.
97+
descriptors: List of descriptors (key-value pairs).
98+
retries: Number of retries to attempt if rate limited, respected only if
99+
block_until_allowed is False.
100+
timeout: gRPC timeout in seconds.
101+
block_until_allowed: If enabled blocks until RateLimiter gets
102+
the token.
103+
namespace: the namespace to use for logging and signaling
104+
throttling is occurring.
105+
"""
106+
super().__init__(namespace=namespace)
107+
108+
self.service_address = service_address
109+
self.domain = domain
110+
self.descriptors = descriptors
111+
self.retries = retries
112+
self.timeout = timeout
113+
self.block_until_allowed = block_until_allowed
114+
self._stub = None
115+
self._lock = threading.Lock()
116+
117+
class RateLimitServiceStub(object):
118+
"""
119+
Wrapper for gRPC stub to be compatible with envoy_data_plane messages.
120+
121+
The envoy-data-plane package uses 'betterproto' which generates async stubs
122+
for 'grpclib'. As Beam uses standard synchronous 'grpcio',
123+
RateLimitServiceStub is a bridge class to use the betterproto Message types
124+
(RateLimitRequest) with a standard grpcio Channel.
125+
"""
126+
def __init__(self, channel):
127+
self.ShouldRateLimit = channel.unary_unary(
128+
'/envoy.service.ratelimit.v3.RateLimitService/ShouldRateLimit',
129+
request_serializer=RateLimitRequest.SerializeToString,
130+
response_deserializer=RateLimitResponse.FromString,
131+
)
132+
133+
def init_connection(self):
134+
if self._stub is None:
135+
# Acquire lock to safegaurd againest multiple DoFn threads sharing the
136+
# same RateLimiter instance, which is the case when using Shared().
137+
with self._lock:
138+
if self._stub is None:
139+
channel = grpc.insecure_channel(self.service_address)
140+
self._stub = EnvoyRateLimiter.RateLimitServiceStub(channel)
141+
142+
def throttle(self, hits_added: int = 1) -> bool:
143+
"""Calls the Envoy RLS to check for rate limits.
144+
145+
Args:
146+
hits_added: Number of hits to add to the rate limit.
147+
148+
Returns:
149+
bool: True if the request is allowed, False if retries exceeded.
150+
"""
151+
self.init_connection()
152+
153+
# execute thread-safe gRPC call
154+
# Convert descriptors to proto format
155+
proto_descriptors = []
156+
for d in self.descriptors:
157+
entries = []
158+
for k, v in d.items():
159+
entries.append(RateLimitDescriptorEntry(key=k, value=v))
160+
proto_descriptors.append(RateLimitDescriptor(entries=entries))
161+
162+
request = RateLimitRequest(
163+
domain=self.domain,
164+
descriptors=proto_descriptors,
165+
hits_addend=hits_added)
166+
167+
self.requests_counter.inc()
168+
attempt = 0
169+
throttled = False
170+
while True:
171+
if not self.block_until_allowed and attempt > self.retries:
172+
break
173+
174+
# retry loop
175+
for retry_attempt in range(_RPC_MAX_RETRIES):
176+
try:
177+
start_time = time.time()
178+
response = self._stub.ShouldRateLimit(request, timeout=self.timeout)
179+
self.rpc_latency.update(int((time.time() - start_time) * 1000))
180+
break
181+
except grpc.RpcError as e:
182+
if retry_attempt == _RPC_MAX_RETRIES - 1:
183+
_LOGGER.error(
184+
"[EnvoyRateLimiter] ratelimit service call failed: %s", e)
185+
self.rpc_errors.inc()
186+
raise e
187+
self.rpc_retries.inc()
188+
_LOGGER.warning(
189+
"[EnvoyRateLimiter] ratelimit service call failed, retrying: %s",
190+
e)
191+
time.sleep(_RPC_RETRY_DELAY_SECONDS)
192+
193+
if response.overall_code == RateLimitResponseCode.OK:
194+
self.requests_allowed.inc()
195+
throttled = True
196+
break
197+
elif response.overall_code == RateLimitResponseCode.OVER_LIMIT:
198+
self.requests_throttled.inc()
199+
# Ratelimit exceeded, sleep for duration until reset and retry
200+
# multiple rules can be set in the RLS config, so we need to find the
201+
# max duration
202+
sleep_s = 0.0
203+
if response.statuses:
204+
for status in response.statuses:
205+
if status.code == RateLimitResponseCode.OVER_LIMIT:
206+
dur = status.duration_until_reset
207+
# duration_until_reset is converted to timedelta by betterproto
208+
val = dur.total_seconds()
209+
if val > sleep_s:
210+
sleep_s = val
211+
212+
# Add 1% additive jitter to prevent thundering herd
213+
jitter = random.uniform(0, 0.01 * sleep_s)
214+
sleep_s += jitter
215+
216+
_LOGGER.warning("[EnvoyRateLimiter] Throttled for %s seconds", sleep_s)
217+
# signal throttled time to backend
218+
self.throttling_signaler.signal_throttled(math.ceil(sleep_s))
219+
time.sleep(sleep_s)
220+
attempt += 1
221+
else:
222+
_LOGGER.error(
223+
"[EnvoyRateLimiter] Unknown code from RLS: %s",
224+
response.overall_code)
225+
break
226+
return throttled

0 commit comments

Comments
 (0)