Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,18 @@ package volume

import (
"fmt"
"strings"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/ec2/ec2iface"
)

const (
// https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/device_naming.html
possibleAttachmentDevicePrefix = "/dev/"
)

type describeVolumesProvider struct {
ec2Client ec2iface.EC2API
instanceID string
Expand Down Expand Up @@ -38,7 +44,7 @@ func (p *describeVolumesProvider) DeviceToSerialMap() (map[string]string, error)
for _, volume := range output.Volumes {
for _, attachment := range volume.Attachments {
if attachment.Device != nil && attachment.VolumeId != nil {
result[aws.StringValue(attachment.Device)] = aws.StringValue(attachment.VolumeId)
result[strings.TrimPrefix(aws.StringValue(attachment.Device), possibleAttachmentDevicePrefix)] = aws.StringValue(attachment.VolumeId)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,17 @@ import (
"errors"
"testing"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/ec2/ec2iface"
"github.com/stretchr/testify/assert"
)

// construct the return results for the mocked DescribeTags api
var (
device1 = "/dev/xvdc"
device1 = "sda1"
volumeId1 = "vol-0303a1cc896c42d28"
volumeAttachment1 = ec2.VolumeAttachment{Device: &device1, VolumeId: &volumeId1}
volumeAttachment1 = ec2.VolumeAttachment{Device: aws.String("/dev/sda1"), VolumeId: aws.String(volumeId1)}
availabilityZone = "us-east-1a"
volume1 = ec2.Volume{
Attachments: []*ec2.VolumeAttachment{&volumeAttachment1},
Expand All @@ -25,15 +26,25 @@ var (
)

var (
device2 = "/dev/xvdf"
device2 = "xvdf"
volumeId2 = "vol-0c241693efb58734a"
volumeAttachment2 = ec2.VolumeAttachment{Device: &device2, VolumeId: &volumeId2}
volumeAttachment2 = ec2.VolumeAttachment{Device: aws.String("/dev/xvdf"), VolumeId: aws.String(volumeId2)}
volume2 = ec2.Volume{
Attachments: []*ec2.VolumeAttachment{&volumeAttachment2},
AvailabilityZone: &availabilityZone,
}
)

var (
device3 = "xvdda"
volumeId3 = "vol-09ada5ca79a65cdd2"
volumeAttachment3 = ec2.VolumeAttachment{Device: aws.String("xvdda"), VolumeId: aws.String(volumeId3)}
volume3 = ec2.Volume{
Attachments: []*ec2.VolumeAttachment{&volumeAttachment3},
AvailabilityZone: &availabilityZone,
}
)

type mockEC2Client struct {
ec2iface.EC2API

Expand All @@ -51,12 +62,12 @@ func (m *mockEC2Client) DescribeVolumes(input *ec2.DescribeVolumesInput) (*ec2.D
if input.NextToken == nil {
return &ec2.DescribeVolumesOutput{
NextToken: &device2,
Volumes: []*ec2.Volume{&volume1},
Volumes: []*ec2.Volume{&volume1, &volume2},
}, nil
}
return &ec2.DescribeVolumesOutput{
NextToken: nil,
Volumes: []*ec2.Volume{&volume2},
Volumes: []*ec2.Volume{&volume3},
}, nil
}

Expand All @@ -66,7 +77,7 @@ func TestDescribeVolumesProvider(t *testing.T) {
got, err := p.DeviceToSerialMap()
assert.NoError(t, err)
assert.Equal(t, 2, ec2Client.callCount)
want := map[string]string{device1: volumeId1, device2: volumeId2}
want := map[string]string{device1: volumeId1, device2: volumeId2, device3: volumeId3}
assert.Equal(t, want, got)
ec2Client.err = errors.New("test")
ec2Client.callCount = 0
Expand Down
54 changes: 5 additions & 49 deletions plugins/processors/ec2tagger/internal/volume/volume.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ package volume
import (
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"sync"

Expand Down Expand Up @@ -40,25 +38,21 @@ type Cache interface {
type cache struct {
sync.RWMutex
// device name to serial mapping
cache map[string]string
provider Provider
fetchBlockName func(string) string
cache map[string]string
provider Provider
}

func NewCache(provider Provider) Cache {
return &cache{
cache: make(map[string]string),
provider: provider,
fetchBlockName: findNvmeBlockNameIfPresent,
cache: make(map[string]string),
provider: provider,
}
}

func (c *cache) add(devName, serial string) {
normalizedName := c.normalizeName(devName)

c.Lock()
defer c.Unlock()
c.cache[normalizedName] = serial
c.cache[devName] = serial
}

func (c *cache) reset() {
Expand Down Expand Up @@ -105,41 +99,3 @@ func (c *cache) Devices() []string {
defer c.RUnlock()
return maps.Keys(c.cache)
}

func (c *cache) normalizeName(devName string) string {
normalized := c.fetchBlockName(devName)
if normalized == "" {
normalized = devName
}

// to match the disk device tag
return strings.ReplaceAll(normalized, "/dev/", "")
}

// find nvme block name by symlink, if symlink doesn't exist, return ""
func findNvmeBlockNameIfPresent(devName string) string {
// for nvme(ssd), there is a symlink from devName to nvme block name, i.e. /dev/xvda -> /dev/nvme0n1
// https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/nvme-ebs-volumes.html
hasRootFs := true
if _, err := os.Lstat("/rootfs/proc"); os.IsNotExist(err) {
hasRootFs = false
}
nvmeName := ""

if hasRootFs {
devName = "/rootfs" + devName
}

if info, err := os.Lstat(devName); err == nil {
if info.Mode()&os.ModeSymlink != 0 {
if path, err := filepath.EvalSymlinks(devName); err == nil {
nvmeName = path
}
}
}

if nvmeName != "" && hasRootFs {
nvmeName = strings.TrimPrefix(nvmeName, "/rootfs")
}
return nvmeName
}
13 changes: 6 additions & 7 deletions plugins/processors/ec2tagger/internal/volume/volume_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,27 +26,26 @@ func TestCache(t *testing.T) {
testErr := errors.New("test")
p := &mockProvider{
serialMap: map[string]string{
"/dev/xvdf": "foo",
"xvdc": "bar",
"xvdc1": "baz",
"xvdf": "foo",
"nvme1n1": "foo",
"xvdc": "bar",
"xvdc1": "baz",
},
err: testErr,
}
c := NewCache(nil).(*cache)
c.fetchBlockName = func(s string) string {
return ""
}
assert.ErrorIs(t, c.Refresh(), errNoProviders)
c.provider = p
assert.ErrorIs(t, c.Refresh(), testErr)
p.err = nil
assert.NoError(t, c.Refresh())
assert.Equal(t, "foo", c.Serial("xvdf"))
assert.Equal(t, "foo", c.Serial("nvme1n1"))
assert.Equal(t, "bar", c.Serial("xvdc"))
assert.Equal(t, "baz", c.Serial("xvdc1"))
assert.Equal(t, "bar", c.Serial("xvdc2"))
assert.Equal(t, "", c.Serial("xvde"))
got := c.Devices()
sort.Strings(got)
assert.Equal(t, []string{"xvdc", "xvdc1", "xvdf"}, got)
assert.Equal(t, []string{"nvme1n1", "xvdc", "xvdc1", "xvdf"}, got)
}