|
15 | 15 | package auth_providers_test |
16 | 16 |
|
17 | 17 | import ( |
| 18 | + "crypto/tls" |
| 19 | + "encoding/pem" |
18 | 20 | "fmt" |
19 | 21 | "net/http" |
| 22 | + "net/url" |
20 | 23 | "os" |
| 24 | + "path/filepath" |
21 | 25 | "strings" |
22 | 26 | "testing" |
23 | 27 |
|
@@ -107,7 +111,19 @@ func TestCommandConfigOauth_Authenticate(t *testing.T) { |
107 | 111 | t.FailNow() |
108 | 112 | } |
109 | 113 |
|
110 | | - caCertPath := "../lib/certs/int-oidc-lab.eastus2.cloudapp.azure.com.crt" |
| 114 | + hostName := os.Getenv(auth_providers.EnvKeyfactorHostName) |
| 115 | + caCertPath := fmt.Sprintf("../lib/certs/%s.crt", hostName) |
| 116 | + // check if the caCertPath exists and if not then reach out to host to get the cert and save it to the path |
| 117 | + if _, err := os.Stat(caCertPath); os.IsNotExist(err) { |
| 118 | + // get the cert from the host |
| 119 | + dErr := DownloadCertificate(hostName, caCertPath) |
| 120 | + if dErr != nil { |
| 121 | + t.Errorf("unable to download certificate from %s: %v", hostName, dErr) |
| 122 | + t.FailNow() |
| 123 | + } |
| 124 | + |
| 125 | + // save the cert to the |
| 126 | + } |
111 | 127 |
|
112 | 128 | //Delete the config file |
113 | 129 | t.Logf("Deleting config file: %s", configFilePath) |
@@ -434,3 +450,83 @@ func unsetOAuthEnvVariables() { |
434 | 450 | //os.Unsetenv(auth_providers.EnvKeyfactorDomain) |
435 | 451 |
|
436 | 452 | } |
| 453 | + |
| 454 | +// DownloadCertificate fetches the SSL certificate chain from the given URL or hostname |
| 455 | +// while ignoring SSL verification and saves it to a file named "<hostname>.crt". |
| 456 | +func DownloadCertificate(input string, outputPath string) error { |
| 457 | + // Ensure the input has a scheme; default to "https://" |
| 458 | + if !strings.HasPrefix(input, "http://") && !strings.HasPrefix(input, "https://") { |
| 459 | + input = "https://" + input |
| 460 | + } |
| 461 | + |
| 462 | + // Parse the URL |
| 463 | + parsedURL, err := url.Parse(input) |
| 464 | + if err != nil { |
| 465 | + return fmt.Errorf("invalid URL: %v", err) |
| 466 | + } |
| 467 | + |
| 468 | + hostname := parsedURL.Hostname() |
| 469 | + if hostname == "" { |
| 470 | + return fmt.Errorf("could not determine hostname from URL: %s", input) |
| 471 | + } |
| 472 | + |
| 473 | + // Set default output path to current working directory if none is provided |
| 474 | + if outputPath == "" { |
| 475 | + cwd, err := os.Getwd() |
| 476 | + if err != nil { |
| 477 | + return fmt.Errorf("failed to get current working directory: %v", err) |
| 478 | + } |
| 479 | + outputPath = cwd |
| 480 | + } |
| 481 | + |
| 482 | + // Ensure the output directory exists |
| 483 | + if err := os.MkdirAll(outputPath, os.ModePerm); err != nil { |
| 484 | + return fmt.Errorf("failed to create output directory: %v", err) |
| 485 | + } |
| 486 | + |
| 487 | + // Create the output file |
| 488 | + outputFile := filepath.Join(outputPath, fmt.Sprintf("%s.crt", hostname)) |
| 489 | + file, err := os.Create(outputFile) |
| 490 | + if err != nil { |
| 491 | + return fmt.Errorf("failed to create file %s: %v", outputFile, err) |
| 492 | + } |
| 493 | + defer file.Close() |
| 494 | + |
| 495 | + // Create an HTTP client that ignores SSL verification |
| 496 | + httpClient := &http.Client{ |
| 497 | + Transport: &http.Transport{ |
| 498 | + TLSClientConfig: &tls.Config{ |
| 499 | + InsecureSkipVerify: true, // Ignore SSL certificate verification |
| 500 | + }, |
| 501 | + }, |
| 502 | + } |
| 503 | + |
| 504 | + // Send an HTTP GET request to the server |
| 505 | + resp, err := httpClient.Get(input) |
| 506 | + if err != nil { |
| 507 | + return fmt.Errorf("failed to connect to %s: %v", input, err) |
| 508 | + } |
| 509 | + defer resp.Body.Close() |
| 510 | + |
| 511 | + // Get the TLS connection state from the response |
| 512 | + tlsConnState := resp.TLS |
| 513 | + if tlsConnState == nil { |
| 514 | + return fmt.Errorf("no TLS connection state found") |
| 515 | + } |
| 516 | + |
| 517 | + // Write the entire certificate chain to the output file in PEM format |
| 518 | + for _, cert := range tlsConnState.PeerCertificates { |
| 519 | + err = pem.Encode( |
| 520 | + file, &pem.Block{ |
| 521 | + Type: "CERTIFICATE", |
| 522 | + Bytes: cert.Raw, |
| 523 | + }, |
| 524 | + ) |
| 525 | + if err != nil { |
| 526 | + return fmt.Errorf("failed to write certificate to file: %v", err) |
| 527 | + } |
| 528 | + } |
| 529 | + |
| 530 | + fmt.Printf("Certificate chain saved to: %s\n", outputFile) |
| 531 | + return nil |
| 532 | +} |
0 commit comments