Skip to content

Commit 2dcb1f3

Browse files
Merge pull request #473 from ramya-bangera/cherrypick-sqli-ss-v0.19
[cherry-pick into v0.19 branch] sanitize sorting to use valid identifiers
2 parents 363d978 + 6057428 commit 2dcb1f3

File tree

2 files changed

+47
-0
lines changed

2 files changed

+47
-0
lines changed

query/sorting.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package query
22

33
import (
44
"fmt"
5+
"regexp"
56
"strings"
67
)
78

@@ -22,6 +23,11 @@ func (c SortCriteria) GoString() string {
2223
return fmt.Sprintf("%s %s", c.Tag, c.Order)
2324
}
2425

26+
// FieldIdentifierRegex is a regular expression that matches valid field
27+
// identifiers. It is used to validate field names in sorting criteria. This can be
28+
// overridden at init() time to allow for custom field name formats.
29+
var FieldIdentifierRegex = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_\.]*$`)
30+
2531
// ParseSorting parses raw string that represent sort criteria into a Sorting
2632
// data structure.
2733
// Provided string is supposed to be in accordance with the sorting collection
@@ -46,6 +52,14 @@ func ParseSorting(s string) (*Sorting, error) {
4652
default:
4753
return nil, fmt.Errorf("invalid sort criteria: %s", craw)
4854
}
55+
// check if tag is not valid
56+
if !FieldIdentifierRegex.MatchString(c.Tag) {
57+
return nil, fmt.Errorf("invalid field name: %s", c.Tag)
58+
}
59+
// check if tag is not empty
60+
if c.Tag == "" {
61+
return nil, fmt.Errorf("empty field name")
62+
}
4963

5064
sorting.Criterias = append(sorting.Criterias, &c)
5165
}

query/sorting_test.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package query
22

33
import (
4+
"reflect"
45
"testing"
56
)
67

@@ -59,3 +60,35 @@ func TestParseSorting(t *testing.T) {
5960
t.Errorf("invalid error message: %s - expected: %s", err, "invalid sort order - \"dask\" in \"name dask\"")
6061
}
6162
}
63+
64+
func TestParseSortingInjection(t *testing.T) {
65+
type args struct {
66+
s string
67+
}
68+
tests := []struct {
69+
name string
70+
args args
71+
want *Sorting
72+
wantErr bool
73+
}{
74+
{
75+
name: "subquery",
76+
args: args{
77+
s: "(SELECT/**/1)::int",
78+
},
79+
wantErr: true,
80+
},
81+
}
82+
for _, tt := range tests {
83+
t.Run(tt.name, func(t *testing.T) {
84+
got, err := ParseSorting(tt.args.s)
85+
if (err != nil) != tt.wantErr {
86+
t.Errorf("ParseSorting() error = %v, wantErr %v", err, tt.wantErr)
87+
return
88+
}
89+
if !reflect.DeepEqual(got, tt.want) {
90+
t.Errorf("ParseSorting() = %v, want %v", got, tt.want)
91+
}
92+
})
93+
}
94+
}

0 commit comments

Comments
 (0)