Skip to content

Commit 26f23ca

Browse files
committed
Support for EnvoyRateLimiter in Apache Beam
1 parent 547ab60 commit 26f23ca

File tree

4 files changed

+448
-0
lines changed

4 files changed

+448
-0
lines changed
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
"""A simple example demonstrating usage of the EnvoyRateLimiter in a Beam pipeline."""
19+
20+
import argparse
21+
import logging
22+
import time
23+
24+
import apache_beam as beam
25+
from apache_beam.utils import shared
26+
from apache_beam.io.components.rate_limiter import EnvoyRateLimiter
27+
from apache_beam.options.pipeline_options import PipelineOptions
28+
29+
30+
class SampleApiDoFn(beam.DoFn):
  """Example DoFn that rate-limits a simulated external API call.

  Every element first passes through ``EnvoyRateLimiter.throttle()`` and is
  then "processed" by a short sleep that stands in for the real API call.
  """
  def __init__(self, rls_address, domain, descriptors):
    """Stores the connection settings; the limiter itself is built in setup().

    Args:
      rls_address: Address of the Envoy Rate Limit Service.
      domain: Rate limit domain forwarded to the EnvoyRateLimiter.
      descriptors: Descriptors forwarded to the EnvoyRateLimiter.
    """
    self.rls_address = rls_address
    self.domain = domain
    self.descriptors = descriptors
    # Handle used in setup() to obtain the limiter through Shared.acquire().
    self._shared = shared.Shared()
    self.rate_limiter = None

  def _new_rate_limiter(self):
    """Factory handed to Shared.acquire(); builds the EnvoyRateLimiter."""
    logging.info(f"Connecting to Envoy RLS at {self.rls_address}")
    return EnvoyRateLimiter(
        service_address=self.rls_address,
        domain=self.domain,
        descriptors=self.descriptors,
        namespace='example_pipeline')

  def setup(self):
    """Obtains the rate limiter when the DoFn instance is set up."""
    # Going through shared.Shared() is intended to give one RateLimiter
    # instance per worker that is reused by all DoFn threads, instead of one
    # instance per DoFn copy.
    self.rate_limiter = self._shared.acquire(self._new_rate_limiter)

  def process(self, element):
    """Throttles, then emits the element after a simulated API call."""
    self.rate_limiter.throttle()

    # Stand-in for the real API call.
    logging.info(f"Processing element: {element}")
    time.sleep(0.1)
    yield element
61+
62+
def parse_known_args(argv):
  """Separates this example's own options from the remaining arguments.

  Args:
    argv: Argument list handed to ``ArgumentParser.parse_known_args``.

  Returns:
    The ``(namespace, remaining_args)`` tuple produced by argparse. The
    namespace carries ``rls_address``; unrecognised arguments are returned
    untouched in ``remaining_args``.
  """
  arg_parser = argparse.ArgumentParser()
  rls_option = {
      'default': 'localhost:8081',
      'help': 'Address of the Envoy Rate Limit Service',
  }
  arg_parser.add_argument('--rls_address', **rls_option)
  known, remaining = arg_parser.parse_known_args(argv)
  return known, remaining
70+
71+
72+
def run(argv=None):
  """Builds and runs the example pipeline.

  Creates the integers 0..99 and sends each one through SampleApiDoFn, which
  throttles against the Envoy Rate Limit Service given by ``--rls_address``.

  Args:
    argv: Command line arguments; arguments not consumed by
      ``parse_known_args`` are passed on to ``PipelineOptions``.
  """
  known_args, pipeline_args = parse_known_args(argv)
  options = PipelineOptions(pipeline_args)

  with beam.Pipeline(options=options) as pipeline:
    numbers = pipeline | 'Create' >> beam.Create(range(100))
    rate_limited_call = SampleApiDoFn(
        rls_address=known_args.rls_address,
        domain="mongo_cps",
        descriptors=[{"database": "users"}])
    # The resulting PCollection is not consumed any further.
    _ = numbers | 'RateLimit' >> beam.ParDo(rate_limited_call)
85+
86+
87+
if __name__ == '__main__':
  # Raise the root logger to INFO so the logging.info() calls made by the
  # example become visible, then start the pipeline.
  root_logger = logging.getLogger()
  root_logger.setLevel(logging.INFO)
  run()
Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
"""
19+
Rate Limiter classes for controlling access to external resources.
20+
"""
21+
22+
import abc
23+
from envoy_data_plane.envoy.service.ratelimit.v3 import RateLimitRequest
24+
from envoy_data_plane.envoy.service.ratelimit.v3 import RateLimitResponse
25+
from envoy_data_plane.envoy.service.ratelimit.v3 import RateLimitResponseCode
26+
from envoy_data_plane.envoy.extensions.common.ratelimit.v3 import RateLimitDescriptor
27+
from envoy_data_plane.envoy.extensions.common.ratelimit.v3 import RateLimitDescriptorEntry
28+
import logging
29+
import time
30+
import threading
31+
import random
32+
from typing import Dict
33+
from typing import List
34+
import grpc
35+
import time
36+
from apache_beam.io.components import adaptive_throttler
37+
from apache_beam.metrics import Metrics
38+
39+
# Module-level logger, named after this module.
_LOGGER = logging.getLogger(__name__)

# Bound of the retry loop around the ShouldRateLimit RPC.
_MAX_CONNECTION_RETRIES = 5
# Seconds slept after a failed RPC before the next attempt.
_RETRY_DELAY_SECONDS = 10
43+
44+
class RateLimiter(abc.ABC):
  """Base class for rate limiters.

  Creates the throttling signaler and the metrics that implementations
  update; subclasses provide the actual ``throttle`` logic.
  """
  def __init__(self, namespace: str = ""):
    """Creates the throttling signaler and metrics under ``namespace``.

    Args:
      namespace: Namespace under which the signaler and every metric are
        registered.
    """
    def counter(metric_name):
      return Metrics.counter(namespace, metric_name)

    # Metrics reported by the RateLimiter. The original author notes that
    # metric updates are thread safe.
    self.throttling_signaler = adaptive_throttler.ThrottlingSignaler(
        namespace=namespace)
    self.requests_counter = counter('envoyRatelimitRequestsTotal')
    self.requests_allowed = counter('envoyRatelimitRequestsAllowed')
    self.requests_throttled = counter('envoyRatelimitRequestsThrottled')
    self.rpc_errors = counter('envoyRatelimitRpcErrors')
    self.rpc_retries = counter('envoyRatelimitRpcRetries')
    self.rpc_latency = Metrics.distribution(
        namespace, 'envoyRatelimitRpcLatencyMs')

  @abc.abstractmethod
  def throttle(self, **kwargs) -> bool:
    """Decides whether the caller may proceed.

    Args:
      **kwargs: Keyword arguments specific to the RateLimiter implementation.

    Returns:
      bool: True if the request is allowed, False if retries exceeded.

    Raises:
      Exception: If an underlying infrastructure error occurs (e.g. RPC
        failure).
    """
72+
73+
74+
75+
class EnvoyRateLimiter(RateLimiter):
  """Rate limiter that delegates decisions to an Envoy Rate Limit Service.

  Each ``throttle()`` call sends a ``ShouldRateLimit`` request for the
  configured domain and descriptors and, when the service answers
  OVER_LIMIT, sleeps for the reported reset duration before asking again.
  """
  class RateLimitServiceStub:
    """Wrapper for gRPC stub to be compatible with envoy_data_plane messages.

    The envoy-data-plane package uses 'betterproto' which generates async
    stubs for 'grpclib'. As Beam uses standard synchronous 'grpcio',
    RateLimitServiceStub is a bridge class to use the betterproto Message
    types (RateLimitRequest) with a standard grpcio Channel.
    """
    def __init__(self, channel):
      self.ShouldRateLimit = channel.unary_unary(
          '/envoy.service.ratelimit.v3.RateLimitService/ShouldRateLimit',
          request_serializer=RateLimitRequest.SerializeToString,
          response_deserializer=RateLimitResponse.FromString,
      )

  def __init__(
      self,
      service_address: str,
      domain: str,
      descriptors: List[Dict[str, str]],
      timeout: float = 1.0,
      block_until_allowed: bool = True,
      retries: int = 3,
      namespace: str = ""):
    """
    Args:
      service_address: Address of the Envoy RLS (e.g., 'localhost:8081').
      domain: The rate limit domain.
      descriptors: List of descriptors (key-value pairs).
      timeout: gRPC timeout in seconds.
      block_until_allowed: If enabled blocks until RateLimiter gets
        the token.
      retries: Number of retries to attempt if rate limited, respected only
        if block_until_allowed is False.
      namespace: the namespace to use for logging and signaling
        throttling is occurring.
    """
    super().__init__(namespace=namespace)

    self.service_address = service_address
    self.domain = domain
    self.descriptors = descriptors
    self.retries = retries
    self.timeout = timeout
    self.block_until_allowed = block_until_allowed
    # The stub is created lazily by init_connection().
    self._stub = None
    self._lock = threading.Lock()

  def init_connection(self):
    """Creates the gRPC channel and stub on first use."""
    if self._stub is None:
      # Acquire the lock to safeguard against multiple DoFn threads sharing
      # the same RateLimiter instance, which is the case when using Shared().
      # The stub is re-checked under the lock so only one thread creates it.
      with self._lock:
        if self._stub is None:
          channel = grpc.insecure_channel(self.service_address)
          self._stub = EnvoyRateLimiter.RateLimitServiceStub(channel)

  def _build_request(self, hits_added):
    """Converts the configured descriptors into a RateLimitRequest."""
    proto_descriptors = [
        RateLimitDescriptor(
            entries=[
                RateLimitDescriptorEntry(key=key, value=value)
                for key, value in descriptor.items()
            ]) for descriptor in self.descriptors
    ]
    return RateLimitRequest(
        domain=self.domain,
        descriptors=proto_descriptors,
        hits_addend=hits_added)

  def _should_rate_limit(self, request):
    """Issues the ShouldRateLimit RPC, retrying on gRPC errors.

    Makes at most _MAX_CONNECTION_RETRIES attempts and sleeps
    _RETRY_DELAY_SECONDS between them.

    Raises:
      grpc.RpcError: If the final attempt fails as well.
    """
    for conn_attempt in range(_MAX_CONNECTION_RETRIES):
      start_time = time.time()
      try:
        response = self._stub.ShouldRateLimit(request, timeout=self.timeout)
      except grpc.RpcError as e:
        self.rpc_errors.inc()
        # range() stops at _MAX_CONNECTION_RETRIES - 1, so that is the last
        # attempt. Comparing against _MAX_CONNECTION_RETRIES itself never
        # matched: the final error was swallowed and the caller went on
        # without a (fresh) response.
        if conn_attempt == _MAX_CONNECTION_RETRIES - 1:
          _LOGGER.error("Envoy RLS Connection Failed: %s", e)
          raise
        self.rpc_retries.inc()
        _LOGGER.error("Envoy RLS Connection Failed, retrying: %s", e)
        time.sleep(_RETRY_DELAY_SECONDS)
      else:
        self.rpc_latency.update(int((time.time() - start_time) * 1000))
        return response

  @staticmethod
  def _seconds_until_reset(response) -> float:
    """Returns the longest reset duration among the OVER_LIMIT statuses."""
    # Multiple rules can be set in the RLS config, so the wait has to cover
    # the rule that takes longest to reset.
    sleep_s = 0.0
    for status in response.statuses or ():
      if status.code == RateLimitResponseCode.OVER_LIMIT:
        # duration_until_reset is converted to timedelta by betterproto;
        # timedelta has microseconds precision.
        sleep_s = max(sleep_s, status.duration_until_reset.total_seconds())
    return sleep_s

  def throttle(self, hits_added: int = 1) -> bool:
    """Calls the Envoy RLS to check for rate limits.

    Args:
      hits_added: Number of hits to add to the rate limit.

    Returns:
      bool: True if the request is allowed. False if block_until_allowed is
        disabled and the retries are exceeded, or if the service answers
        with a code other than OK or OVER_LIMIT.

    Raises:
      grpc.RpcError: If the RPC keeps failing after all connection retries.
    """
    self.init_connection()
    request = self._build_request(hits_added)

    self.requests_counter.inc()
    attempt = 0
    allowed = False
    # `retries` only bounds the loop when we are not blocking until allowed.
    while self.block_until_allowed or attempt <= self.retries:
      response = self._should_rate_limit(request)

      if response.overall_code == RateLimitResponseCode.OK:
        self.requests_allowed.inc()
        allowed = True
        break

      if response.overall_code != RateLimitResponseCode.OVER_LIMIT:
        _LOGGER.error(
            "Envoy RLS returned unknown code: %s", response.overall_code)
        break

      # Rate limit exceeded: sleep until the limit resets, then retry.
      self.requests_throttled.inc()
      sleep_s = self._seconds_until_reset(response)
      _LOGGER.warning("Throttled for %s seconds", sleep_s)
      # Signal throttled time to backend.
      self.throttling_signaler.signal_throttled(int(sleep_s))
      time.sleep(sleep_s)
      attempt += 1
    return allowed

0 commit comments

Comments
 (0)