Skip to content
5 changes: 5 additions & 0 deletions k8s/cloud/base/api_deployment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ spec:
configMapKeyRef:
name: pl-service-config
key: PL_VZMGR_SERVICE
- name: PL_SCRIPTMGR_SERVICE
valueFrom:
configMapKeyRef:
name: pl-service-config
key: PL_SCRIPTMGR_SERVICE
- name: PL_AUTH_SERVICE
valueFrom:
configMapKeyRef:
Expand Down
10 changes: 6 additions & 4 deletions src/cloud/api/api_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ func init() {

pflag.String("auth_connector_name", "", "If any, the name of the auth connector to be used with Pixie")
pflag.String("auth_connector_callback_url", "", "If any, the callback URL for the auth connector")
pflag.Bool("disable_script_modification", false, "If script modification should be disallowed to prevent arbitrary script execution")
}

func main() {
Expand Down Expand Up @@ -213,17 +214,18 @@ func main() {
authServer := &controllers.AuthServer{AuthClient: ac}
cloudpb.RegisterAuthServiceServer(s.GRPCServer(), authServer)

vpt := ptproxy.NewVizierPassThroughProxy(nc, vc)
vizierpb.RegisterVizierServiceServer(s.GRPCServer(), vpt)
vizierpb.RegisterVizierDebugServiceServer(s.GRPCServer(), vpt)

sm, err := apienv.NewScriptMgrServiceClient()
if err != nil {
log.WithError(err).Fatal("Failed to init scriptmgr client.")
}
sms := &controllers.ScriptMgrServer{ScriptMgr: sm}
cloudpb.RegisterScriptMgrServer(s.GRPCServer(), sms)

disableScriptModification := viper.GetBool("disable_script_modification")
vpt := ptproxy.NewVizierPassThroughProxy(nc, vc, sm, disableScriptModification)
vizierpb.RegisterVizierServiceServer(s.GRPCServer(), vpt)
vizierpb.RegisterVizierDebugServiceServer(s.GRPCServer(), vpt)

mdIndexName := viper.GetString("md_index_name")
if mdIndexName == "" {
log.Fatal("Must specify a name for the elastic index.")
Expand Down
4 changes: 2 additions & 2 deletions src/cloud/api/apienv/scriptmgr_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import (
)

func init() {
pflag.String("scriptmgr_service", "scriptmgr-service.plc.svc.local:52000", "The profile service url (load balancer/list is ok)")
pflag.String("scriptmgr_service", "scriptmgr-service.plc.svc.local:52000", "The scriptmgr service url (load balancer/list is ok)")
}

// NewScriptMgrServiceClient creates a new scriptmgr RPC client stub.
Expand All @@ -38,7 +38,7 @@ func NewScriptMgrServiceClient() (scriptmgrpb.ScriptMgrServiceClient, error) {
return nil, err
}

authChannel, err := grpc.Dial(viper.GetString("scripts_service"), dialOpts...)
authChannel, err := grpc.Dial(viper.GetString("scriptmgr_service"), dialOpts...)
if err != nil {
return nil, err
}
Expand Down
4 changes: 4 additions & 0 deletions src/cloud/api/ptproxy/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,19 @@ go_library(
deps = [
"//src/api/proto/uuidpb:uuid_pl_go_proto",
"//src/api/proto/vizierpb:vizier_pl_go_proto",
"//src/cloud/scriptmgr/scriptmgrpb:service_pl_go_proto",
"//src/cloud/shared/vzshard",
"//src/shared/cvmsgspb:cvmsgs_pl_go_proto",
"//src/shared/services/authcontext",
"//src/shared/services/jwtpb:jwt_pl_go_proto",
"//src/shared/services/utils",
"//src/utils",
"@com_github_gofrs_uuid//:uuid",
"@com_github_gogo_protobuf//proto",
"@com_github_gogo_protobuf//types",
"@com_github_nats_io_nats_go//:nats_go",
"@com_github_sirupsen_logrus//:logrus",
"@com_github_spf13_viper//:viper",
"@org_golang_google_grpc//:grpc",
"@org_golang_google_grpc//codes",
"@org_golang_google_grpc//metadata",
Expand All @@ -53,6 +56,7 @@ pl_go_test(
":ptproxy",
"//src/api/proto/uuidpb:uuid_pl_go_proto",
"//src/api/proto/vizierpb:vizier_pl_go_proto",
"//src/cloud/scriptmgr/scriptmgrpb:service_pl_go_proto",
"//src/cloud/shared/vzshard",
"//src/shared/cvmsgspb:cvmsgs_pl_go_proto",
"//src/shared/services/env",
Expand Down
62 changes: 58 additions & 4 deletions src/cloud/api/ptproxy/vizier_pt_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,32 +20,75 @@ package ptproxy

import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"

"github.com/nats-io/nats.go"
"github.com/spf13/viper"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"

"px.dev/pixie/src/api/proto/uuidpb"
"px.dev/pixie/src/api/proto/vizierpb"
"px.dev/pixie/src/cloud/scriptmgr/scriptmgrpb"
"px.dev/pixie/src/shared/cvmsgspb"
"px.dev/pixie/src/shared/services/authcontext"
"px.dev/pixie/src/shared/services/jwtpb"
jwtutils "px.dev/pixie/src/shared/services/utils"
)

type vzmgrClient interface {
GetVizierInfo(ctx context.Context, in *uuidpb.UUID, opts ...grpc.CallOption) (*cvmsgspb.VizierInfo, error)
GetVizierConnectionInfo(ctx context.Context, in *uuidpb.UUID, opts ...grpc.CallOption) (*cvmsgspb.VizierConnectionInfo, error)
}

type scriptmgrClient interface {
GetScriptByHash(ctx context.Context, req *scriptmgrpb.GetScriptByHashReq, opts ...grpc.CallOption) (*scriptmgrpb.GetScriptByHashResp, error)
}

// VizierPassThroughProxy implements the VizierAPI and allows proxying the data to the actual
// vizier cluster.
type VizierPassThroughProxy struct {
nc *nats.Conn
vc vzmgrClient
nc *nats.Conn
vc vzmgrClient
sm scriptmgrClient
disableScriptModifiation bool
}

// getServiceCredentials returns JWT credentials for inter-service requests.
func getServiceCredentials(signingKey string) (string, error) {
claims := jwtutils.GenerateJWTForService("cloud api", viper.GetString("domain_name"))
return jwtutils.SignJWTClaims(claims, signingKey)
}

// NewVizierPassThroughProxy creates a new passthrough proxy.
func NewVizierPassThroughProxy(nc *nats.Conn, vc vzmgrClient) *VizierPassThroughProxy {
return &VizierPassThroughProxy{nc: nc, vc: vc}
func NewVizierPassThroughProxy(nc *nats.Conn, vc vzmgrClient, sm scriptmgrClient, disableScriptModifiation bool) *VizierPassThroughProxy {
return &VizierPassThroughProxy{nc: nc, vc: vc, sm: sm, disableScriptModifiation: disableScriptModifiation}
}

func (v *VizierPassThroughProxy) isScriptModified(ctx context.Context, script string) (bool, error) {
hash := sha256.New()
hash.Write([]byte(script))
hashStr := hex.EncodeToString(hash.Sum(nil))
req := &scriptmgrpb.GetScriptByHashReq{Sha256Hash: hashStr}

serviceAuthToken, err := getServiceCredentials(viper.GetString("jwt_signing_key"))
ctx = metadata.AppendToOutgoingContext(ctx, "authorization",
fmt.Sprintf("bearer %s", serviceAuthToken))

if err != nil {
return false, err
}

resp, err := v.sm.GetScriptByHash(ctx, req)

if err != nil {
return false, err
}
return !resp.Exists, nil
}

// ExecuteScript is the GRPC stream method.
Expand All @@ -55,6 +98,17 @@ func (v *VizierPassThroughProxy) ExecuteScript(req *vizierpb.ExecuteScriptReques
return err
}
defer rp.Finish()
if v.disableScriptModifiation {
modified, err := v.isScriptModified(srv.Context(), req.QueryStr)
if err != nil {
return err
}

if modified {
return status.Error(codes.InvalidArgument, "Script modification has been disabled")
}
}

vizReq := rp.prepareVizierRequest()
vizReq.Msg = &cvmsgspb.C2VAPIStreamRequest_ExecReq{ExecReq: req}
if err := rp.sendMessageToVizier(vizReq); err != nil {
Expand Down
Loading
Loading