Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions cmd/node-termination-handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ func main() {
log.Fatal().Err(err).Msg("Unable to instantiate a node for various kubernetes node functions,")
}

metrics, err := observability.InitMetrics(nthConfig.EnablePrometheus, nthConfig.PrometheusPort)
metrics, err := observability.InitMetrics(nthConfig)
if err != nil {
nthConfig.Print()
log.Fatal().Err(err).Msg("Unable to instantiate observability metrics,")
Expand Down Expand Up @@ -215,6 +215,10 @@ func main() {
}
log.Debug().Msgf("AWS Credentials retrieved from provider: %s", creds.ProviderName)

ec2Client := ec2.New(sess)

go metrics.InitNodeMetrics(node, ec2Client)

completeLifecycleActionDelay := time.Duration(nthConfig.CompleteLifecycleActionDelaySeconds) * time.Second
sqsMonitor := sqsevent.SQSMonitor{
CheckIfManaged: nthConfig.CheckTagBeforeDraining,
Expand All @@ -224,7 +228,7 @@ func main() {
CancelChan: cancelChan,
SQS: sqsevent.GetSqsClient(sess),
ASG: autoscaling.New(sess),
EC2: ec2.New(sess),
EC2: ec2Client,
BeforeCompleteLifecycleAction: func() { <-time.After(completeLifecycleActionDelay) },
}
monitoringFns[sqsEvents] = sqsMonitor
Expand Down
98 changes: 98 additions & 0 deletions pkg/ec2helper/ec2helper.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
// Copyright 2016-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"). You may
// not use this file except in compliance with the License. A copy of the
// License is located at
//
// http://aws.amazon.com/apache2.0/
//
// or in the "license" file accompanying this file. This file is distributed
// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
// express or implied. See the License for the specific language governing
// permissions and limitations under the License.

package ec2helper

import (
"fmt"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/ec2/ec2iface"
)

type IEC2Helper interface {
GetInstanceIdsMapByTagKey(tag string) (map[string]bool, error)
}

type EC2Helper struct {
ec2ServiceClient ec2iface.EC2API
}

func New(ec2 ec2iface.EC2API) EC2Helper {
return EC2Helper{
ec2ServiceClient: ec2,
}
}

func (h EC2Helper) GetInstanceIdsByTagKey(tag string) ([]string, error) {
ids := []string{}
var nextToken string

for {
result, err := h.ec2ServiceClient.DescribeInstances(&ec2.DescribeInstancesInput{
Filters: []*ec2.Filter{
{
Name: aws.String("tag-key"),
Values: []*string{aws.String(tag)},
},
},
NextToken: &nextToken,
})

if err != nil {
return nil, err
}

if result == nil || result.Reservations == nil {
return nil, fmt.Errorf("failed to describe instances")
}

for _, reservation := range result.Reservations {
if reservation.Instances == nil {
continue
}
for _, instance := range reservation.Instances {
if instance == nil || instance.InstanceId == nil {
continue
}
ids = append(ids, *instance.InstanceId)
}
}

if result.NextToken == nil {
break
}
nextToken = *result.NextToken
}

return ids, nil
}

func (h EC2Helper) GetInstanceIdsMapByTagKey(tag string) (map[string]bool, error) {
idMap := map[string]bool{}
ids, err := h.GetInstanceIdsByTagKey(tag)
if err != nil {
return nil, err
}

if ids == nil {
return nil, fmt.Errorf("failed to describe instances")
}

for _, id := range ids {
idMap[id] = true
}

return idMap, nil
}
74 changes: 74 additions & 0 deletions pkg/ec2helper/ec2helper_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Copyright 2016-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"). You may
// not use this file except in compliance with the License. A copy of the
// License is located at
//
// http://aws.amazon.com/apache2.0/
//
// or in the "license" file accompanying this file. This file is distributed
// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
// express or implied. See the License for the specific language governing
// permissions and limitations under the License.

package ec2helper_test

import (
"testing"

"github.com/aws/aws-node-termination-handler/pkg/ec2helper"
h "github.com/aws/aws-node-termination-handler/pkg/test"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ec2"
)

const (
instanceId1 = "i-1"
instanceId2 = "i-2"
)

func TestGetInstanceIdsByTagKey(t *testing.T) {
ec2Mock := h.MockedEC2{
DescribeInstancesResp: getDescribeInstancesResp(),
}
ec2Helper := ec2helper.New(ec2Mock)
instanceIds, err := ec2Helper.GetInstanceIdsByTagKey("myNTHManagedTag")
h.Ok(t, err)

h.Equals(t, 2, len(instanceIds))
h.Equals(t, instanceId1, instanceIds[0])
h.Equals(t, instanceId2, instanceIds[1])
}

func TestGetInstanceIdsMapByTagKey(t *testing.T) {
ec2Mock := h.MockedEC2{
DescribeInstancesResp: getDescribeInstancesResp(),
}
ec2Helper := ec2helper.New(ec2Mock)
instanceIdsMap, err := ec2Helper.GetInstanceIdsMapByTagKey("myNTHManagedTag")
h.Ok(t, err)

_, exist := instanceIdsMap[instanceId1]
h.Equals(t, true, exist)
_, exist = instanceIdsMap[instanceId2]
h.Equals(t, true, exist)
_, exist = instanceIdsMap["non-existent instance id"]
h.Equals(t, false, exist)
}

func getDescribeInstancesResp() ec2.DescribeInstancesOutput {
return ec2.DescribeInstancesOutput{
Reservations: []*ec2.Reservation{
{
Instances: []*ec2.Instance{
{
InstanceId: aws.String(instanceId1),
},
{
InstanceId: aws.String(instanceId2),
},
},
},
},
}
}
37 changes: 37 additions & 0 deletions pkg/node/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"context"
"encoding/json"
"fmt"
"regexp"
"strconv"
"strings"
"time"
Expand Down Expand Up @@ -74,6 +75,7 @@ const (
var (
maxRetryDeadline time.Duration = 5 * time.Second
conflictRetryInterval time.Duration = 750 * time.Millisecond
instanceIDRegex = regexp.MustCompile(`^i-.*`)
)

// Node represents a kubernetes node with functions to manipulate its state via the kubernetes api server
Expand Down Expand Up @@ -635,6 +637,41 @@ func (n Node) fetchKubernetesNode(nodeName string) (*corev1.Node, error) {
return &matchingNodes.Items[0], nil
}

// fetchKubernetesNode will send an http request to the k8s api server and return list of AWS EC2 instance id
func (n Node) FetchKubernetesNodeInstanceIds() ([]string, error) {
ids := []string{}

if n.nthConfig.DryRun {
log.Info().Msgf("Would have retrieved nodes, but dry-run flag was set")
return ids, nil
}
matchingNodes, err := n.drainHelper.Client.CoreV1().Nodes().List(context.TODO(), metav1.ListOptions{})
if err != nil {
log.Warn().Msgf("Unable to list Nodes")
return nil, err
}

if matchingNodes == nil || matchingNodes.Items == nil {
return nil, fmt.Errorf("failed to list nodes")
}

for _, node := range matchingNodes.Items {
// sample providerID: aws:///us-west-2a/i-0abcd1234efgh5678
parts := strings.Split(node.Spec.ProviderID, "/")
if len(parts) < 2 {
log.Warn().Msgf("Found invalid providerID: %s", node.Spec.ProviderID)
continue
}

instanceId := parts[len(parts)-1]
if instanceIDRegex.MatchString(instanceId) {
ids = append(ids, parts[len(parts)-1])
}
}

return ids, nil
}

func (n Node) fetchAllPods(nodeName string) (*corev1.PodList, error) {
if n.nthConfig.DryRun {
log.Info().Msgf("Would have retrieved running pod list on node %s, but dry-run flag was set", nodeName)
Expand Down
32 changes: 31 additions & 1 deletion pkg/node/node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ package node_test

import (
"context"
"fmt"
"strconv"
"strings"
"testing"
Expand All @@ -35,7 +36,11 @@ import (
)

// Size of the fakeRecorder buffer
const recorderBufferSize = 10
const (
recorderBufferSize = 10
instanceId1 = "i-0abcd1234efgh5678"
instanceId2 = "i-0wxyz5678ijkl1234"
)

var nodeName = "NAME"

Expand Down Expand Up @@ -379,6 +384,31 @@ func TestUncordonIfRebootedTimeParseFailure(t *testing.T) {
h.Assert(t, err != nil, "Failed to return error on UncordonIfReboted failure to parse time")
}

func TestFetchKubernetesNodeInstanceIds(t *testing.T) {
client := fake.NewSimpleClientset(
&v1.Node{
ObjectMeta: metav1.ObjectMeta{Name: "node-1"},
Spec: v1.NodeSpec{ProviderID: fmt.Sprintf("aws:///us-west-2a/%s", instanceId1)},
},
&v1.Node{
ObjectMeta: metav1.ObjectMeta{Name: "node-2"},
Spec: v1.NodeSpec{ProviderID: fmt.Sprintf("aws:///us-west-2a/%s", instanceId2)},
},
)

_, err := client.CoreV1().Nodes().List(context.Background(), metav1.ListOptions{})
h.Ok(t, err)

node, err := newNode(config.Config{}, client)
h.Ok(t, err)

instanceIds, err := node.FetchKubernetesNodeInstanceIds()
h.Ok(t, err)
h.Equals(t, 2, len(instanceIds))
h.Equals(t, instanceId1, instanceIds[0])
h.Equals(t, instanceId2, instanceIds[1])
}

func TestFilterOutDaemonSetPods(t *testing.T) {
tNode, err := newNode(config.Config{IgnoreDaemonSets: true}, fake.NewSimpleClientset())
h.Ok(t, err)
Expand Down
Loading