Skip to content

Commit 22e4f41

Browse files
authored
Merge pull request #65 from chrigu/feature/http-proxy
Feature/http proxy
2 parents be9a827 + aba9b96 commit 22e4f41

File tree

3 files changed

+149
-8
lines changed

3 files changed

+149
-8
lines changed

README.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ Features
3636
* Ability to set priority for notifications
3737
* Ability to set collapse-key for notifications
3838
* Ability to use production or development APNs server
39+
* Support for basic HTTP-Proxies
3940

4041

4142
Installation

aioapns/client.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ def __init__(
2323
use_sandbox: bool = False,
2424
no_cert_validation: bool = False,
2525
ssl_context: Optional[SSLContext] = None,
26+
proxy_host: Optional[str] = None,
27+
proxy_port: Optional[int] = None,
2628
err_func: Optional[
2729
Callable[
2830
[NotificationRequest, NotificationResult], Awaitable[None]
@@ -42,6 +44,8 @@ def __init__(
4244
use_sandbox=use_sandbox,
4345
no_cert_validation=no_cert_validation,
4446
ssl_context=ssl_context,
47+
proxy_host=proxy_host,
48+
proxy_port=proxy_port,
4549
)
4650
elif key and key_id and team_id and topic:
4751
self.pool = APNsKeyConnectionPool(
@@ -53,6 +57,8 @@ def __init__(
5357
max_connection_attempts=max_connection_attempts,
5458
use_sandbox=use_sandbox,
5559
ssl_context=ssl_context,
60+
proxy_host=proxy_host,
61+
proxy_port=proxy_port,
5662
)
5763
else:
5864
raise ValueError(

aioapns/connection.py

Lines changed: 142 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,8 @@ def __init__(
329329
max_connections: int = 10,
330330
max_connection_attempts: int = 5,
331331
use_sandbox: bool = False,
332+
proxy_host: Optional[str] = None,
333+
proxy_port: Optional[int] = None,
332334
) -> None:
333335
self.apns_topic = topic
334336
self.max_connections = max_connections
@@ -342,6 +344,10 @@ def __init__(
342344
self.connections: List[APNsBaseClientProtocol] = []
343345
self._lock = asyncio.Lock()
344346
self.max_connection_attempts = max_connection_attempts
347+
self.ssl_context: Optional[ssl.SSLContext] = None
348+
349+
self.proxy_host = proxy_host
350+
self.proxy_port = proxy_port
345351

346352
async def create_connection(self) -> APNsBaseClientProtocol:
347353
raise NotImplementedError
@@ -428,6 +434,31 @@ async def send_notification(
428434
logger.error("Failed to send after %d attempts.", attempts)
429435
raise MaxAttemptsExceeded
430436

437+
async def _create_proxy_connection(
438+
self, apns_protocol_factory
439+
) -> APNsBaseClientProtocol:
440+
assert self.proxy_host is not None, "proxy_host must be set"
441+
assert self.proxy_port is not None, "proxy_port must be set"
442+
443+
_, protocol = await self.loop.create_connection(
444+
protocol_factory=partial(
445+
HttpProxyProtocol,
446+
self.protocol_class.APNS_SERVER,
447+
self.protocol_class.APNS_PORT,
448+
self.loop,
449+
self.ssl_context,
450+
apns_protocol_factory,
451+
),
452+
host=self.proxy_host,
453+
port=self.proxy_port,
454+
)
455+
await protocol.apns_connection_ready.wait()
456+
457+
assert (
458+
protocol.apns_protocol is not None
459+
), "protocol.apns_protocol could not be set"
460+
return protocol.apns_protocol
461+
431462

432463
class APNsCertConnectionPool(APNsBaseConnectionPool):
433464
def __init__(
@@ -439,12 +470,16 @@ def __init__(
439470
use_sandbox: bool = False,
440471
no_cert_validation: bool = False,
441472
ssl_context: Optional[ssl.SSLContext] = None,
473+
proxy_host: Optional[str] = None,
474+
proxy_port: Optional[int] = None,
442475
) -> None:
443476
super(APNsCertConnectionPool, self).__init__(
444477
topic=topic,
445478
max_connections=max_connections,
446479
max_connection_attempts=max_connection_attempts,
447480
use_sandbox=use_sandbox,
481+
proxy_host=proxy_host,
482+
proxy_port=proxy_port,
448483
)
449484

450485
self.cert_file = cert_file
@@ -463,13 +498,23 @@ def __init__(
463498
self.apns_topic = cert.get_subject().UID
464499

465500
async def create_connection(self) -> APNsBaseClientProtocol:
501+
apns_protocol_factory = partial(
502+
self.protocol_class,
503+
self.apns_topic,
504+
self.loop,
505+
self.discard_connection,
506+
)
507+
508+
if self.proxy_host and self.proxy_port:
509+
return await self._create_proxy_connection(apns_protocol_factory)
510+
else:
511+
return await self._create_connection(apns_protocol_factory)
512+
513+
async def _create_connection(
514+
self, apns_protocol_factory
515+
) -> APNsBaseClientProtocol:
466516
_, protocol = await self.loop.create_connection(
467-
protocol_factory=partial(
468-
self.protocol_class,
469-
self.apns_topic,
470-
self.loop,
471-
self.discard_connection,
472-
),
517+
protocol_factory=apns_protocol_factory,
473518
host=self.protocol_class.APNS_SERVER,
474519
port=self.protocol_class.APNS_PORT,
475520
ssl=self.ssl_context,
@@ -488,12 +533,16 @@ def __init__(
488533
max_connection_attempts: int = 5,
489534
use_sandbox: bool = False,
490535
ssl_context: Optional[ssl.SSLContext] = None,
536+
proxy_host: Optional[str] = None,
537+
proxy_port: Optional[int] = None,
491538
) -> None:
492539
super(APNsKeyConnectionPool, self).__init__(
493540
topic=topic,
494541
max_connections=max_connections,
495542
max_connection_attempts=max_connection_attempts,
496543
use_sandbox=use_sandbox,
544+
proxy_host=proxy_host,
545+
proxy_port=proxy_port,
497546
)
498547

499548
self.ssl_context = ssl_context or ssl.create_default_context()
@@ -508,16 +557,101 @@ async def create_connection(self) -> APNsBaseClientProtocol:
508557
auth_provider = JWTAuthorizationHeaderProvider(
509558
key=self.key, key_id=self.key_id, team_id=self.team_id
510559
)
511-
_, protocol = await self.loop.create_connection(
512-
protocol_factory=partial(
560+
apns_protocol_factory = (
561+
partial(
513562
self.protocol_class,
514563
self.apns_topic,
515564
self.loop,
516565
self.discard_connection,
517566
auth_provider,
518567
),
568+
)
569+
570+
if self.proxy_host and self.proxy_port:
571+
return await self._create_proxy_connection(apns_protocol_factory)
572+
else:
573+
return await self._create_connection(apns_protocol_factory)
574+
575+
async def _create_connection(
576+
self, apns_protocol_factory
577+
) -> APNsBaseClientProtocol:
578+
_, protocol = await self.loop.create_connection(
579+
protocol_factory=apns_protocol_factory,
519580
host=self.protocol_class.APNS_SERVER,
520581
port=self.protocol_class.APNS_PORT,
521582
ssl=self.ssl_context,
522583
)
523584
return protocol
585+
586+
587+
class HttpProxyProtocol(asyncio.Protocol):
588+
def __init__(
589+
self,
590+
apns_host: str,
591+
apns_port: int,
592+
loop: asyncio.AbstractEventLoop,
593+
ssl_context: ssl.SSLContext,
594+
protocol_factory,
595+
):
596+
self.apns_host = apns_host
597+
self.apns_port = apns_port
598+
self.buffer = bytearray()
599+
self.loop = loop
600+
self.ssl_context = ssl_context
601+
self.apns_protocol_factory = protocol_factory
602+
self.apns_protocol: Optional[APNsBaseClientProtocol] = None
603+
self.transport = None
604+
self.apns_connection_ready = (
605+
asyncio.Event()
606+
) # Event to signal APNs readiness
607+
608+
def connection_made(self, transport):
609+
logger.debug(
610+
"Proxy connection made.",
611+
)
612+
self.transport = transport
613+
connect_request = (
614+
f"CONNECT {self.apns_host}:{self.apns_port} "
615+
f"HTTP/1.1\r\nHost: "
616+
f"{self.apns_host}\r\nConnection: close\r\n\r\n"
617+
)
618+
self.transport.write(connect_request.encode("utf-8"))
619+
620+
def data_received(self, data):
621+
# Data is usually received in bytes,
622+
# so you might want to decode or process it
623+
logger.debug("Raw data received: %s", data)
624+
self.buffer.extend(data)
625+
# some proxies send "HTTP/1.1 200 Connection established"
626+
# others "HTTP/1.1 200 Connected"
627+
if b"HTTP/1.1 200 Connect" in data:
628+
logger.debug(
629+
"Proxy tunnel established.",
630+
)
631+
asyncio.create_task(self.create_apns_connection())
632+
else:
633+
logger.debug(
634+
"Data received (before APNs connection establishment): %s",
635+
data.decode(),
636+
)
637+
638+
async def create_apns_connection(self):
639+
# Use the existing transport to create a new APNs connection
640+
logger.debug(
641+
"Initiating APNs connection.",
642+
)
643+
sock = self.transport.get_extra_info("socket")
644+
_, self.apns_protocol = await self.loop.create_connection(
645+
self.apns_protocol_factory,
646+
server_hostname=self.apns_host,
647+
ssl=self.ssl_context,
648+
sock=sock,
649+
)
650+
# Signal that APNs connection is ready
651+
self.apns_connection_ready.set()
652+
653+
def connection_lost(self, exc):
654+
logger.debug(
655+
"Proxy connection lost.",
656+
)
657+
self.transport.close()

0 commit comments

Comments
 (0)