@@ -329,6 +329,8 @@ def __init__(
329329 max_connections : int = 10 ,
330330 max_connection_attempts : int = 5 ,
331331 use_sandbox : bool = False ,
332+ proxy_host : Optional [str ] = None ,
333+ proxy_port : Optional [int ] = None ,
332334 ) -> None :
333335 self .apns_topic = topic
334336 self .max_connections = max_connections
@@ -342,6 +344,10 @@ def __init__(
342344 self .connections : List [APNsBaseClientProtocol ] = []
343345 self ._lock = asyncio .Lock ()
344346 self .max_connection_attempts = max_connection_attempts
347+ self .ssl_context : Optional [ssl .SSLContext ] = None
348+
349+ self .proxy_host = proxy_host
350+ self .proxy_port = proxy_port
345351
346352 async def create_connection (self ) -> APNsBaseClientProtocol :
347353 raise NotImplementedError
@@ -428,6 +434,31 @@ async def send_notification(
428434 logger .error ("Failed to send after %d attempts." , attempts )
429435 raise MaxAttemptsExceeded
430436
437+ async def _create_proxy_connection (
438+ self , apns_protocol_factory
439+ ) -> APNsBaseClientProtocol :
440+ assert self .proxy_host is not None , "proxy_host must be set"
441+ assert self .proxy_port is not None , "proxy_port must be set"
442+
443+ _ , protocol = await self .loop .create_connection (
444+ protocol_factory = partial (
445+ HttpProxyProtocol ,
446+ self .protocol_class .APNS_SERVER ,
447+ self .protocol_class .APNS_PORT ,
448+ self .loop ,
449+ self .ssl_context ,
450+ apns_protocol_factory ,
451+ ),
452+ host = self .proxy_host ,
453+ port = self .proxy_port ,
454+ )
455+ await protocol .apns_connection_ready .wait ()
456+
457+ assert (
458+ protocol .apns_protocol is not None
459+ ), "protocol.apns_protocol could not be set"
460+ return protocol .apns_protocol
461+
431462
432463class APNsCertConnectionPool (APNsBaseConnectionPool ):
433464 def __init__ (
@@ -439,12 +470,16 @@ def __init__(
439470 use_sandbox : bool = False ,
440471 no_cert_validation : bool = False ,
441472 ssl_context : Optional [ssl .SSLContext ] = None ,
473+ proxy_host : Optional [str ] = None ,
474+ proxy_port : Optional [int ] = None ,
442475 ) -> None :
443476 super (APNsCertConnectionPool , self ).__init__ (
444477 topic = topic ,
445478 max_connections = max_connections ,
446479 max_connection_attempts = max_connection_attempts ,
447480 use_sandbox = use_sandbox ,
481+ proxy_host = proxy_host ,
482+ proxy_port = proxy_port ,
448483 )
449484
450485 self .cert_file = cert_file
@@ -463,13 +498,23 @@ def __init__(
463498 self .apns_topic = cert .get_subject ().UID
464499
465500 async def create_connection (self ) -> APNsBaseClientProtocol :
501+ apns_protocol_factory = partial (
502+ self .protocol_class ,
503+ self .apns_topic ,
504+ self .loop ,
505+ self .discard_connection ,
506+ )
507+
508+ if self .proxy_host and self .proxy_port :
509+ return await self ._create_proxy_connection (apns_protocol_factory )
510+ else :
511+ return await self ._create_connection (apns_protocol_factory )
512+
513+ async def _create_connection (
514+ self , apns_protocol_factory
515+ ) -> APNsBaseClientProtocol :
466516 _ , protocol = await self .loop .create_connection (
467- protocol_factory = partial (
468- self .protocol_class ,
469- self .apns_topic ,
470- self .loop ,
471- self .discard_connection ,
472- ),
517+ protocol_factory = apns_protocol_factory ,
473518 host = self .protocol_class .APNS_SERVER ,
474519 port = self .protocol_class .APNS_PORT ,
475520 ssl = self .ssl_context ,
@@ -488,12 +533,16 @@ def __init__(
488533 max_connection_attempts : int = 5 ,
489534 use_sandbox : bool = False ,
490535 ssl_context : Optional [ssl .SSLContext ] = None ,
536+ proxy_host : Optional [str ] = None ,
537+ proxy_port : Optional [int ] = None ,
491538 ) -> None :
492539 super (APNsKeyConnectionPool , self ).__init__ (
493540 topic = topic ,
494541 max_connections = max_connections ,
495542 max_connection_attempts = max_connection_attempts ,
496543 use_sandbox = use_sandbox ,
544+ proxy_host = proxy_host ,
545+ proxy_port = proxy_port ,
497546 )
498547
499548 self .ssl_context = ssl_context or ssl .create_default_context ()
@@ -508,16 +557,101 @@ async def create_connection(self) -> APNsBaseClientProtocol:
508557 auth_provider = JWTAuthorizationHeaderProvider (
509558 key = self .key , key_id = self .key_id , team_id = self .team_id
510559 )
511- _ , protocol = await self . loop . create_connection (
512- protocol_factory = partial (
560+ apns_protocol_factory = (
561+ partial (
513562 self .protocol_class ,
514563 self .apns_topic ,
515564 self .loop ,
516565 self .discard_connection ,
517566 auth_provider ,
518567 ),
568+ )
569+
570+ if self .proxy_host and self .proxy_port :
571+ return await self ._create_proxy_connection (apns_protocol_factory )
572+ else :
573+ return await self ._create_connection (apns_protocol_factory )
574+
575+ async def _create_connection (
576+ self , apns_protocol_factory
577+ ) -> APNsBaseClientProtocol :
578+ _ , protocol = await self .loop .create_connection (
579+ protocol_factory = apns_protocol_factory ,
519580 host = self .protocol_class .APNS_SERVER ,
520581 port = self .protocol_class .APNS_PORT ,
521582 ssl = self .ssl_context ,
522583 )
523584 return protocol
585+
586+
587+ class HttpProxyProtocol (asyncio .Protocol ):
588+ def __init__ (
589+ self ,
590+ apns_host : str ,
591+ apns_port : int ,
592+ loop : asyncio .AbstractEventLoop ,
593+ ssl_context : ssl .SSLContext ,
594+ protocol_factory ,
595+ ):
596+ self .apns_host = apns_host
597+ self .apns_port = apns_port
598+ self .buffer = bytearray ()
599+ self .loop = loop
600+ self .ssl_context = ssl_context
601+ self .apns_protocol_factory = protocol_factory
602+ self .apns_protocol : Optional [APNsBaseClientProtocol ] = None
603+ self .transport = None
604+ self .apns_connection_ready = (
605+ asyncio .Event ()
606+ ) # Event to signal APNs readiness
607+
608+ def connection_made (self , transport ):
609+ logger .debug (
610+ "Proxy connection made." ,
611+ )
612+ self .transport = transport
613+ connect_request = (
614+ f"CONNECT { self .apns_host } :{ self .apns_port } "
615+ f"HTTP/1.1\r \n Host: "
616+ f"{ self .apns_host } \r \n Connection: close\r \n \r \n "
617+ )
618+ self .transport .write (connect_request .encode ("utf-8" ))
619+
620+ def data_received (self , data ):
621+ # Data is usually received in bytes,
622+ # so you might want to decode or process it
623+ logger .debug ("Raw data received: %s" , data )
624+ self .buffer .extend (data )
625+ # some proxies send "HTTP/1.1 200 Connection established"
626+ # others "HTTP/1.1 200 Connected"
627+ if b"HTTP/1.1 200 Connect" in data :
628+ logger .debug (
629+ "Proxy tunnel established." ,
630+ )
631+ asyncio .create_task (self .create_apns_connection ())
632+ else :
633+ logger .debug (
634+ "Data received (before APNs connection establishment): %s" ,
635+ data .decode (),
636+ )
637+
638+ async def create_apns_connection (self ):
639+ # Use the existing transport to create a new APNs connection
640+ logger .debug (
641+ "Initiating APNs connection." ,
642+ )
643+ sock = self .transport .get_extra_info ("socket" )
644+ _ , self .apns_protocol = await self .loop .create_connection (
645+ self .apns_protocol_factory ,
646+ server_hostname = self .apns_host ,
647+ ssl = self .ssl_context ,
648+ sock = sock ,
649+ )
650+ # Signal that APNs connection is ready
651+ self .apns_connection_ready .set ()
652+
653+ def connection_lost (self , exc ):
654+ logger .debug (
655+ "Proxy connection lost." ,
656+ )
657+ self .transport .close ()
0 commit comments