@@ -127,6 +127,18 @@ def get_device_args_rewrite(view):
127127 return pk .hex
128128
129129
130+ def get_vpn_args_rewrite (view ):
131+ """
132+ Use only the PK parameter for calculating the cache key for VPN
133+ """
134+ pk = view .kwargs ["pk" ]
135+ try :
136+ pk = uuid .UUID (pk )
137+ except ValueError :
138+ return pk
139+ return pk .hex
140+
141+
130142class DeviceChecksumView (UpdateLastIpMixin , GetDeviceView ):
131143 """
132144 returns device's configuration checksum
@@ -463,24 +475,43 @@ class GetVpnView(SingleObjectMixin, View):
463475 model = Vpn
464476
465477 def get_object (self , * args , ** kwargs ):
466- queryset = self .model .objects .select_related ("organization" ). filter (
467- Q ( organization__is_active = True ) | Q ( organization__isnull = True )
468- )
478+ queryset = self .model .objects .select_related (
479+ "organization" , "ca" , "cert" , "subnet" , "ip"
480+ ). filter ( Q ( organization__is_active = True ) | Q ( organization__isnull = True ))
469481 return get_object_or_404 (queryset , * args , ** kwargs )
470482
483+ @cache_memoize (
484+ timeout = Vpn ._CHECKSUM_CACHE_TIMEOUT , args_rewrite = get_vpn_args_rewrite
485+ )
486+ def get_vpn (self ):
487+ pk = self .kwargs ["pk" ]
488+ logger .debug (f"retrieving VPN ID { pk } from DB" )
489+ return self .get_object (pk = pk )
490+
491+ @classmethod
492+ def invalidate_get_vpn_cache (cls , instance , ** kwargs ):
493+ """
494+ Called from signal receiver which performs cache invalidation
495+ """
496+ view = cls ()
497+ pk = str (instance .pk .hex )
498+ view .kwargs = {"pk" : pk }
499+ view .get_vpn .invalidate (view )
500+ logger .debug (f"invalidated view cache for VPN ID { pk } " )
501+
471502
472503class VpnChecksumView (GetVpnView ):
473504 """
474505 returns vpn's configuration checksum
475506 """
476507
477508 def get (self , request , * args , ** kwargs ):
478- vpn = self .get_object ( * args , ** kwargs )
509+ vpn = self .get_vpn ( )
479510 bad_request = forbid_unallowed (request , "GET" , "key" , vpn .key )
480511 if bad_request :
481512 return bad_request
482513 checksum_requested .send (sender = vpn .__class__ , instance = vpn , request = request )
483- return ControllerResponse (vpn .checksum , content_type = "text/plain" )
514+ return ControllerResponse (vpn .get_cached_checksum () , content_type = "text/plain" )
484515
485516
486517class VpnDownloadConfigView (GetVpnView ):
@@ -489,7 +520,7 @@ class VpnDownloadConfigView(GetVpnView):
489520 """
490521
491522 def get (self , request , * args , ** kwargs ):
492- vpn = self .get_object ( * args , ** kwargs )
523+ vpn = self .get_vpn ( )
493524 bad_request = forbid_unallowed (request , "GET" , "key" , vpn .key )
494525 if bad_request :
495526 return bad_request
0 commit comments