Skip to content

Commit 50b889b

Browse files
authored
Merge pull request #3 from DefangLabs/edw/add-listener-rule
Add listener rule when invoked the first time
2 parents 8cc649c + 9b97bbf commit 50b889b

File tree

3 files changed

+162
-30
lines changed

3 files changed

+162
-30
lines changed

acme/update.go

Lines changed: 43 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@ import (
1010
"log"
1111
"os"
1212

13-
awsalb "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2/types"
1413
"github.com/DefangLabs/cloudacme/aws/acm"
1514
"github.com/DefangLabs/cloudacme/aws/alb"
15+
awsalb "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2/types"
1616
"github.com/mholt/acmez"
1717
"go.uber.org/zap"
1818
)
@@ -33,14 +33,43 @@ func UpdateAcmeCertificate(ctx context.Context, albArn, domain string, solver ac
3333
return fmt.Errorf("failed to get account key: %w", err)
3434
}
3535

36+
certToUpdate, _, err := GetExistingCertificate(ctx, albArn, domain)
37+
if err != nil {
38+
return fmt.Errorf("failed to get existing certificate: %w", err)
39+
}
40+
41+
acmeDirectory := os.Getenv("ACME_DIRECTORY")
42+
if acmeDirectory == "" {
43+
acmeDirectory = DefaultAcmeDirectory
44+
}
45+
46+
acmeClient := Acme{
47+
Directory: acmeDirectory,
48+
AccountKey: accountKey,
49+
Logger: logger,
50+
AlbArn: albArn,
51+
HttpSolver: solver,
52+
}
53+
54+
key, chain, err := acmeClient.GetCertificate(ctx, []string{domain})
55+
if err != nil {
56+
return fmt.Errorf("failed to get certificates: %w", err)
57+
}
58+
59+
if err := acm.ImportCertificate(ctx, key, chain, certToUpdate); err != nil {
60+
return fmt.Errorf("error importing certificate: %w", err)
61+
}
62+
return nil
63+
}
64+
65+
func GetExistingCertificate(ctx context.Context, albArn, domain string) (string, *x509.Certificate, error) {
3666
// Find the certificate to update from all the certificates attached to the ALB
3767
certArns, err := alb.GetAlbCerts(ctx, albArn)
3868
if err != nil {
39-
return fmt.Errorf("failed to get ALB certificates: %w", err)
69+
return "", nil, fmt.Errorf("failed to get ALB certificates: %w", err)
4070
}
4171

4272
var getCertErrs []error
43-
certToUpdate := ""
4473
for _, certArn := range certArns {
4574
certPem, err := acm.GetCertificate(ctx, certArn)
4675
if err != nil {
@@ -60,37 +89,25 @@ func UpdateAcmeCertificate(ctx context.Context, albArn, domain string, solver ac
6089
if cert.Subject.CommonName == domain {
6190
// TODO: check the issuer and expiration date
6291
// TODO: should we check SANs? probably not, as byod domain are added as SNI single domain certs
63-
certToUpdate = certArn
64-
break
92+
return certArn, cert, nil
6593
}
6694
}
67-
if certToUpdate == "" {
68-
if len(getCertErrs) == 0 {
69-
return fmt.Errorf("no certificate matching %v found", domain)
70-
}
71-
return fmt.Errorf("failed to get certificate: %w", errors.Join(getCertErrs...))
72-
}
73-
74-
acmeDirectory := os.Getenv("ACME_DIRECTORY")
75-
if acmeDirectory == "" {
76-
acmeDirectory = DefaultAcmeDirectory
77-
}
95+
return "", nil, fmt.Errorf("no certificate matching %v found: %w", domain, errors.Join(getCertErrs...))
96+
}
7897

79-
acmeClient := Acme{
80-
Directory: acmeDirectory,
81-
AccountKey: accountKey,
82-
Logger: logger,
83-
AlbArn: albArn,
84-
HttpSolver: solver,
98+
func SetupHttpRule(ctx context.Context, albArn, lambdaArn string, ruleCond alb.RuleCondition) error {
99+
listener, err := alb.GetListener(ctx, albArn, awsalb.ProtocolEnumHttp, 80)
100+
if err != nil {
101+
return fmt.Errorf("cannot get http listener: %w", err)
85102
}
86103

87-
key, chain, err := acmeClient.GetCertificate(ctx, []string{domain})
104+
targetGroupArn, err := alb.GetLambdaTargetGroup(ctx, lambdaArn)
88105
if err != nil {
89-
return fmt.Errorf("failed to get certificates: %w", err)
106+
return fmt.Errorf("cannot get target group for lambda %v: %w", lambdaArn, err)
90107
}
91108

92-
if err := acm.ImportCertificate(ctx, key, chain, certToUpdate); err != nil {
93-
return fmt.Errorf("error importing certificate: %w", err)
109+
if err := alb.AddListenerTriggerTargetGroupRule(ctx, *listener.ListenerArn, ruleCond, targetGroupArn); err != nil {
110+
return fmt.Errorf("failed to create listener static rule: %w", err)
94111
}
95112
return nil
96113
}

aws/alb/updatealb.go

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@ import (
44
"context"
55
"errors"
66
"fmt"
7+
"log"
78
"sort"
89
"strconv"
10+
"strings"
911

1012
"github.com/DefangLabs/cloudacme/aws"
1113
elbv2 "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2"
@@ -132,6 +134,79 @@ func AddListenerStaticRule(ctx context.Context, listenerArn string, ruleCond Rul
132134
return nil
133135
}
134136

137+
func AddListenerTriggerTargetGroupRule(ctx context.Context, listenerArn string, ruleCond RuleCondition, targetArn string) error {
138+
svc := elbv2.NewFromConfig(aws.LoadConfig())
139+
140+
priority, err := GetNextAvailablePriority(ctx, listenerArn)
141+
if err != nil {
142+
return err
143+
}
144+
145+
input := &elbv2.CreateRuleInput{
146+
Actions: []types.Action{
147+
{
148+
Type: types.ActionTypeEnumForward,
149+
TargetGroupArn: &targetArn,
150+
},
151+
},
152+
Conditions: []types.RuleCondition{
153+
{
154+
Field: ptr.String("path-pattern"),
155+
PathPatternConfig: &types.PathPatternConditionConfig{Values: ruleCond.PathPattern},
156+
},
157+
{
158+
Field: ptr.String("host-header"),
159+
HostHeaderConfig: &types.HostHeaderConditionConfig{Values: ruleCond.HostHeader},
160+
},
161+
},
162+
ListenerArn: &listenerArn,
163+
Priority: ptr.Int32(priority),
164+
}
165+
166+
if _, err := svc.CreateRule(ctx, input); err != nil {
167+
return err
168+
}
169+
return nil
170+
}
171+
172+
func GetLambdaTargetGroup(ctx context.Context, lambdaArn string) (string, error) {
173+
svc := elbv2.NewFromConfig(aws.LoadConfig())
174+
paginator := elbv2.NewDescribeTargetGroupsPaginator(svc, &elbv2.DescribeTargetGroupsInput{})
175+
176+
log.Printf("Searching for target group for lambda %s", lambdaArn)
177+
for paginator.HasMorePages() {
178+
page, err := paginator.NextPage(ctx)
179+
if err != nil {
180+
return "", fmt.Errorf("failed to list target groups: %w", err)
181+
}
182+
183+
for _, tg := range page.TargetGroups {
184+
log.Printf("Checking target group %s of type %s", *tg.TargetGroupArn, tg.TargetType)
185+
if tg.TargetType != types.TargetTypeEnumLambda {
186+
continue
187+
}
188+
189+
// Check registered targets for this target group
190+
targetsOut, err := svc.DescribeTargetHealth(ctx, &elbv2.DescribeTargetHealthInput{
191+
TargetGroupArn: tg.TargetGroupArn,
192+
})
193+
if err != nil {
194+
return "", fmt.Errorf("describe target health failed for %s: %w", *tg.TargetGroupArn, err)
195+
}
196+
197+
for _, desc := range targetsOut.TargetHealthDescriptions {
198+
if desc.Target != nil && desc.Target.Id != nil {
199+
log.Printf("Checking target %s with status %s", *desc.Target.Id, desc.TargetHealth.State)
200+
if strings.HasPrefix(lambdaArn, *desc.Target.Id) {
201+
return *tg.TargetGroupArn, nil
202+
}
203+
}
204+
}
205+
}
206+
}
207+
return "", fmt.Errorf("no target group found for lambda %s", lambdaArn)
208+
}
209+
135210
func GetNextAvailablePriority(ctx context.Context, listenerArn string) (int32, error) {
136211
rules, err := GetAllRules(ctx, listenerArn)
137212
if err != nil {

cmd/lambda/main.go

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,21 @@ package main
33
import (
44
"context"
55
"crypto/tls"
6+
"crypto/x509"
67
"errors"
78
"fmt"
89
"log"
910
"net/http"
1011
"net/url"
12+
"strings"
1113
"time"
1214

13-
"github.com/aws/aws-lambda-go/events"
14-
"github.com/aws/aws-lambda-go/lambda"
1515
"github.com/DefangLabs/cloudacme/acme"
1616
"github.com/DefangLabs/cloudacme/aws/alb"
1717
"github.com/DefangLabs/cloudacme/solver"
18+
"github.com/aws/aws-lambda-go/events"
19+
"github.com/aws/aws-lambda-go/lambda"
20+
"github.com/aws/aws-lambda-go/lambdacontext"
1821
)
1922

2023
var version = "dev" // to be set by ldflags
@@ -34,7 +37,28 @@ func HandleEvent(ctx context.Context, evt Event) (any, error) {
3437
if evt.HTTPMethod != "" {
3538
return HandleALBEvent(ctx, evt.ALBTargetGroupRequest)
3639
} else {
37-
return nil, HandleEventBridgeEvent(ctx, evt.CertificateRenewalEvent)
40+
_, cert, err := acme.GetExistingCertificate(ctx, evt.AlbArn, evt.Domain)
41+
if err != nil {
42+
return nil, fmt.Errorf("failed to get existing certificate: %w", err)
43+
}
44+
45+
ownArn := ""
46+
if lc, ok := lambdacontext.FromContext(ctx); ok {
47+
ownArn = lc.InvokedFunctionArn
48+
}
49+
if ownArn == "" {
50+
return nil, errors.New("unable to determine own Lambda ARN from context")
51+
}
52+
53+
if !IsLetsEncryptCertificate(cert) {
54+
log.Printf("Certificate for domain %s is not issued by Let's Encrypt, initial run, setup load balancer rule for acme lambda", evt.Domain)
55+
return nil, acme.SetupHttpRule(ctx, evt.AlbArn, ownArn, alb.RuleCondition{
56+
HostHeader: []string{evt.Domain},
57+
PathPattern: []string{"/"},
58+
})
59+
} else {
60+
return nil, HandleScheduledRenewalEvent(ctx, evt.CertificateRenewalEvent)
61+
}
3862
}
3963
}
4064

@@ -124,7 +148,7 @@ func getHttpsRedirectURL(evt events.ALBTargetGroupRequest) string {
124148
return fmt.Sprintf("https://%s%s%s", evt.Headers["host"], evt.Path, params)
125149
}
126150

127-
func HandleEventBridgeEvent(ctx context.Context, evt CertificateRenewalEvent) error {
151+
func HandleScheduledRenewalEvent(ctx context.Context, evt CertificateRenewalEvent) error {
128152
log.Printf("Handling Certificate Renewal Event: %+v", evt)
129153

130154
albSolver := solver.AlbHttp01Solver{
@@ -139,6 +163,22 @@ func HandleEventBridgeEvent(ctx context.Context, evt CertificateRenewalEvent) er
139163
return nil
140164
}
141165

166+
func IsLetsEncryptCertificate(cert *x509.Certificate) bool {
167+
// Check Issuer Organization
168+
for _, org := range cert.Issuer.Organization {
169+
if strings.Contains(strings.ToLower(org), "let's encrypt") {
170+
return true
171+
}
172+
}
173+
174+
// Fallback: check Common Name
175+
if strings.Contains(strings.ToLower(cert.Issuer.CommonName), "let's encrypt") {
176+
return true
177+
}
178+
179+
return false
180+
}
181+
142182
func main() {
143183
lambda.Start(HandleEvent)
144184
}

0 commit comments

Comments
 (0)