@@ -344,7 +344,7 @@ func (c *mockClient) ListPolicies(_ context.Context, find *api.PolicyFindMessage
344344 if policy .State == api .Deleted && ! find .ShowDeleted {
345345 continue
346346 }
347- if policy .Name == name {
347+ if name == "policies" || policy .Name == name {
348348 policies = append (policies , policy )
349349 }
350350 }
@@ -365,10 +365,75 @@ func (c *mockClient) GetPolicy(_ context.Context, find *api.PolicyFindMessage) (
365365 return policy , nil
366366}
367367
368- // TODO(ed): update the mock and tests
369368// UpsertPolicy creates or updates the policy.
370- func (* mockClient ) UpsertPolicy (_ context.Context , _ * api.PolicyFindMessage , _ * api.PolicyPatchMessage ) (* api.PolicyMessage , error ) {
371- return nil , nil
369+ func (c * mockClient ) UpsertPolicy (_ context.Context , find * api.PolicyFindMessage , patch * api.PolicyPatchMessage ) (* api.PolicyMessage , error ) {
370+ name := getPolicyRequestName (find )
371+ policy , existed := c .policyMap [name ]
372+
373+ if ! existed {
374+ policy = & api.PolicyMessage {
375+ Name : name ,
376+ State : api .Active ,
377+ }
378+ }
379+
380+ switch patch .Type {
381+ case api .PolicyTypeAccessControl :
382+ if ! existed {
383+ if patch .AccessControlPolicy == nil {
384+ return nil , errors .Errorf ("payload is required to create the policy" )
385+ }
386+ }
387+ if v := patch .AccessControlPolicy ; v != nil {
388+ policy .AccessControlPolicy = v
389+ }
390+ case api .PolicyTypeBackupPlan :
391+ if ! existed {
392+ if patch .BackupPlanPolicy == nil {
393+ return nil , errors .Errorf ("payload is required to create the policy" )
394+ }
395+ }
396+ if v := patch .BackupPlanPolicy ; v != nil {
397+ policy .BackupPlanPolicy = v
398+ }
399+ case api .PolicyTypeDeploymentApproval :
400+ if ! existed {
401+ if patch .DeploymentApprovalPolicy == nil {
402+ return nil , errors .Errorf ("payload is required to create the policy" )
403+ }
404+ }
405+ if v := patch .DeploymentApprovalPolicy ; v != nil {
406+ policy .DeploymentApprovalPolicy = v
407+ }
408+ case api .PolicyTypeSQLReview :
409+ if ! existed {
410+ if patch .SQLReviewPolicy == nil {
411+ return nil , errors .Errorf ("payload is required to create the policy" )
412+ }
413+ }
414+ if v := patch .SQLReviewPolicy ; v != nil {
415+ policy .SQLReviewPolicy = v
416+ }
417+ case api .PolicyTypeSensitiveData :
418+ if ! existed {
419+ if patch .SensitiveDataPolicy == nil {
420+ return nil , errors .Errorf ("payload is required to create the policy" )
421+ }
422+ }
423+ if v := patch .SensitiveDataPolicy ; v != nil {
424+ policy .SensitiveDataPolicy = v
425+ }
426+ default :
427+ return nil , errors .Errorf ("invalid policy type %v" , patch .Type )
428+ }
429+
430+ if v := patch .InheritFromParent ; v != nil {
431+ policy .InheritFromParent = * v
432+ }
433+
434+ c .policyMap [name ] = policy
435+
436+ return policy , nil
372437}
373438
374439// DeletePolicy deletes the policy.
0 commit comments