diff --git a/CHANGELOG.md b/CHANGELOG.md index 96e592a39ac..ebcea780da9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -77,6 +77,10 @@ and this project adheres to - [#5290](https://github.com/firecracker-microvm/firecracker/pull/5290): Fixed MMDS to reject PUT requests containing `X-Forwarded-For` header regardless of its casing (e.g. `x-forwarded-for`). +- [#5328](https://github.com/firecracker-microvm/firecracker/pull/5328): Fixed + MMDS to set the token TTL header (i.e. "X-metadata-token-ttl-seconds" or + "X-aws-ec2-metadata-token-ttl-seconds") in the response to "PUT + /latest/api/token", as EC2 IMDS does. ## [1.12.0] diff --git a/Cargo.lock b/Cargo.lock index 9e00c993221..6daf7e5c9e8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -955,7 +955,7 @@ dependencies = [ [[package]] name = "micro_http" version = "0.1.0" -source = "git+https://github.com/firecracker-microvm/micro-http#11cc5da16ac86f9107d3f45791944fa6b964a6a9" +source = "git+https://github.com/firecracker-microvm/micro-http#98d85677ba603d16c40103c09059b54c38d71825" dependencies = [ "libc", "vmm-sys-util", diff --git a/resources/overlay/usr/local/bin/fillmem.c b/resources/overlay/usr/local/bin/fillmem.c index 2e8ea902c4e..821cf9d821d 100644 --- a/resources/overlay/usr/local/bin/fillmem.c +++ b/resources/overlay/usr/local/bin/fillmem.c @@ -19,9 +19,9 @@ int fill_mem(int mb_count) { - int i, j; + int i; char *ptr = NULL; - for(j = 0; j < mb_count; j++) { + for(i = 0; i < mb_count; i++) { do { // We can't map the whole chunk of memory at once because // in case the system is already in a memory pressured @@ -50,7 +50,7 @@ int main(int argc, char *const argv[]) { printf("Usage: ./fillmem mb_count\n"); return -1; } - + int mb_count = atoi(argv[1]); int pid = fork(); diff --git a/resources/overlay/usr/local/bin/go_sdk_cred_provider.go/main.go b/resources/overlay/usr/local/bin/go_sdk_cred_provider.go/main.go new file mode 100644 index 00000000000..9e6a265a477 --- /dev/null +++ b/resources/overlay/usr/local/bin/go_sdk_cred_provider.go/main.go @@ -0,0 +1,34 @@ +package main + +import ( + "context" + "fmt" + "log" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" +) + +func main() { + cfg, err := config.LoadDefaultConfig( + context.TODO(), + config.WithClientLogMode( + aws.LogSigning| + aws.LogRetries| + aws.LogRequest| + aws.LogRequestWithBody| + aws.LogResponse| + aws.LogResponseWithBody, + ), + ) + if err != nil { + log.Fatalf("Unable to load config: %v", err) + } + + cred, err := cfg.Credentials.Retrieve(context.TODO()) + if err != nil { + log.Fatalf("Unable to retrieve credentials: %v", err) + } + + fmt.Printf("%v,%v,%v\n", cred.AccessKeyID, cred.SecretAccessKey, cred.SessionToken) +} diff --git a/resources/overlay/usr/local/bin/go_sdk_cred_provider_with_custom_endpoint.go/main.go b/resources/overlay/usr/local/bin/go_sdk_cred_provider_with_custom_endpoint.go/main.go new file mode 100644 index 00000000000..4a16a6823c7 --- /dev/null +++ b/resources/overlay/usr/local/bin/go_sdk_cred_provider_with_custom_endpoint.go/main.go @@ -0,0 +1,138 @@ +package main + +import ( + "context" + "fmt" + "io/ioutil" + "log" + "net/http" + "net/http/httputil" + "os" + + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials/endpointcreds" +) + +const mmdsBaseUrl = "http://169.254.169.254" + +func main() { + // Get MMDS token + token, err := getMmdsToken() + if err != nil { + log.Fatalf("Failed to get MMDS token: %v", err) + } + + // Construct a client + client := &http.Client{ + Transport: &tokenInjector{ + token: token, + next: &loggingRoundTripper{ + next: http.DefaultTransport, + }, + }, + } + + // Construct a credential provider + endpoint := fmt.Sprintf("%s/latest/meta-data/iam/security-credentials/role", mmdsBaseUrl) + provider := endpointcreds.New(endpoint, func(o *endpointcreds.Options) { + o.HTTPClient = client + }) + + // Load config with the custom provider + cfg, err := config.LoadDefaultConfig( + context.TODO(), + config.WithCredentialsProvider(provider), + ) + if err != nil { + log.Fatalf("Unable to load config: %v", err) + } + + // Retrieve credentials + cred, err := cfg.Credentials.Retrieve(context.TODO()) + if err != nil { + log.Fatalf("Unable to retrieve credentials: %v", err) + } + + fmt.Printf("%v,%v,%v\n", cred.AccessKeyID, cred.SecretAccessKey, cred.SessionToken) +} + +func getMmdsToken() (string, error) { + client := &http.Client{} + + // Construct a request + req, err := http.NewRequest("PUT", mmdsBaseUrl + "/latest/api/token", nil) + if err != nil { + return "", err + } + req.Header.Set("x-aws-ec2-metadata-token-ttl-seconds", "21600") + + // Log the request + dumpReq, err := httputil.DumpRequest(req, true) + if err != nil { + return "", err + } + fmt.Fprintf(os.Stderr, "REQUEST:\n%s\n", dumpReq) + + // Perform the request + resp, err := client.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + + // Log the response + dumpResp, err := httputil.DumpResponse(resp, true) + if err != nil { + return "", err + } + fmt.Fprintf(os.Stderr, "RESPONSE:\n%s\n", dumpResp) + + // Check the response status code. + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("Status: %s", resp.Status) + } + + // Read the body + body, _ := ioutil.ReadAll(resp.Body) + return string(body), nil +} + +// tokenInjector adds the token header on every metadata request +type tokenInjector struct { + token string + next http.RoundTripper +} + +func (t *tokenInjector) RoundTrip(req *http.Request) (*http.Response, error) { + req.Header.Set("x-aws-ec2-metadata-token", t.token) + return t.next.RoundTrip(req) +} + +// logginRoundTripper logs requests and responses +type loggingRoundTripper struct { + next http.RoundTripper +} + +func (l *loggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + // Log the request + dumpReq, err := httputil.DumpRequest(req, true) + if err != nil { + return nil, err + } + fmt.Fprintf(os.Stderr, "REQUEST:\n%s\n", dumpReq) + + // Perform the request + resp, err := l.next.RoundTrip(req) + if err != nil { + return nil, err + } + + // Log the response + dumpResp, err := httputil.DumpResponse(resp, true) + if err != nil { + return nil, err + } + fmt.Fprintf(os.Stderr, "RESPONSE:\n%s\n", dumpResp) + + return resp, nil +} diff --git a/resources/rebuild.sh b/resources/rebuild.sh index c6d5e2dd38d..f7215af371e 100755 --- a/resources/rebuild.sh +++ b/resources/rebuild.sh @@ -18,6 +18,18 @@ source "$GIT_ROOT_DIR/tools/functions" function install_dependencies { apt update apt install -y bc flex bison gcc make libelf-dev libssl-dev squashfs-tools busybox-static tree cpio curl patch docker.io + + # Install Go + version=$(curl -s https://go.dev/VERSION?m=text | head -n 1) + case $ARCH in + x86_64) archive="${version}.linux-amd64.tar.gz" ;; + aarch64) archive="${version}.linux-arm64.tar.gz" ;; + esac + curl -LO http://go.dev/dl/${archive} + tar -C /usr/local -xzf $archive + export PATH=$PATH:/usr/local/go/bin + go version + rm $archive } function prepare_docker { @@ -28,11 +40,19 @@ function prepare_docker { } function compile_and_install { - local C_FILE=$1 - local BIN_FILE=$2 - local OUTPUT_DIR=$(dirname $BIN_FILE) - mkdir -pv $OUTPUT_DIR - gcc -Wall -o $BIN_FILE $C_FILE + local SRC=$1 + local BIN="${SRC%.*}" + if [[ $SRC == *.c ]]; then + gcc -Wall -o $BIN $SRC + elif [[ $SRC == *.go ]]; then + pushd $SRC + local MOD=$(basename $BIN) + go mod init $MOD + go mod tidy + go build -o ../$MOD + rm go.mod go.sum + popd + fi } # Build a rootfs @@ -65,12 +85,6 @@ for d in $dirs; do tar c "/$d" | tar x -C $rootfs; done mkdir -pv $rootfs/{dev,proc,sys,run,tmp,var/lib/systemd} # So apt works mkdir -pv $rootfs/var/lib/dpkg/ - -# Install AWS CLI v2 -curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip" -unzip awscliv2.zip -./aws/install --install-dir $rootfs/usr/local/aws-cli --bin-dir $rootfs/usr/local/bin -rm -rf awscliv2.zip aws EOF # TBD what abt /etc/hosts? @@ -80,9 +94,6 @@ EOF mv $rootfs/root/manifest $OUTPUT_DIR/$ROOTFS_NAME.manifest mksquashfs $rootfs $rootfs_img -all-root -noappend -comp zstd rm -rf $rootfs - for bin in fast_page_fault_helper fillmem init readmem; do - rm $PWD/overlay/usr/local/bin/$bin - done rm -f nohup.out } @@ -187,17 +198,24 @@ function build_al_kernel { } function prepare_and_build_rootfs { - BIN=overlay/usr/local/bin - compile_and_install $BIN/init.c $BIN/init - compile_and_install $BIN/fillmem.c $BIN/fillmem - compile_and_install $BIN/fast_page_fault_helper.c $BIN/fast_page_fault_helper - compile_and_install $BIN/readmem.c $BIN/readmem + BIN_DIR=overlay/usr/local/bin + + SRCS=(init.c fillmem.c fast_page_fault_helper.c readmem.c go_sdk_cred_provider.go go_sdk_cred_provider_with_custom_endpoint.go) if [ $ARCH == "aarch64" ]; then - compile_and_install $BIN/devmemread.c $BIN/devmemread + SRCS+=(devmemread.c) fi + for SRC in ${SRCS[@]}; do + compile_and_install $BIN_DIR/$SRC + done + build_rootfs ubuntu-24.04 noble build_initramfs + + for SRC in ${SRCS[@]}; do + BIN="${SRC%.*}" + rm $BIN_DIR/$BIN + done } function vmlinux_split_debuginfo { diff --git a/src/vmm/src/mmds/mod.rs b/src/vmm/src/mmds/mod.rs index 509442048b9..44e5a82b841 100644 --- a/src/vmm/src/mmds/mod.rs +++ b/src/vmm/src/mmds/mod.rs @@ -271,7 +271,7 @@ fn respond_to_put_request(mmds: &mut Mmds, request: Request) -> Response { } // Get token lifetime value. - let ttl_seconds = match get_header_value_pair( + let (header, ttl_seconds) = match get_header_value_pair( custom_headers, &[ X_METADATA_TOKEN_TTL_SECONDS_HEADER, @@ -279,8 +279,8 @@ fn respond_to_put_request(mmds: &mut Mmds, request: Request) -> Response { ], ) { // Header found - Some((k, v)) => match v.parse::() { - Ok(ttl_seconds) => ttl_seconds, + Some((header, value)) => match value.parse::() { + Ok(ttl_seconds) => (header, ttl_seconds), Err(_) => { return build_response( request.http_version(), @@ -288,8 +288,8 @@ fn respond_to_put_request(mmds: &mut Mmds, request: Request) -> Response { MediaType::PlainText, Body::new( RequestError::HeaderError(HttpHeaderError::InvalidValue( - k.into(), - v.into(), + header.into(), + value.into(), )) .to_string(), ), @@ -310,12 +310,22 @@ fn respond_to_put_request(mmds: &mut Mmds, request: Request) -> Response { // Generate token. let result = mmds.generate_token(ttl_seconds); match result { - Ok(token) => build_response( - request.http_version(), - StatusCode::OK, - MediaType::PlainText, - Body::new(token), - ), + Ok(token) => { + let mut response = build_response( + request.http_version(), + StatusCode::OK, + MediaType::PlainText, + Body::new(token), + ); + let custom_headers = [(header.into(), ttl_seconds.to_string())].into(); + // Safe to unwrap because the header name and the value are valid as US-ASCII. + // - `header` is either `X_METADATA_TOKEN_TTL_SECONDS_HEADER` or + // `X_AWS_EC2_METADATA_TOKEN_SSL_SECONDS_HEADER`. + // - `ttl_seconds` is a decimal number between `MIN_TOKEN_TTL_SECONDS` and + // `MAX_TOKEN_TTL_SECONDS`. + response.set_custom_headers(&custom_headers).unwrap(); + response + } Err(err) => build_response( request.http_version(), StatusCode::BadRequest, @@ -752,6 +762,13 @@ mod tests { let actual_response = convert_to_response(mmds.clone(), request); assert_eq!(actual_response.status(), StatusCode::OK); assert_eq!(actual_response.content_type(), MediaType::PlainText); + assert_eq!( + actual_response + .custom_headers() + .get("X-metadata-token-ttl-seconds") + .unwrap(), + "60" + ); // Test unsupported `X-Forwarded-For` header for header in ["X-Forwarded-For", "x-forwarded-for", "X-fOrWaRdEd-FoR"] { diff --git a/tests/integration_tests/functional/test_mmds.py b/tests/integration_tests/functional/test_mmds.py index 1b548944dd8..a9ca8ead8dd 100644 --- a/tests/integration_tests/functional/test_mmds.py +++ b/tests/integration_tests/functional/test_mmds.py @@ -748,19 +748,13 @@ def test_deprecated_mmds_config(uvm_plain): ) -@pytest.mark.parametrize("version", MMDS_VERSIONS) -@pytest.mark.parametrize("imds_compat", [None, False, True]) -def test_aws_credential_provider(uvm_plain, version, imds_compat): - """ - Test AWS CLI credential provider - """ - test_microvm = uvm_plain - test_microvm.spawn() - test_microvm.basic_config() - test_microvm.add_net_iface() +def _configure_with_aws_credentials(microvm, version, imds_compat): + microvm.spawn() + microvm.basic_config() + microvm.add_net_iface() # V2 requires session tokens for GET requests configure_mmds( - test_microvm, iface_ids=["eth0"], version=version, imds_compat=imds_compat + microvm, iface_ids=["eth0"], version=version, imds_compat=imds_compat ) now = datetime.now(timezone.utc) credentials = { @@ -782,21 +776,78 @@ def test_aws_credential_provider(uvm_plain, version, imds_compat): } } } - populate_data_store(test_microvm, data_store) - test_microvm.start() - - ssh_connection = test_microvm.ssh + populate_data_store(microvm, data_store) + microvm.start() + ssh_connection = microvm.ssh run_guest_cmd(ssh_connection, f"ip route add {DEFAULT_IPV4} dev eth0", "") - cmd = r"""python3 - <