Skip to content

Commit 75e6c33

Browse files
author
Andrew Hare
committed
fix: PR feedback
1 parent db28180 commit 75e6c33

File tree

3 files changed

+35
-35
lines changed

3 files changed

+35
-35
lines changed

gateway/grpc/server.go

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,32 @@ import (
1717
ctypes "github.com/akash-network/akash-api/go/node/cert/v1beta3"
1818
leasev1 "github.com/akash-network/akash-api/go/provider/lease/v1"
1919
providerv1 "github.com/akash-network/akash-api/go/provider/v1"
20+
cmblog "github.com/tendermint/tendermint/libs/log"
2021

2122
"github.com/akash-network/provider"
2223
"github.com/akash-network/provider/gateway/utils"
2324
"github.com/akash-network/provider/tools/fromctx"
2425
)
2526

27+
var (
28+
_ providerv1.ProviderRPCServer = (*server)(nil)
29+
_ leasev1.LeaseRPCServer = (*server)(nil)
30+
)
31+
32+
type server struct {
33+
*providerV1
34+
*leaseV1
35+
}
36+
2637
func Serve(ctx context.Context, endpoint string, certs []tls.Certificate, c provider.Client) error {
2738
group, err := fromctx.ErrGroupFromCtx(ctx)
2839
if err != nil {
2940
return err
3041
}
3142

32-
var (
33-
grpcSrv = newServer(ctx, certs, c)
34-
log = fromctx.LogcFromCtx(ctx)
35-
)
43+
grpcSrv := newServer(ctx, certs, c)
44+
45+
log := fromctx.LogcFromCtx(ctx)
3646

3747
group.Go(func() error {
3848
grpcLis, err := net.Listen("tcp", endpoint)
@@ -56,16 +66,6 @@ func Serve(ctx context.Context, endpoint string, certs []tls.Certificate, c prov
5666
return nil
5767
}
5868

59-
var (
60-
_ providerv1.ProviderRPCServer = (*server)(nil)
61-
_ leasev1.LeaseRPCServer = (*server)(nil)
62-
)
63-
64-
type server struct {
65-
*providerV1
66-
*leaseV1
67-
}
68-
6969
func newServer(ctx context.Context, certs []tls.Certificate, c provider.Client) *grpc.Server {
7070
// InsecureSkipVerify is set to true due to inability to use normal TLS verification
7171
// certificate validation and authentication performed later in mtlsHandler
@@ -88,7 +88,7 @@ func newServer(ctx context.Context, certs []tls.Certificate, c provider.Client)
8888
}),
8989
grpc.ChainUnaryInterceptor(
9090
mtlsInterceptor(cquery),
91-
errorLogInterceptor(),
91+
errorLogInterceptor(fromctx.LogcFromCtx(ctx)),
9292
),
9393
)
9494

@@ -111,10 +111,10 @@ func newServer(ctx context.Context, certs []tls.Certificate, c provider.Client)
111111
}
112112

113113
func mtlsInterceptor(cquery ctypes.QueryClient) grpc.UnaryServerInterceptor {
114-
return func(ctx context.Context, req any, _ *grpc.UnaryServerInfo, h grpc.UnaryHandler) (any, error) {
114+
return func(ctx context.Context, req any, _ *grpc.UnaryServerInfo, next grpc.UnaryHandler) (any, error) {
115115
if p, ok := peer.FromContext(ctx); ok {
116116
if mtls, ok := p.AuthInfo.(credentials.TLSInfo); ok {
117-
owner, err := utils.VerifyCertChain(ctx, mtls.State.PeerCertificates, "", x509.ExtKeyUsageServerAuth, cquery)
117+
owner, err := utils.VerifyOwnerCert(ctx, mtls.State.PeerCertificates, "", x509.ExtKeyUsageServerAuth, cquery)
118118
if err != nil {
119119
return nil, fmt.Errorf("verify cert chain: %w", err)
120120
}
@@ -125,18 +125,18 @@ func mtlsInterceptor(cquery ctypes.QueryClient) grpc.UnaryServerInterceptor {
125125
}
126126
}
127127

128-
return h(ctx, req)
128+
return next(ctx, req)
129129
}
130130
}
131131

132132
// TODO(andrewhare): Possibly replace this with
133133
// https://github.com/grpc-ecosystem/go-grpc-middleware/tree/main/interceptors/logging
134134
// to get full request/response logging?
135-
func errorLogInterceptor() grpc.UnaryServerInterceptor {
136-
return func(ctx context.Context, req any, i *grpc.UnaryServerInfo, h grpc.UnaryHandler) (any, error) {
137-
resp, err := h(ctx, req)
135+
func errorLogInterceptor(l cmblog.Logger) grpc.UnaryServerInterceptor {
136+
return func(ctx context.Context, req any, i *grpc.UnaryServerInfo, next grpc.UnaryHandler) (any, error) {
137+
resp, err := next(ctx, req)
138138
if err != nil {
139-
fromctx.LogcFromCtx(ctx).Error(i.FullMethod, "err", err)
139+
l.Error(i.FullMethod, "err", err)
140140
}
141141

142142
return resp, err

gateway/rest/client.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ func (c *client) verifyPeerCertificate(certificates [][]byte, _ [][]*x509.Certif
296296
return errors.Errorf("tls: invalid certificate chain")
297297
}
298298

299-
prov, err := utils.VerifyCertChain(
299+
prov, err := utils.VerifyOwnerCert(
300300
context.Background(),
301301
certificates,
302302
c.host.Hostname(),

gateway/utils/utils.go

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ func NewServerTLSConfig(ctx context.Context, certs []tls.Certificate, cquery cty
2323
InsecureSkipVerify: true, // nolint: gosec
2424
MinVersion: tls.VersionTLS13,
2525
VerifyPeerCertificate: func(certificates [][]byte, _ [][]*x509.Certificate) error {
26-
if _, err := VerifyCertChain(ctx, certificates, "", x509.ExtKeyUsageClientAuth, cquery); err != nil {
26+
if _, err := VerifyOwnerCert(ctx, certificates, "", x509.ExtKeyUsageClientAuth, cquery); err != nil {
2727
return err
2828
}
2929
return nil
@@ -33,11 +33,11 @@ func NewServerTLSConfig(ctx context.Context, certs []tls.Certificate, cquery cty
3333
return cfg, nil
3434
}
3535

36-
type certChain interface {
36+
type cert interface {
3737
*x509.Certificate | []byte
3838
}
3939

40-
func VerifyCertChain[T certChain](
40+
func VerifyOwnerCert[T cert](
4141
ctx context.Context,
4242
chain []T,
4343
dnsName string,
@@ -52,31 +52,31 @@ func VerifyCertChain[T certChain](
5252
return nil, errors.Errorf("tls: invalid certificate chain")
5353
}
5454

55-
var cert *x509.Certificate
55+
var c *x509.Certificate
5656

5757
switch t := any(chain).(type) {
5858
case []*x509.Certificate:
59-
cert = t[0]
59+
c = t[0]
6060
case [][]byte:
6161
var err error
62-
if cert, err = x509.ParseCertificate(t[0]); err != nil {
62+
if c, err = x509.ParseCertificate(t[0]); err != nil {
6363
return nil, fmt.Errorf("tls: failed to parse certificate: %w", err)
6464
}
6565
}
6666

6767
// validation
68-
owner, err := sdk.AccAddressFromBech32(cert.Subject.CommonName)
68+
owner, err := sdk.AccAddressFromBech32(c.Subject.CommonName)
6969
if err != nil {
7070
return nil, fmt.Errorf("tls: invalid certificate's subject common name: %w", err)
7171
}
7272

7373
// 1. CommonName in issuer and Subject must match and be as Bech32 format
74-
if cert.Subject.CommonName != cert.Issuer.CommonName {
74+
if c.Subject.CommonName != c.Issuer.CommonName {
7575
return nil, fmt.Errorf("tls: invalid certificate's issuer common name: %w", err)
7676
}
7777

7878
// 2. serial number must be in
79-
if cert.SerialNumber == nil {
79+
if c.SerialNumber == nil {
8080
return nil, fmt.Errorf("tls: invalid certificate serial number: %w", err)
8181
}
8282

@@ -87,7 +87,7 @@ func VerifyCertChain[T certChain](
8787
&ctypes.QueryCertificatesRequest{
8888
Filter: ctypes.CertificateFilter{
8989
Owner: owner.String(),
90-
Serial: cert.SerialNumber.String(),
90+
Serial: c.SerialNumber.String(),
9191
State: "valid",
9292
},
9393
},
@@ -100,7 +100,7 @@ func VerifyCertChain[T certChain](
100100
}
101101

102102
clientCertPool := x509.NewCertPool()
103-
clientCertPool.AddCert(cert)
103+
clientCertPool.AddCert(c)
104104

105105
opts := x509.VerifyOptions{
106106
DNSName: dnsName,
@@ -110,7 +110,7 @@ func VerifyCertChain[T certChain](
110110
MaxConstraintComparisions: 0,
111111
}
112112

113-
if _, err = cert.Verify(opts); err != nil {
113+
if _, err = c.Verify(opts); err != nil {
114114
return nil, fmt.Errorf("tls: unable to verify certificate: %w", err)
115115
}
116116

0 commit comments

Comments
 (0)