Skip to content

Commit 7f996bc

Browse files
committed
Implement s3:// protocol
For those that want to pull from s3 Signed-off-by: Eric Curtin <[email protected]>
1 parent 3d804de commit 7f996bc

File tree

1 file changed

+63
-0
lines changed

1 file changed

+63
-0
lines changed

examples/run/run.cpp

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,23 @@ static int printe(const char * fmt, ...) {
6565
return ret;
6666
}
6767

68+
// Formats `tm` according to the strftime-style pattern `fmt` and returns the
// result as a std::string.
//
// Returns an empty string when `fmt` is null or empty, or when the formatted
// output would not fit in the maximum buffer size tried below.
//
// std::strftime() returns 0 both when the buffer is too small and when the
// formatted result is legitimately empty, so a plain "grow until non-zero"
// loop never terminates for inputs that format to nothing. The early return
// and the growth cap below guarantee termination.
static std::string strftime_fmt(const char * fmt, const std::tm & tm) {
    // Nothing to format: avoid the ambiguous zero return altogether.
    if (fmt == nullptr || *fmt == '\0') {
        return std::string();
    }

    // Upper bound on the buffer we are willing to try. Anything that still
    // reports 0 at this size is treated as an empty result.
    constexpr size_t max_size = 64 * 1024;

    // Start with a size that fits typical date/time strings.
    std::string buffer(128, '\0');
    for (;;) {
        const size_t len = std::strftime(buffer.data(), buffer.size(), fmt, &tm);
        if (len != 0) {
            // Trim to the number of characters actually written.
            buffer.resize(len);
            return buffer;
        }

        if (buffer.size() >= max_size) {
            // Either the output is empty or it is unreasonably large.
            return std::string();
        }

        // The buffer may have been too small: double it and try again.
        buffer.resize(buffer.size() * 2);
    }
}
84+
6885
class Opt {
6986
public:
7087
int init(int argc, const char ** argv) {
@@ -698,6 +715,49 @@ class LlamaData {
698715
return download(url, bn, true);
699716
}
700717

718+
int s3_dl(const std::string & model, const std::string & bn) {
719+
// Extract bucket and key from S3 URI
720+
const std::string prefix = "s3://";
721+
const size_t pos = model.find(prefix);
722+
if (pos != 0) {
723+
return 1;
724+
}
725+
726+
const std::string path = model.substr(prefix.length());
727+
const size_t slash_pos = path.find('/');
728+
if (slash_pos == std::string::npos) {
729+
return 1;
730+
}
731+
732+
const std::string bucket = path.substr(0, slash_pos);
733+
const std::string key = path.substr(slash_pos + 1);
734+
735+
// Get AWS credentials from environment, used env vars, open to change
736+
const char * access_key = std::getenv("AWS_ACCESS_KEY_ID");
737+
const char * secret_key = std::getenv("AWS_SECRET_ACCESS_KEY");
738+
if (!access_key || !secret_key) {
739+
printe("AWS credentials not found in environment\n");
740+
return 1;
741+
}
742+
743+
// Generate AWS Signature Version 4 headers
744+
// (Implementation requires HMAC-SHA256 and date handling)
745+
// Get current timestamp
746+
const time_t now = time(nullptr);
747+
const tm tm = *gmtime(&now);
748+
std::string date = strftime_fmt("%Y%m%d", tm);
749+
const std::vector<std::string> headers = {
750+
"Authorization: AWS4-HMAC-SHA256 Credential=" + std::string(access_key) + "/" + date +
751+
"/us-east-1/s3/aws4_request",
752+
"x-amz-content-sha256: UNSIGNED-PAYLOAD", "x-amz-date: 20240130T000000Z"
753+
};
754+
755+
// Construct S3 endpoint URL
756+
const std::string url = "https://" + bucket + ".s3.amazonaws.com/" + key;
757+
758+
return download(url, bn, true, headers);
759+
}
760+
701761
std::string basename(const std::string & path) {
702762
const size_t pos = path.find_last_of("/\\");
703763
if (pos == std::string::npos) {
@@ -738,6 +798,9 @@ class LlamaData {
738798
rm_until_substring(model_, "github:");
739799
rm_until_substring(model_, "://");
740800
ret = github_dl(model_, bn);
801+
} else if (string_starts_with(model_, "s3://")) {
802+
rm_until_substring(model_, "://");
803+
ret = s3_dl(model_, bn);
741804
} else { // ollama:// or nothing
742805
rm_until_substring(model_, "ollama.com/library/");
743806
rm_until_substring(model_, "://");

0 commit comments

Comments
 (0)