@@ -158,21 +158,9 @@ def write_connection_file(
158
158
cfg ["signature_scheme" ] = signature_scheme
159
159
cfg ["kernel_name" ] = kernel_name
160
160
161
- # Prevent over-writing a file that has already been written with the same
162
- # info. This is to prevent a race condition where the process has
163
- # already been launched but has not yet read the connection file.
164
- if os .path .exists (fname ):
165
- with open (fname ) as f :
166
- try :
167
- data = json .load (f )
168
- if data == cfg :
169
- return fname , cfg
170
- except Exception :
171
- pass
172
-
173
161
# Only ever write this file as user read/writeable
174
162
# This would otherwise introduce a vulnerability as a file has secrets
175
- # which would let others execute arbitrarily code as you
163
+ # which would let others execute arbitrary code as you
176
164
with secure_write (fname ) as f :
177
165
f .write (json .dumps (cfg , indent = 2 ))
178
166
@@ -579,18 +567,70 @@ def load_connection_info(self, info: KernelConnectionInfo) -> None:
579
567
if "signature_scheme" in info :
580
568
self .session .signature_scheme = info ["signature_scheme" ]
581
569
582
- def _force_connection_info (self , info : KernelConnectionInfo ) -> None :
583
- """Unconditionally loads connection info from a dict containing connection info .
570
+ def _reconcile_connection_info (self , info : KernelConnectionInfo ) -> None :
571
+ """Reconciles the connection information returned from the Provisioner .
584
572
585
- Overwrites connection info-based attributes, regardless of their current values
586
- and writes this information to the connection file.
573
+ Because some provisioners (like derivations of LocalProvisioner) may have already
574
+ written the connection file, this method needs to ensure that, if the connection
575
+ file exists, its contents match that of what was returned by the provisioner. If
576
+ the file does exist and its contents do not match, a ValueError is raised.
577
+
578
+ If the file does not exist, the connection information in 'info' is loaded into the
579
+ KernelManager and written to the file.
587
580
"""
588
- # Reset current ports to 0 and indicate file has not been written to enable override
589
- self ._connection_file_written = False
590
- for name in port_names :
591
- setattr (self , name , 0 )
592
- self .load_connection_info (info )
593
- self .write_connection_file ()
581
+ # Prevent over-writing a file that has already been written with the same
582
+ # info. This is to prevent a race condition where the process has
583
+ # already been launched but has not yet read the connection file - as is
584
+ # the case with LocalProvisioners.
585
+ file_exists : bool = False
586
+ if os .path .exists (self .connection_file ):
587
+ with open (self .connection_file ) as f :
588
+ file_info = json .load (f )
589
+ # Prior to the following comparison, we need to adjust the value of "key" to
590
+ # be bytes, otherwise the comparison below will fail.
591
+ file_info ["key" ] = file_info ["key" ].encode ()
592
+ if not self ._equal_connections (info , file_info ):
593
+ raise ValueError (
594
+ "Connection file already exists and does not match "
595
+ "the expected values returned from provisioner!"
596
+ )
597
+ file_exists = True
598
+
599
+ if not file_exists :
600
+ # Load the connection info and write out file. Note, this does not necessarily
601
+ # overwrite non-zero port values, so we'll validate afterward.
602
+ self .load_connection_info (info )
603
+ self .write_connection_file ()
604
+
605
+ # Ensure what is in KernelManager is what we expect. This will catch issues if the file
606
+ # already existed, yet it's contents differed from the KernelManager's (and provisioner).
607
+ km_info = self .get_connection_info ()
608
+ if not self ._equal_connections (info , km_info ):
609
+ raise ValueError (
610
+ "KernelManager's connection information already exists and does not match "
611
+ "the expected values returned from provisioner!"
612
+ )
613
+
614
+ @staticmethod
615
+ def _equal_connections (conn1 : KernelConnectionInfo , conn2 : KernelConnectionInfo ) -> bool :
616
+ """Compares pertinent keys of connection info data. Returns True if equivalent, False otherwise."""
617
+
618
+ pertinent_keys = [
619
+ "key" ,
620
+ "ip" ,
621
+ "stdin_port" ,
622
+ "iopub_port" ,
623
+ "shell_port" ,
624
+ "control_port" ,
625
+ "hb_port" ,
626
+ "transport" ,
627
+ "signature_scheme" ,
628
+ ]
629
+
630
+ for key in pertinent_keys :
631
+ if conn1 .get (key ) != conn2 .get (key ):
632
+ return False
633
+ return True
594
634
595
635
# --------------------------------------------------------------------------
596
636
# Creating connected sockets
0 commit comments