Skip to content

Commit 1167cb9

Browse files
authored
fix: Fix casing during reading of resource group (#1673)
* fix: Fix casing of filter name for resource group * fix: fix broken unit test
1 parent a8f6ca1 commit 1167cb9

File tree

3 files changed

+196
-25
lines changed

3 files changed

+196
-25
lines changed

api/resource_groups.go

Lines changed: 87 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
_ "embed"
2323
"encoding/json"
2424
"fmt"
25+
"strings"
2526
"time"
2627

2728
"github.com/pkg/errors"
@@ -95,14 +96,52 @@ func (svc *ResourceGroupsService) List() (response ResourceGroupsResponse, err e
9596
return rawResponse, err
9697
}
9798

99+
err = sanitizeFieldsInRawResponseList(&rawResponse, &response)
100+
if err != nil {
101+
return rawResponse, err
102+
}
103+
98104
return rawResponse, nil
99105
}
100106

107+
func sanitizeFieldsInRawResponse(rawResponse *ResourceGroupResponse, response interface{}) error {
108+
// update filters keys to match the query template
109+
updateFiltersKeys(&rawResponse.Data)
110+
111+
j, err := json.Marshal(rawResponse)
112+
if err != nil {
113+
return err
114+
}
115+
116+
return json.Unmarshal(j, &response)
117+
}
118+
119+
func sanitizeFieldsInRawResponseList(rawResponse *ResourceGroupsResponse, response interface{}) error {
120+
for i := range rawResponse.Data {
121+
// update filters keys to match the query template
122+
updateFiltersKeys(&rawResponse.Data[i])
123+
}
124+
125+
j, err := json.Marshal(rawResponse)
126+
if err != nil {
127+
return err
128+
}
129+
130+
return json.Unmarshal(j, &response)
131+
}
132+
101133
func (svc *ResourceGroupsService) Create(group ResourceGroupData) (
102134
response ResourceGroupResponse,
103135
err error,
104136
) {
105-
err = svc.create(group, &response)
137+
var rawResponse ResourceGroupResponse
138+
err = svc.create(group, &rawResponse)
139+
if err != nil {
140+
return
141+
}
142+
143+
err = sanitizeFieldsInRawResponse(&rawResponse, &response)
144+
106145
return
107146
}
108147

@@ -117,14 +156,58 @@ func (svc *ResourceGroupsService) Update(data *ResourceGroupData) (
117156
guid := data.ID()
118157
data.ResetResourceGUID()
119158

120-
err = svc.update(guid, data, &response)
159+
var rawResponse ResourceGroupResponse
160+
err = svc.update(guid, data, &rawResponse)
161+
121162
if err != nil {
122163
return
123164
}
124165

166+
err = sanitizeFieldsInRawResponse(&rawResponse, &response)
167+
125168
return
126169
}
127170

171+
func collectFilterNames(children []*RGChild, filterNames map[string]string) {
172+
for _, child := range children {
173+
if child.FilterName != "" {
174+
normalizedKey := strings.ReplaceAll(strings.ToLower(child.FilterName), "_", "")
175+
filterNames[normalizedKey] = child.FilterName
176+
}
177+
if len(child.Children) > 0 {
178+
collectFilterNames(child.Children, filterNames)
179+
}
180+
}
181+
}
182+
183+
/*
184+
updateFiltersKeys updates the keys in the Filters map of ResourceGroupData to ensure they match the filter names
185+
defined in the nested children of the query expression. This is necessary because JSON decoding/encoding can
186+
convert keys to camel case, causing mismatches. The function normalizes the keys by removing underscores and
187+
converting them to lower case, then compares them with the filter names. If a mismatch is found, the key is
188+
updated to the value in RGExpression.Children
189+
*/
190+
func updateFiltersKeys(data *ResourceGroupData) {
191+
if data.Query == nil || data.Query.Expression == nil {
192+
return
193+
}
194+
195+
filterNames := make(map[string]string)
196+
collectFilterNames(data.Query.Expression.Children, filterNames)
197+
198+
updatedFilters := make(map[string]*RGFilter)
199+
for key, value := range data.Query.Filters {
200+
normalizedKey := strings.ReplaceAll(strings.ToLower(key), "_", "")
201+
if _, exists := filterNames[normalizedKey]; exists {
202+
updatedFilters[filterNames[normalizedKey]] = value
203+
} else {
204+
updatedFilters[key] = value
205+
}
206+
}
207+
208+
data.Query.Filters = updatedFilters
209+
}
210+
128211
func (group *ResourceGroupData) ResetResourceGUID() {
129212
group.ResourceGroupGuid = ""
130213
group.UpdatedBy = ""
@@ -149,20 +232,17 @@ func (svc *ResourceGroupsService) Delete(guid string) error {
149232

150233
func (svc *ResourceGroupsService) Get(guid string, response interface{}) error {
151234
var rawResponse ResourceGroupResponse
235+
152236
err := svc.get(guid, &rawResponse)
153237
if err != nil {
154238
return err
155239
}
156240

157-
j, err := json.Marshal(rawResponse)
241+
err = sanitizeFieldsInRawResponse(&rawResponse, response)
158242
if err != nil {
159243
return err
160244
}
161245

162-
err = json.Unmarshal(j, &response)
163-
if err != nil {
164-
return err
165-
}
166246
return nil
167247
}
168248

api/resource_groups_test.go

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
package api_test
2020

2121
import (
22+
"encoding/json"
2223
"fmt"
2324
"net/http"
2425
"strings"
@@ -132,6 +133,113 @@ func TestResourceGroupGet(t *testing.T) {
132133
})
133134
}
134135

136+
func TestResourceGroupsGetCorrectlyParsersFilterNames(t *testing.T) {
137+
var (
138+
queryJson = `
139+
{
140+
"expression": {
141+
"children": [
142+
{
143+
"filterName": "filter_account"
144+
},
145+
{
146+
"filterName": "filter1"
147+
},
148+
{
149+
"filterName": "filter2"
150+
},
151+
{
152+
"children": [
153+
{
154+
"filterName": "team_Account"
155+
}
156+
],
157+
"operator": "OR"
158+
}
159+
160+
],
161+
"operator": "AND"
162+
},
163+
"filters": {
164+
"filter1": {
165+
"field": "Resource Tag",
166+
"key": "Hostname",
167+
"operation": "INCLUDES",
168+
"values": [
169+
"*"
170+
]
171+
},
172+
"filter2": {
173+
"field": "Region",
174+
"operation": "STARTS_WITH",
175+
"values": [
176+
"ap-south"
177+
]
178+
},
179+
"filter_account": {
180+
"field": "Account",
181+
"operation": "EQUALS",
182+
"values": [
183+
"123456789012"
184+
]
185+
},
186+
"team_Account": {
187+
"field": "Account",
188+
"operation": "EQUALS",
189+
"values": [
190+
"123456789012"
191+
]
192+
}
193+
}
194+
}
195+
`
196+
resourceGUID = intgguid.New()
197+
vanillaType = "VANILLA"
198+
apiPath = fmt.Sprintf("ResourceGroups/%s", resourceGUID)
199+
vanillaGroup = singleVanillaResourceGroup(resourceGUID, vanillaType, queryJson)
200+
fakeServer = lacework.MockServer()
201+
)
202+
203+
fakeServer.MockToken("TOKEN")
204+
defer fakeServer.Close()
205+
206+
fakeServer.MockAPI(apiPath,
207+
func(w http.ResponseWriter, r *http.Request) {
208+
if assert.Equal(t, "GET", r.Method, "Get() should be a GET method") {
209+
fmt.Fprintf(w, generateResourceGroupResponse(vanillaGroup))
210+
}
211+
},
212+
)
213+
214+
c, err := api.NewClient("test",
215+
api.WithToken("TOKEN"),
216+
api.WithURL(fakeServer.URL()),
217+
)
218+
219+
assert.Nil(t, err)
220+
221+
t.Run("when resource groups GET is called. Filter keys are correctly parsed", func(t *testing.T) {
222+
var response api.ResourceGroupResponse
223+
err := c.V2.ResourceGroups.Get(resourceGUID, &response)
224+
assert.Nil(t, err)
225+
if assert.NotNil(t, response) {
226+
assert.Equal(t, resourceGUID, response.Data.ResourceGroupGuid)
227+
assert.Equal(t, "group_name", response.Data.Name)
228+
assert.Equal(t, "VANILLA", response.Data.Type)
229+
// assert that the filter names in queryjson matach RGQuery
230+
var expectedQuery api.RGQuery
231+
err = json.Unmarshal([]byte(queryJson), &expectedQuery)
232+
assert.Nil(t, err)
233+
234+
assert.NotNil(t, response.Data.Query.Filters["filter_account"])
235+
assert.Equal(t, expectedQuery.Filters["filter_account"], response.Data.Query.Filters["filter_account"])
236+
237+
assert.NotNil(t, response.Data.Query.Filters["team_Account"])
238+
assert.Equal(t, expectedQuery.Filters["team_Account"], response.Data.Query.Filters["team_Account"])
239+
}
240+
})
241+
}
242+
135243
func TestResourceGroupsDelete(t *testing.T) {
136244
var (
137245
resourceGUID = intgguid.New()

cli/cmd/resource_groups.go

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -208,24 +208,7 @@ func promptCreateResourceGroup() error {
208208
return err
209209
}
210210

211-
switch group {
212-
case "AWS":
213-
return createResourceGroup("AWS")
214-
case "AZURE":
215-
return createResourceGroup("AZURE")
216-
case "GCP":
217-
return createResourceGroup("GCP")
218-
case "CONTAINER":
219-
return createResourceGroup("CONTAINER")
220-
case "MACHINE":
221-
return createResourceGroup("MACHINE")
222-
case "OCI":
223-
return createResourceGroup("OCI")
224-
case "KUBERNETES":
225-
return createResourceGroup("KUBERNETES")
226-
default:
227-
return errors.New("unknown resource group type")
228-
}
211+
return createResourceGroup(group)
229212
}
230213

231214
func init() {

0 commit comments

Comments
 (0)