Skip to content

Commit eac1f3c

Browse files
(Backport v1.2.x)fix: prevent hash collisions while resolving subnets, security groups and AMIs from nodeclass selectors (#8660)
Co-authored-by: Saurav Agarwalla <saurav-agarwalla@users.noreply.github.com>
1 parent 95881e1 commit eac1f3c

File tree

16 files changed

+640
-147
lines changed

16 files changed

+640
-147
lines changed

pkg/cloudprovider/suite_test.go

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -656,7 +656,7 @@ var _ = Describe("CloudProvider", func() {
656656
},
657657
},
658658
})
659-
awsEnv.EC2API.DescribeSecurityGroupsOutput.Set(&ec2.DescribeSecurityGroupsOutput{
659+
awsEnv.EC2API.DescribeSecurityGroupsBehavior.Output.Set(&ec2.DescribeSecurityGroupsOutput{
660660
SecurityGroups: []ec2types.SecurityGroup{
661661
{
662662
GroupId: aws.String(validSecurityGroup),
@@ -670,7 +670,7 @@ var _ = Describe("CloudProvider", func() {
670670
},
671671
},
672672
})
673-
awsEnv.EC2API.DescribeSubnetsOutput.Set(&ec2.DescribeSubnetsOutput{
673+
awsEnv.EC2API.DescribeSubnetsBehavior.Output.Set(&ec2.DescribeSubnetsOutput{
674674
Subnets: []ec2types.Subnet{
675675
{
676676
SubnetId: aws.String(validSubnet1),
@@ -1152,7 +1152,7 @@ var _ = Describe("CloudProvider", func() {
11521152
})
11531153
It("should launch instances into subnet with the most available IP addresses", func() {
11541154
awsEnv.SubnetCache.Flush()
1155-
awsEnv.EC2API.DescribeSubnetsOutput.Set(&ec2.DescribeSubnetsOutput{Subnets: []ec2types.Subnet{
1155+
awsEnv.EC2API.DescribeSubnetsBehavior.Output.Set(&ec2.DescribeSubnetsOutput{Subnets: []ec2types.Subnet{
11561156
{SubnetId: aws.String("test-subnet-1"), AvailabilityZone: aws.String("test-zone-1a"), AvailabilityZoneId: aws.String("tstz1-1a"), AvailableIpAddressCount: aws.Int32(10),
11571157
Tags: []ec2types.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-1")}}},
11581158
{SubnetId: aws.String("test-subnet-2"), AvailabilityZone: aws.String("test-zone-1a"), AvailabilityZoneId: aws.String("tstz1-1a"), AvailableIpAddressCount: aws.Int32(100),
@@ -1169,7 +1169,7 @@ var _ = Describe("CloudProvider", func() {
11691169
})
11701170
It("should launch instances into subnet with the most available IP addresses in-between cache refreshes", func() {
11711171
awsEnv.SubnetCache.Flush()
1172-
awsEnv.EC2API.DescribeSubnetsOutput.Set(&ec2.DescribeSubnetsOutput{Subnets: []ec2types.Subnet{
1172+
awsEnv.EC2API.DescribeSubnetsBehavior.Output.Set(&ec2.DescribeSubnetsOutput{Subnets: []ec2types.Subnet{
11731173
{SubnetId: aws.String("test-subnet-1"), AvailabilityZone: aws.String("test-zone-1a"), AvailabilityZoneId: aws.String("tstz1-1a"), AvailableIpAddressCount: aws.Int32(10),
11741174
Tags: []ec2types.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-1")}}},
11751175
{SubnetId: aws.String("test-subnet-2"), AvailabilityZone: aws.String("test-zone-1a"), AvailabilityZoneId: aws.String("tstz1-1a"), AvailableIpAddressCount: aws.Int32(11),
@@ -1197,7 +1197,7 @@ var _ = Describe("CloudProvider", func() {
11971197
Expect(fake.SubnetsFromFleetRequest(createFleetInput)).To(ConsistOf("test-subnet-1"))
11981198
})
11991199
It("should update in-flight IPs when a CreateFleet error occurs", func() {
1200-
awsEnv.EC2API.DescribeSubnetsOutput.Set(&ec2.DescribeSubnetsOutput{Subnets: []ec2types.Subnet{
1200+
awsEnv.EC2API.DescribeSubnetsBehavior.Output.Set(&ec2.DescribeSubnetsOutput{Subnets: []ec2types.Subnet{
12011201
{SubnetId: aws.String("test-subnet-1"), AvailabilityZone: aws.String("test-zone-1a"), AvailableIpAddressCount: aws.Int32(10),
12021202
Tags: []ec2types.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-1")}}},
12031203
}})
@@ -1208,12 +1208,20 @@ var _ = Describe("CloudProvider", func() {
12081208
Expect(len(bindings)).To(Equal(0))
12091209
})
12101210
It("should launch instances into subnets that are excluded by another NodePool", func() {
1211-
awsEnv.EC2API.DescribeSubnetsOutput.Set(&ec2.DescribeSubnetsOutput{Subnets: []ec2types.Subnet{
1212-
{SubnetId: aws.String("test-subnet-1"), AvailabilityZone: aws.String("test-zone-1a"), AvailabilityZoneId: aws.String("tstz1-1a"), AvailableIpAddressCount: aws.Int32(10),
1213-
Tags: []ec2types.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-1")}}},
1214-
{SubnetId: aws.String("test-subnet-2"), AvailabilityZone: aws.String("test-zone-1b"), AvailabilityZoneId: aws.String("tstz1-1a"), AvailableIpAddressCount: aws.Int32(100),
1215-
Tags: []ec2types.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-2")}}},
1216-
}})
1211+
awsEnv.EC2API.Subnets.Store("test-zone-1a", ec2types.Subnet{
1212+
SubnetId: aws.String("test-subnet-1"),
1213+
AvailabilityZone: aws.String("test-zone-1a"),
1214+
AvailabilityZoneId: aws.String("tstz1-1a"),
1215+
AvailableIpAddressCount: aws.Int32(10),
1216+
Tags: []ec2types.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-1")}},
1217+
})
1218+
awsEnv.EC2API.Subnets.Store("test-zone-1b", ec2types.Subnet{
1219+
SubnetId: aws.String("test-subnet-2"),
1220+
AvailabilityZone: aws.String("test-zone-1b"),
1221+
AvailabilityZoneId: aws.String("tstz1-1a"),
1222+
AvailableIpAddressCount: aws.Int32(100),
1223+
Tags: []ec2types.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-2")}},
1224+
})
12171225
nodeClass.Spec.SubnetSelectorTerms = []v1.SubnetSelectorTerm{{Tags: map[string]string{"Name": "test-subnet-1"}}}
12181226
ExpectApplied(ctx, env.Client, nodePool, nodeClass)
12191227
controller := nodeclass.NewController(env.Client, recorder, awsEnv.SubnetProvider, awsEnv.SecurityGroupProvider, awsEnv.AMIProvider, awsEnv.InstanceProfileProvider, awsEnv.LaunchTemplateProvider)

pkg/controllers/nodeclass/ami_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -631,7 +631,7 @@ var _ = Describe("NodeClass AMI Status Controller", func() {
631631
awsEnv.Clock.Step(40 * time.Minute)
632632

633633
// Flush Cache
634-
awsEnv.EC2Cache.Flush()
634+
awsEnv.AMICache.Flush()
635635

636636
ExpectObjectReconciled(ctx, env.Client, controller, nodeClass)
637637
nodeClass = ExpectExists(ctx, env.Client, nodeClass)
@@ -730,7 +730,7 @@ var _ = Describe("NodeClass AMI Status Controller", func() {
730730
},
731731
})
732732

733-
awsEnv.EC2Cache.Flush()
733+
awsEnv.AMICache.Flush()
734734

735735
ExpectApplied(ctx, env.Client, nodeClass)
736736
ExpectObjectReconciled(ctx, env.Client, controller, nodeClass)

pkg/controllers/nodeclass/subnet_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ var _ = Describe("NodeClass Subnet Status Controller", func() {
8080
Expect(nodeClass.StatusConditions().IsTrue(v1.ConditionTypeSubnetsReady)).To(BeTrue())
8181
})
8282
It("Should have the correct ordering for the Subnets", func() {
83-
awsEnv.EC2API.DescribeSubnetsOutput.Set(&ec2.DescribeSubnetsOutput{Subnets: []ec2types.Subnet{
83+
awsEnv.EC2API.DescribeSubnetsBehavior.Output.Set(&ec2.DescribeSubnetsOutput{Subnets: []ec2types.Subnet{
8484
{SubnetId: aws.String("subnet-test1"), AvailabilityZone: aws.String("test-zone-1a"), AvailabilityZoneId: aws.String("tstz1-1a"), AvailableIpAddressCount: aws.Int32(20)},
8585
{SubnetId: aws.String("subnet-test2"), AvailabilityZone: aws.String("test-zone-1b"), AvailabilityZoneId: aws.String("tstz1-1b"), AvailableIpAddressCount: aws.Int32(100)},
8686
{SubnetId: aws.String("subnet-test3"), AvailabilityZone: aws.String("test-zone-1c"), AvailabilityZoneId: aws.String("tstz1-1c"), AvailableIpAddressCount: aws.Int32(50)},

pkg/controllers/providers/ssm/invalidation/controller.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ import (
2929
v1 "github.com/aws/karpenter-provider-aws/pkg/apis/v1"
3030
"github.com/aws/karpenter-provider-aws/pkg/providers/amifamily"
3131
"github.com/aws/karpenter-provider-aws/pkg/providers/ssm"
32+
33+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
34+
"k8s.io/apimachinery/pkg/util/uuid"
3235
)
3336

3437
// The SSM Invalidation controller is responsible for invalidating "latest" SSM parameters when they point to deprecated
@@ -66,6 +69,9 @@ func (c *Controller) Reconcile(ctx context.Context) (reconcile.Result, error) {
6669
amis := []amifamily.AMI{}
6770
for _, nodeClass := range lo.Map(lo.Keys(amiIDsToParameters), func(amiID string, _ int) *v1.EC2NodeClass {
6871
return &v1.EC2NodeClass{
72+
ObjectMeta: metav1.ObjectMeta{
73+
UID: uuid.NewUUID(), // ensures that this doesn't hit the AMI cache.
74+
},
6975
Spec: v1.EC2NodeClassSpec{
7076
AMISelectorTerms: []v1.AMISelectorTerm{{ID: amiID}},
7177
},

pkg/fake/ec2api.go

Lines changed: 98 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ type CapacityPool struct {
4848
type EC2Behavior struct {
4949
DescribeImagesOutput AtomicPtr[ec2.DescribeImagesOutput]
5050
DescribeLaunchTemplatesOutput AtomicPtr[ec2.DescribeLaunchTemplatesOutput]
51-
DescribeSubnetsOutput AtomicPtr[ec2.DescribeSubnetsOutput]
52-
DescribeSecurityGroupsOutput AtomicPtr[ec2.DescribeSecurityGroupsOutput]
51+
DescribeSubnetsBehavior MockedFunction[ec2.DescribeSubnetsInput, ec2.DescribeSubnetsOutput]
52+
DescribeSecurityGroupsBehavior MockedFunction[ec2.DescribeSecurityGroupsInput, ec2.DescribeSecurityGroupsOutput]
5353
DescribeInstanceTypesOutput AtomicPtr[ec2.DescribeInstanceTypesOutput]
5454
DescribeInstanceTypeOfferingsOutput AtomicPtr[ec2.DescribeInstanceTypeOfferingsOutput]
5555
DescribeAvailabilityZonesOutput AtomicPtr[ec2.DescribeAvailabilityZonesOutput]
@@ -60,6 +60,7 @@ type EC2Behavior struct {
6060
CreateTagsBehavior MockedFunction[ec2.CreateTagsInput, ec2.CreateTagsOutput]
6161
CalledWithCreateLaunchTemplateInput AtomicPtrSlice[ec2.CreateLaunchTemplateInput]
6262
CalledWithDescribeImagesInput AtomicPtrSlice[ec2.DescribeImagesInput]
63+
Subnets sync.Map
6364
Instances sync.Map
6465
LaunchTemplates sync.Map
6566
InsufficientCapacityPools atomic.Slice[CapacityPool]
@@ -83,8 +84,8 @@ var DefaultSupportedUsageClasses = []ec2types.UsageClassType{ec2types.UsageClass
8384
func (e *EC2API) Reset() {
8485
e.DescribeImagesOutput.Reset()
8586
e.DescribeLaunchTemplatesOutput.Reset()
86-
e.DescribeSubnetsOutput.Reset()
87-
e.DescribeSecurityGroupsOutput.Reset()
87+
e.DescribeSubnetsBehavior.Reset()
88+
e.DescribeSecurityGroupsBehavior.Reset()
8889
e.DescribeInstanceTypesOutput.Reset()
8990
e.DescribeInstanceTypeOfferingsOutput.Reset()
9091
e.DescribeAvailabilityZonesOutput.Reset()
@@ -380,107 +381,109 @@ func (e *EC2API) DeleteLaunchTemplate(_ context.Context, input *ec2.DeleteLaunch
380381
}
381382

382383
func (e *EC2API) DescribeSubnets(_ context.Context, input *ec2.DescribeSubnetsInput, _ ...func(*ec2.Options)) (*ec2.DescribeSubnetsOutput, error) {
383-
if !e.NextError.IsNil() {
384-
defer e.NextError.Reset()
385-
return nil, e.NextError.Get()
386-
}
387-
if !e.DescribeSubnetsOutput.IsNil() {
388-
describeSubnetsOutput := e.DescribeSubnetsOutput.Clone()
389-
describeSubnetsOutput.Subnets = FilterDescribeSubnets(describeSubnetsOutput.Subnets, input.Filters)
390-
return describeSubnetsOutput, nil
391-
}
392-
subnets := []ec2types.Subnet{
393-
{
394-
SubnetId: aws.String("subnet-test1"),
395-
AvailabilityZone: aws.String("test-zone-1a"),
396-
AvailabilityZoneId: aws.String("tstz1-1a"),
397-
AvailableIpAddressCount: aws.Int32(100),
398-
MapPublicIpOnLaunch: aws.Bool(false),
399-
Tags: []ec2types.Tag{
400-
{Key: aws.String("Name"), Value: aws.String("test-subnet-1")},
401-
{Key: aws.String("foo"), Value: aws.String("bar")},
384+
return e.DescribeSubnetsBehavior.Invoke(input, func(input *ec2.DescribeSubnetsInput) (*ec2.DescribeSubnetsOutput, error) {
385+
output := &ec2.DescribeSubnetsOutput{}
386+
e.Subnets.Range(func(key, value any) bool {
387+
subnet := value.(ec2types.Subnet)
388+
if lo.Contains(input.SubnetIds, lo.FromPtr(subnet.SubnetId)) || len(input.Filters) != 0 && len(FilterDescribeSubnets([]ec2types.Subnet{subnet}, input.Filters)) != 0 {
389+
output.Subnets = append(output.Subnets, subnet)
390+
}
391+
return true
392+
})
393+
if len(output.Subnets) != 0 {
394+
return output, nil
395+
}
396+
397+
defaultSubnets := []ec2types.Subnet{
398+
{
399+
SubnetId: aws.String("subnet-test1"),
400+
AvailabilityZone: aws.String("test-zone-1a"),
401+
AvailabilityZoneId: aws.String("tstz1-1a"),
402+
AvailableIpAddressCount: aws.Int32(100),
403+
MapPublicIpOnLaunch: aws.Bool(false),
404+
Tags: []ec2types.Tag{
405+
{Key: aws.String("Name"), Value: aws.String("test-subnet-1")},
406+
{Key: aws.String("foo"), Value: aws.String("bar")},
407+
},
408+
VpcId: aws.String("vpc-test1"),
402409
},
403-
},
404-
{
405-
SubnetId: aws.String("subnet-test2"),
406-
AvailabilityZone: aws.String("test-zone-1b"),
407-
AvailabilityZoneId: aws.String("tstz1-1b"),
408-
AvailableIpAddressCount: aws.Int32(100),
409-
MapPublicIpOnLaunch: aws.Bool(true),
410-
Tags: []ec2types.Tag{
411-
{Key: aws.String("Name"), Value: aws.String("test-subnet-2")},
412-
{Key: aws.String("foo"), Value: aws.String("bar")},
410+
{
411+
SubnetId: aws.String("subnet-test2"),
412+
AvailabilityZone: aws.String("test-zone-1b"),
413+
AvailabilityZoneId: aws.String("tstz1-1b"),
414+
AvailableIpAddressCount: aws.Int32(100),
415+
MapPublicIpOnLaunch: aws.Bool(true),
416+
Tags: []ec2types.Tag{
417+
{Key: aws.String("Name"), Value: aws.String("test-subnet-2")},
418+
{Key: aws.String("foo"), Value: aws.String("bar")},
419+
},
420+
VpcId: aws.String("vpc-test1"),
413421
},
414-
},
415-
{
416-
SubnetId: aws.String("subnet-test3"),
417-
AvailabilityZone: aws.String("test-zone-1c"),
418-
AvailabilityZoneId: aws.String("tstz1-1c"),
419-
AvailableIpAddressCount: aws.Int32(100),
420-
Tags: []ec2types.Tag{
421-
{Key: aws.String("Name"), Value: aws.String("test-subnet-3")},
422-
{Key: aws.String("TestTag")},
423-
{Key: aws.String("foo"), Value: aws.String("bar")},
422+
{
423+
SubnetId: aws.String("subnet-test3"),
424+
AvailabilityZone: aws.String("test-zone-1c"),
425+
AvailabilityZoneId: aws.String("tstz1-1c"),
426+
AvailableIpAddressCount: aws.Int32(100),
427+
Tags: []ec2types.Tag{
428+
{Key: aws.String("Name"), Value: aws.String("test-subnet-3")},
429+
{Key: aws.String("TestTag")},
430+
{Key: aws.String("foo"), Value: aws.String("bar")},
431+
},
432+
VpcId: aws.String("vpc-test1"),
424433
},
425-
},
426-
{
427-
SubnetId: aws.String("subnet-test4"),
428-
AvailabilityZone: aws.String("test-zone-1a-local"),
429-
AvailabilityZoneId: aws.String("tstz1-1alocal"),
430-
AvailableIpAddressCount: aws.Int32(100),
431-
MapPublicIpOnLaunch: aws.Bool(true),
432-
Tags: []ec2types.Tag{
433-
{Key: aws.String("Name"), Value: aws.String("test-subnet-4")},
434+
{
435+
SubnetId: aws.String("subnet-test4"),
436+
AvailabilityZone: aws.String("test-zone-1a-local"),
437+
AvailabilityZoneId: aws.String("tstz1-1alocal"),
438+
AvailableIpAddressCount: aws.Int32(100),
439+
MapPublicIpOnLaunch: aws.Bool(true),
440+
Tags: []ec2types.Tag{
441+
{Key: aws.String("Name"), Value: aws.String("test-subnet-4")},
442+
},
443+
VpcId: aws.String("vpc-test1"),
434444
},
435-
},
436-
}
437-
if len(input.Filters) == 0 {
438-
return nil, fmt.Errorf("InvalidParameterValue: The filter 'null' is invalid")
439-
}
440-
return &ec2.DescribeSubnetsOutput{Subnets: FilterDescribeSubnets(subnets, input.Filters)}, nil
445+
}
446+
if len(input.Filters) == 0 {
447+
return nil, fmt.Errorf("InvalidParameterValue: The filter 'null' is invalid")
448+
}
449+
return &ec2.DescribeSubnetsOutput{Subnets: FilterDescribeSubnets(defaultSubnets, input.Filters)}, nil
450+
})
441451
}
442452

443453
func (e *EC2API) DescribeSecurityGroups(_ context.Context, input *ec2.DescribeSecurityGroupsInput, _ ...func(*ec2.Options)) (*ec2.DescribeSecurityGroupsOutput, error) {
444-
if !e.NextError.IsNil() {
445-
defer e.NextError.Reset()
446-
return nil, e.NextError.Get()
447-
}
448-
if !e.DescribeSecurityGroupsOutput.IsNil() {
449-
describeSecurityGroupsOutput := e.DescribeSecurityGroupsOutput.Clone()
450-
describeSecurityGroupsOutput.SecurityGroups = FilterDescribeSecurtyGroups(describeSecurityGroupsOutput.SecurityGroups, input.Filters)
451-
return e.DescribeSecurityGroupsOutput.Clone(), nil
452-
}
453-
sgs := []ec2types.SecurityGroup{
454-
{
455-
GroupId: aws.String("sg-test1"),
456-
GroupName: aws.String("securityGroup-test1"),
457-
Tags: []ec2types.Tag{
458-
{Key: aws.String("Name"), Value: aws.String("test-security-group-1")},
459-
{Key: aws.String("foo"), Value: aws.String("bar")},
454+
return e.DescribeSecurityGroupsBehavior.Invoke(input, func(input *ec2.DescribeSecurityGroupsInput) (*ec2.DescribeSecurityGroupsOutput, error) {
455+
defaultSecurityGroups := []ec2types.SecurityGroup{
456+
{
457+
GroupId: aws.String("sg-test1"),
458+
GroupName: aws.String("securityGroup-test1"),
459+
Tags: []ec2types.Tag{
460+
{Key: aws.String("Name"), Value: aws.String("test-security-group-1")},
461+
{Key: aws.String("foo"), Value: aws.String("bar")},
462+
},
460463
},
461-
},
462-
{
463-
GroupId: aws.String("sg-test2"),
464-
GroupName: aws.String("securityGroup-test2"),
465-
Tags: []ec2types.Tag{
466-
{Key: aws.String("Name"), Value: aws.String("test-security-group-2")},
467-
{Key: aws.String("foo"), Value: aws.String("bar")},
464+
{
465+
GroupId: aws.String("sg-test2"),
466+
GroupName: aws.String("securityGroup-test2"),
467+
Tags: []ec2types.Tag{
468+
{Key: aws.String("Name"), Value: aws.String("test-security-group-2")},
469+
{Key: aws.String("foo"), Value: aws.String("bar")},
470+
},
468471
},
469-
},
470-
{
471-
GroupId: aws.String("sg-test3"),
472-
GroupName: aws.String("securityGroup-test3"),
473-
Tags: []ec2types.Tag{
474-
{Key: aws.String("Name"), Value: aws.String("test-security-group-3")},
475-
{Key: aws.String("TestTag")},
476-
{Key: aws.String("foo"), Value: aws.String("bar")},
472+
{
473+
GroupId: aws.String("sg-test3"),
474+
GroupName: aws.String("securityGroup-test3"),
475+
Tags: []ec2types.Tag{
476+
{Key: aws.String("Name"), Value: aws.String("test-security-group-3")},
477+
{Key: aws.String("TestTag")},
478+
{Key: aws.String("foo"), Value: aws.String("bar")},
479+
},
477480
},
478-
},
479-
}
480-
if len(input.Filters) == 0 {
481-
return nil, fmt.Errorf("InvalidParameterValue: The filter 'null' is invalid")
482-
}
483-
return &ec2.DescribeSecurityGroupsOutput{SecurityGroups: FilterDescribeSecurtyGroups(sgs, input.Filters)}, nil
481+
}
482+
if len(input.Filters) == 0 {
483+
return nil, fmt.Errorf("InvalidParameterValue: The filter 'null' is invalid")
484+
}
485+
return &ec2.DescribeSecurityGroupsOutput{SecurityGroups: FilterDescribeSecurtyGroups(defaultSecurityGroups, input.Filters)}, nil
486+
})
484487
}
485488

486489
func (e *EC2API) DescribeAvailabilityZones(context.Context, *ec2.DescribeAvailabilityZonesInput, ...func(*ec2.Options)) (*ec2.DescribeAvailabilityZonesOutput, error) {

pkg/fake/types.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525

2626
type MockedFunction[I any, O any] struct {
2727
Output AtomicPtr[O] // Output to return on call to this function
28+
MultiOut AtomicPtrSlice[O]
2829
OutputPages AtomicPtrSlice[O]
2930
CalledWithInput AtomicPtrSlice[I] // Slice used to keep track of passed input to this function
3031
Error AtomicError // Error to return a certain number of times defined by custom error options
@@ -38,6 +39,7 @@ type MockedFunction[I any, O any] struct {
3839
// each other.
3940
func (m *MockedFunction[I, O]) Reset() {
4041
m.Output.Reset()
42+
m.MultiOut.Reset()
4143
m.OutputPages.Reset()
4244
m.CalledWithInput.Reset()
4345
m.Error.Reset()
@@ -59,6 +61,11 @@ func (m *MockedFunction[I, O]) Invoke(input *I, defaultTransformer func(*I) (*O,
5961
m.successfulCalls.Add(1)
6062
return m.Output.Clone(), nil
6163
}
64+
65+
if m.MultiOut.Len() > 0 {
66+
m.successfulCalls.Add(1)
67+
return m.MultiOut.Pop(), nil
68+
}
6269
// This output pages multi-threaded handling isn't perfect
6370
// It will fail if pages are asynchronously requested from the same NextToken
6471
if m.OutputPages.Len() > 0 {

0 commit comments

Comments
 (0)