Skip to content

Commit 5e92fca

Browse files
committed
Add S3 URI handling to modules
1 parent 98a5486 commit 5e92fca

File tree

5 files changed

+275
-38
lines changed

5 files changed

+275
-38
lines changed

cft/pkg/content.go

Lines changed: 46 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -17,33 +17,14 @@ type ModuleContent struct {
1717
BaseUri string
1818
}
1919

20-
// Helper function to handle zip file extraction with proper path resolution
21-
func handleZipFile(root string, location string, hash string, path string) ([]byte, error) {
22-
// Resolve the path to the zip file if it's a local path
23-
zipPath := location
24-
if !strings.HasPrefix(zipPath, "http://") && !strings.HasPrefix(zipPath, "https://") {
25-
// If it's a relative path, resolve it relative to the template's directory
26-
if !filepath.IsAbs(zipPath) {
27-
zipPath = filepath.Join(root, zipPath)
28-
}
29-
}
3020

31-
// Check if the zip file exists if it's a local file
32-
if !strings.HasPrefix(zipPath, "http://") && !strings.HasPrefix(zipPath, "https://") {
33-
_, err := os.Stat(zipPath)
34-
if err != nil {
35-
return nil, fmt.Errorf("error accessing zip file %s: %v", zipPath, err)
36-
}
37-
}
3821

39-
// Unzip, verify hash if there is one, and put the files in memory
40-
content, err := DownloadFromZip(zipPath, hash, path)
41-
if err != nil {
42-
config.Debugf("ZIP: Error extracting from zip: %v", err)
43-
return nil, err
44-
}
22+
func isHttpsUrl(uri string) bool {
23+
return strings.HasPrefix(uri, "https://")
24+
}
4525

46-
return content, nil
26+
func isS3URI(uri string) bool {
27+
return strings.HasPrefix(uri, "s3://")
4728
}
4829

4930
// Get the module's content from a local file, memory, or a remote uri
@@ -54,6 +35,8 @@ func getModuleContent(
5435
baseUri string,
5536
uri string) (*ModuleContent, error) {
5637

38+
config.Debugf("getModuleContent root: %s, uri: %s", root, uri)
39+
5740
var content []byte
5841
var err error
5942
var newRootDir string
@@ -71,7 +54,15 @@ func getModuleContent(
7154

7255
if strings.HasSuffix(packageAlias.Location, ".zip") {
7356
isZip = true
74-
content, err = handleZipFile(root, packageAlias.Location, packageAlias.Hash, path)
57+
58+
// Use DownloadFromZip directly
59+
zipLocation := packageAlias.Location
60+
// For local files, resolve the path relative to the template's directory
61+
if !isS3URI(zipLocation) && !isHttpsUrl(zipLocation) && !filepath.IsAbs(zipLocation) {
62+
zipLocation = filepath.Join(root, zipLocation)
63+
}
64+
65+
content, err = DownloadFromZip(zipLocation, packageAlias.Hash, path)
7566
if err != nil {
7667
return nil, err
7768
}
@@ -92,12 +83,25 @@ func getModuleContent(
9283
// getModuleContent: root=cft/pkg/tmpl/awscli-modules, baseUri=, uri=package.zip/zip-module.yaml
9384
if strings.Contains(uri, ".zip/") {
9485
isZip = true
95-
tokens := strings.Split(uri, "/")
96-
location := tokens[0]
97-
path := strings.Join(tokens[1:], "/")
98-
content, err = handleZipFile(root, location, "", path)
99-
if err != nil {
100-
return nil, err
86+
87+
// Extract the zip location and path within the zip
88+
zipIndex := strings.Index(uri, ".zip/")
89+
if zipIndex > 0 {
90+
zipLocation := uri[:zipIndex+4] // Include the .zip part
91+
zipPath := uri[zipIndex+5:] // Skip the .zip/ part
92+
93+
// For local files, resolve the path relative to the template's directory
94+
if !isS3URI(zipLocation) && !isHttpsUrl(zipLocation) && !filepath.IsAbs(zipLocation) {
95+
zipLocation = filepath.Join(root, zipLocation)
96+
}
97+
98+
config.Debugf("Extracting from zip: %s, path: %s", zipLocation, zipPath)
99+
100+
// Use DownloadFromZip directly - it can handle S3, HTTPS, and local files
101+
content, err = DownloadFromZip(zipLocation, "", zipPath)
102+
if err != nil {
103+
return nil, err
104+
}
101105
}
102106
}
103107

@@ -109,7 +113,15 @@ func getModuleContent(
109113
path := strings.Replace(uri, packageAlias.Alias+"/", "", 1)
110114
if strings.HasSuffix(packageAlias.Location, ".zip") {
111115
isZip = true
112-
content, err = handleZipFile(root, packageAlias.Location, packageAlias.Hash, path)
116+
117+
// Use DownloadFromZip directly
118+
zipLocation := packageAlias.Location
119+
// For local files, resolve the path relative to the template's directory
120+
if !isS3URI(zipLocation) && !isHttpsUrl(zipLocation) && !filepath.IsAbs(zipLocation) {
121+
zipLocation = filepath.Join(root, zipLocation)
122+
}
123+
124+
content, err = DownloadFromZip(zipLocation, packageAlias.Hash, path)
113125
if err != nil {
114126
return nil, err
115127
}
@@ -122,7 +134,7 @@ func getModuleContent(
122134
// Is this a local file or a URL or did we already unzip a package?
123135
if isZip {
124136
config.Debugf("Using content from a zipped module package (length: %d bytes)", len(content))
125-
} else if strings.HasPrefix(uri, "https://") {
137+
} else if isHttpsUrl(uri) || isS3URI(uri) {
126138
config.Debugf("Downloading from URL: %s", uri)
127139
content, err = downloadModule(uri)
128140
if err != nil {
@@ -138,6 +150,7 @@ func getModuleContent(
138150
baseUri = strings.Join(urlParts[:len(urlParts)-1], "/")
139151

140152
} else {
153+
config.Debugf("Downloading from a local file, baseUri=%s, uri=%s", baseUri, uri)
141154
if baseUri != "" {
142155
// If we have a base URL, prepend it to the relative path
143156
uri = baseUri + "/" + uri

cft/pkg/download.go

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"path/filepath"
1212
"strings"
1313

14+
"github.com/aws-cloudformation/rain/internal/aws/s3"
1415
"github.com/aws-cloudformation/rain/internal/config"
1516
"github.com/google/uuid"
1617
)
@@ -42,11 +43,25 @@ func downloadHash(uri string) (string, error) {
4243

4344
// DownloadFromZip retrieves a single file from a zip file hosted on a URI
4445
func DownloadFromZip(uriString string, verifyHash string, path string) ([]byte, error) {
46+
47+
config.Debugf("DownloadFromZip uriString: %s, path: %s",
48+
uriString, path)
49+
4550
var zipData []byte
4651
var err error
47-
48-
// Check if it's a URL or local file
49-
if strings.HasPrefix(uriString, "http://") || strings.HasPrefix(uriString, "https://") {
52+
53+
isUrl := isHttpsUrl(uriString)
54+
isS3 := isS3URI(uriString)
55+
56+
// Check if it's an S3 URI, HTTPS URL, or local file
57+
if isS3 {
58+
// Download from S3
59+
config.Debugf("Downloading from S3: %s", uriString)
60+
zipData, err = downloadS3(uriString)
61+
if err != nil {
62+
return nil, fmt.Errorf("failed to download zip from S3: %v", err)
63+
}
64+
} else if isUrl {
5065
// Download from URL
5166
config.Debugf("Downloading %s", uriString)
5267
resp, err := http.Get(uriString)
@@ -59,7 +74,7 @@ func DownloadFromZip(uriString string, verifyHash string, path string) ([]byte,
5974
config.Debugf("Error closing body: %v", err)
6075
}
6176
}(resp.Body)
62-
77+
6378
zipData, err = io.ReadAll(resp.Body)
6479
if err != nil {
6580
return nil, err
@@ -112,7 +127,7 @@ func DownloadFromZip(uriString string, verifyHash string, path string) ([]byte,
112127

113128
// Download or read the hash
114129
var originalHash string
115-
if strings.HasPrefix(verifyHash, "http://") || strings.HasPrefix(verifyHash, "https://") {
130+
if isUrl {
116131
originalHash, err = downloadHash(verifyHash)
117132
if err != nil {
118133
return nil, err
@@ -215,9 +230,31 @@ func Unzip(f *os.File, dest string) error {
215230
return nil
216231
}
217232

233+
func downloadS3(uri string) ([]byte, error) {
234+
// Parse the S3 URI
235+
bucket, key, err := s3.ParseURI(uri)
236+
if err != nil {
237+
return nil, err
238+
}
239+
240+
// Download the file from S3
241+
content, err := s3.GetObject(bucket, key)
242+
if err != nil {
243+
return nil, err
244+
}
245+
246+
return content, nil
247+
}
248+
218249
// downloadModule downloads the file from the given URI and returns its content as a byte slice.
219250
func downloadModule(uri string) ([]byte, error) {
220251
config.Debugf("Downloading %s", uri)
252+
253+
// If it's an S3 uri, use the s3 package to download the file
254+
if strings.HasPrefix(uri, "s3://") {
255+
return downloadS3(uri)
256+
}
257+
221258
resp, err := http.Get(uri)
222259
if err != nil {
223260
return nil, err
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
Packages:
2+
abc:
3+
Source: s3://ezbeard-awscli/modules/package.zip
4+
Constants:
5+
ModuleSource: s3://ezbeard-awscli/modules
6+
Modules:
7+
Content:
8+
Source: !Sub ${Const::ModuleSource}/basic-module.yaml
9+
Properties:
10+
Name: foo
11+
Overrides:
12+
Bucket:
13+
Properties:
14+
OverrideMe: def
15+
TestPackage:
16+
Source: $abc/zip-module.yaml
17+
Resources:
18+
OtherResource:
19+
Type: AWS::S3::Bucket

internal/aws/s3/s3.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,9 +295,52 @@ func RainBucket(forceCreation bool) string {
295295
return bucketName
296296
}
297297

298+
// ParseURI parses an S3 URI like s3://bucket/key
299+
// The object key name is a sequence of Unicode characters with UTF-8 encoding of up to 1,024 bytes long.
300+
// The bucket name must be a valid DNS name and follow S3 bucket naming rules.
301+
func ParseURI(uri string) (bucket string, key string, err error) {
302+
// Check if the URI starts with s3:// prefix
303+
if !strings.HasPrefix(uri, "s3://") {
304+
err = fmt.Errorf("invalid s3 uri: %s, must start with s3://", uri)
305+
return
306+
}
307+
308+
// Remove the s3:// prefix
309+
uri = strings.TrimPrefix(uri, "s3://")
310+
311+
// Split the remaining string by the first /
312+
parts := strings.SplitN(uri, "/", 2)
313+
if len(parts) == 0 || parts[0] == "" {
314+
err = fmt.Errorf("invalid s3 uri: %s, bucket name missing", uri)
315+
return
316+
}
317+
318+
bucket = parts[0]
319+
320+
// Validate bucket name (basic validation)
321+
if len(bucket) < 3 || len(bucket) > 63 {
322+
err = fmt.Errorf("invalid bucket name: %s, length must be between 3 and 63 characters", bucket)
323+
return
324+
}
325+
326+
if len(parts) == 1 {
327+
key = ""
328+
} else {
329+
key = parts[1]
330+
331+
// Validate key length (S3 limit is 1024 bytes)
332+
if len(key) > 1024 {
333+
err = fmt.Errorf("invalid key: %s, length exceeds 1024 bytes", key)
334+
return
335+
}
336+
}
337+
return
338+
}
339+
298340
// GetObject gets an object by key from an S3 bucket
299341
func GetObject(bucketName string, key string) ([]byte, error) {
300342

343+
config.Debugf("GetObject bucket %s, key: %s", bucketName, key)
301344
accountId, err := getAccountId()
302345
if err != nil {
303346
return nil, err

0 commit comments

Comments
 (0)