Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions cmd/node-termination-handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ func main() {
log.Fatal().Err(err).Msg("Unable to instantiate a node for various kubernetes node functions,")
}

metrics, err := observability.InitMetrics(nthConfig.EnablePrometheus, nthConfig.PrometheusPort)
metrics, err := observability.InitMetrics(nthConfig)
if err != nil {
nthConfig.Print()
log.Fatal().Err(err).Msg("Unable to instantiate observability metrics,")
Expand Down Expand Up @@ -215,6 +215,12 @@ func main() {
}
log.Debug().Msgf("AWS Credentials retrieved from provider: %s", creds.ProviderName)

ec2Client := ec2.New(sess)

if nthConfig.EnablePrometheus {
go metrics.InitNodeMetrics(node, ec2Client)
}

completeLifecycleActionDelay := time.Duration(nthConfig.CompleteLifecycleActionDelaySeconds) * time.Second
sqsMonitor := sqsevent.SQSMonitor{
CheckIfManaged: nthConfig.CheckTagBeforeDraining,
Expand All @@ -224,7 +230,7 @@ func main() {
CancelChan: cancelChan,
SQS: sqsevent.GetSqsClient(sess),
ASG: autoscaling.New(sess),
EC2: ec2.New(sess),
EC2: ec2Client,
BeforeCompleteLifecycleAction: func() { <-time.After(completeLifecycleActionDelay) },
}
monitoringFns[sqsEvents] = sqsMonitor
Expand Down
98 changes: 98 additions & 0 deletions pkg/ec2helper/ec2helper.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
// Copyright 2016-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"). You may
// not use this file except in compliance with the License. A copy of the
// License is located at
//
// http://aws.amazon.com/apache2.0/
//
// or in the "license" file accompanying this file. This file is distributed
// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
// express or implied. See the License for the specific language governing
// permissions and limitations under the License.

package ec2helper

import (
"fmt"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/ec2/ec2iface"
)

type IEC2Helper interface {
GetInstanceIdsMapByTagKey(tag string) (map[string]bool, error)
}

type EC2Helper struct {
ec2ServiceClient ec2iface.EC2API
}

func New(ec2 ec2iface.EC2API) EC2Helper {
return EC2Helper{
ec2ServiceClient: ec2,
}
}

func (h EC2Helper) GetInstanceIdsByTagKey(tag string) ([]string, error) {
ids := []string{}
var nextToken string

for {
result, err := h.ec2ServiceClient.DescribeInstances(&ec2.DescribeInstancesInput{
Filters: []*ec2.Filter{
{
Name: aws.String("tag-key"),
Values: []*string{aws.String(tag)},
},
},
NextToken: &nextToken,
})

if err != nil {
return nil, err
}

if result == nil || result.Reservations == nil {
return nil, fmt.Errorf("failed to describe instances")
}

for _, reservation := range result.Reservations {
if reservation.Instances == nil {
continue
}
for _, instance := range reservation.Instances {
if instance == nil || instance.InstanceId == nil {
continue
}
ids = append(ids, *instance.InstanceId)
}
}

if result.NextToken == nil {
break
}
nextToken = *result.NextToken
}

return ids, nil
}

func (h EC2Helper) GetInstanceIdsMapByTagKey(tag string) (map[string]bool, error) {
idMap := map[string]bool{}
ids, err := h.GetInstanceIdsByTagKey(tag)
if err != nil {
return nil, err
}

if ids == nil {
return nil, fmt.Errorf("failed to describe instances")
}

for _, id := range ids {
idMap[id] = true
}

return idMap, nil
}
74 changes: 74 additions & 0 deletions pkg/ec2helper/ec2helper_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Copyright 2016-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"). You may
// not use this file except in compliance with the License. A copy of the
// License is located at
//
// http://aws.amazon.com/apache2.0/
//
// or in the "license" file accompanying this file. This file is distributed
// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
// express or implied. See the License for the specific language governing
// permissions and limitations under the License.

package ec2helper_test

import (
"testing"

"github.com/aws/aws-node-termination-handler/pkg/ec2helper"
h "github.com/aws/aws-node-termination-handler/pkg/test"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ec2"
)

const (
instanceId1 = "i-1"
instanceId2 = "i-2"
)

func TestGetInstanceIdsByTagKey(t *testing.T) {
ec2Mock := h.MockedEC2{
DescribeInstancesResp: getDescribeInstancesResp(),
}
ec2Helper := ec2helper.New(ec2Mock)
instanceIds, err := ec2Helper.GetInstanceIdsByTagKey("myNTHManagedTag")
h.Ok(t, err)

h.Equals(t, 2, len(instanceIds))
h.Equals(t, instanceId1, instanceIds[0])
h.Equals(t, instanceId2, instanceIds[1])
}

func TestGetInstanceIdsMapByTagKey(t *testing.T) {
ec2Mock := h.MockedEC2{
DescribeInstancesResp: getDescribeInstancesResp(),
}
ec2Helper := ec2helper.New(ec2Mock)
instanceIdsMap, err := ec2Helper.GetInstanceIdsMapByTagKey("myNTHManagedTag")
h.Ok(t, err)

_, exist := instanceIdsMap[instanceId1]
h.Equals(t, true, exist)
_, exist = instanceIdsMap[instanceId2]
h.Equals(t, true, exist)
_, exist = instanceIdsMap["non-existent instance id"]
h.Equals(t, false, exist)
}

func getDescribeInstancesResp() ec2.DescribeInstancesOutput {
return ec2.DescribeInstancesOutput{
Reservations: []*ec2.Reservation{
{
Instances: []*ec2.Instance{
{
InstanceId: aws.String(instanceId1),
},
{
InstanceId: aws.String(instanceId2),
},
},
},
},
}
}
37 changes: 37 additions & 0 deletions pkg/node/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"context"
"encoding/json"
"fmt"
"regexp"
"strconv"
"strings"
"time"
Expand Down Expand Up @@ -74,6 +75,7 @@ const (
var (
maxRetryDeadline time.Duration = 5 * time.Second
conflictRetryInterval time.Duration = 750 * time.Millisecond
instanceIDRegex = regexp.MustCompile(`^i-.*`)
)

// Node represents a kubernetes node with functions to manipulate its state via the kubernetes api server
Expand Down Expand Up @@ -635,6 +637,41 @@ func (n Node) fetchKubernetesNode(nodeName string) (*corev1.Node, error) {
return &matchingNodes.Items[0], nil
}

// fetchKubernetesNode will send an http request to the k8s api server and return list of AWS EC2 instance id
func (n Node) FetchKubernetesNodeInstanceIds() ([]string, error) {
ids := []string{}

if n.nthConfig.DryRun {
log.Info().Msgf("Would have retrieved nodes, but dry-run flag was set")
return ids, nil
}
matchingNodes, err := n.drainHelper.Client.CoreV1().Nodes().List(context.TODO(), metav1.ListOptions{})
if err != nil {
log.Warn().Msgf("Unable to list Nodes")
return nil, err
}

if matchingNodes == nil || matchingNodes.Items == nil {
return nil, fmt.Errorf("failed to list nodes")
}

for _, node := range matchingNodes.Items {
// sample providerID: aws:///us-west-2a/i-0abcd1234efgh5678
parts := strings.Split(node.Spec.ProviderID, "/")
if len(parts) < 2 {
log.Warn().Msgf("Found invalid providerID: %s", node.Spec.ProviderID)
continue
}

instanceId := parts[len(parts)-1]
if instanceIDRegex.MatchString(instanceId) {
ids = append(ids, parts[len(parts)-1])
}
}

return ids, nil
}

func (n Node) fetchAllPods(nodeName string) (*corev1.PodList, error) {
if n.nthConfig.DryRun {
log.Info().Msgf("Would have retrieved running pod list on node %s, but dry-run flag was set", nodeName)
Expand Down
32 changes: 31 additions & 1 deletion pkg/node/node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ package node_test

import (
"context"
"fmt"
"strconv"
"strings"
"testing"
Expand All @@ -35,7 +36,11 @@ import (
)

// Size of the fakeRecorder buffer
const recorderBufferSize = 10
const (
recorderBufferSize = 10
instanceId1 = "i-0abcd1234efgh5678"
instanceId2 = "i-0wxyz5678ijkl1234"
)

var nodeName = "NAME"

Expand Down Expand Up @@ -379,6 +384,31 @@ func TestUncordonIfRebootedTimeParseFailure(t *testing.T) {
h.Assert(t, err != nil, "Failed to return error on UncordonIfReboted failure to parse time")
}

func TestFetchKubernetesNodeInstanceIds(t *testing.T) {
client := fake.NewSimpleClientset(
&v1.Node{
ObjectMeta: metav1.ObjectMeta{Name: "node-1"},
Spec: v1.NodeSpec{ProviderID: fmt.Sprintf("aws:///us-west-2a/%s", instanceId1)},
},
&v1.Node{
ObjectMeta: metav1.ObjectMeta{Name: "node-2"},
Spec: v1.NodeSpec{ProviderID: fmt.Sprintf("aws:///us-west-2a/%s", instanceId2)},
},
)

_, err := client.CoreV1().Nodes().List(context.Background(), metav1.ListOptions{})
h.Ok(t, err)

node, err := newNode(config.Config{}, client)
h.Ok(t, err)

instanceIds, err := node.FetchKubernetesNodeInstanceIds()
h.Ok(t, err)
h.Equals(t, 2, len(instanceIds))
h.Equals(t, instanceId1, instanceIds[0])
h.Equals(t, instanceId2, instanceIds[1])
}

func TestFilterOutDaemonSetPods(t *testing.T) {
tNode, err := newNode(config.Config{IgnoreDaemonSets: true}, fake.NewSimpleClientset())
h.Ok(t, err)
Expand Down
Loading