fix(aws): replace context.TODO() with proper timeouts#611
Conversation
Pull Request Test Coverage Report for Build 21755952179Warning: This coverage report may be inaccurate.This pull request's base commit is no longer the HEAD commit of its target branch. This means it includes changes from outside the original pull request, including, potentially, unrelated coverage changes.
Details
💛 - Coveralls |
There was a problem hiding this comment.
Pull request overview
This pull request aims to replace context.TODO() calls with proper timeout contexts in pkg/provider/aws/create.go to prevent indefinite hangs during AWS API operations. However, the PR includes significant scope creep with multiple new CLI commands unrelated to the stated purpose.
Changes:
- Added timeout constants for AWS operations (VPC, Subnet, IGW, Route Table, Security Group, EC2)
- Partially replaced
context.TODO()withcontext.WithTimeout()in 6 AWS creation functions - Added 7 new CLI commands (validate, provision, describe, get, scp, ssh, update) - not mentioned in PR description
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 12 comments.
| File | Description |
|---|---|
| pkg/provider/aws/create.go | Adds timeout constants and partially replaces context.TODO() calls; many AWS API calls still lack timeout protection |
| cmd/cli/validate/validate.go | New validate command - not mentioned in PR description |
| cmd/cli/provision/provision.go | New provision command - not mentioned in PR description |
| cmd/cli/main.go | Registers new CLI commands and updates help text - not mentioned in PR description |
cmd/cli/validate/validate.go
Outdated
| /* | ||
| * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package validate | ||
|
|
||
| import ( | ||
| "context" | ||
| "fmt" | ||
| "os" | ||
| "strings" | ||
|
|
||
| "github.com/aws/aws-sdk-go-v2/config" | ||
|
|
||
| "github.com/NVIDIA/holodeck/api/holodeck/v1alpha1" | ||
| "github.com/NVIDIA/holodeck/internal/logger" | ||
| "github.com/NVIDIA/holodeck/pkg/jyaml" | ||
|
|
||
| cli "github.com/urfave/cli/v2" | ||
| ) | ||
|
|
||
| type command struct { | ||
| log *logger.FunLogger | ||
| envFile string | ||
| strict bool | ||
| } | ||
|
|
||
| // ValidationResult represents the result of a validation check | ||
| type ValidationResult struct { | ||
| Check string | ||
| Passed bool | ||
| Message string | ||
| } | ||
|
|
||
| // NewCommand constructs the validate command with the specified logger | ||
| func NewCommand(log *logger.FunLogger) *cli.Command { | ||
| c := command{ | ||
| log: log, | ||
| } | ||
| return c.build() | ||
| } | ||
|
|
||
| func (m *command) build() *cli.Command { | ||
| validateCmd := cli.Command{ | ||
| Name: "validate", | ||
| Usage: "Validate a Holodeck environment file", | ||
| ArgsUsage: "", | ||
| Description: `Validate an environment file before creating an instance. | ||
|
|
||
| Checks performed: | ||
| - Environment file is valid YAML | ||
| - Required fields are present | ||
| - AWS credentials are configured (for AWS provider) | ||
| - SSH private key is readable | ||
| - SSH public key is readable | ||
|
|
||
| Examples: | ||
| # Validate an environment file | ||
| holodeck validate -f env.yaml | ||
|
|
||
| # Strict mode (fail on warnings) | ||
| holodeck validate -f env.yaml --strict`, | ||
| Flags: []cli.Flag{ | ||
| &cli.StringFlag{ | ||
| Name: "envFile", | ||
| Aliases: []string{"f"}, | ||
| Usage: "Path to the Environment file", | ||
| Destination: &m.envFile, | ||
| Required: true, | ||
| }, | ||
| &cli.BoolFlag{ | ||
| Name: "strict", | ||
| Usage: "Fail on warnings (not just errors)", | ||
| Destination: &m.strict, | ||
| }, | ||
| }, | ||
| Action: func(c *cli.Context) error { | ||
| return m.run() | ||
| }, | ||
| } | ||
|
|
||
| return &validateCmd | ||
| } | ||
|
|
||
| func (m *command) run() error { | ||
| results := make([]ValidationResult, 0) | ||
| hasErrors := false | ||
| hasWarnings := false | ||
|
|
||
| // 1. Validate environment file exists and is valid YAML | ||
| env, err := m.validateEnvFile() | ||
| if err != nil { | ||
| results = append(results, ValidationResult{ | ||
| Check: "Environment file", | ||
| Passed: false, | ||
| Message: err.Error(), | ||
| }) | ||
| hasErrors = true | ||
| m.printResults(results) | ||
| return fmt.Errorf("validation failed") | ||
| } | ||
| results = append(results, ValidationResult{ | ||
| Check: "Environment file", | ||
| Passed: true, | ||
| Message: "Valid YAML structure", | ||
| }) | ||
|
|
||
| // 2. Validate required fields | ||
| fieldResults := m.validateRequiredFields(env) | ||
| for _, r := range fieldResults { | ||
| results = append(results, r) | ||
| if !r.Passed { | ||
| hasErrors = true | ||
| } | ||
| } | ||
|
|
||
| // 3. Validate SSH keys | ||
| keyResults := m.validateSSHKeys(env) | ||
| for _, r := range keyResults { | ||
| results = append(results, r) | ||
| if !r.Passed { | ||
| hasErrors = true | ||
| } | ||
| } | ||
|
|
||
| // 4. Validate AWS credentials (if AWS provider) | ||
| if env.Spec.Provider == v1alpha1.ProviderAWS { | ||
| awsResult := m.validateAWSCredentials() | ||
| results = append(results, awsResult) | ||
| if !awsResult.Passed { | ||
| if strings.Contains(awsResult.Message, "warning") { | ||
| hasWarnings = true | ||
| } else { | ||
| hasErrors = true | ||
| } | ||
| } | ||
| } | ||
|
|
||
| // 5. Validate component configuration | ||
| compResults := m.validateComponents(env) | ||
| for _, r := range compResults { | ||
| results = append(results, r) | ||
| if !r.Passed { | ||
| hasWarnings = true | ||
| } | ||
| } | ||
|
|
||
| // Print results | ||
| m.printResults(results) | ||
|
|
||
| // Determine exit status | ||
| if hasErrors { | ||
| return fmt.Errorf("validation failed with errors") | ||
| } | ||
| if hasWarnings && m.strict { | ||
| return fmt.Errorf("validation failed with warnings (strict mode)") | ||
| } | ||
|
|
||
| m.log.Info("\n✅ Validation passed") | ||
| return nil | ||
| } | ||
|
|
||
| func (m *command) validateEnvFile() (*v1alpha1.Environment, error) { | ||
| if m.envFile == "" { | ||
| return nil, fmt.Errorf("environment file path is required") | ||
| } | ||
|
|
||
| if _, err := os.Stat(m.envFile); os.IsNotExist(err) { | ||
| return nil, fmt.Errorf("file not found: %s", m.envFile) | ||
| } | ||
|
|
||
| env, err := jyaml.UnmarshalFromFile[v1alpha1.Environment](m.envFile) | ||
| if err != nil { | ||
| return nil, fmt.Errorf("invalid YAML: %v", err) | ||
| } | ||
|
|
||
| return &env, nil | ||
| } | ||
|
|
||
| func (m *command) validateRequiredFields(env *v1alpha1.Environment) []ValidationResult { | ||
| results := make([]ValidationResult, 0) | ||
|
|
||
| // Check provider | ||
| if env.Spec.Provider == "" { | ||
| results = append(results, ValidationResult{ | ||
| Check: "Provider", | ||
| Passed: false, | ||
| Message: "Provider is required (aws or ssh)", | ||
| }) | ||
| } else { | ||
| results = append(results, ValidationResult{ | ||
| Check: "Provider", | ||
| Passed: true, | ||
| Message: fmt.Sprintf("Provider: %s", env.Spec.Provider), | ||
| }) | ||
| } | ||
|
|
||
| // Check auth | ||
| if env.Spec.Auth.KeyName == "" { | ||
| results = append(results, ValidationResult{ | ||
| Check: "Auth.KeyName", | ||
| Passed: false, | ||
| Message: "KeyName is required", | ||
| }) | ||
| } else { | ||
| results = append(results, ValidationResult{ | ||
| Check: "Auth.KeyName", | ||
| Passed: true, | ||
| Message: fmt.Sprintf("KeyName: %s", env.Spec.Auth.KeyName), | ||
| }) | ||
| } | ||
|
|
||
| // Check region (for AWS) | ||
| if env.Spec.Provider == v1alpha1.ProviderAWS { | ||
| region := "" | ||
| if env.Spec.Cluster != nil { | ||
| region = env.Spec.Cluster.Region | ||
| } else { | ||
| region = env.Spec.Instance.Region | ||
| } | ||
|
|
||
| if region == "" { | ||
| results = append(results, ValidationResult{ | ||
| Check: "Region", | ||
| Passed: false, | ||
| Message: "Region is required for AWS provider", | ||
| }) | ||
| } else { | ||
| results = append(results, ValidationResult{ | ||
| Check: "Region", | ||
| Passed: true, | ||
| Message: fmt.Sprintf("Region: %s", region), | ||
| }) | ||
| } | ||
| } | ||
|
|
||
| // Check instance type or cluster config | ||
| if env.Spec.Provider == v1alpha1.ProviderAWS { | ||
| if env.Spec.Cluster == nil { | ||
| if env.Spec.Instance.Type == "" { | ||
| results = append(results, ValidationResult{ | ||
| Check: "Instance.Type", | ||
| Passed: false, | ||
| Message: "Instance type is required for single-node AWS deployment", | ||
| }) | ||
| } else { | ||
| results = append(results, ValidationResult{ | ||
| Check: "Instance.Type", | ||
| Passed: true, | ||
| Message: fmt.Sprintf("Instance type: %s", env.Spec.Instance.Type), | ||
| }) | ||
| } | ||
| } else { | ||
| results = append(results, ValidationResult{ | ||
| Check: "Cluster config", | ||
| Passed: true, | ||
| Message: fmt.Sprintf("Cluster mode: %d CP, %d workers", | ||
| env.Spec.Cluster.ControlPlane.Count, | ||
| func() int32 { | ||
| if env.Spec.Cluster.Workers != nil { | ||
| return env.Spec.Cluster.Workers.Count | ||
| } | ||
| return 0 | ||
| }()), | ||
| }) | ||
| } | ||
| } | ||
|
|
||
| // Check host URL for SSH provider | ||
| if env.Spec.Provider == v1alpha1.ProviderSSH { | ||
| if env.Spec.Instance.HostUrl == "" { | ||
| results = append(results, ValidationResult{ | ||
| Check: "HostUrl", | ||
| Passed: false, | ||
| Message: "HostUrl is required for SSH provider", | ||
| }) | ||
| } else { | ||
| results = append(results, ValidationResult{ | ||
| Check: "HostUrl", | ||
| Passed: true, | ||
| Message: fmt.Sprintf("Host: %s", env.Spec.Instance.HostUrl), | ||
| }) | ||
| } | ||
| } | ||
|
|
||
| return results | ||
| } | ||
|
|
||
| func (m *command) validateSSHKeys(env *v1alpha1.Environment) []ValidationResult { | ||
| results := make([]ValidationResult, 0) | ||
|
|
||
| // Check private key | ||
| if env.Spec.Auth.PrivateKey == "" { | ||
| results = append(results, ValidationResult{ | ||
| Check: "SSH private key", | ||
| Passed: false, | ||
| Message: "Private key path is required", | ||
| }) | ||
| } else { | ||
| // Expand home directory | ||
| keyPath := expandPath(env.Spec.Auth.PrivateKey) | ||
| if _, err := os.Stat(keyPath); os.IsNotExist(err) { | ||
| results = append(results, ValidationResult{ | ||
| Check: "SSH private key", | ||
| Passed: false, | ||
| Message: fmt.Sprintf("Private key not found: %s", keyPath), | ||
| }) | ||
| } else { | ||
| // Check if readable | ||
| if _, err := os.ReadFile(keyPath); err != nil { | ||
| results = append(results, ValidationResult{ | ||
| Check: "SSH private key", | ||
| Passed: false, | ||
| Message: fmt.Sprintf("Cannot read private key: %v", err), | ||
| }) | ||
| } else { | ||
| results = append(results, ValidationResult{ | ||
| Check: "SSH private key", | ||
| Passed: true, | ||
| Message: fmt.Sprintf("Readable: %s", keyPath), | ||
| }) | ||
| } | ||
| } | ||
| } | ||
|
|
||
| // Check public key | ||
| if env.Spec.Auth.PublicKey == "" { | ||
| results = append(results, ValidationResult{ | ||
| Check: "SSH public key", | ||
| Passed: false, | ||
| Message: "Public key path is required", | ||
| }) | ||
| } else { | ||
| keyPath := expandPath(env.Spec.Auth.PublicKey) | ||
| if _, err := os.Stat(keyPath); os.IsNotExist(err) { | ||
| results = append(results, ValidationResult{ | ||
| Check: "SSH public key", | ||
| Passed: false, | ||
| Message: fmt.Sprintf("Public key not found: %s", keyPath), | ||
| }) | ||
| } else { | ||
| results = append(results, ValidationResult{ | ||
| Check: "SSH public key", | ||
| Passed: true, | ||
| Message: fmt.Sprintf("Found: %s", keyPath), | ||
| }) | ||
| } | ||
| } | ||
|
|
||
| return results | ||
| } | ||
|
|
||
| func (m *command) validateAWSCredentials() ValidationResult { | ||
| // Try to load AWS config | ||
| ctx := context.Background() | ||
| cfg, err := config.LoadDefaultConfig(ctx) | ||
| if err != nil { | ||
| return ValidationResult{ | ||
| Check: "AWS credentials", | ||
| Passed: false, | ||
| Message: fmt.Sprintf("Failed to load AWS config: %v", err), | ||
| } | ||
| } | ||
|
|
||
| // Check if credentials are available | ||
| creds, err := cfg.Credentials.Retrieve(ctx) | ||
| if err != nil { | ||
| return ValidationResult{ | ||
| Check: "AWS credentials", | ||
| Passed: false, | ||
| Message: fmt.Sprintf("Failed to retrieve credentials: %v", err), | ||
| } | ||
| } | ||
|
|
||
| if creds.AccessKeyID == "" { | ||
| return ValidationResult{ | ||
| Check: "AWS credentials", | ||
| Passed: false, | ||
| Message: "No AWS access key found", | ||
| } | ||
| } | ||
|
|
||
| return ValidationResult{ | ||
| Check: "AWS credentials", | ||
| Passed: true, | ||
| Message: fmt.Sprintf("Configured (source: %s)", creds.Source), | ||
| } | ||
| } | ||
|
|
||
| func (m *command) validateComponents(env *v1alpha1.Environment) []ValidationResult { | ||
| results := make([]ValidationResult, 0) | ||
|
|
||
| // Check for common misconfigurations | ||
| if env.Spec.NVIDIAContainerToolkit.Install && !env.Spec.ContainerRuntime.Install { | ||
| results = append(results, ValidationResult{ | ||
| Check: "Component dependencies", | ||
| Passed: false, | ||
| Message: "Warning: Container Toolkit requires a container runtime", | ||
| }) | ||
| } | ||
|
|
||
| if env.Spec.Kubernetes.Install && !env.Spec.ContainerRuntime.Install { | ||
| results = append(results, ValidationResult{ | ||
| Check: "Component dependencies", | ||
| Passed: false, | ||
| Message: "Warning: Kubernetes requires a container runtime", | ||
| }) | ||
| } | ||
|
|
||
| // Check driver branch/version | ||
| if env.Spec.NVIDIADriver.Install { | ||
| if env.Spec.NVIDIADriver.Version != "" && env.Spec.NVIDIADriver.Branch != "" { | ||
| results = append(results, ValidationResult{ | ||
| Check: "NVIDIA Driver config", | ||
| Passed: true, | ||
| Message: "Both version and branch specified; version takes precedence", | ||
| }) | ||
| } | ||
| } | ||
|
|
||
| // Kubernetes installer validation | ||
| if env.Spec.Kubernetes.Install { | ||
| installer := env.Spec.Kubernetes.KubernetesInstaller | ||
| if installer == "" { | ||
| installer = "kubeadm" | ||
| } | ||
| validInstallers := map[string]bool{"kubeadm": true, "kind": true, "microk8s": true} | ||
| if !validInstallers[installer] { | ||
| results = append(results, ValidationResult{ | ||
| Check: "Kubernetes installer", | ||
| Passed: false, | ||
| Message: fmt.Sprintf("Warning: Unknown installer %q, expected kubeadm/kind/microk8s", installer), | ||
| }) | ||
| } | ||
| } | ||
|
|
||
| return results | ||
| } | ||
|
|
||
| func (m *command) printResults(results []ValidationResult) { | ||
| fmt.Println("\n=== Validation Results ===\n") | ||
|
|
||
| for _, r := range results { | ||
| icon := "✓" | ||
| if !r.Passed { | ||
| icon = "✗" | ||
| } | ||
| fmt.Printf(" %s %s\n", icon, r.Check) | ||
| fmt.Printf(" %s\n", r.Message) | ||
| } | ||
| } | ||
|
|
||
| // expandPath expands ~ to home directory | ||
| func expandPath(path string) string { | ||
| if strings.HasPrefix(path, "~/") { | ||
| home, err := os.UserHomeDir() | ||
| if err == nil { | ||
| return strings.Replace(path, "~", home, 1) | ||
| } | ||
| } | ||
| return path | ||
| } |
There was a problem hiding this comment.
This file appears to be a new command implementation that was not mentioned in the PR description. The PR description states it only replaces context.TODO() calls in pkg/provider/aws/create.go with proper timeouts. Adding entirely new CLI commands (validate, provision, describe, get, scp, ssh, update) represents a significant expansion of scope beyond the stated purpose of the PR.
cmd/cli/main.go
Outdated
| @@ -63,9 +70,22 @@ Examples: | |||
| # List all environments | |||
| holodeck list | |||
|
|
|||
| # List environments in JSON format | |||
| holodeck list -o json | |||
|
|
|||
| # Get status of a specific environment | |||
| holodeck status <instance-id> | |||
|
|
|||
| # SSH into an instance | |||
| holodeck ssh <instance-id> | |||
|
|
|||
| # Run a command on an instance | |||
| holodeck ssh <instance-id> -- nvidia-smi | |||
|
|
|||
| # Copy files to/from an instance | |||
| holodeck scp ./local-file.txt <instance-id>:/remote/path/ | |||
| holodeck scp <instance-id>:/remote/file.log ./local/ | |||
|
|
|||
| # Delete an environment | |||
| holodeck delete <instance-id> | |||
|
|
|||
| @@ -93,10 +113,17 @@ Examples: | |||
| cleanup.NewCommand(log), | |||
| create.NewCommand(log), | |||
| delete.NewCommand(log), | |||
| describe.NewCommand(log), | |||
| dryrun.NewCommand(log), | |||
| get.NewCommand(log), | |||
| list.NewCommand(log), | |||
| oscmd.NewCommand(log), | |||
| provision.NewCommand(log), | |||
| scp.NewCommand(log), | |||
| ssh.NewCommand(log), | |||
| status.NewCommand(log), | |||
| update.NewCommand(log), | |||
| validate.NewCommand(log), | |||
There was a problem hiding this comment.
These changes add multiple new CLI commands (describe, get, provision, scp, ssh, update, validate) that were not mentioned in the PR description. The PR description states it only replaces context.TODO() calls in pkg/provider/aws/create.go with proper timeouts. Adding new commands and updating help text represents a significant expansion of scope beyond the stated purpose.
| ctx, cancel := context.WithTimeout(context.Background(), defaultVPCTimeout) | ||
|
|
||
|
|
||
| defer cancel() | ||
|
|
||
|
|
||
| vpcOutput, err := p.ec2.CreateVpc(ctx, vpcInput) |
There was a problem hiding this comment.
The context timeout starts at line 114 but this AWS API call (CreateVpc) uses the context on line 120. However, the subsequent ModifyVpcAttribute call on line 132 (outside changed region) still uses context.Background() without a timeout. Consider whether all AWS API calls in this function should share the same timeout context.
| } | ||
| gwOutput, err := p.ec2.CreateInternetGateway(context.TODO(), gwInput) | ||
| ctx, cancel := context.WithTimeout(context.Background(), defaultIGWTimeout) | ||
|
|
There was a problem hiding this comment.
There are extra blank lines added here (lines 186-187) that create inconsistent spacing. This doesn't follow typical Go code formatting conventions. Remove the extra blank lines.
| } | ||
| rtOutput, err := p.ec2.CreateRouteTable(context.TODO(), rtInput) | ||
| ctx, cancel := context.WithTimeout(context.Background(), defaultRouteTableTimeout) | ||
|
|
There was a problem hiding this comment.
There are extra blank lines added here (lines 229-230) that create inconsistent spacing. This doesn't follow typical Go code formatting conventions. Remove the extra blank lines.
| networkInterfaceId := *instance.NetworkInterfaces[0].NetworkInterfaceId | ||
| _, err = p.ec2.CreateTags(context.TODO(), &ec2.CreateTagsInput{ | ||
| ctx, cancel := context.WithTimeout(context.Background(), defaultEC2Timeout) | ||
|
|
There was a problem hiding this comment.
There's an extra blank line added here (line 437) that creates inconsistent spacing. This doesn't follow typical Go code formatting conventions. Remove the extra blank line.
| ctx, cancel := context.WithTimeout(context.Background(), defaultEC2Timeout) | ||
|
|
||
| defer cancel() | ||
|
|
||
| _, err = p.ec2.CreateTags(ctx, &ec2.CreateTagsInput{ |
There was a problem hiding this comment.
The context with timeout is created here but is only used for the CreateTags and ModifyNetworkInterfaceAttribute calls later in the function. However, the RunInstances call on line 401, the waiter.Wait call on line 416, and the DescribeInstances call on line 424 (all outside the changed region) still use context.Background() without timeouts. These are the most critical and time-consuming operations in EC2 instance creation and should also have proper timeout protection. Consider moving the context creation to the beginning of the function and using it for all AWS API calls.
cmd/cli/provision/provision.go
Outdated
| /* | ||
| * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package provision | ||
|
|
||
| import ( | ||
| "fmt" | ||
| "os" | ||
| "path/filepath" | ||
|
|
||
| "github.com/NVIDIA/holodeck/api/holodeck/v1alpha1" | ||
| "github.com/NVIDIA/holodeck/internal/instances" | ||
| "github.com/NVIDIA/holodeck/internal/logger" | ||
| "github.com/NVIDIA/holodeck/pkg/jyaml" | ||
| "github.com/NVIDIA/holodeck/pkg/provider/aws" | ||
| "github.com/NVIDIA/holodeck/pkg/provisioner" | ||
| "github.com/NVIDIA/holodeck/pkg/utils" | ||
|
|
||
| cli "github.com/urfave/cli/v2" | ||
| ) | ||
|
|
||
| type command struct { | ||
| log *logger.FunLogger | ||
| cachePath string | ||
| kubeconfig string | ||
|
|
||
| // SSH provider flags | ||
| sshMode bool | ||
| host string | ||
| keyPath string | ||
| username string | ||
| envFile string | ||
| } | ||
|
|
||
| // NewCommand constructs the provision command with the specified logger | ||
| func NewCommand(log *logger.FunLogger) *cli.Command { | ||
| c := command{ | ||
| log: log, | ||
| } | ||
| return c.build() | ||
| } | ||
|
|
||
| func (m *command) build() *cli.Command { | ||
| provisionCmd := cli.Command{ | ||
| Name: "provision", | ||
| Usage: "Provision or re-provision a Holodeck instance", | ||
| ArgsUsage: "[instance-id]", | ||
| Description: `Provision or re-provision an existing Holodeck instance. | ||
|
|
||
| This command runs the provisioning scripts on an instance. Because templates | ||
| are idempotent, it's safe to re-run provisioning to add components or recover | ||
| from failures. | ||
|
|
||
| Modes: | ||
| 1. Instance mode: Provision an existing instance by ID | ||
| 2. SSH mode: Provision a remote host directly (no instance required) | ||
|
|
||
| Examples: | ||
| # Provision an existing instance | ||
| holodeck provision abc123 | ||
|
|
||
| # Re-provision with kubeconfig download | ||
| holodeck provision abc123 -k ./kubeconfig | ||
|
|
||
| # SSH mode: Provision a remote host directly | ||
| holodeck provision --ssh --host 1.2.3.4 --key ~/.ssh/id_rsa -f env.yaml | ||
|
|
||
| # SSH mode with custom username | ||
| holodeck provision --ssh --host myhost.example.com --key ~/.ssh/key --user ec2-user -f env.yaml`, | ||
| Flags: []cli.Flag{ | ||
| &cli.StringFlag{ | ||
| Name: "cachepath", | ||
| Aliases: []string{"c"}, | ||
| Usage: "Path to the cache directory", | ||
| Destination: &m.cachePath, | ||
| }, | ||
| &cli.StringFlag{ | ||
| Name: "kubeconfig", | ||
| Aliases: []string{"k"}, | ||
| Usage: "Path to save the kubeconfig file", | ||
| Destination: &m.kubeconfig, | ||
| }, | ||
| // SSH mode flags | ||
| &cli.BoolFlag{ | ||
| Name: "ssh", | ||
| Usage: "SSH mode: provision a remote host directly", | ||
| Destination: &m.sshMode, | ||
| }, | ||
| &cli.StringFlag{ | ||
| Name: "host", | ||
| Usage: "SSH mode: remote host address", | ||
| Destination: &m.host, | ||
| }, | ||
| &cli.StringFlag{ | ||
| Name: "key", | ||
| Usage: "SSH mode: path to SSH private key", | ||
| Destination: &m.keyPath, | ||
| }, | ||
| &cli.StringFlag{ | ||
| Name: "user", | ||
| Aliases: []string{"u"}, | ||
| Usage: "SSH mode: SSH username (default: ubuntu)", | ||
| Destination: &m.username, | ||
| Value: "ubuntu", | ||
| }, | ||
| &cli.StringFlag{ | ||
| Name: "envFile", | ||
| Aliases: []string{"f"}, | ||
| Usage: "Path to the Environment file (required for SSH mode)", | ||
| Destination: &m.envFile, | ||
| }, | ||
| }, | ||
| Action: func(c *cli.Context) error { | ||
| if m.sshMode { | ||
| return m.runSSHMode() | ||
| } | ||
|
|
||
| if c.NArg() != 1 { | ||
| return fmt.Errorf("instance ID is required (or use --ssh mode)") | ||
| } | ||
| return m.runInstanceMode(c.Args().Get(0)) | ||
| }, | ||
| } | ||
|
|
||
| return &provisionCmd | ||
| } | ||
|
|
||
| func (m *command) runInstanceMode(instanceID string) error { | ||
| // Get instance details | ||
| manager := instances.NewManager(m.log, m.cachePath) | ||
| instance, err := manager.GetInstance(instanceID) | ||
| if err != nil { | ||
| return fmt.Errorf("failed to get instance: %v", err) | ||
| } | ||
|
|
||
| // Load environment | ||
| env, err := jyaml.UnmarshalFromFile[v1alpha1.Environment](instance.CacheFile) | ||
| if err != nil { | ||
| return fmt.Errorf("failed to read environment: %v", err) | ||
| } | ||
|
|
||
| m.log.Info("Provisioning instance %s...", instanceID) | ||
|
|
||
| // Run provisioning based on instance type | ||
| if env.Spec.Cluster != nil && env.Status.Cluster != nil && len(env.Status.Cluster.Nodes) > 0 { | ||
| if err := m.runClusterProvision(&env); err != nil { | ||
| return err | ||
| } | ||
| } else { | ||
| if err := m.runSingleNodeProvision(&env); err != nil { | ||
| return err | ||
| } | ||
| } | ||
|
|
||
| // Update provisioned status | ||
| env.Labels[instances.InstanceProvisionedLabelKey] = "true" | ||
| data, err := jyaml.MarshalYAML(env) | ||
| if err != nil { | ||
| return fmt.Errorf("failed to marshal environment: %v", err) | ||
| } | ||
| if err := os.WriteFile(instance.CacheFile, data, 0600); err != nil { | ||
| return fmt.Errorf("failed to update cache file: %v", err) | ||
| } | ||
|
|
||
| // Download kubeconfig if requested and Kubernetes is installed | ||
| if m.kubeconfig != "" && env.Spec.Kubernetes.Install { | ||
| hostUrl, err := m.getHostURL(&env) | ||
| if err != nil { | ||
| m.log.Warning("Failed to get host URL for kubeconfig: %v", err) | ||
| } else { | ||
| if err := utils.GetKubeConfig(m.log, &env, hostUrl, m.kubeconfig); err != nil { | ||
| m.log.Warning("Failed to download kubeconfig: %v", err) | ||
| } | ||
| } | ||
| } | ||
|
|
||
| m.log.Info("✅ Provisioning completed successfully") | ||
| return nil | ||
| } | ||
|
|
||
| func (m *command) runSSHMode() error { | ||
| // Validate SSH mode flags | ||
| if m.host == "" { | ||
| return fmt.Errorf("--host is required in SSH mode") | ||
| } | ||
| if m.keyPath == "" { | ||
| return fmt.Errorf("--key is required in SSH mode") | ||
| } | ||
| if m.envFile == "" { | ||
| return fmt.Errorf("--envFile/-f is required in SSH mode") | ||
| } | ||
|
|
||
| // Load environment file | ||
| env, err := jyaml.UnmarshalFromFile[v1alpha1.Environment](m.envFile) | ||
| if err != nil { | ||
| return fmt.Errorf("failed to read environment file: %v", err) | ||
| } | ||
|
|
||
| // Override with SSH mode settings | ||
| env.Spec.Provider = v1alpha1.ProviderSSH | ||
| env.Spec.HostUrl = m.host | ||
| env.Spec.PrivateKey = m.keyPath | ||
| env.Spec.Username = m.username | ||
|
|
||
| m.log.Info("Provisioning %s via SSH...", m.host) | ||
|
|
||
| // Create provisioner and run | ||
| p, err := provisioner.New(m.log, m.keyPath, m.username, m.host) | ||
| if err != nil { | ||
| return fmt.Errorf("failed to create provisioner: %v", err) | ||
| } | ||
| defer p.Client.Close() | ||
|
|
||
| if err := p.Run(env); err != nil { | ||
| return fmt.Errorf("provisioning failed: %v", err) | ||
| } | ||
|
|
||
| // Download kubeconfig if requested and Kubernetes is installed | ||
| if m.kubeconfig != "" && env.Spec.Kubernetes.Install { | ||
| if err := utils.GetKubeConfig(m.log, &env, m.host, m.kubeconfig); err != nil { | ||
| m.log.Warning("Failed to download kubeconfig: %v", err) | ||
| } | ||
| } | ||
|
|
||
| m.log.Info("✅ Provisioning completed successfully") | ||
| return nil | ||
| } | ||
|
|
||
| func (m *command) runSingleNodeProvision(env *v1alpha1.Environment) error { | ||
| hostUrl, err := m.getHostURL(env) | ||
| if err != nil { | ||
| return fmt.Errorf("failed to get host URL: %v", err) | ||
| } | ||
|
|
||
| // Create provisioner and run | ||
| p, err := provisioner.New(m.log, env.Spec.PrivateKey, env.Spec.Username, hostUrl) | ||
| if err != nil { | ||
| return fmt.Errorf("failed to create provisioner: %v", err) | ||
| } | ||
| defer p.Client.Close() | ||
|
|
||
| return p.Run(*env) | ||
| } | ||
|
|
||
| func (m *command) runClusterProvision(env *v1alpha1.Environment) error { | ||
| // Build node list from cluster status | ||
| var nodes []provisioner.NodeInfo | ||
| for _, node := range env.Status.Cluster.Nodes { | ||
| nodes = append(nodes, provisioner.NodeInfo{ | ||
| Name: node.Name, | ||
| PublicIP: node.PublicIP, | ||
| PrivateIP: node.PrivateIP, | ||
| Role: node.Role, | ||
| SSHUsername: node.SSHUsername, | ||
| }) | ||
| } | ||
|
|
||
| if len(nodes) == 0 { | ||
| return fmt.Errorf("no nodes found in cluster status") | ||
| } | ||
|
|
||
| // Create cluster provisioner | ||
| cp := provisioner.NewClusterProvisioner( | ||
| m.log, | ||
| env.Spec.PrivateKey, | ||
| env.Spec.Username, | ||
| env, | ||
| ) | ||
|
|
||
| return cp.ProvisionCluster(nodes) | ||
| } | ||
|
|
||
| func (m *command) getHostURL(env *v1alpha1.Environment) (string, error) { | ||
| // For multinode clusters, get first control-plane | ||
| if env.Spec.Cluster != nil && env.Status.Cluster != nil && len(env.Status.Cluster.Nodes) > 0 { | ||
| for _, node := range env.Status.Cluster.Nodes { | ||
| if node.Role == "control-plane" { | ||
| return node.PublicIP, nil | ||
| } | ||
| } | ||
| return env.Status.Cluster.Nodes[0].PublicIP, nil | ||
| } | ||
|
|
||
| // Single node - get from properties | ||
| if env.Spec.Provider == v1alpha1.ProviderAWS { | ||
| for _, p := range env.Status.Properties { | ||
| if p.Name == aws.PublicDnsName { | ||
| return p.Value, nil | ||
| } | ||
| } | ||
| } else if env.Spec.Provider == v1alpha1.ProviderSSH { | ||
| return env.Spec.HostUrl, nil | ||
| } | ||
|
|
||
| return "", fmt.Errorf("unable to determine host URL") | ||
| } | ||
|
|
||
| // getKubeconfigPath returns the path to save kubeconfig | ||
| func getKubeconfigPath(instanceID string) string { | ||
| homeDir, err := os.UserHomeDir() | ||
| if err != nil { | ||
| return fmt.Sprintf("kubeconfig-%s", instanceID) | ||
| } | ||
| kubeDir := filepath.Join(homeDir, ".kube") | ||
| _ = os.MkdirAll(kubeDir, 0755) | ||
| return filepath.Join(kubeDir, fmt.Sprintf("config-%s", instanceID)) | ||
| } |
There was a problem hiding this comment.
This file appears to be a new command implementation that was not mentioned in the PR description. The PR description states it only replaces context.TODO() calls in pkg/provider/aws/create.go with proper timeouts. Adding entirely new CLI commands represents a significant expansion of scope beyond the stated purpose of the PR.
pkg/provider/aws/create.go
Outdated
|
|
||
|
|
||
| defer cancel() | ||
|
|
||
|
|
There was a problem hiding this comment.
There are extra blank lines added here (lines 115-116, 118-119) that create inconsistent spacing. This doesn't follow typical Go code formatting conventions. Remove the extra blank lines to maintain clean code style.
| defer cancel() | |
| defer cancel() |
| } | ||
| subnetOutput, err := p.ec2.CreateSubnet(context.TODO(), subnetInput) | ||
| ctx, cancel := context.WithTimeout(context.Background(), defaultSubnetTimeout) | ||
|
|
There was a problem hiding this comment.
There's an extra blank line added here (line 158) that creates inconsistent spacing. This doesn't follow typical Go code formatting conventions. Remove the extra blank line.
501e7f9 to
5b10db8
Compare
- Added timeout constants for VPC, subnet, IGW, route table, security group, EC2, and waiter operations - Replaced 10 context.TODO() calls with context.WithTimeout() calls - Modified 6 functions: createVPC, createSubnet, createInternetGateway, createRouteTable, createSecurityGroup, createEC2Instance - Timeouts: 2 minutes for network resources, 10 minutes for EC2 operations Signed-off-by: Carlos Eduardo Arango Gutierrez <eduardoa@nvidia.com>
5b10db8 to
30fe15e
Compare
…VIDIA#611) - Added timeout constants for VPC, subnet, IGW, route table, security group, EC2, and waiter operations - Replaced 10 context.TODO() calls with context.WithTimeout() calls - Modified 6 functions: createVPC, createSubnet, createInternetGateway, createRouteTable, createSecurityGroup, createEC2Instance - Timeouts: 2 minutes for network resources, 10 minutes for EC2 operations Signed-off-by: Carlos Eduardo Arango Gutierrez <eduardoa@nvidia.com>
…VIDIA#611) - Added timeout constants for VPC, subnet, IGW, route table, security group, EC2, and waiter operations - Replaced 10 context.TODO() calls with context.WithTimeout() calls - Modified 6 functions: createVPC, createSubnet, createInternetGateway, createRouteTable, createSecurityGroup, createEC2Instance - Timeouts: 2 minutes for network resources, 10 minutes for EC2 operations Signed-off-by: Carlos Eduardo Arango Gutierrez <eduardoa@nvidia.com>
…VIDIA#611) - Added timeout constants for VPC, subnet, IGW, route table, security group, EC2, and waiter operations - Replaced 10 context.TODO() calls with context.WithTimeout() calls - Modified 6 functions: createVPC, createSubnet, createInternetGateway, createRouteTable, createSecurityGroup, createEC2Instance - Timeouts: 2 minutes for network resources, 10 minutes for EC2 operations Signed-off-by: Carlos Eduardo Arango Gutierrez <eduardoa@nvidia.com>
…VIDIA#611) - Added timeout constants for VPC, subnet, IGW, route table, security group, EC2, and waiter operations - Replaced 10 context.TODO() calls with context.WithTimeout() calls - Modified 6 functions: createVPC, createSubnet, createInternetGateway, createRouteTable, createSecurityGroup, createEC2Instance - Timeouts: 2 minutes for network resources, 10 minutes for EC2 operations Signed-off-by: Carlos Eduardo Arango Gutierrez <eduardoa@nvidia.com>
…VIDIA#611) - Added timeout constants for VPC, subnet, IGW, route table, security group, EC2, and waiter operations - Replaced 10 context.TODO() calls with context.WithTimeout() calls - Modified 6 functions: createVPC, createSubnet, createInternetGateway, createRouteTable, createSecurityGroup, createEC2Instance - Timeouts: 2 minutes for network resources, 10 minutes for EC2 operations Signed-off-by: Carlos Eduardo Arango Gutierrez <eduardoa@nvidia.com>
…VIDIA#611) - Added timeout constants for VPC, subnet, IGW, route table, security group, EC2, and waiter operations - Replaced 10 context.TODO() calls with context.WithTimeout() calls - Modified 6 functions: createVPC, createSubnet, createInternetGateway, createRouteTable, createSecurityGroup, createEC2Instance - Timeouts: 2 minutes for network resources, 10 minutes for EC2 operations Signed-off-by: Carlos Eduardo Arango Gutierrez <eduardoa@nvidia.com>
Summary
Replace all
context.TODO()calls inpkg/provider/aws/create.gowith proper timeout contexts.Changes
Timeout Constants Added
defaultVPCTimeoutdefaultSubnetTimeoutdefaultIGWTimeoutdefaultRouteTableTimeoutdefaultSecurityGroupTimeoutdefaultEC2TimeoutdefaultWaiterTimeoutFunctions Modified (6)
createVPC()createSubnet()createInternetGateway()createRouteTable()createSecurityGroup()createEC2Instance()Impact
Test plan
go build ./pkg/provider/aws/...- compilesgo test ./pkg/provider/aws/...- verify no regressions