Skip to content

Commit 715f8a6

Browse files
committed
fix: data race on cfg and the `make test` command
- The assignment `*cfg = *newCfg` in `func (cfg *Config) reload(path string)` could lead to a data race. Avoid this by refactoring to use `atomic.Value`. - Previously, `make test` didn't actually work; this commit fixes that as well. Signed-off-by: imeoer <yansong.ys@antgroup.com>
1 parent 8a994c0 commit 715f8a6

25 files changed

+210
-180
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/model-csi-driver
22
/model-csi-cli
3-
/unit.test
3+
/server.test
44
/output
55
/cover.out.tmp
66
/coverage.log

Makefile

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,14 @@ REVISION=$(shell git rev-parse HEAD)$(shell if ! git diff --no-ext-diff --quiet
1313

1414
RELEASE_INFO = -X main.revision=${REVISION} -X main.gitVersion=${VERSION} -X main.buildTime=${BUILD_TIMESTAMP}
1515

16-
.PHONY: release
16+
.PHONY: release test
1717

1818
release:
1919
@CGO_ENABLED=0 ${PROXY} GOOS=linux GOARCH=${GOARCH} go vet -tags disable_libgit2 $(PACKAGES)
2020
@CGO_ENABLED=0 ${PROXY} GOOS=linux GOARCH=${GOARCH} go build -tags disable_libgit2 -ldflags '${RELEASE_INFO} -w -extldflags "-static"' -o ./ ./cmd/model-csi-driver
2121
@CGO_ENABLED=0 ${PROXY} GOOS=linux GOARCH=${GOARCH} go build -tags disable_libgit2 -ldflags '${RELEASE_INFO} -w -extldflags "-static"' -o ./ ./cmd/model-csi-cli
2222

2323
test:
24-
@CGO_ENABLED=1 go test -tags disable_libgit2 -coverprofile cover.out.tmp -race -v -timeout 10m github.com/modelpack/model-csi-driver/pkg/server | tee coverage.log
25-
26-
test-local:
27-
go test -tags disable_libgit2 -race -c -o ./unit.test github.com/modelpack/model-csi-driver/pkg/server
28-
sudo CONFIG_PATH=./test/testdata/config.test.yaml ./unit.test -test.timeout 1h -test.v -test.run ^TestServer$
24+
go list ./... | grep -v -E github.com/modelpack/model-csi-driver/pkg/server | xargs go test -tags disable_libgit2 -race -v -timeout 10m
25+
go test -tags disable_libgit2 -race -c -o ./server.test github.com/modelpack/model-csi-driver/pkg/server
26+
sudo CONFIG_PATH=./test/testdata/config.test.yaml ./server.test -test.v -test.timeout 10m

pkg/client/grpc.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ func NewGRPCClient(cfg *config.Config, addr string) (*GRPCClient, error) {
4343
invoker grpc.UnaryInvoker,
4444
opts ...grpc.CallOption,
4545
) error {
46-
newCtx := metadata.AppendToOutgoingContext(ctx, authTokenKey, cfg.ExternalCSIAuthorization)
46+
newCtx := metadata.AppendToOutgoingContext(ctx, authTokenKey, cfg.Get().ExternalCSIAuthorization)
4747
return invoker(newCtx, method, req, reply, cc, opts...)
4848
}),
4949
)
@@ -119,7 +119,7 @@ func (c *GRPCClient) PublishStaticInlineVolume(ctx context.Context, volumeID, ta
119119
VolumeId: volumeID,
120120
TargetPath: targetPath,
121121
VolumeContext: map[string]string{
122-
c.cfg.ParameterKeyReference(): reference,
122+
c.cfg.Get().ParameterKeyReference(): reference,
123123
},
124124
})
125125
if err != nil {

pkg/config/config.go

Lines changed: 58 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@ import (
44
"net/url"
55
"os"
66
"path/filepath"
7+
"sync/atomic"
78

89
"github.com/dustin/go-humanize"
10+
"github.com/modelpack/model-csi-driver/pkg/logger"
911
"github.com/pkg/errors"
1012
"gopkg.in/yaml.v2"
1113
)
@@ -28,7 +30,7 @@ func (s *HumanizeSize) UnmarshalYAML(unmarshal func(interface{}) error) error {
2830
return nil
2931
}
3032

31-
type Config struct {
33+
type RawConfig struct {
3234
// Pattern:
3335
// static: /var/lib/dragonfly/model-csi/volumes/$volumeName/model
3436
// dynamic: /var/lib/dragonfly/model-csi/volumes/$volumeName/models
@@ -40,7 +42,7 @@ type Config struct {
4042
DynamicCSIEndpoint string `yaml:"dynamic_csi_endpoint"`
4143
CSIEndpoint string `yaml:"csi_endpoint"`
4244
MetricsAddr string `yaml:"metrics_addr"`
43-
TraceEndpooint string `yaml:"trace_endpoint"`
45+
TraceEndpoint string `yaml:"trace_endpoint"`
4446
PprofAddr string `yaml:"pprof_addr"`
4547
PullConfig PullConfig `yaml:"pull_config"`
4648
Features Features `yaml:"features"`
@@ -61,89 +63,89 @@ type PullConfig struct {
6163
PullLayerTimeoutInSeconds uint `yaml:"pull_layer_timeout_in_seconds"`
6264
}
6365

64-
func (cfg *Config) ParameterKeyType() string {
66+
func (cfg *RawConfig) ParameterKeyType() string {
6567
return cfg.ServiceName + "/type"
6668
}
6769

68-
func (cfg *Config) ParameterKeyReference() string {
70+
func (cfg *RawConfig) ParameterKeyReference() string {
6971
return cfg.ServiceName + "/reference"
7072
}
7173

72-
func (cfg *Config) ParameterKeyMountID() string {
74+
func (cfg *RawConfig) ParameterKeyMountID() string {
7375
return cfg.ServiceName + "/mount-id"
7476
}
7577

76-
func (cfg *Config) ParameterKeyStatusState() string {
78+
func (cfg *RawConfig) ParameterKeyStatusState() string {
7779
return cfg.ServiceName + "/status/state"
7880
}
7981

80-
func (cfg *Config) ParameterKeyStatusProgress() string {
82+
func (cfg *RawConfig) ParameterKeyStatusProgress() string {
8183
return cfg.ServiceName + "/status/progress"
8284
}
8385

84-
func (cfg *Config) ParameterVolumeContextNodeIP() string {
86+
func (cfg *RawConfig) ParameterVolumeContextNodeIP() string {
8587
return cfg.ServiceName + "/node-ip"
8688
}
8789

88-
func (cfg *Config) ParameterKeyCheckDiskQuota() string {
90+
func (cfg *RawConfig) ParameterKeyCheckDiskQuota() string {
8991
return cfg.ServiceName + "/check-disk-quota"
9092
}
9193

9294
// /var/lib/dragonfly/model-csi/volumes
93-
func (cfg *Config) GetVolumesDir() string {
95+
func (cfg *RawConfig) GetVolumesDir() string {
9496
return filepath.Join(cfg.RootDir, "volumes")
9597
}
9698

9799
// /var/lib/dragonfly/model-csi/volumes/$volumeName
98-
func (cfg *Config) GetVolumeDir(volumeName string) string {
100+
func (cfg *RawConfig) GetVolumeDir(volumeName string) string {
99101
return filepath.Join(cfg.GetVolumesDir(), volumeName)
100102
}
101103

102104
// /var/lib/dragonfly/model-csi/volumes/$volumeName/model
103-
func (cfg *Config) GetModelDir(volumeName string) string {
105+
func (cfg *RawConfig) GetModelDir(volumeName string) string {
104106
return filepath.Join(cfg.GetVolumesDir(), volumeName, "model")
105107
}
106108

107109
// /var/lib/dragonfly/model-csi/volumes/$volumeName
108-
func (cfg *Config) GetVolumeDirForDynamic(volumeName string) string {
110+
func (cfg *RawConfig) GetVolumeDirForDynamic(volumeName string) string {
109111
return filepath.Join(cfg.GetVolumesDir(), volumeName)
110112
}
111113

112114
// /var/lib/dragonfly/model-csi/volumes/$volumeName/models
113-
func (cfg *Config) GetModelsDirForDynamic(volumeName string) string {
115+
func (cfg *RawConfig) GetModelsDirForDynamic(volumeName string) string {
114116
return filepath.Join(cfg.GetVolumeDirForDynamic(volumeName), "models")
115117
}
116118

117119
// /var/lib/dragonfly/model-csi/volumes/$volumeName/models/$mountID
118-
func (cfg *Config) GetMountIDDirForDynamic(volumeName, mountID string) string {
120+
func (cfg *RawConfig) GetMountIDDirForDynamic(volumeName, mountID string) string {
119121
return filepath.Join(cfg.GetVolumeDirForDynamic(volumeName), "models", mountID)
120122
}
121123

122124
// /var/lib/dragonfly/model-csi/volumes/$volumeName/models/$mountID/model
123-
func (cfg *Config) GetModelDirForDynamic(volumeName, mountID string) string {
125+
func (cfg *RawConfig) GetModelDirForDynamic(volumeName, mountID string) string {
124126
return filepath.Join(cfg.GetVolumeDirForDynamic(volumeName), "models", mountID, "model")
125127
}
126128

127129
// /var/lib/dragonfly/model-csi/volumes/$volumeName/csi
128-
func (cfg *Config) GetCSISockDirForDynamic(volumeName string) string {
130+
func (cfg *RawConfig) GetCSISockDirForDynamic(volumeName string) string {
129131
return filepath.Join(cfg.GetVolumeDirForDynamic(volumeName), "csi")
130132
}
131133

132-
func (cfg *Config) IsControllerMode() bool {
134+
func (cfg *RawConfig) IsControllerMode() bool {
133135
return cfg.Mode == "controller"
134136
}
135137

136-
func (cfg *Config) IsNodeMode() bool {
138+
func (cfg *RawConfig) IsNodeMode() bool {
137139
return cfg.Mode == "node"
138140
}
139141

140-
func parse(path string) (*Config, error) {
142+
func parse(path string) (*RawConfig, error) {
141143
data, err := os.ReadFile(path)
142144
if err != nil {
143145
return nil, errors.Wrap(err, "read config file")
144146
}
145147

146-
var cfg Config
148+
var cfg RawConfig
147149
if err := yaml.Unmarshal(data, &cfg); err != nil {
148150
return nil, errors.Wrap(err, "unmarshal config file")
149151
}
@@ -206,13 +208,46 @@ func parse(path string) (*Config, error) {
206208
return &cfg, nil
207209
}
208210

211+
type Config struct {
212+
atomic.Value
213+
}
214+
209215
func New(path string) (*Config, error) {
210216
cfg, err := parse(path)
211217
if err != nil {
212218
return nil, err
213219
}
214220

215-
go cfg.watch(path)
221+
atomicCfg := NewWithRaw(cfg)
222+
223+
go atomicCfg.watch(path)
224+
225+
return atomicCfg, nil
226+
}
227+
228+
func NewWithRaw(cfg *RawConfig) *Config {
229+
atomicCfg := &Config{
230+
Value: atomic.Value{},
231+
}
232+
atomicCfg.Store(cfg)
233+
return atomicCfg
234+
}
235+
236+
func (cfg *Config) Get() *RawConfig {
237+
return cfg.Load().(*RawConfig)
238+
}
239+
240+
func (cfg *Config) reload(path string) {
241+
newCfg, err := parse(path)
242+
if err != nil {
243+
logger.Logger().WithError(err).Error("failed to parse config file")
244+
return
245+
}
246+
247+
mutex.Lock()
248+
defer mutex.Unlock()
249+
250+
cfg.Store(newCfg)
216251

217-
return cfg, nil
252+
logger.Logger().Infof("config reloaded: %s", path)
218253
}

pkg/config/config_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,5 +50,5 @@ func TestConfig(t *testing.T) {
5050
time.Sleep(time.Second * 1)
5151

5252
// Verify the config is reloaded
53-
require.Equal(t, uint64(0x50000000000), uint64(cfg.Features.DiskUsageLimit))
53+
require.Equal(t, uint64(0x50000000000), uint64(cfg.Get().Features.DiskUsageLimit))
5454
}

pkg/config/watcher.go

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -48,18 +48,3 @@ func (cfg *Config) watch(path string) {
4848

4949
select {}
5050
}
51-
52-
func (cfg *Config) reload(path string) {
53-
newCfg, err := parse(path)
54-
if err != nil {
55-
logger.Logger().WithError(err).Error("failed to parse config file")
56-
return
57-
}
58-
59-
mutex.Lock()
60-
defer mutex.Unlock()
61-
62-
*cfg = *newCfg
63-
64-
logger.Logger().Infof("config reloaded: %s", path)
65-
}

pkg/provider/provider.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ func New(cfg *config.Config, svc *service.Service) (gocsi.StoragePluginProvider,
2828
sp *gocsi.StoragePlugin,
2929
lis net.Listener) error {
3030

31-
log.WithField("service", cfg.ServiceName).Debug("BeforeServe")
31+
log.WithField("service", cfg.Get().ServiceName).Debug("BeforeServe")
3232
return nil
3333
},
3434

pkg/server/http.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ type ErrorResponse struct {
3434
func NewHTTPServer(cfg *config.Config, svc *service.Service) (*HttpServer, error) {
3535
echo := echo.New()
3636

37-
endpoint, err := url.Parse(cfg.DynamicCSIEndpoint)
37+
endpoint, err := url.Parse(cfg.Get().DynamicCSIEndpoint)
3838
if err != nil {
39-
return nil, errors.Wrapf(err, "parse dynamic csi endpoint: %s", cfg.DynamicCSIEndpoint)
39+
return nil, errors.Wrapf(err, "parse dynamic csi endpoint: %s", cfg.Get().DynamicCSIEndpoint)
4040
}
4141

4242
listener, err := net.Listen("unix", endpoint.Path)

pkg/server/http_handler.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,10 +90,10 @@ func (h *HttpHandler) CreateVolume(c echo.Context) error {
9090
_, err := h.svc.CreateVolume(c.Request().Context(), &csi.CreateVolumeRequest{
9191
Name: volumeName,
9292
Parameters: map[string]string{
93-
h.cfg.ParameterKeyType(): "image",
94-
h.cfg.ParameterKeyReference(): req.Reference,
95-
h.cfg.ParameterKeyMountID(): req.MountID,
96-
h.cfg.ParameterKeyCheckDiskQuota(): strconv.FormatBool(req.CheckDiskQuota),
93+
h.cfg.Get().ParameterKeyType(): "image",
94+
h.cfg.Get().ParameterKeyReference(): req.Reference,
95+
h.cfg.Get().ParameterKeyMountID(): req.MountID,
96+
h.cfg.Get().ParameterKeyCheckDiskQuota(): strconv.FormatBool(req.CheckDiskQuota),
9797
},
9898
})
9999
if err != nil {

0 commit comments

Comments
 (0)