Skip to content

Commit 7f8a83d

Browse files
(Backport v1.3.x)fix: prevent hash collisions while resolving subnets, security groups and AMIs from nodeclass selectors (#8659)
Co-authored-by: Saurav Agarwalla <saurav-agarwalla@users.noreply.github.com>
1 parent a313d43 commit 7f8a83d

File tree

16 files changed

+638
-146
lines changed

16 files changed

+638
-146
lines changed

pkg/cloudprovider/suite_test.go

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -659,7 +659,7 @@ var _ = Describe("CloudProvider", func() {
659659
},
660660
},
661661
})
662-
awsEnv.EC2API.DescribeSecurityGroupsOutput.Set(&ec2.DescribeSecurityGroupsOutput{
662+
awsEnv.EC2API.DescribeSecurityGroupsBehavior.Output.Set(&ec2.DescribeSecurityGroupsOutput{
663663
SecurityGroups: []ec2types.SecurityGroup{
664664
{
665665
GroupId: aws.String(validSecurityGroup),
@@ -673,7 +673,7 @@ var _ = Describe("CloudProvider", func() {
673673
},
674674
},
675675
})
676-
awsEnv.EC2API.DescribeSubnetsOutput.Set(&ec2.DescribeSubnetsOutput{
676+
awsEnv.EC2API.DescribeSubnetsBehavior.Output.Set(&ec2.DescribeSubnetsOutput{
677677
Subnets: []ec2types.Subnet{
678678
{
679679
SubnetId: aws.String(validSubnet1),
@@ -1180,7 +1180,7 @@ var _ = Describe("CloudProvider", func() {
11801180
})
11811181
It("should launch instances into subnet with the most available IP addresses", func() {
11821182
awsEnv.SubnetCache.Flush()
1183-
awsEnv.EC2API.DescribeSubnetsOutput.Set(&ec2.DescribeSubnetsOutput{Subnets: []ec2types.Subnet{
1183+
awsEnv.EC2API.DescribeSubnetsBehavior.Output.Set(&ec2.DescribeSubnetsOutput{Subnets: []ec2types.Subnet{
11841184
{SubnetId: aws.String("test-subnet-1"), AvailabilityZone: aws.String("test-zone-1a"), AvailabilityZoneId: aws.String("tstz1-1a"), AvailableIpAddressCount: aws.Int32(10),
11851185
Tags: []ec2types.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-1")}}},
11861186
{SubnetId: aws.String("test-subnet-2"), AvailabilityZone: aws.String("test-zone-1a"), AvailabilityZoneId: aws.String("tstz1-1a"), AvailableIpAddressCount: aws.Int32(100),
@@ -1197,7 +1197,7 @@ var _ = Describe("CloudProvider", func() {
11971197
})
11981198
It("should launch instances into subnet with the most available IP addresses in-between cache refreshes", func() {
11991199
awsEnv.SubnetCache.Flush()
1200-
awsEnv.EC2API.DescribeSubnetsOutput.Set(&ec2.DescribeSubnetsOutput{Subnets: []ec2types.Subnet{
1200+
awsEnv.EC2API.DescribeSubnetsBehavior.Output.Set(&ec2.DescribeSubnetsOutput{Subnets: []ec2types.Subnet{
12011201
{SubnetId: aws.String("test-subnet-1"), AvailabilityZone: aws.String("test-zone-1a"), AvailabilityZoneId: aws.String("tstz1-1a"), AvailableIpAddressCount: aws.Int32(10),
12021202
Tags: []ec2types.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-1")}}},
12031203
{SubnetId: aws.String("test-subnet-2"), AvailabilityZone: aws.String("test-zone-1a"), AvailabilityZoneId: aws.String("tstz1-1a"), AvailableIpAddressCount: aws.Int32(11),
@@ -1225,7 +1225,7 @@ var _ = Describe("CloudProvider", func() {
12251225
Expect(fake.SubnetsFromFleetRequest(createFleetInput)).To(ConsistOf("test-subnet-1"))
12261226
})
12271227
It("should update in-flight IPs when a CreateFleet error occurs", func() {
1228-
awsEnv.EC2API.DescribeSubnetsOutput.Set(&ec2.DescribeSubnetsOutput{Subnets: []ec2types.Subnet{
1228+
awsEnv.EC2API.DescribeSubnetsBehavior.Output.Set(&ec2.DescribeSubnetsOutput{Subnets: []ec2types.Subnet{
12291229
{SubnetId: aws.String("test-subnet-1"), AvailabilityZone: aws.String("test-zone-1a"), AvailableIpAddressCount: aws.Int32(10),
12301230
Tags: []ec2types.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-1")}}},
12311231
}})
@@ -1236,12 +1236,20 @@ var _ = Describe("CloudProvider", func() {
12361236
Expect(len(bindings)).To(Equal(0))
12371237
})
12381238
It("should launch instances into subnets that are excluded by another NodePool", func() {
1239-
awsEnv.EC2API.DescribeSubnetsOutput.Set(&ec2.DescribeSubnetsOutput{Subnets: []ec2types.Subnet{
1240-
{SubnetId: aws.String("test-subnet-1"), AvailabilityZone: aws.String("test-zone-1a"), AvailabilityZoneId: aws.String("tstz1-1a"), AvailableIpAddressCount: aws.Int32(10),
1241-
Tags: []ec2types.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-1")}}},
1242-
{SubnetId: aws.String("test-subnet-2"), AvailabilityZone: aws.String("test-zone-1b"), AvailabilityZoneId: aws.String("tstz1-1a"), AvailableIpAddressCount: aws.Int32(100),
1243-
Tags: []ec2types.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-2")}}},
1244-
}})
1239+
awsEnv.EC2API.Subnets.Store("test-zone-1a", ec2types.Subnet{
1240+
SubnetId: aws.String("test-subnet-1"),
1241+
AvailabilityZone: aws.String("test-zone-1a"),
1242+
AvailabilityZoneId: aws.String("tstz1-1a"),
1243+
AvailableIpAddressCount: aws.Int32(10),
1244+
Tags: []ec2types.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-1")}},
1245+
})
1246+
awsEnv.EC2API.Subnets.Store("test-zone-1b", ec2types.Subnet{
1247+
SubnetId: aws.String("test-subnet-2"),
1248+
AvailabilityZone: aws.String("test-zone-1b"),
1249+
AvailabilityZoneId: aws.String("tstz1-1a"),
1250+
AvailableIpAddressCount: aws.Int32(100),
1251+
Tags: []ec2types.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-2")}},
1252+
})
12451253
nodeClass.Spec.SubnetSelectorTerms = []v1.SubnetSelectorTerm{{Tags: map[string]string{"Name": "test-subnet-1"}}}
12461254
ExpectApplied(ctx, env.Client, nodePool, nodeClass)
12471255
controller := nodeclass.NewController(awsEnv.Clock, env.Client, recorder, awsEnv.SubnetProvider, awsEnv.SecurityGroupProvider, awsEnv.AMIProvider, awsEnv.InstanceProfileProvider, awsEnv.LaunchTemplateProvider, awsEnv.CapacityReservationProvider, awsEnv.EC2API, awsEnv.ValidationCache, awsEnv.AMIResolver)

pkg/controllers/nodeclass/ami_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -631,7 +631,7 @@ var _ = Describe("NodeClass AMI Status Controller", func() {
631631
awsEnv.Clock.Step(40 * time.Minute)
632632

633633
// Flush Cache
634-
awsEnv.EC2Cache.Flush()
634+
awsEnv.AMICache.Flush()
635635

636636
ExpectObjectReconciled(ctx, env.Client, controller, nodeClass)
637637
nodeClass = ExpectExists(ctx, env.Client, nodeClass)
@@ -730,7 +730,7 @@ var _ = Describe("NodeClass AMI Status Controller", func() {
730730
},
731731
})
732732

733-
awsEnv.EC2Cache.Flush()
733+
awsEnv.AMICache.Flush()
734734

735735
ExpectApplied(ctx, env.Client, nodeClass)
736736
ExpectObjectReconciled(ctx, env.Client, controller, nodeClass)

pkg/controllers/nodeclass/subnet_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ var _ = Describe("NodeClass Subnet Status Controller", func() {
8080
Expect(nodeClass.StatusConditions().IsTrue(v1.ConditionTypeSubnetsReady)).To(BeTrue())
8181
})
8282
It("Should have the correct ordering for the Subnets", func() {
83-
awsEnv.EC2API.DescribeSubnetsOutput.Set(&ec2.DescribeSubnetsOutput{Subnets: []ec2types.Subnet{
83+
awsEnv.EC2API.DescribeSubnetsBehavior.Output.Set(&ec2.DescribeSubnetsOutput{Subnets: []ec2types.Subnet{
8484
{SubnetId: aws.String("subnet-test1"), AvailabilityZone: aws.String("test-zone-1a"), AvailabilityZoneId: aws.String("tstz1-1a"), AvailableIpAddressCount: aws.Int32(20)},
8585
{SubnetId: aws.String("subnet-test2"), AvailabilityZone: aws.String("test-zone-1b"), AvailabilityZoneId: aws.String("tstz1-1b"), AvailableIpAddressCount: aws.Int32(100)},
8686
{SubnetId: aws.String("subnet-test3"), AvailabilityZone: aws.String("test-zone-1c"), AvailabilityZoneId: aws.String("tstz1-1c"), AvailableIpAddressCount: aws.Int32(50)},

pkg/controllers/providers/ssm/invalidation/controller.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ import (
2929
v1 "github.com/aws/karpenter-provider-aws/pkg/apis/v1"
3030
"github.com/aws/karpenter-provider-aws/pkg/providers/amifamily"
3131
"github.com/aws/karpenter-provider-aws/pkg/providers/ssm"
32+
33+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
34+
"k8s.io/apimachinery/pkg/util/uuid"
3235
)
3336

3437
// The SSM Invalidation controller is responsible for invalidating "latest" SSM parameters when they point to deprecated
@@ -66,6 +69,9 @@ func (c *Controller) Reconcile(ctx context.Context) (reconcile.Result, error) {
6669
amis := []amifamily.AMI{}
6770
for _, nodeClass := range lo.Map(lo.Keys(amiIDsToParameters), func(amiID string, _ int) *v1.EC2NodeClass {
6871
return &v1.EC2NodeClass{
72+
ObjectMeta: metav1.ObjectMeta{
73+
UID: uuid.NewUUID(), // ensures that this doesn't hit the AMI cache.
74+
},
6975
Spec: v1.EC2NodeClassSpec{
7076
AMISelectorTerms: []v1.AMISelectorTerm{{ID: amiID}},
7177
},

pkg/fake/ec2api.go

Lines changed: 98 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ type EC2Behavior struct {
5050
DescribeCapacityReservationsOutput AtomicPtr[ec2.DescribeCapacityReservationsOutput]
5151
DescribeImagesOutput AtomicPtr[ec2.DescribeImagesOutput]
5252
DescribeLaunchTemplatesOutput AtomicPtr[ec2.DescribeLaunchTemplatesOutput]
53-
DescribeSubnetsOutput AtomicPtr[ec2.DescribeSubnetsOutput]
54-
DescribeSecurityGroupsOutput AtomicPtr[ec2.DescribeSecurityGroupsOutput]
53+
DescribeSubnetsBehavior MockedFunction[ec2.DescribeSubnetsInput, ec2.DescribeSubnetsOutput]
54+
DescribeSecurityGroupsBehavior MockedFunction[ec2.DescribeSecurityGroupsInput, ec2.DescribeSecurityGroupsOutput]
5555
DescribeInstanceTypesOutput AtomicPtr[ec2.DescribeInstanceTypesOutput]
5656
DescribeInstanceTypeOfferingsOutput AtomicPtr[ec2.DescribeInstanceTypeOfferingsOutput]
5757
DescribeAvailabilityZonesOutput AtomicPtr[ec2.DescribeAvailabilityZonesOutput]
@@ -67,6 +67,7 @@ type EC2Behavior struct {
6767
InsufficientCapacityPools atomic.Slice[CapacityPool]
6868
NextError AtomicError
6969

70+
Subnets sync.Map
7071
LaunchTemplates sync.Map
7172
launchTemplatesToCapacityReservations sync.Map // map[lt-name]cr-id
7273
}
@@ -88,8 +89,8 @@ var DefaultSupportedUsageClasses = []ec2types.UsageClassType{ec2types.UsageClass
8889
func (e *EC2API) Reset() {
8990
e.DescribeImagesOutput.Reset()
9091
e.DescribeLaunchTemplatesOutput.Reset()
91-
e.DescribeSubnetsOutput.Reset()
92-
e.DescribeSecurityGroupsOutput.Reset()
92+
e.DescribeSubnetsBehavior.Reset()
93+
e.DescribeSecurityGroupsBehavior.Reset()
9394
e.DescribeInstanceTypesOutput.Reset()
9495
e.DescribeInstanceTypeOfferingsOutput.Reset()
9596
e.DescribeAvailabilityZonesOutput.Reset()
@@ -455,107 +456,109 @@ func (e *EC2API) DeleteLaunchTemplate(_ context.Context, input *ec2.DeleteLaunch
455456
}
456457

457458
func (e *EC2API) DescribeSubnets(_ context.Context, input *ec2.DescribeSubnetsInput, _ ...func(*ec2.Options)) (*ec2.DescribeSubnetsOutput, error) {
458-
if !e.NextError.IsNil() {
459-
defer e.NextError.Reset()
460-
return nil, e.NextError.Get()
461-
}
462-
if !e.DescribeSubnetsOutput.IsNil() {
463-
describeSubnetsOutput := e.DescribeSubnetsOutput.Clone()
464-
describeSubnetsOutput.Subnets = FilterDescribeSubnets(describeSubnetsOutput.Subnets, input.Filters)
465-
return describeSubnetsOutput, nil
466-
}
467-
subnets := []ec2types.Subnet{
468-
{
469-
SubnetId: aws.String("subnet-test1"),
470-
AvailabilityZone: aws.String("test-zone-1a"),
471-
AvailabilityZoneId: aws.String("tstz1-1a"),
472-
AvailableIpAddressCount: aws.Int32(100),
473-
MapPublicIpOnLaunch: aws.Bool(false),
474-
Tags: []ec2types.Tag{
475-
{Key: aws.String("Name"), Value: aws.String("test-subnet-1")},
476-
{Key: aws.String("foo"), Value: aws.String("bar")},
459+
return e.DescribeSubnetsBehavior.Invoke(input, func(input *ec2.DescribeSubnetsInput) (*ec2.DescribeSubnetsOutput, error) {
460+
output := &ec2.DescribeSubnetsOutput{}
461+
e.Subnets.Range(func(key, value any) bool {
462+
subnet := value.(ec2types.Subnet)
463+
if lo.Contains(input.SubnetIds, lo.FromPtr(subnet.SubnetId)) || len(input.Filters) != 0 && len(FilterDescribeSubnets([]ec2types.Subnet{subnet}, input.Filters)) != 0 {
464+
output.Subnets = append(output.Subnets, subnet)
465+
}
466+
return true
467+
})
468+
if len(output.Subnets) != 0 {
469+
return output, nil
470+
}
471+
472+
defaultSubnets := []ec2types.Subnet{
473+
{
474+
SubnetId: aws.String("subnet-test1"),
475+
AvailabilityZone: aws.String("test-zone-1a"),
476+
AvailabilityZoneId: aws.String("tstz1-1a"),
477+
AvailableIpAddressCount: aws.Int32(100),
478+
MapPublicIpOnLaunch: aws.Bool(false),
479+
Tags: []ec2types.Tag{
480+
{Key: aws.String("Name"), Value: aws.String("test-subnet-1")},
481+
{Key: aws.String("foo"), Value: aws.String("bar")},
482+
},
483+
VpcId: aws.String("vpc-test1"),
477484
},
478-
},
479-
{
480-
SubnetId: aws.String("subnet-test2"),
481-
AvailabilityZone: aws.String("test-zone-1b"),
482-
AvailabilityZoneId: aws.String("tstz1-1b"),
483-
AvailableIpAddressCount: aws.Int32(100),
484-
MapPublicIpOnLaunch: aws.Bool(true),
485-
Tags: []ec2types.Tag{
486-
{Key: aws.String("Name"), Value: aws.String("test-subnet-2")},
487-
{Key: aws.String("foo"), Value: aws.String("bar")},
485+
{
486+
SubnetId: aws.String("subnet-test2"),
487+
AvailabilityZone: aws.String("test-zone-1b"),
488+
AvailabilityZoneId: aws.String("tstz1-1b"),
489+
AvailableIpAddressCount: aws.Int32(100),
490+
MapPublicIpOnLaunch: aws.Bool(true),
491+
Tags: []ec2types.Tag{
492+
{Key: aws.String("Name"), Value: aws.String("test-subnet-2")},
493+
{Key: aws.String("foo"), Value: aws.String("bar")},
494+
},
495+
VpcId: aws.String("vpc-test1"),
488496
},
489-
},
490-
{
491-
SubnetId: aws.String("subnet-test3"),
492-
AvailabilityZone: aws.String("test-zone-1c"),
493-
AvailabilityZoneId: aws.String("tstz1-1c"),
494-
AvailableIpAddressCount: aws.Int32(100),
495-
Tags: []ec2types.Tag{
496-
{Key: aws.String("Name"), Value: aws.String("test-subnet-3")},
497-
{Key: aws.String("TestTag")},
498-
{Key: aws.String("foo"), Value: aws.String("bar")},
497+
{
498+
SubnetId: aws.String("subnet-test3"),
499+
AvailabilityZone: aws.String("test-zone-1c"),
500+
AvailabilityZoneId: aws.String("tstz1-1c"),
501+
AvailableIpAddressCount: aws.Int32(100),
502+
Tags: []ec2types.Tag{
503+
{Key: aws.String("Name"), Value: aws.String("test-subnet-3")},
504+
{Key: aws.String("TestTag")},
505+
{Key: aws.String("foo"), Value: aws.String("bar")},
506+
},
507+
VpcId: aws.String("vpc-test1"),
499508
},
500-
},
501-
{
502-
SubnetId: aws.String("subnet-test4"),
503-
AvailabilityZone: aws.String("test-zone-1a-local"),
504-
AvailabilityZoneId: aws.String("tstz1-1alocal"),
505-
AvailableIpAddressCount: aws.Int32(100),
506-
MapPublicIpOnLaunch: aws.Bool(true),
507-
Tags: []ec2types.Tag{
508-
{Key: aws.String("Name"), Value: aws.String("test-subnet-4")},
509+
{
510+
SubnetId: aws.String("subnet-test4"),
511+
AvailabilityZone: aws.String("test-zone-1a-local"),
512+
AvailabilityZoneId: aws.String("tstz1-1alocal"),
513+
AvailableIpAddressCount: aws.Int32(100),
514+
MapPublicIpOnLaunch: aws.Bool(true),
515+
Tags: []ec2types.Tag{
516+
{Key: aws.String("Name"), Value: aws.String("test-subnet-4")},
517+
},
518+
VpcId: aws.String("vpc-test1"),
509519
},
510-
},
511-
}
512-
if len(input.Filters) == 0 {
513-
return nil, fmt.Errorf("InvalidParameterValue: The filter 'null' is invalid")
514-
}
515-
return &ec2.DescribeSubnetsOutput{Subnets: FilterDescribeSubnets(subnets, input.Filters)}, nil
520+
}
521+
if len(input.Filters) == 0 {
522+
return nil, fmt.Errorf("InvalidParameterValue: The filter 'null' is invalid")
523+
}
524+
return &ec2.DescribeSubnetsOutput{Subnets: FilterDescribeSubnets(defaultSubnets, input.Filters)}, nil
525+
})
516526
}
517527

518528
func (e *EC2API) DescribeSecurityGroups(_ context.Context, input *ec2.DescribeSecurityGroupsInput, _ ...func(*ec2.Options)) (*ec2.DescribeSecurityGroupsOutput, error) {
519-
if !e.NextError.IsNil() {
520-
defer e.NextError.Reset()
521-
return nil, e.NextError.Get()
522-
}
523-
if !e.DescribeSecurityGroupsOutput.IsNil() {
524-
describeSecurityGroupsOutput := e.DescribeSecurityGroupsOutput.Clone()
525-
describeSecurityGroupsOutput.SecurityGroups = FilterDescribeSecurtyGroups(describeSecurityGroupsOutput.SecurityGroups, input.Filters)
526-
return e.DescribeSecurityGroupsOutput.Clone(), nil
527-
}
528-
sgs := []ec2types.SecurityGroup{
529-
{
530-
GroupId: aws.String("sg-test1"),
531-
GroupName: aws.String("securityGroup-test1"),
532-
Tags: []ec2types.Tag{
533-
{Key: aws.String("Name"), Value: aws.String("test-security-group-1")},
534-
{Key: aws.String("foo"), Value: aws.String("bar")},
529+
return e.DescribeSecurityGroupsBehavior.Invoke(input, func(input *ec2.DescribeSecurityGroupsInput) (*ec2.DescribeSecurityGroupsOutput, error) {
530+
defaultSecurityGroups := []ec2types.SecurityGroup{
531+
{
532+
GroupId: aws.String("sg-test1"),
533+
GroupName: aws.String("securityGroup-test1"),
534+
Tags: []ec2types.Tag{
535+
{Key: aws.String("Name"), Value: aws.String("test-security-group-1")},
536+
{Key: aws.String("foo"), Value: aws.String("bar")},
537+
},
535538
},
536-
},
537-
{
538-
GroupId: aws.String("sg-test2"),
539-
GroupName: aws.String("securityGroup-test2"),
540-
Tags: []ec2types.Tag{
541-
{Key: aws.String("Name"), Value: aws.String("test-security-group-2")},
542-
{Key: aws.String("foo"), Value: aws.String("bar")},
539+
{
540+
GroupId: aws.String("sg-test2"),
541+
GroupName: aws.String("securityGroup-test2"),
542+
Tags: []ec2types.Tag{
543+
{Key: aws.String("Name"), Value: aws.String("test-security-group-2")},
544+
{Key: aws.String("foo"), Value: aws.String("bar")},
545+
},
543546
},
544-
},
545-
{
546-
GroupId: aws.String("sg-test3"),
547-
GroupName: aws.String("securityGroup-test3"),
548-
Tags: []ec2types.Tag{
549-
{Key: aws.String("Name"), Value: aws.String("test-security-group-3")},
550-
{Key: aws.String("TestTag")},
551-
{Key: aws.String("foo"), Value: aws.String("bar")},
547+
{
548+
GroupId: aws.String("sg-test3"),
549+
GroupName: aws.String("securityGroup-test3"),
550+
Tags: []ec2types.Tag{
551+
{Key: aws.String("Name"), Value: aws.String("test-security-group-3")},
552+
{Key: aws.String("TestTag")},
553+
{Key: aws.String("foo"), Value: aws.String("bar")},
554+
},
552555
},
553-
},
554-
}
555-
if len(input.Filters) == 0 {
556-
return nil, fmt.Errorf("InvalidParameterValue: The filter 'null' is invalid")
557-
}
558-
return &ec2.DescribeSecurityGroupsOutput{SecurityGroups: FilterDescribeSecurtyGroups(sgs, input.Filters)}, nil
556+
}
557+
if len(input.Filters) == 0 {
558+
return nil, fmt.Errorf("InvalidParameterValue: The filter 'null' is invalid")
559+
}
560+
return &ec2.DescribeSecurityGroupsOutput{SecurityGroups: FilterDescribeSecurtyGroups(defaultSecurityGroups, input.Filters)}, nil
561+
})
559562
}
560563

561564
func (e *EC2API) DescribeAvailabilityZones(context.Context, *ec2.DescribeAvailabilityZonesInput, ...func(*ec2.Options)) (*ec2.DescribeAvailabilityZonesOutput, error) {

pkg/fake/types.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525

2626
type MockedFunction[I any, O any] struct {
2727
Output AtomicPtr[O] // Output to return on call to this function
28+
MultiOut AtomicPtrSlice[O]
2829
OutputPages AtomicPtrSlice[O]
2930
CalledWithInput AtomicPtrSlice[I] // Slice used to keep track of passed input to this function
3031
Error AtomicError // Error to return a certain number of times defined by custom error options
@@ -38,6 +39,7 @@ type MockedFunction[I any, O any] struct {
3839
// each other.
3940
func (m *MockedFunction[I, O]) Reset() {
4041
m.Output.Reset()
42+
m.MultiOut.Reset()
4143
m.OutputPages.Reset()
4244
m.CalledWithInput.Reset()
4345
m.Error.Reset()
@@ -60,6 +62,11 @@ func (m *MockedFunction[I, O]) Invoke(input *I, defaultTransformer func(*I) (*O,
6062
m.successfulCalls.Add(1)
6163
return m.Output.Clone(), nil
6264
}
65+
66+
if m.MultiOut.Len() > 0 {
67+
m.successfulCalls.Add(1)
68+
return m.MultiOut.Pop(), nil
69+
}
6370
// This output pages multi-threaded handling isn't perfect
6471
// It will fail if pages are asynchronously requested from the same NextToken
6572
if m.OutputPages.Len() > 0 {

0 commit comments

Comments
 (0)