88 "errors"
99 "fmt"
1010 "net/http"
11+ "net/netip"
1112 "strings"
1213 "time"
1314
@@ -17,6 +18,7 @@ import (
1718 "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization/v2"
1819 "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6"
1920 "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v6"
21+ "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6"
2022 "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
2123 "github.com/google/uuid"
2224 corev1 "k8s.io/api/core/v1"
@@ -39,6 +41,7 @@ type Cluster struct {
3941 ClusterParams * ClusterParams
4042 Maintenance * armcontainerservice.MaintenanceConfiguration
4143 DebugPod * corev1.Pod
44+ Bastion * armnetwork.BastionHost
4245}
4346
4447// Returns true if the cluster is configured with Azure CNI
@@ -66,6 +69,11 @@ func prepareCluster(ctx context.Context, cluster *armcontainerservice.ManagedClu
6669 return nil , fmt .Errorf ("get or create cluster: %w" , err )
6770 }
6871
72+ bastion , err := getOrCreateBastion (ctx , cluster )
73+ if err != nil {
74+ return nil , fmt .Errorf ("get or create bastion: %w" , err )
75+ }
76+
6977 maintenance , err := getOrCreateMaintenanceConfiguration (ctx , cluster )
7078 if err != nil {
7179 return nil , fmt .Errorf ("get or create maintenance configuration: %w" , err )
@@ -143,6 +151,7 @@ func prepareCluster(ctx context.Context, cluster *armcontainerservice.ManagedClu
143151 Maintenance : maintenance ,
144152 ClusterParams : clusterParams ,
145153 DebugPod : hostPod ,
154+ Bastion : bastion ,
146155 }, nil
147156}
148157
@@ -470,6 +479,205 @@ func createNewMaintenanceConfiguration(ctx context.Context, cluster *armcontaine
470479 return & maintenance , nil
471480}
472481
482+ func getOrCreateBastion (ctx context.Context , cluster * armcontainerservice.ManagedCluster ) (* armnetwork.BastionHost , error ) {
483+ nodeRG := * cluster .Properties .NodeResourceGroup
484+ bastionName := fmt .Sprintf ("%s-bastion" , * cluster .Name )
485+
486+ existing , err := config .Azure .BastionHosts .Get (ctx , nodeRG , bastionName , nil )
487+ var azErr * azcore.ResponseError
488+ if errors .As (err , & azErr ) && azErr .StatusCode == http .StatusNotFound {
489+ return createNewBastion (ctx , cluster )
490+ }
491+ if err != nil {
492+ return nil , fmt .Errorf ("failed to get bastion %q in rg %q: %w" , bastionName , nodeRG , err )
493+ }
494+ return & existing .BastionHost , nil
495+ }
496+
497+ func createNewBastion (ctx context.Context , cluster * armcontainerservice.ManagedCluster ) (* armnetwork.BastionHost , error ) {
498+ nodeRG := * cluster .Properties .NodeResourceGroup
499+ location := * cluster .Location
500+ bastionName := fmt .Sprintf ("%s-bastion" , * cluster .Name )
501+ publicIPName := fmt .Sprintf ("%s-bastion-pip" , * cluster .Name )
502+ publicIPName = sanitizeAzureResourceName (publicIPName )
503+
504+ vnet , err := getClusterVNet (ctx , nodeRG )
505+ if err != nil {
506+ return nil , fmt .Errorf ("get cluster vnet in rg %q: %w" , nodeRG , err )
507+ }
508+
509+ // Azure Bastion requires a dedicated subnet named AzureBastionSubnet. Standard SKU (required for
510+ // native client support/tunneling) requires at least a /26.
511+ bastionSubnetName := "AzureBastionSubnet"
512+ bastionSubnetPrefix := "10.226.0.0/26"
513+ if _ , err := netip .ParsePrefix (bastionSubnetPrefix ); err != nil {
514+ return nil , fmt .Errorf ("invalid bastion subnet prefix %q: %w" , bastionSubnetPrefix , err )
515+ }
516+
517+ var bastionSubnetID string
518+ bastionSubnet , subnetGetErr := config .Azure .Subnet .Get (ctx , nodeRG , vnet .name , bastionSubnetName , nil )
519+ if subnetGetErr != nil {
520+ var subnetAzErr * azcore.ResponseError
521+ if ! errors .As (subnetGetErr , & subnetAzErr ) || subnetAzErr .StatusCode != http .StatusNotFound {
522+ return nil , fmt .Errorf ("get subnet %q in vnet %q rg %q: %w" , bastionSubnetName , vnet .name , nodeRG , subnetGetErr )
523+ }
524+
525+ logf (ctx , "creating subnet %s in VNet %s (rg %s)" , bastionSubnetName , vnet .name , nodeRG )
526+ subnetParams := armnetwork.Subnet {
527+ Properties : & armnetwork.SubnetPropertiesFormat {
528+ AddressPrefix : to .Ptr (bastionSubnetPrefix ),
529+ },
530+ }
531+ subnetPoller , err := config .Azure .Subnet .BeginCreateOrUpdate (ctx , nodeRG , vnet .name , bastionSubnetName , subnetParams , nil )
532+ if err != nil {
533+ return nil , fmt .Errorf ("failed to start creating bastion subnet: %w" , err )
534+ }
535+ bastionSubnet , err := subnetPoller .PollUntilDone (ctx , config .DefaultPollUntilDoneOptions )
536+ if err != nil {
537+ return nil , fmt .Errorf ("failed to create bastion subnet: %w" , err )
538+ }
539+ bastionSubnetID = * bastionSubnet .ID
540+ } else {
541+ bastionSubnetID = * bastionSubnet .ID
542+ }
543+
544+ // Public IP for Bastion
545+ pipParams := armnetwork.PublicIPAddress {
546+ Location : to .Ptr (location ),
547+ SKU : & armnetwork.PublicIPAddressSKU {
548+ Name : to .Ptr (armnetwork .PublicIPAddressSKUNameStandard ),
549+ },
550+ Properties : & armnetwork.PublicIPAddressPropertiesFormat {
551+ PublicIPAllocationMethod : to .Ptr (armnetwork .IPAllocationMethodStatic ),
552+ },
553+ }
554+
555+ logf (ctx , "creating bastion public IP %s (rg %s)" , publicIPName , nodeRG )
556+ pipPoller , err := config .Azure .PublicIPAddresses .BeginCreateOrUpdate (ctx , nodeRG , publicIPName , pipParams , nil )
557+ if err != nil {
558+ return nil , fmt .Errorf ("failed to start creating bastion public IP: %w" , err )
559+ }
560+ pipResp , err := pipPoller .PollUntilDone (ctx , config .DefaultPollUntilDoneOptions )
561+ if err != nil {
562+ return nil , fmt .Errorf ("failed to create bastion public IP: %w" , err )
563+ }
564+ if pipResp .ID == nil {
565+ return nil , fmt .Errorf ("bastion public IP response missing ID" )
566+ }
567+
568+ bastion := armnetwork.BastionHost {
569+ Location : to .Ptr (location ),
570+ SKU : & armnetwork.SKU {
571+ Name : to .Ptr (armnetwork .BastionHostSKUNameStandard ),
572+ },
573+ Properties : & armnetwork.BastionHostPropertiesFormat {
574+ // Native client support is enabled via tunneling.
575+ EnableTunneling : to .Ptr (true ),
576+ IPConfigurations : []* armnetwork.BastionHostIPConfiguration {
577+ {
578+ Name : to .Ptr ("bastion-ipcfg" ),
579+ Properties : & armnetwork.BastionHostIPConfigurationPropertiesFormat {
580+ Subnet : & armnetwork.SubResource {
581+ ID : to .Ptr (bastionSubnetID ),
582+ },
583+ PublicIPAddress : & armnetwork.SubResource {
584+ ID : pipResp .ID ,
585+ },
586+ },
587+ },
588+ },
589+ },
590+ }
591+
592+ logf (ctx , "creating bastion %s (native client/tunneling enabled) in rg %s" , bastionName , nodeRG )
593+ bastionPoller , err := config .Azure .BastionHosts .BeginCreateOrUpdate (ctx , nodeRG , bastionName , bastion , nil )
594+ if err != nil {
595+ return nil , fmt .Errorf ("failed to start creating bastion: %w" , err )
596+ }
597+ resp , err := bastionPoller .PollUntilDone (ctx , config .DefaultPollUntilDoneOptions )
598+ if err != nil {
599+ return nil , fmt .Errorf ("failed to create bastion: %w" , err )
600+ }
601+
602+ if err := verifyBastion (ctx , cluster , & resp .BastionHost ); err != nil {
603+ return nil , fmt .Errorf ("failed to verify bastion: %w" , err )
604+ }
605+ return & resp .BastionHost , nil
606+ }
607+
608+ func verifyBastion (ctx context.Context , cluster * armcontainerservice.ManagedCluster , bastion * armnetwork.BastionHost ) error {
609+ nodeRG := * cluster .Properties .NodeResourceGroup
610+ vmssName , err := getSystemPoolVMSSName (ctx , cluster )
611+ if err != nil {
612+ return err
613+ }
614+
615+ var vmssVM * armcompute.VirtualMachineScaleSetVM
616+ pager := config .Azure .VMSSVM .NewListPager (nodeRG , vmssName , nil )
617+ if pager .More () {
618+ page , err := pager .NextPage (ctx )
619+ if err != nil {
620+ return fmt .Errorf ("list vmss vms for %q in rg %q: %w" , vmssName , nodeRG , err )
621+ }
622+ if len (page .Value ) > 0 {
623+ vmssVM = page .Value [0 ]
624+ }
625+ }
626+
627+ ctx , cancel := context .WithCancel (ctx )
628+ defer cancel ()
629+ localPort , pid , err := startBastionTunnel (ctx , * bastion .Name , nodeRG , * vmssVM .ID )
630+ if err != nil {
631+ return err
632+ }
633+
634+ defer cleanupBastionTunnel (localPort , pid )
635+
636+ result , err := runSSHCommandWithPrivateKeyFile (ctx , localPort , "uname -a" , config .SysSSHPrivateKey )
637+ if err != nil {
638+ return err
639+ }
640+ if strings .Contains (result .stdout , * vmssVM .Name ) {
641+ return nil
642+ }
643+ return fmt .Errorf ("Executed ssh on wrong VM: %s" , result .stdout )
644+ }
645+
646+ func getSystemPoolVMSSName (ctx context.Context , cluster * armcontainerservice.ManagedCluster ) (string , error ) {
647+ nodeRG := * cluster .Properties .NodeResourceGroup
648+ var systemPoolName string
649+ for _ , pool := range cluster .Properties .AgentPoolProfiles {
650+ if strings .EqualFold (string (* pool .Mode ), "System" ) {
651+ systemPoolName = * pool .Name
652+ }
653+ }
654+ pager := config .Azure .VMSS .NewListPager (nodeRG , nil )
655+ if pager .More () {
656+ page , err := pager .NextPage (ctx )
657+ if err != nil {
658+ return "" , fmt .Errorf ("list vmss in rg %q: %w" , nodeRG , err )
659+ }
660+ for _ , vmss := range page .Value {
661+ if strings .Contains (strings .ToLower (* vmss .Name ), strings .ToLower (systemPoolName )) {
662+ return * vmss .Name , nil
663+ }
664+ }
665+ }
666+ return "" , fmt .Errorf ("no matching VMSS found for system pool %q in rg %q" , systemPoolName , nodeRG )
667+ }
668+
669+ func sanitizeAzureResourceName (name string ) string {
670+ // Azure resource name restrictions vary by type. For our usage here (Public IP name) we just
671+ // keep it simple and strip problematic characters.
672+ replacer := strings .NewReplacer ("/" , "-" , "\\ " , "-" , ":" , "-" , "_" , "-" , " " , "-" )
673+ name = replacer .Replace (name )
674+ name = strings .Trim (name , "-" )
675+ if len (name ) > 80 {
676+ name = name [:80 ]
677+ }
678+ return name
679+ }
680+
473681type VNet struct {
474682 name string
475683 subnetId string
0 commit comments