Skip to content

Commit ee9aaac

Browse files
committed
Cleaning up a little of the implementation
1 parent 75915d6 commit ee9aaac

File tree

2 files changed

+32
-23
lines changed

2 files changed

+32
-23
lines changed

pkg/controller/kube_controller.go

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@ import (
44
"context"
55
"fmt"
66
"io"
7+
"net/http"
78
"net/url"
9+
"slices"
810
"strings"
911
"time"
1012

@@ -22,6 +24,7 @@ const channelURLSuffix = "https://dl.k8s.io/release/"
2224

2325
type ClusterVersionScheduler struct {
2426
client kubernetes.Interface
27+
http http.Client
2528
log *logrus.Entry
2629
metrics *metrics.Metrics
2730
interval time.Duration
@@ -40,10 +43,18 @@ func NewKubeReconciler(
4043
log.Info("Kubernetes version checking disabled (no channel specified)")
4144
return nil
4245
}
46+
log = log.WithField("controller", "channel")
47+
48+
httpClient := retryablehttp.NewClient()
49+
httpClient.RetryMax = 3
50+
httpClient.RetryWaitMin = 1 * time.Second
51+
httpClient.RetryWaitMax = 30 * time.Second
52+
httpClient.Logger = log
4353

4454
return &ClusterVersionScheduler{
4555
log: log.WithField("channel", channel),
4656
client: kubernetes.NewForConfigOrDie(config),
57+
http: *httpClient.StandardClient(),
4758
interval: interval,
4859
metrics: metrics,
4960
channel: channel,
@@ -52,6 +63,7 @@ func NewKubeReconciler(
5263

5364
func (s *ClusterVersionScheduler) Start(ctx context.Context) error {
5465
go s.runScheduler(ctx)
66+
// Run an initial check on startup
5567
return s.reconcile()
5668
}
5769

@@ -83,7 +95,7 @@ func (s *ClusterVersionScheduler) reconcile() error {
8395
}
8496

8597
// Get latest version from specified channel
86-
latest, err := getLatestVersion(s.channel)
98+
latest, err := s.getLatestVersion(s.channel)
8799
if err != nil {
88100
return fmt.Errorf("fetching latest version from channel %s: %w", s.channel, err)
89101
}
@@ -114,13 +126,13 @@ func (s *ClusterVersionScheduler) reconcile() error {
114126
return nil
115127
}
116128

117-
func getLatestVersion(channel string) (string, error) {
129+
func (s *ClusterVersionScheduler) getLatestVersion(channel string) (string, error) {
118130
// Always use upstream Kubernetes channels - this is the authoritative source
119131
// Platform detection is kept for logging purposes only
120-
return getLatestVersionFromUpstream(channel)
132+
return s.getLatestVersionFromUpstream(channel)
121133
}
122134

123-
func getLatestVersionFromUpstream(channel string) (string, error) {
135+
func (s *ClusterVersionScheduler) getLatestVersionFromUpstream(channel string) (string, error) {
124136
// Validate channel - only allow known Kubernetes channels
125137
if !isValidKubernetesChannel(channel) {
126138
return "", fmt.Errorf("unsupported channel: %s. Valid channels: stable, latest, latest-1.xx", channel)
@@ -135,21 +147,20 @@ func getLatestVersionFromUpstream(channel string) (string, error) {
135147
return "", fmt.Errorf("failed to join channel URL: %w", err)
136148
}
137149

138-
client := retryablehttp.NewClient()
139-
client.RetryMax = 3
140-
client.RetryWaitMin = 1 * time.Second
141-
client.RetryWaitMax = 30 * time.Second
142-
client.Logger = nil
143-
144-
resp, err := client.Get(channelURL)
150+
resp, err := s.http.Get(channelURL)
145151
if err != nil {
146152
return "", fmt.Errorf("failed to fetch from channel URL %s: %w", channelURL, err)
147153
}
148-
defer resp.Body.Close()
154+
defer func() {
155+
_ = resp.Body.Close()
156+
}()
149157

150-
if resp.StatusCode != 200 {
158+
if resp.StatusCode != http.StatusOK {
151159
return "", fmt.Errorf("unexpected status code %d when fetching channel %s", resp.StatusCode, channel)
152160
}
161+
if resp.Header.Get("content-type") != "text/plain" {
162+
return "", fmt.Errorf("unexpected content-type %s when fetching channel %s", resp.Header.Get("content-type"), channel)
163+
}
153164

154165
body, err := io.ReadAll(resp.Body)
155166
if err != nil {
@@ -158,7 +169,7 @@ func getLatestVersionFromUpstream(channel string) (string, error) {
158169

159170
version := strings.TrimSpace(string(body))
160171
if version == "" {
161-
return "", fmt.Errorf("empty version returned from channel %s", channel)
172+
return "", fmt.Errorf("empty version received from channel %s", channel)
162173
}
163174

164175
return version, nil
@@ -168,15 +179,10 @@ func isValidKubernetesChannel(channel string) bool {
168179
// Only allow official Kubernetes channels
169180
validChannels := []string{"stable", "latest"}
170181

171-
// Allow latest-X.Y format
172-
if strings.HasPrefix(channel, "latest-1.") {
182+
// Allow latest-X.Y and stable-X.Y formats
183+
if strings.HasPrefix(channel, "latest-1.") || strings.HasPrefix(channel, "stable-1.") {
173184
return true
174185
}
175186

176-
for _, valid := range validChannels {
177-
if channel == valid {
178-
return true
179-
}
180-
}
181-
return false
187+
return slices.Contains(validChannels, channel)
182188
}

pkg/controller/kube_controller_test.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,10 @@ func TestGetLatestVersion(t *testing.T) {
220220
if err != nil {
221221
return "", err
222222
}
223-
defer resp.Body.Close()
223+
224+
defer func() {
225+
require.NoError(t, resp.Body.Close())
226+
}()
224227

225228
if resp.StatusCode != http.StatusOK {
226229
return "", assert.AnError

0 commit comments

Comments
 (0)