Skip to content

Commit 2913a86

Browse files
committed
feat(registry): refactor default registry options to use cached values and ensure thread safety
1 parent 33ad3e4 commit 2913a86

File tree

5 files changed

+40
-74
lines changed

5 files changed

+40
-74
lines changed

cmd/cli/commands/package.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,7 @@ func newModelRunnerTarget(client *desktop.Client, tag string) (*modelRunnerTarge
414414
var err error
415415
// Normalize the tag to add default namespace (ai/) and tag (:latest) if missing
416416
normalizedTag := models.NormalizeModelName(tag)
417-
target.tag, err = name.NewTag(normalizedTag, getDefaultRegistryOptions()...)
417+
target.tag, err = name.NewTag(normalizedTag, registry.GetDefaultRegistryOptions()...)
418418
if err != nil {
419419
return nil, fmt.Errorf("invalid tag: %w", err)
420420
}

cmd/cli/commands/tag.go

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,31 +2,16 @@ package commands
22

33
import (
44
"fmt"
5-
"os"
65
"strings"
76

87
"github.com/docker/model-runner/cmd/cli/commands/completion"
98
"github.com/docker/model-runner/cmd/cli/desktop"
9+
"github.com/docker/model-runner/pkg/distribution/registry"
1010
"github.com/docker/model-runner/pkg/inference/models"
1111
"github.com/google/go-containerregistry/pkg/name"
1212
"github.com/spf13/cobra"
1313
)
1414

15-
// getDefaultRegistryOptions returns name.Option slice with custom default registry
16-
// and insecure flag if the corresponding environment variables are set.
17-
// - DEFAULT_REGISTRY: Override the default registry (index.docker.io)
18-
// - INSECURE_REGISTRY: Set to "true" to allow HTTP connections
19-
func getDefaultRegistryOptions() []name.Option {
20-
var opts []name.Option
21-
if defaultReg := os.Getenv("DEFAULT_REGISTRY"); defaultReg != "" {
22-
opts = append(opts, name.WithDefaultRegistry(defaultReg))
23-
}
24-
if os.Getenv("INSECURE_REGISTRY") == "true" {
25-
opts = append(opts, name.Insecure)
26-
}
27-
return opts
28-
}
29-
3015
func newTagCmd() *cobra.Command {
3116
c := &cobra.Command{
3217
Use: "tag SOURCE TARGET",
@@ -58,7 +43,7 @@ func tagModel(cmd *cobra.Command, desktopClient *desktop.Client, source, target
5843
// Normalize target model name to add default org and tag if missing
5944
target = models.NormalizeModelName(target)
6045
// Ensure tag is valid
61-
tag, err := name.NewTag(target, getDefaultRegistryOptions()...)
46+
tag, err := name.NewTag(target, registry.GetDefaultRegistryOptions()...)
6247
if err != nil {
6348
return fmt.Errorf("invalid tag: %w", err)
6449
}

pkg/distribution/internal/store/index.go

Lines changed: 9 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,30 +8,17 @@ import (
88
"path/filepath"
99

1010
"github.com/google/go-containerregistry/pkg/name"
11-
)
1211

13-
// getDefaultRegistryOptions returns name.Option slice with custom default registry
14-
// and insecure flag if the corresponding environment variables are set.
15-
// - DEFAULT_REGISTRY: Override the default registry (index.docker.io)
16-
// - INSECURE_REGISTRY: Set to "true" to allow HTTP connections
17-
func getDefaultRegistryOptions() []name.Option {
18-
var opts []name.Option
19-
if defaultReg := os.Getenv("DEFAULT_REGISTRY"); defaultReg != "" {
20-
opts = append(opts, name.WithDefaultRegistry(defaultReg))
21-
}
22-
if os.Getenv("INSECURE_REGISTRY") == "true" {
23-
opts = append(opts, name.Insecure)
24-
}
25-
return opts
26-
}
12+
"github.com/docker/model-runner/pkg/distribution/registry"
13+
)
2714

2815
// Index represents the index of all models in the store
2916
type Index struct {
3017
Models []IndexEntry `json:"models"`
3118
}
3219

3320
func (i Index) Tag(reference string, tag string) (Index, error) {
34-
tagRef, err := name.NewTag(tag, getDefaultRegistryOptions()...)
21+
tagRef, err := name.NewTag(tag, registry.GetDefaultRegistryOptions()...)
3522
if err != nil {
3623
return Index{}, fmt.Errorf("invalid tag: %w", err)
3724
}
@@ -54,7 +41,7 @@ func (i Index) Tag(reference string, tag string) (Index, error) {
5441
}
5542

5643
func (i Index) UnTag(tag string) (name.Tag, Index, error) {
57-
tagRef, err := name.NewTag(tag, getDefaultRegistryOptions()...)
44+
tagRef, err := name.NewTag(tag, registry.GetDefaultRegistryOptions()...)
5845
if err != nil {
5946
return name.Tag{}, Index{}, err
6047
}
@@ -156,12 +143,12 @@ type IndexEntry struct {
156143
}
157144

158145
func (e IndexEntry) HasTag(tag string) bool {
159-
ref, err := name.NewTag(tag, getDefaultRegistryOptions()...)
146+
ref, err := name.NewTag(tag, registry.GetDefaultRegistryOptions()...)
160147
if err != nil {
161148
return false
162149
}
163150
for _, t := range e.Tags {
164-
tr, err := name.ParseReference(t, getDefaultRegistryOptions()...)
151+
tr, err := name.ParseReference(t, registry.GetDefaultRegistryOptions()...)
165152
if err != nil {
166153
continue
167154
}
@@ -174,7 +161,7 @@ func (e IndexEntry) HasTag(tag string) bool {
174161

175162
func (e IndexEntry) hasTag(tag name.Tag) bool {
176163
for _, t := range e.Tags {
177-
tr, err := name.ParseReference(t, getDefaultRegistryOptions()...)
164+
tr, err := name.ParseReference(t, registry.GetDefaultRegistryOptions()...)
178165
if err != nil {
179166
continue
180167
}
@@ -189,7 +176,7 @@ func (e IndexEntry) MatchesReference(reference string) bool {
189176
if e.ID == reference {
190177
return true
191178
}
192-
ref, err := name.ParseReference(reference, getDefaultRegistryOptions()...)
179+
ref, err := name.ParseReference(reference, registry.GetDefaultRegistryOptions()...)
193180
if err != nil {
194181
return false
195182
}
@@ -215,7 +202,7 @@ func (e IndexEntry) Tag(tag name.Tag) IndexEntry {
215202
func (e IndexEntry) UnTag(tag name.Tag) IndexEntry {
216203
var tags []string
217204
for i, t := range e.Tags {
218-
tr, err := name.ParseReference(t, getDefaultRegistryOptions()...)
205+
tr, err := name.ParseReference(t, registry.GetDefaultRegistryOptions()...)
219206
if err != nil {
220207
continue
221208
}

pkg/distribution/registry/client.go

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"net/http"
88
"os"
99
"strings"
10+
"sync"
1011

1112
"github.com/google/go-containerregistry/pkg/authn"
1213
"github.com/google/go-containerregistry/pkg/name"
@@ -22,25 +23,32 @@ const (
2223
DefaultUserAgent = "model-distribution"
2324
)
2425

25-
// getDefaultRegistryOptions returns name.Option slice with custom default registry
26+
var (
27+
defaultRegistryOpts []name.Option
28+
once sync.Once
29+
DefaultTransport = remote.DefaultTransport
30+
)
31+
32+
// GetDefaultRegistryOptions returns name.Option slice with custom default registry
2633
// and insecure flag if the corresponding environment variables are set.
34+
// Environment variables are read once at first call and cached for consistency.
35+
// Returns a copy of the options to prevent race conditions from slice modifications.
2736
// - DEFAULT_REGISTRY: Override the default registry (index.docker.io)
2837
// - INSECURE_REGISTRY: Set to "true" to allow HTTP connections
29-
func getDefaultRegistryOptions() []name.Option {
30-
var opts []name.Option
31-
if defaultReg := os.Getenv("DEFAULT_REGISTRY"); defaultReg != "" {
32-
opts = append(opts, name.WithDefaultRegistry(defaultReg))
33-
}
34-
if os.Getenv("INSECURE_REGISTRY") == "true" {
35-
opts = append(opts, name.Insecure)
36-
}
37-
return opts
38+
func GetDefaultRegistryOptions() []name.Option {
39+
once.Do(func() {
40+
var opts []name.Option
41+
if defaultReg := os.Getenv("DEFAULT_REGISTRY"); defaultReg != "" {
42+
opts = append(opts, name.WithDefaultRegistry(defaultReg))
43+
}
44+
if os.Getenv("INSECURE_REGISTRY") == "true" {
45+
opts = append(opts, name.Insecure)
46+
}
47+
defaultRegistryOpts = opts
48+
})
49+
return append([]name.Option(nil), defaultRegistryOpts...)
3850
}
3951

40-
var (
41-
DefaultTransport = remote.DefaultTransport
42-
)
43-
4452
type Client struct {
4553
transport http.RoundTripper
4654
userAgent string
@@ -91,7 +99,7 @@ func NewClient(opts ...ClientOption) *Client {
9199

92100
func (c *Client) Model(ctx context.Context, reference string) (types.ModelArtifact, error) {
93101
// Parse the reference
94-
ref, err := name.ParseReference(reference, getDefaultRegistryOptions()...)
102+
ref, err := name.ParseReference(reference, GetDefaultRegistryOptions()...)
95103
if err != nil {
96104
return nil, NewReferenceError(reference, err)
97105
}
@@ -131,7 +139,7 @@ func (c *Client) Model(ctx context.Context, reference string) (types.ModelArtifa
131139

132140
func (c *Client) BlobURL(reference string, digest v1.Hash) (string, error) {
133141
// Parse the reference
134-
ref, err := name.ParseReference(reference, getDefaultRegistryOptions()...)
142+
ref, err := name.ParseReference(reference, GetDefaultRegistryOptions()...)
135143
if err != nil {
136144
return "", NewReferenceError(reference, err)
137145
}
@@ -145,7 +153,7 @@ func (c *Client) BlobURL(reference string, digest v1.Hash) (string, error) {
145153

146154
func (c *Client) BearerToken(ctx context.Context, reference string) (string, error) {
147155
// Parse the reference
148-
ref, err := name.ParseReference(reference, getDefaultRegistryOptions()...)
156+
ref, err := name.ParseReference(reference, GetDefaultRegistryOptions()...)
149157
if err != nil {
150158
return "", NewReferenceError(reference, err)
151159
}
@@ -181,7 +189,7 @@ type Target struct {
181189
}
182190

183191
func (c *Client) NewTarget(tag string) (*Target, error) {
184-
ref, err := name.NewTag(tag, getDefaultRegistryOptions()...)
192+
ref, err := name.NewTag(tag, GetDefaultRegistryOptions()...)
185193
if err != nil {
186194
return nil, fmt.Errorf("invalid tag: %q: %w", tag, err)
187195
}

pkg/metrics/metrics.go

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"strings"
88
"time"
99

10+
"github.com/docker/model-runner/pkg/distribution/registry"
1011
"github.com/docker/model-runner/pkg/distribution/types"
1112
"github.com/docker/model-runner/pkg/logging"
1213
"github.com/google/go-containerregistry/pkg/authn"
@@ -15,21 +16,6 @@ import (
1516
"github.com/sirupsen/logrus"
1617
)
1718

18-
// getDefaultRegistryOptions returns name.Option slice with custom default registry
19-
// and insecure flag if the corresponding environment variables are set.
20-
// - DEFAULT_REGISTRY: Override the default registry (index.docker.io)
21-
// - INSECURE_REGISTRY: Set to "true" to allow HTTP connections
22-
func getDefaultRegistryOptions() []name.Option {
23-
var opts []name.Option
24-
if defaultReg := os.Getenv("DEFAULT_REGISTRY"); defaultReg != "" {
25-
opts = append(opts, name.WithDefaultRegistry(defaultReg))
26-
}
27-
if os.Getenv("INSECURE_REGISTRY") == "true" {
28-
opts = append(opts, name.Insecure)
29-
}
30-
return opts
31-
}
32-
3319
type Tracker struct {
3420
doNotTrack bool
3521
transport http.RoundTripper
@@ -101,7 +87,7 @@ func (t *Tracker) trackModel(model types.Model, userAgent, action string) {
10187
}
10288
ua := strings.Join(parts, " ")
10389
for _, tag := range tags {
104-
ref, err := name.ParseReference(tag, getDefaultRegistryOptions()...)
90+
ref, err := name.ParseReference(tag, registry.GetDefaultRegistryOptions()...)
10591
if err != nil {
10692
t.log.Errorf("Error parsing reference: %v\n", err)
10793
return

0 commit comments

Comments
 (0)