Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions cloudconnexa/access_groups.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,10 @@ func (c *AccessGroupsService) List() ([]AccessGroup, error) {

// Get retrieves a specific access group by its ID from the CloudConnexa API.
func (c *AccessGroupsService) Get(id string) (*AccessGroup, error) {
endpoint := fmt.Sprintf("%s/access-groups/%s", c.client.GetV1Url(), id)
if err := validateID(id); err != nil {
return nil, err
}
endpoint := buildURL(c.client.GetV1Url(), "access-groups", id)
req, err := http.NewRequest(http.MethodGet, endpoint, nil)
if err != nil {
return nil, err
Expand Down Expand Up @@ -140,7 +143,7 @@ func (c *AccessGroupsService) Create(accessGroup *AccessGroup) (*AccessGroup, er
return nil, err
}

endpoint := fmt.Sprintf("%s/access-groups", c.client.GetV1Url())
endpoint := buildURL(c.client.GetV1Url(), "access-groups")

req, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewBuffer(accessGroupJSON))
if err != nil {
Expand All @@ -163,12 +166,15 @@ func (c *AccessGroupsService) Create(accessGroup *AccessGroup) (*AccessGroup, er
// Update updates an existing access group in the CloudConnexa API.
// It returns the updated access group.
func (c *AccessGroupsService) Update(id string, accessGroup *AccessGroup) (*AccessGroup, error) {
if err := validateID(id); err != nil {
return nil, err
}
accessGroupJSON, err := json.Marshal(accessGroup)
if err != nil {
return nil, err
}

endpoint := fmt.Sprintf("%s/access-groups/%s", c.client.GetV1Url(), id)
endpoint := buildURL(c.client.GetV1Url(), "access-groups", id)

req, err := http.NewRequest(http.MethodPut, endpoint, bytes.NewBuffer(accessGroupJSON))
if err != nil {
Expand All @@ -190,7 +196,10 @@ func (c *AccessGroupsService) Update(id string, accessGroup *AccessGroup) (*Acce

// Delete removes an access group from the CloudConnexa API by its ID.
func (c *AccessGroupsService) Delete(id string) error {
endpoint := fmt.Sprintf("%s/access-groups/%s", c.client.GetV1Url(), id)
if err := validateID(id); err != nil {
return err
}
endpoint := buildURL(c.client.GetV1Url(), "access-groups", id)
req, err := http.NewRequest(http.MethodDelete, endpoint, nil)
if err != nil {
return err
Expand Down
24 changes: 24 additions & 0 deletions cloudconnexa/cloudconnexa.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"
Expand Down Expand Up @@ -243,3 +244,26 @@ func (c *Client) AssignLimits(res *http.Response, rateLimiter *rate.Limiter) err
func (c *Client) GetV1Url() string {
return c.BaseURL + "/api/v1"
}

// buildURL constructs a URL with escaped path segments for safe API calls.
// Example: buildURL(c.GetV1Url(), "users", userID, "activate")
// Returns: https://api.example.com/api/v1/users/{escaped-id}/activate
func buildURL(base string, segments ...string) string {
if len(segments) == 0 {
return base
}
escaped := make([]string, len(segments))
for i, seg := range segments {
escaped[i] = url.PathEscape(seg)
}
return base + "/" + strings.Join(escaped, "/")
}

// validateID returns an error if the provided ID is empty.
// This should be called before making API calls that require an ID parameter.
func validateID(id string) error {
if id == "" {
return ErrEmptyID
}
return nil
}
96 changes: 96 additions & 0 deletions cloudconnexa/cloudconnexa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,99 @@ func TestDoRequest(t *testing.T) {
})
}
}

// TestBuildURL tests the buildURL function that constructs URLs with escaped path segments.
func TestBuildURL(t *testing.T) {
tests := []struct {
name string
base string
segments []string
expected string
}{
{
name: "no segments",
base: "https://api.example.com/v1",
segments: []string{},
expected: "https://api.example.com/v1",
},
{
name: "single segment",
base: "https://api.example.com/v1",
segments: []string{"users"},
expected: "https://api.example.com/v1/users",
},
{
name: "multiple segments",
base: "https://api.example.com/v1",
segments: []string{"users", "abc-123", "activate"},
expected: "https://api.example.com/v1/users/abc-123/activate",
},
{
name: "path traversal escaped",
base: "https://api.example.com/v1",
segments: []string{"users", "../admin"},
expected: "https://api.example.com/v1/users/..%2Fadmin",
},
{
name: "forward slash escaped",
base: "https://api.example.com/v1",
segments: []string{"users", "user/admin"},
expected: "https://api.example.com/v1/users/user%2Fadmin",
},
{
name: "space escaped",
base: "https://api.example.com/v1",
segments: []string{"users", "user 123"},
expected: "https://api.example.com/v1/users/user%20123",
},
{
name: "question mark escaped",
base: "https://api.example.com/v1",
segments: []string{"users", "user?role=admin"},
expected: "https://api.example.com/v1/users/user%3Frole=admin",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := buildURL(tt.base, tt.segments...)
assert.Equal(t, tt.expected, result)
})
}
}

// TestValidateID tests the validateID function that validates ID parameters.
func TestValidateID(t *testing.T) {
tests := []struct {
name string
id string
wantErr error
}{
{
name: "valid ID",
id: "abc-123",
wantErr: nil,
},
{
name: "valid UUID",
id: "550e8400-e29b-41d4-a716-446655440000",
wantErr: nil,
},
{
name: "empty ID",
id: "",
wantErr: ErrEmptyID,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateID(tt.id)
if tt.wantErr != nil {
assert.Equal(t, tt.wantErr, err)
} else {
assert.NoError(t, err)
}
})
}
}
10 changes: 8 additions & 2 deletions cloudconnexa/devices.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,10 @@ func (d *DevicesService) ListAll() ([]DeviceDetail, error) {

// GetByID retrieves a specific device by its ID.
func (d *DevicesService) GetByID(deviceID string) (*DeviceDetail, error) {
endpoint := fmt.Sprintf("%s/devices/%s", d.client.GetV1Url(), deviceID)
if err := validateID(deviceID); err != nil {
return nil, err
}
endpoint := buildURL(d.client.GetV1Url(), "devices", deviceID)
req, err := http.NewRequest(http.MethodGet, endpoint, nil)
if err != nil {
return nil, err
Expand All @@ -172,12 +175,15 @@ func (d *DevicesService) GetByID(deviceID string) (*DeviceDetail, error) {

// Update updates an existing device by its ID.
func (d *DevicesService) Update(deviceID string, updateRequest DeviceUpdateRequest) (*DeviceDetail, error) {
if err := validateID(deviceID); err != nil {
return nil, err
}
requestJSON, err := json.Marshal(updateRequest)
if err != nil {
return nil, err
}

endpoint := fmt.Sprintf("%s/devices/%s", d.client.GetV1Url(), deviceID)
endpoint := buildURL(d.client.GetV1Url(), "devices", deviceID)
req, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewBuffer(requestJSON))
if err != nil {
return nil, err
Expand Down
17 changes: 13 additions & 4 deletions cloudconnexa/dns_records.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,10 @@ func (c *DNSRecordsService) List() ([]DNSRecord, error) {
// This is the preferred method for getting a single DNS record as it uses the direct
// GET /api/v1/dns-records/{id} endpoint introduced in API v1.1.0.
func (c *DNSRecordsService) GetByID(recordID string) (*DNSRecord, error) {
endpoint := fmt.Sprintf("%s/dns-records/%s", c.client.GetV1Url(), recordID)
if err := validateID(recordID); err != nil {
return nil, err
}
endpoint := buildURL(c.client.GetV1Url(), "dns-records", recordID)
req, err := http.NewRequest(http.MethodGet, endpoint, nil)
if err != nil {
return nil, err
Expand Down Expand Up @@ -134,7 +137,7 @@ func (c *DNSRecordsService) Create(record DNSRecord) (*DNSRecord, error) {
return nil, err
}

req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/dns-records", c.client.GetV1Url()), bytes.NewBuffer(recordJSON))
req, err := http.NewRequest(http.MethodPost, buildURL(c.client.GetV1Url(), "dns-records"), bytes.NewBuffer(recordJSON))
if err != nil {
return nil, err
}
Expand All @@ -154,12 +157,15 @@ func (c *DNSRecordsService) Create(record DNSRecord) (*DNSRecord, error) {

// Update updates an existing DNS record.
func (c *DNSRecordsService) Update(record DNSRecord) error {
if err := validateID(record.ID); err != nil {
return err
}
recordJSON, err := json.Marshal(record)
if err != nil {
return err
}

req, err := http.NewRequest(http.MethodPut, fmt.Sprintf("%s/dns-records/%s", c.client.GetV1Url(), record.ID), bytes.NewBuffer(recordJSON))
req, err := http.NewRequest(http.MethodPut, buildURL(c.client.GetV1Url(), "dns-records", record.ID), bytes.NewBuffer(recordJSON))
if err != nil {
return err
}
Expand All @@ -170,7 +176,10 @@ func (c *DNSRecordsService) Update(record DNSRecord) error {

// Delete deletes a DNS record by ID.
func (c *DNSRecordsService) Delete(recordID string) error {
req, err := http.NewRequest(http.MethodDelete, fmt.Sprintf("%s/dns-records/%s", c.client.GetV1Url(), recordID), nil)
if err := validateID(recordID); err != nil {
return err
}
req, err := http.NewRequest(http.MethodDelete, buildURL(c.client.GetV1Url(), "dns-records", recordID), nil)
if err != nil {
return err
}
Expand Down
3 changes: 3 additions & 0 deletions cloudconnexa/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,6 @@ import "errors"

// ErrCredentialsRequired is returned when client ID or client secret is missing.
var ErrCredentialsRequired = errors.New("both client_id and client_secret credentials must be specified")

// ErrEmptyID is returned when an empty ID is provided to a method that requires one.
var ErrEmptyID = errors.New("id cannot be empty")
15 changes: 12 additions & 3 deletions cloudconnexa/host_applications.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,10 @@ func (c *HostApplicationsService) List() ([]ApplicationResponse, error) {

// Get retrieves a specific host application by its ID.
func (c *HostApplicationsService) Get(id string) (*ApplicationResponse, error) {
endpoint := fmt.Sprintf("%s/hosts/applications/%s", c.client.GetV1Url(), id)
if err := validateID(id); err != nil {
return nil, err
}
endpoint := buildURL(c.client.GetV1Url(), "hosts", "applications", id)
req, err := http.NewRequest(http.MethodGet, endpoint, nil)
if err != nil {
return nil, err
Expand Down Expand Up @@ -178,12 +181,15 @@ func (c *HostApplicationsService) Create(application *Application) (*Application

// Update updates an existing host application by its ID.
func (c *HostApplicationsService) Update(id string, application *Application) (*ApplicationResponse, error) {
if err := validateID(id); err != nil {
return nil, err
}
applicationJSON, err := json.Marshal(application)
if err != nil {
return nil, err
}

endpoint := fmt.Sprintf("%s/hosts/applications/%s", c.client.GetV1Url(), id)
endpoint := buildURL(c.client.GetV1Url(), "hosts", "applications", id)

req, err := http.NewRequest(http.MethodPut, endpoint, bytes.NewBuffer(applicationJSON))
if err != nil {
Expand All @@ -205,7 +211,10 @@ func (c *HostApplicationsService) Update(id string, application *Application) (*

// Delete removes a host application by its ID.
func (c *HostApplicationsService) Delete(id string) error {
endpoint := fmt.Sprintf("%s/hosts/applications/%s", c.client.GetV1Url(), id)
if err := validateID(id); err != nil {
return err
}
endpoint := buildURL(c.client.GetV1Url(), "hosts", "applications", id)
req, err := http.NewRequest(http.MethodDelete, endpoint, nil)
if err != nil {
return err
Expand Down
Loading
Loading