Skip to content

Commit 6d51f16

Browse files
committed
Implement s3:// protocol
For those that want to pull from s3 Signed-off-by: Eric Curtin <[email protected]>
1 parent 3d804de commit 6d51f16

File tree

1 file changed

+41
-0
lines changed

1 file changed

+41
-0
lines changed

examples/run/run.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -698,6 +698,44 @@ class LlamaData {
698698
return download(url, bn, true);
699699
}
700700

701+
int s3_dl(const std::string & model, const std::string & bn) {
702+
// Extract bucket and key from S3 URI
703+
const std::string prefix = "s3://";
704+
const size_t pos = model.find(prefix);
705+
if (pos != 0) {
706+
return 1;
707+
}
708+
709+
const std::string path = model.substr(prefix.length());
710+
const size_t slash_pos = path.find('/');
711+
if (slash_pos == std::string::npos) {
712+
return 1;
713+
}
714+
715+
const std::string bucket = path.substr(0, slash_pos);
716+
const std::string key = path.substr(slash_pos + 1);
717+
718+
// Get AWS credentials from environment, used env vars, open to change
719+
const char * access_key = std::getenv("AWS_ACCESS_KEY_ID");
720+
const char * secret_key = std::getenv("AWS_SECRET_ACCESS_KEY");
721+
if (!access_key || !secret_key) {
722+
printe("AWS credentials not found in environment\n");
723+
return 1;
724+
}
725+
726+
// Generate AWS Signature Version 4 headers
727+
// (Implementation requires HMAC-SHA256 and date handling)
728+
const std::vector<std::string> headers = { "Authorization: AWS4-HMAC-SHA256 Credential=" +
729+
std::string(access_key) + "/20240130/us-east-1/s3/aws4_request",
730+
"x-amz-content-sha256: UNSIGNED-PAYLOAD",
731+
"x-amz-date: 20240130T000000Z" };
732+
733+
// Construct S3 endpoint URL
734+
const std::string url = "https://" + bucket + ".s3.amazonaws.com/" + key;
735+
736+
return download(url, bn, true, headers);
737+
}
738+
701739
std::string basename(const std::string & path) {
702740
const size_t pos = path.find_last_of("/\\");
703741
if (pos == std::string::npos) {
@@ -738,6 +776,9 @@ class LlamaData {
738776
rm_until_substring(model_, "github:");
739777
rm_until_substring(model_, "://");
740778
ret = github_dl(model_, bn);
779+
} else if (string_starts_with(model_, "s3://")) {
780+
rm_until_substring(model_, "://");
781+
ret = s3_dl(model_, bn);
741782
} else { // ollama:// or nothing
742783
rm_until_substring(model_, "ollama.com/library/");
743784
rm_until_substring(model_, "://");

0 commit comments

Comments
 (0)