|
| 1 | +// Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | +// Licensed under the MIT license |
| 3 | + |
| 4 | +package internal |
| 5 | + |
| 6 | +import ( |
| 7 | + "context" |
| 8 | + |
| 9 | + "github.com/microsoft/moc/pkg/status" |
| 10 | + "github.com/microsoft/wssd-sdk-for-go/services/security/keyvault" |
| 11 | + |
| 12 | + "github.com/microsoft/moc/pkg/auth" |
| 13 | + "github.com/microsoft/moc/pkg/errors" |
| 14 | + wssdcommonproto "github.com/microsoft/moc/rpc/common" |
| 15 | + wssdsecurity "github.com/microsoft/moc/rpc/nodeagent/security" |
| 16 | + wssdclient "github.com/microsoft/wssd-sdk-for-go/pkg/client" |
| 17 | + log "k8s.io/klog" |
| 18 | +) |
| 19 | + |
| 20 | +type client struct { |
| 21 | + wssdsecurity.KeyAgentClient |
| 22 | +} |
| 23 | + |
| 24 | +// NewKeyClient - creates a client session with the backend wssd agent |
| 25 | +func NewKeyClient(subID string, authorizer auth.Authorizer) (*client, error) { |
| 26 | + c, err := wssdclient.GetKeyClient(&subID, authorizer) |
| 27 | + if err != nil { |
| 28 | + return nil, err |
| 29 | + } |
| 30 | + return &client{c}, nil |
| 31 | +} |
| 32 | + |
| 33 | +// Get |
| 34 | +func (c *client) Get(ctx context.Context, name string, vaultName string) (*[]keyvault.Key, error) { |
| 35 | + request := getKeyRequest(wssdcommonproto.Operation_GET, name, vaultName, nil, 0, nil) |
| 36 | + response, err := c.KeyAgentClient.Invoke(ctx, request) |
| 37 | + if err != nil { |
| 38 | + return nil, err |
| 39 | + } |
| 40 | + return getKeysFromResponse(response), nil |
| 41 | +} |
| 42 | + |
| 43 | +// CreateOrUpdate |
| 44 | +func (c *client) CreateOrUpdate(ctx context.Context, keyIn *keyvault.Key) (*keyvault.Key, error) { |
| 45 | + err := c.validate(ctx, keyIn) |
| 46 | + if err != nil { |
| 47 | + return nil, err |
| 48 | + } |
| 49 | + |
| 50 | + request := getKeyRequest(wssdcommonproto.Operation_POST, "", "", nil, 0, keyIn) |
| 51 | + response, err := c.KeyAgentClient.Invoke(ctx, request) |
| 52 | + if err != nil { |
| 53 | + log.Errorf("[Key] Create failed with error %v", err) |
| 54 | + return nil, err |
| 55 | + } |
| 56 | + |
| 57 | + keys := getKeysFromResponse(response) |
| 58 | + |
| 59 | + if len(*keys) == 0 { |
| 60 | + return nil, errors.New("[Key][Create] Unexpected error: Creating a key returned no result") |
| 61 | + } |
| 62 | + |
| 63 | + return &((*keys)[0]), err |
| 64 | +} |
| 65 | + |
| 66 | +func (c *client) validate(ctx context.Context, key *keyvault.Key) (err error) { |
| 67 | + if key == nil || key.VaultName == nil || key.Name == nil || key.KeyType == nil || getMOCKeySize(key.KeySize) == wssdcommonproto.KeySize_K_UNKNOWN { |
| 68 | + return errors.Wrapf(errors.InvalidInput, "[Key][Create] Invalid Input") |
| 69 | + } |
| 70 | + |
| 71 | + return nil |
| 72 | +} |
| 73 | + |
| 74 | +// Delete methods invokes create or update on the client |
| 75 | +func (c *client) Delete(ctx context.Context, key *keyvault.Key) error { |
| 76 | + keys, err := c.Get(ctx, *key.Name, *key.VaultName) |
| 77 | + if err != nil { |
| 78 | + return err |
| 79 | + } |
| 80 | + if len(*keys) == 0 { |
| 81 | + return errors.Wrapf(errors.NotFound, "Key [%s] not found", *key.Name) |
| 82 | + } |
| 83 | + |
| 84 | + request := getKeyRequest(wssdcommonproto.Operation_DELETE, "", "", nil, 0, &(*keys)[0]) |
| 85 | + _, err = c.KeyAgentClient.Invoke(ctx, request) |
| 86 | + return err |
| 87 | +} |
| 88 | + |
| 89 | +// Rotates a key and returns the new key |
| 90 | +func (c *client) RotateKey(ctx context.Context, keyReq *keyvault.KeyOperationRequest) (*keyvault.KeyOperationResult, error) { |
| 91 | + wssdReq, err := getKeyOperationRequest(keyReq, wssdcommonproto.ProviderAccessOperation_Key_Rotate) |
| 92 | + if err != nil { |
| 93 | + return nil, err |
| 94 | + } |
| 95 | + |
| 96 | + wssdRep, err := c.KeyAgentClient.Operate(ctx, wssdReq) |
| 97 | + |
| 98 | + if err != nil { |
| 99 | + return nil, err |
| 100 | + } |
| 101 | + |
| 102 | + keyOpRes := keyvault.KeyOperationResult{ |
| 103 | + Key: getKey(wssdRep.GetKey()), |
| 104 | + Result: nil} // No result expected from rotate |
| 105 | + |
| 106 | + return &keyOpRes, nil |
| 107 | +} |
| 108 | + |
| 109 | +// Wraps a key and returns the result |
| 110 | +func (c *client) WrapKey(ctx context.Context, keyReq *keyvault.KeyOperationRequest) (*keyvault.KeyOperationResult, error) { |
| 111 | + wssdReq, err := getKeyOperationRequest(keyReq, wssdcommonproto.ProviderAccessOperation_Key_WrapKey) |
| 112 | + if err != nil { |
| 113 | + return nil, err |
| 114 | + } |
| 115 | + |
| 116 | + wssdRep, err := c.KeyAgentClient.Operate(ctx, wssdReq) |
| 117 | + |
| 118 | + if err != nil { |
| 119 | + return nil, err |
| 120 | + } |
| 121 | + |
| 122 | + if wssdRep.Data == "" { |
| 123 | + return nil, errors.New("[Key][Wrap] Unexpected error: Wrapping a key returned no result") |
| 124 | + } |
| 125 | + |
| 126 | + keyOpRes := keyvault.KeyOperationResult{ |
| 127 | + Key: nil, // No key changes expected |
| 128 | + Result: &wssdRep.Data} |
| 129 | + |
| 130 | + return &keyOpRes, nil |
| 131 | +} |
| 132 | + |
| 133 | +// Unwraps a key and returns the result |
| 134 | +func (c *client) UnwrapKey(ctx context.Context, keyReq *keyvault.KeyOperationRequest) (*keyvault.KeyOperationResult, error) { |
| 135 | + wssdReq, err := getKeyOperationRequest(keyReq, wssdcommonproto.ProviderAccessOperation_Key_UnwrapKey) |
| 136 | + if err != nil { |
| 137 | + return nil, err |
| 138 | + } |
| 139 | + |
| 140 | + wssdRep, err := c.KeyAgentClient.Operate(ctx, wssdReq) |
| 141 | + |
| 142 | + if err != nil { |
| 143 | + return nil, err |
| 144 | + } |
| 145 | + |
| 146 | + if wssdRep.Data == "" { |
| 147 | + return nil, errors.New("[Key][Wrap] Unexpected error: Unwrapping a key returned no result") |
| 148 | + } |
| 149 | + |
| 150 | + keyOpRes := keyvault.KeyOperationResult{ |
| 151 | + Key: nil, // No key changes expected |
| 152 | + Result: &wssdRep.Data} |
| 153 | + |
| 154 | + return &keyOpRes, nil |
| 155 | +} |
| 156 | + |
| 157 | +func getKeyOperationRequest(keyReq *keyvault.KeyOperationRequest, op wssdcommonproto.ProviderAccessOperation) (*wssdsecurity.KeyOperationRequest, error) { |
| 158 | + var wssdReq wssdsecurity.KeyOperationRequest |
| 159 | + switch op { |
| 160 | + case wssdcommonproto.ProviderAccessOperation_Key_Rotate: |
| 161 | + wssdReq = wssdsecurity.KeyOperationRequest{ |
| 162 | + Key: getWssdKey(keyReq.Key), |
| 163 | + OperationType: wssdcommonproto.ProviderAccessOperation_Key_Rotate} |
| 164 | + |
| 165 | + case wssdcommonproto.ProviderAccessOperation_Key_UnwrapKey: |
| 166 | + case wssdcommonproto.ProviderAccessOperation_Key_WrapKey: |
| 167 | + wssdReq = wssdsecurity.KeyOperationRequest{ |
| 168 | + Key: getWssdKey(keyReq.Key), |
| 169 | + Algorithm: wssdcommonproto.Algorithm(wssdcommonproto.Algorithm_value[string(*keyReq.Algorithm)]), |
| 170 | + OperationType: op, |
| 171 | + Data: *keyReq.Data} |
| 172 | + default: |
| 173 | + return nil, errors.InvalidInput |
| 174 | + } |
| 175 | + |
| 176 | + return &wssdReq, nil |
| 177 | +} |
| 178 | + |
| 179 | +func getKeysFromResponse(response *wssdsecurity.KeyResponse) *[]keyvault.Key { |
| 180 | + Keys := []keyvault.Key{} |
| 181 | + for _, key := range response.GetKeys() { |
| 182 | + Keys = append(Keys, *(getKey(key))) |
| 183 | + } |
| 184 | + |
| 185 | + return &Keys |
| 186 | +} |
| 187 | + |
| 188 | +func getKeyRequest(opType wssdcommonproto.Operation, name, vaultName string, keyType *keyvault.JSONWebKeyType, keySize int32, key *keyvault.Key) *wssdsecurity.KeyRequest { |
| 189 | + request := &wssdsecurity.KeyRequest{ |
| 190 | + OperationType: opType, |
| 191 | + Keys: []*wssdsecurity.Key{}, |
| 192 | + } |
| 193 | + |
| 194 | + if key != nil { |
| 195 | + request.Keys = append(request.Keys, getWssdKey(key)) |
| 196 | + } else if len(name) > 0 { |
| 197 | + tempKey := &wssdsecurity.Key{ |
| 198 | + Name: name, |
| 199 | + VaultName: vaultName, |
| 200 | + Size: getMOCKeySize(keySize)} |
| 201 | + |
| 202 | + if keyType != nil { |
| 203 | + tempKey.Type = wssdcommonproto.JsonWebKeyType(wssdcommonproto.JsonWebKeyType_value[string(*keyType)]) |
| 204 | + } |
| 205 | + |
| 206 | + request.Keys = append(request.Keys, tempKey) |
| 207 | + } |
| 208 | + return request |
| 209 | +} |
| 210 | + |
| 211 | +func getKey(key *wssdsecurity.Key) *keyvault.Key { |
| 212 | + ct := key.CreationTime.AsTime() |
| 213 | + keyType := keyvault.JSONWebKeyType(wssdcommonproto.JsonWebKeyType_name[int32(key.Type)]) |
| 214 | + return &keyvault.Key{ |
| 215 | + ID: &key.Id, |
| 216 | + Name: &key.Name, |
| 217 | + VaultName: &key.VaultName, |
| 218 | + CreationTime: &ct, |
| 219 | + KeyVersion: key.KeyVersion, |
| 220 | + KeyType: &keyType, |
| 221 | + KeySize: getKeySize(key.Size), |
| 222 | + ProvisioningState: status.GetProvisioningState(key.GetStatus().GetProvisioningStatus())} |
| 223 | +} |
| 224 | + |
| 225 | +func getWssdKey(key *keyvault.Key) *wssdsecurity.Key { |
| 226 | + keyOut := &wssdsecurity.Key{ |
| 227 | + Name: *key.Name, |
| 228 | + VaultName: *key.VaultName, |
| 229 | + Type: wssdcommonproto.JsonWebKeyType(wssdcommonproto.JsonWebKeyType_value[string(*key.KeyType)]), |
| 230 | + Size: getMOCKeySize(key.KeySize), |
| 231 | + KeyVersion: key.KeyVersion} |
| 232 | + |
| 233 | + return keyOut |
| 234 | +} |
| 235 | + |
| 236 | +func getMOCKeySize(size int32) (ksize wssdcommonproto.KeySize) { |
| 237 | + switch size { |
| 238 | + case 256: |
| 239 | + ksize = wssdcommonproto.KeySize__256 |
| 240 | + default: |
| 241 | + ksize = wssdcommonproto.KeySize_K_UNKNOWN |
| 242 | + } |
| 243 | + return |
| 244 | +} |
| 245 | + |
| 246 | +func getKeySize(ksize wssdcommonproto.KeySize) (size int32) { |
| 247 | + switch ksize { |
| 248 | + case wssdcommonproto.KeySize__256: |
| 249 | + size = 256 |
| 250 | + default: |
| 251 | + size = -1 |
| 252 | + } |
| 253 | + return |
| 254 | +} |
0 commit comments