Skip to content

Commit 7fe7ed5

Browse files
authored
Merge pull request #85 from OpenVPN/fix/url-path-injection
Refactor API endpoint construction and add ID validation
2 parents 769ea18 + d6d17b9 commit 7fe7ed5

19 files changed

+429
-75
lines changed

cloudconnexa/access_groups.go

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,10 @@ func (c *AccessGroupsService) List() ([]AccessGroup, error) {
9696

9797
// Get retrieves a specific access group by its ID from the CloudConnexa API.
9898
func (c *AccessGroupsService) Get(id string) (*AccessGroup, error) {
99-
endpoint := fmt.Sprintf("%s/access-groups/%s", c.client.GetV1Url(), id)
99+
if err := validateID(id); err != nil {
100+
return nil, err
101+
}
102+
endpoint := buildURL(c.client.GetV1Url(), "access-groups", id)
100103
req, err := http.NewRequest(http.MethodGet, endpoint, nil)
101104
if err != nil {
102105
return nil, err
@@ -140,7 +143,7 @@ func (c *AccessGroupsService) Create(accessGroup *AccessGroup) (*AccessGroup, er
140143
return nil, err
141144
}
142145

143-
endpoint := fmt.Sprintf("%s/access-groups", c.client.GetV1Url())
146+
endpoint := buildURL(c.client.GetV1Url(), "access-groups")
144147

145148
req, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewBuffer(accessGroupJSON))
146149
if err != nil {
@@ -163,12 +166,15 @@ func (c *AccessGroupsService) Create(accessGroup *AccessGroup) (*AccessGroup, er
163166
// Update updates an existing access group in the CloudConnexa API.
164167
// It returns the updated access group.
165168
func (c *AccessGroupsService) Update(id string, accessGroup *AccessGroup) (*AccessGroup, error) {
169+
if err := validateID(id); err != nil {
170+
return nil, err
171+
}
166172
accessGroupJSON, err := json.Marshal(accessGroup)
167173
if err != nil {
168174
return nil, err
169175
}
170176

171-
endpoint := fmt.Sprintf("%s/access-groups/%s", c.client.GetV1Url(), id)
177+
endpoint := buildURL(c.client.GetV1Url(), "access-groups", id)
172178

173179
req, err := http.NewRequest(http.MethodPut, endpoint, bytes.NewBuffer(accessGroupJSON))
174180
if err != nil {
@@ -190,7 +196,10 @@ func (c *AccessGroupsService) Update(id string, accessGroup *AccessGroup) (*Acce
190196

191197
// Delete removes an access group from the CloudConnexa API by its ID.
192198
func (c *AccessGroupsService) Delete(id string) error {
193-
endpoint := fmt.Sprintf("%s/access-groups/%s", c.client.GetV1Url(), id)
199+
if err := validateID(id); err != nil {
200+
return err
201+
}
202+
endpoint := buildURL(c.client.GetV1Url(), "access-groups", id)
194203
req, err := http.NewRequest(http.MethodDelete, endpoint, nil)
195204
if err != nil {
196205
return err

cloudconnexa/cloudconnexa.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"fmt"
88
"io"
99
"net/http"
10+
"net/url"
1011
"strconv"
1112
"strings"
1213
"time"
@@ -243,3 +244,26 @@ func (c *Client) AssignLimits(res *http.Response, rateLimiter *rate.Limiter) err
243244
func (c *Client) GetV1Url() string {
244245
return c.BaseURL + "/api/v1"
245246
}
247+
248+
// buildURL constructs a URL with escaped path segments for safe API calls.
249+
// Example: buildURL(c.GetV1Url(), "users", userID, "activate")
250+
// Returns: https://api.example.com/api/v1/users/{escaped-id}/activate
251+
func buildURL(base string, segments ...string) string {
252+
if len(segments) == 0 {
253+
return base
254+
}
255+
escaped := make([]string, len(segments))
256+
for i, seg := range segments {
257+
escaped[i] = url.PathEscape(seg)
258+
}
259+
return base + "/" + strings.Join(escaped, "/")
260+
}
261+
262+
// validateID returns an error if the provided ID is empty.
263+
// This should be called before making API calls that require an ID parameter.
264+
func validateID(id string) error {
265+
if id == "" {
266+
return ErrEmptyID
267+
}
268+
return nil
269+
}

cloudconnexa/cloudconnexa_test.go

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,3 +111,99 @@ func TestDoRequest(t *testing.T) {
111111
})
112112
}
113113
}
114+
115+
// TestBuildURL tests the buildURL function that constructs URLs with escaped path segments.
116+
func TestBuildURL(t *testing.T) {
117+
tests := []struct {
118+
name string
119+
base string
120+
segments []string
121+
expected string
122+
}{
123+
{
124+
name: "no segments",
125+
base: "https://api.example.com/v1",
126+
segments: []string{},
127+
expected: "https://api.example.com/v1",
128+
},
129+
{
130+
name: "single segment",
131+
base: "https://api.example.com/v1",
132+
segments: []string{"users"},
133+
expected: "https://api.example.com/v1/users",
134+
},
135+
{
136+
name: "multiple segments",
137+
base: "https://api.example.com/v1",
138+
segments: []string{"users", "abc-123", "activate"},
139+
expected: "https://api.example.com/v1/users/abc-123/activate",
140+
},
141+
{
142+
name: "path traversal escaped",
143+
base: "https://api.example.com/v1",
144+
segments: []string{"users", "../admin"},
145+
expected: "https://api.example.com/v1/users/..%2Fadmin",
146+
},
147+
{
148+
name: "forward slash escaped",
149+
base: "https://api.example.com/v1",
150+
segments: []string{"users", "user/admin"},
151+
expected: "https://api.example.com/v1/users/user%2Fadmin",
152+
},
153+
{
154+
name: "space escaped",
155+
base: "https://api.example.com/v1",
156+
segments: []string{"users", "user 123"},
157+
expected: "https://api.example.com/v1/users/user%20123",
158+
},
159+
{
160+
name: "question mark escaped",
161+
base: "https://api.example.com/v1",
162+
segments: []string{"users", "user?role=admin"},
163+
expected: "https://api.example.com/v1/users/user%3Frole=admin",
164+
},
165+
}
166+
167+
for _, tt := range tests {
168+
t.Run(tt.name, func(t *testing.T) {
169+
result := buildURL(tt.base, tt.segments...)
170+
assert.Equal(t, tt.expected, result)
171+
})
172+
}
173+
}
174+
175+
// TestValidateID tests the validateID function that validates ID parameters.
176+
func TestValidateID(t *testing.T) {
177+
tests := []struct {
178+
name string
179+
id string
180+
wantErr error
181+
}{
182+
{
183+
name: "valid ID",
184+
id: "abc-123",
185+
wantErr: nil,
186+
},
187+
{
188+
name: "valid UUID",
189+
id: "550e8400-e29b-41d4-a716-446655440000",
190+
wantErr: nil,
191+
},
192+
{
193+
name: "empty ID",
194+
id: "",
195+
wantErr: ErrEmptyID,
196+
},
197+
}
198+
199+
for _, tt := range tests {
200+
t.Run(tt.name, func(t *testing.T) {
201+
err := validateID(tt.id)
202+
if tt.wantErr != nil {
203+
assert.Equal(t, tt.wantErr, err)
204+
} else {
205+
assert.NoError(t, err)
206+
}
207+
})
208+
}
209+
}

cloudconnexa/devices.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,10 @@ func (d *DevicesService) ListAll() ([]DeviceDetail, error) {
150150

151151
// GetByID retrieves a specific device by its ID.
152152
func (d *DevicesService) GetByID(deviceID string) (*DeviceDetail, error) {
153-
endpoint := fmt.Sprintf("%s/devices/%s", d.client.GetV1Url(), deviceID)
153+
if err := validateID(deviceID); err != nil {
154+
return nil, err
155+
}
156+
endpoint := buildURL(d.client.GetV1Url(), "devices", deviceID)
154157
req, err := http.NewRequest(http.MethodGet, endpoint, nil)
155158
if err != nil {
156159
return nil, err
@@ -172,12 +175,15 @@ func (d *DevicesService) GetByID(deviceID string) (*DeviceDetail, error) {
172175

173176
// Update updates an existing device by its ID.
174177
func (d *DevicesService) Update(deviceID string, updateRequest DeviceUpdateRequest) (*DeviceDetail, error) {
178+
if err := validateID(deviceID); err != nil {
179+
return nil, err
180+
}
175181
requestJSON, err := json.Marshal(updateRequest)
176182
if err != nil {
177183
return nil, err
178184
}
179185

180-
endpoint := fmt.Sprintf("%s/devices/%s", d.client.GetV1Url(), deviceID)
186+
endpoint := buildURL(d.client.GetV1Url(), "devices", deviceID)
181187
req, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewBuffer(requestJSON))
182188
if err != nil {
183189
return nil, err

cloudconnexa/dns_records.go

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,10 @@ func (c *DNSRecordsService) List() ([]DNSRecord, error) {
8383
// This is the preferred method for getting a single DNS record as it uses the direct
8484
// GET /api/v1/dns-records/{id} endpoint introduced in API v1.1.0.
8585
func (c *DNSRecordsService) GetByID(recordID string) (*DNSRecord, error) {
86-
endpoint := fmt.Sprintf("%s/dns-records/%s", c.client.GetV1Url(), recordID)
86+
if err := validateID(recordID); err != nil {
87+
return nil, err
88+
}
89+
endpoint := buildURL(c.client.GetV1Url(), "dns-records", recordID)
8790
req, err := http.NewRequest(http.MethodGet, endpoint, nil)
8891
if err != nil {
8992
return nil, err
@@ -134,7 +137,7 @@ func (c *DNSRecordsService) Create(record DNSRecord) (*DNSRecord, error) {
134137
return nil, err
135138
}
136139

137-
req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/dns-records", c.client.GetV1Url()), bytes.NewBuffer(recordJSON))
140+
req, err := http.NewRequest(http.MethodPost, buildURL(c.client.GetV1Url(), "dns-records"), bytes.NewBuffer(recordJSON))
138141
if err != nil {
139142
return nil, err
140143
}
@@ -154,12 +157,15 @@ func (c *DNSRecordsService) Create(record DNSRecord) (*DNSRecord, error) {
154157

155158
// Update updates an existing DNS record.
156159
func (c *DNSRecordsService) Update(record DNSRecord) error {
160+
if err := validateID(record.ID); err != nil {
161+
return err
162+
}
157163
recordJSON, err := json.Marshal(record)
158164
if err != nil {
159165
return err
160166
}
161167

162-
req, err := http.NewRequest(http.MethodPut, fmt.Sprintf("%s/dns-records/%s", c.client.GetV1Url(), record.ID), bytes.NewBuffer(recordJSON))
168+
req, err := http.NewRequest(http.MethodPut, buildURL(c.client.GetV1Url(), "dns-records", record.ID), bytes.NewBuffer(recordJSON))
163169
if err != nil {
164170
return err
165171
}
@@ -170,7 +176,10 @@ func (c *DNSRecordsService) Update(record DNSRecord) error {
170176

171177
// Delete deletes a DNS record by ID.
172178
func (c *DNSRecordsService) Delete(recordID string) error {
173-
req, err := http.NewRequest(http.MethodDelete, fmt.Sprintf("%s/dns-records/%s", c.client.GetV1Url(), recordID), nil)
179+
if err := validateID(recordID); err != nil {
180+
return err
181+
}
182+
req, err := http.NewRequest(http.MethodDelete, buildURL(c.client.GetV1Url(), "dns-records", recordID), nil)
174183
if err != nil {
175184
return err
176185
}

cloudconnexa/errors.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,6 @@ import "errors"
44

55
// ErrCredentialsRequired is returned when client ID or client secret is missing.
66
var ErrCredentialsRequired = errors.New("both client_id and client_secret credentials must be specified")
7+
8+
// ErrEmptyID is returned when an empty ID is provided to a method that requires one.
9+
var ErrEmptyID = errors.New("id cannot be empty")

cloudconnexa/host_applications.go

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,10 @@ func (c *HostApplicationsService) List() ([]ApplicationResponse, error) {
106106

107107
// Get retrieves a specific host application by its ID.
108108
func (c *HostApplicationsService) Get(id string) (*ApplicationResponse, error) {
109-
endpoint := fmt.Sprintf("%s/hosts/applications/%s", c.client.GetV1Url(), id)
109+
if err := validateID(id); err != nil {
110+
return nil, err
111+
}
112+
endpoint := buildURL(c.client.GetV1Url(), "hosts", "applications", id)
110113
req, err := http.NewRequest(http.MethodGet, endpoint, nil)
111114
if err != nil {
112115
return nil, err
@@ -178,12 +181,15 @@ func (c *HostApplicationsService) Create(application *Application) (*Application
178181

179182
// Update updates an existing host application by its ID.
180183
func (c *HostApplicationsService) Update(id string, application *Application) (*ApplicationResponse, error) {
184+
if err := validateID(id); err != nil {
185+
return nil, err
186+
}
181187
applicationJSON, err := json.Marshal(application)
182188
if err != nil {
183189
return nil, err
184190
}
185191

186-
endpoint := fmt.Sprintf("%s/hosts/applications/%s", c.client.GetV1Url(), id)
192+
endpoint := buildURL(c.client.GetV1Url(), "hosts", "applications", id)
187193

188194
req, err := http.NewRequest(http.MethodPut, endpoint, bytes.NewBuffer(applicationJSON))
189195
if err != nil {
@@ -205,7 +211,10 @@ func (c *HostApplicationsService) Update(id string, application *Application) (*
205211

206212
// Delete removes a host application by its ID.
207213
func (c *HostApplicationsService) Delete(id string) error {
208-
endpoint := fmt.Sprintf("%s/hosts/applications/%s", c.client.GetV1Url(), id)
214+
if err := validateID(id); err != nil {
215+
return err
216+
}
217+
endpoint := buildURL(c.client.GetV1Url(), "hosts", "applications", id)
209218
req, err := http.NewRequest(http.MethodDelete, endpoint, nil)
210219
if err != nil {
211220
return err

0 commit comments

Comments
 (0)