Skip to content

Commit 0ab9ce4

Browse files
mtls refactor for endpoints (#4869) (#3348)
Signed-off-by: Modular Magician <[email protected]>
1 parent f7f6a0b commit 0ab9ce4

File tree

9 files changed

+443
-279
lines changed

9 files changed

+443
-279
lines changed

.changelog/4869.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
```release-note:enhancement
2+
added support for [mtls authentication](https://google.aip.dev/auth/4114)
3+
```

google-beta/config.go

Lines changed: 269 additions & 163 deletions
Large diffs are not rendered by default.

google-beta/mtls_util.go

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
package google
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"net/url"
7+
"strings"
8+
9+
"google.golang.org/api/option/internaloption"
10+
"google.golang.org/api/transport"
11+
)
12+
13+
// The transport libaray does not natively expose logic to determine whether
14+
// the user is within mtls mode or not. They do return the mtls endpoint if
15+
// it is enabled during client creation so we will use this logic to determine
16+
// the mode the user is in and throw away the client they give us back.
17+
func isMtls() bool {
18+
regularEndpoint := "https://mockservice.googleapis.com/v1/"
19+
mtlsEndpoint := getMtlsEndpoint(regularEndpoint)
20+
_, endpoint, err := transport.NewHTTPClient(context.Background(),
21+
internaloption.WithDefaultEndpoint(regularEndpoint),
22+
internaloption.WithDefaultMTLSEndpoint(mtlsEndpoint),
23+
)
24+
if err != nil {
25+
return false
26+
}
27+
isMtls := endpoint == mtlsEndpoint
28+
return isMtls
29+
}
30+
31+
func getMtlsEndpoint(baseEndpoint string) string {
32+
u, err := url.Parse(baseEndpoint)
33+
if err != nil {
34+
if strings.Contains(baseEndpoint, ".googleapis") {
35+
return strings.Replace(baseEndpoint, ".googleapis", ".mtls.googleapis", 1)
36+
}
37+
return baseEndpoint
38+
}
39+
domainParts := strings.Split(u.Host, ".")
40+
if len(domainParts) > 1 {
41+
u.Host = fmt.Sprintf("%s.mtls.%s", domainParts[0], strings.Join(domainParts[1:], "."))
42+
} else {
43+
u.Host = fmt.Sprintf("%s.mtls", domainParts[0])
44+
}
45+
return u.String()
46+
}

google-beta/mtls_util_test.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package google
2+
3+
import (
4+
"strings"
5+
"testing"
6+
)
7+
8+
func TestUnitMtls_urlSwitching(t *testing.T) {
9+
t.Parallel()
10+
for key, bp := range DefaultBasePaths {
11+
url := getMtlsEndpoint(bp)
12+
if !strings.Contains(url, ".mtls.") {
13+
t.Errorf("%s: mtls conversion unsuccessful preconv - %s postconv - %s", key, bp, url)
14+
}
15+
}
16+
}

0 commit comments

Comments
 (0)