Skip to content

Commit 3505de3

Browse files
committed
feat: support exclude_model_weights option
The option ignores weight layers when mounting, useful for model parameter distribution. Signed-off-by: imeoer <[email protected]>
1 parent 6fbe66c commit 3505de3

File tree

10 files changed

+156
-37
lines changed

10 files changed

+156
-37
lines changed

pkg/config/config.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,10 @@ func (cfg *RawConfig) ParameterKeyCheckDiskQuota() string {
9191
return cfg.ServiceName + "/check-disk-quota"
9292
}
9393

94+
func (cfg *RawConfig) ParameterKeyExcludeModelWeights() string {
95+
return cfg.ServiceName + "/exclude-model-weights"
96+
}
97+
9498
// /var/lib/dragonfly/model-csi/volumes
9599
func (cfg *RawConfig) GetVolumesDir() string {
96100
return filepath.Join(cfg.RootDir, "volumes")
@@ -203,6 +207,10 @@ func parse(path string) (*RawConfig, error) {
203207
return nil, errors.Wrapf(err, "check dragonfly endpoint: %s", endpoint.Path)
204208
}
205209
}
210+
211+
if cfg.PullConfig.Concurrency == 0 {
212+
cfg.PullConfig.Concurrency = 5
213+
}
206214
}
207215

208216
return &cfg, nil

pkg/server/http_handler.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,10 +90,11 @@ func (h *HttpHandler) CreateVolume(c echo.Context) error {
9090
_, err := h.svc.CreateVolume(c.Request().Context(), &csi.CreateVolumeRequest{
9191
Name: volumeName,
9292
Parameters: map[string]string{
93-
h.cfg.Get().ParameterKeyType(): "image",
94-
h.cfg.Get().ParameterKeyReference(): req.Reference,
95-
h.cfg.Get().ParameterKeyMountID(): req.MountID,
96-
h.cfg.Get().ParameterKeyCheckDiskQuota(): strconv.FormatBool(req.CheckDiskQuota),
93+
h.cfg.Get().ParameterKeyType(): "image",
94+
h.cfg.Get().ParameterKeyReference(): req.Reference,
95+
h.cfg.Get().ParameterKeyMountID(): req.MountID,
96+
h.cfg.Get().ParameterKeyCheckDiskQuota(): strconv.FormatBool(req.CheckDiskQuota),
97+
h.cfg.Get().ParameterKeyExcludeModelWeights(): strconv.FormatBool(req.ExcludeModelWeights),
9798
},
9899
})
99100
if err != nil {

pkg/server/server_test.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,9 @@ type mockPuller struct {
4040
hook *service.Hook
4141
}
4242

43-
func (puller *mockPuller) Pull(ctx context.Context, reference, targetDir string) error {
43+
func (puller *mockPuller) Pull(
44+
ctx context.Context, reference, targetDir string, excludeModelWeights bool,
45+
) error {
4446
if err := os.MkdirAll(targetDir, 0755); err != nil {
4547
return err
4648
}

pkg/service/controller_local.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ func (s *Service) localCreateVolume(ctx context.Context, req *csi.CreateVolumeRe
3333
modelReference := strings.TrimSpace(parameters[s.cfg.Get().ParameterKeyReference()])
3434
mountID := strings.TrimSpace(parameters[s.cfg.Get().ParameterKeyMountID()])
3535
checkDiskQuotaParam := strings.TrimSpace(parameters[s.cfg.Get().ParameterKeyCheckDiskQuota()])
36+
excludeModelWeightsParam := strings.TrimSpace(parameters[s.cfg.Get().ParameterKeyExcludeModelWeights()])
3637
isStaticVolume := mountID == ""
3738

3839
if volumeName == "" {
@@ -58,6 +59,14 @@ func (s *Service) localCreateVolume(ctx context.Context, req *csi.CreateVolumeRe
5859
return nil, isStaticVolume, status.Errorf(codes.InvalidArgument, "invalid parameter:%s: %v", s.cfg.Get().ParameterKeyCheckDiskQuota(), err)
5960
}
6061
}
62+
excludeModelWeights := false
63+
if excludeModelWeightsParam != "" {
64+
var err error
65+
excludeModelWeights, err = strconv.ParseBool(excludeModelWeightsParam)
66+
if err != nil {
67+
return nil, isStaticVolume, status.Errorf(codes.InvalidArgument, "invalid parameter:%s: %v", s.cfg.Get().ParameterKeyExcludeModelWeights(), err)
68+
}
69+
}
6170

6271
parentSpan := trace.SpanFromContext(ctx)
6372
parentSpan.SetAttributes(attribute.String("volume_name", volumeName))
@@ -69,7 +78,7 @@ func (s *Service) localCreateVolume(ctx context.Context, req *csi.CreateVolumeRe
6978
startedAt := time.Now()
7079
ctx, span := tracing.Tracer.Start(ctx, "PullModel")
7180
span.SetAttributes(attribute.String("model_dir", modelDir))
72-
if err := s.worker.PullModel(ctx, isStaticVolume, volumeName, "", modelReference, modelDir, checkDiskQuota); err != nil {
81+
if err := s.worker.PullModel(ctx, isStaticVolume, volumeName, "", modelReference, modelDir, checkDiskQuota, excludeModelWeights); err != nil {
7382
span.SetStatus(otelCodes.Error, "failed to pull model")
7483
span.RecordError(err)
7584
span.End()
@@ -102,7 +111,7 @@ func (s *Service) localCreateVolume(ctx context.Context, req *csi.CreateVolumeRe
102111
startedAt := time.Now()
103112
ctx, span := tracing.Tracer.Start(ctx, "PullModel")
104113
span.SetAttributes(attribute.String("model_dir", modelDir))
105-
if err := s.worker.PullModel(ctx, isStaticVolume, volumeName, mountID, modelReference, modelDir, checkDiskQuota); err != nil {
114+
if err := s.worker.PullModel(ctx, isStaticVolume, volumeName, mountID, modelReference, modelDir, checkDiskQuota, excludeModelWeights); err != nil {
106115
span.SetStatus(otelCodes.Error, "failed to pull model")
107116
span.RecordError(err)
108117
span.End()

pkg/service/node.go

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package service
22

33
import (
44
"path/filepath"
5+
"strconv"
56
"strings"
67
"time"
78

@@ -92,9 +93,19 @@ func (s *Service) nodePublishVolume(
9293
}
9394

9495
staticInlineModelReference := volumeAttributes[s.cfg.Get().ParameterKeyReference()]
96+
excludeModelWeightsParam := volumeAttributes[s.cfg.Get().ParameterKeyExcludeModelWeights()]
9597
if staticInlineModelReference != "" {
98+
excludeModelWeights := false
99+
if excludeModelWeightsParam != "" {
100+
var err error
101+
excludeModelWeights, err = strconv.ParseBool(excludeModelWeightsParam)
102+
if err != nil {
103+
return nil, isStaticVolume, status.Errorf(codes.InvalidArgument, "invalid parameter:%s: %v", s.cfg.Get().ParameterKeyExcludeModelWeights(), err)
104+
}
105+
}
106+
96107
logger.WithContext(ctx).Infof("publishing static inline volume: %s", staticInlineModelReference)
97-
resp, err := s.nodePublishVolumeStaticInlineVolume(ctx, volumeID, targetPath, staticInlineModelReference)
108+
resp, err := s.nodePublishVolumeStaticInlineVolume(ctx, volumeID, targetPath, staticInlineModelReference, excludeModelWeights)
98109
return resp, isStaticVolume, err
99110
}
100111

pkg/service/node_static_inline.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@ import (
1515
"google.golang.org/grpc/status"
1616
)
1717

18-
func (s *Service) nodePublishVolumeStaticInlineVolume(ctx context.Context, volumeName, targetPath, reference string) (*csi.NodePublishVolumeResponse, error) {
18+
func (s *Service) nodePublishVolumeStaticInlineVolume(ctx context.Context, volumeName, targetPath, reference string, excludeModelWeights bool) (*csi.NodePublishVolumeResponse, error) {
1919
modelDir := s.cfg.Get().GetModelDir(volumeName)
2020

2121
startedAt := time.Now()
22-
if err := s.worker.PullModel(ctx, true, volumeName, "", reference, modelDir, false); err != nil {
22+
if err := s.worker.PullModel(ctx, true, volumeName, "", reference, modelDir, false, excludeModelWeights); err != nil {
2323
return nil, status.Error(codes.Internal, errors.Wrap(err, "pull model").Error())
2424
}
2525
duration := time.Since(startedAt)

pkg/service/puller.go

Lines changed: 96 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"fmt"
66
"io"
77
"os"
8+
"path/filepath"
89
"sort"
910
"sync"
1011
"sync/atomic"
@@ -27,6 +28,11 @@ import (
2728
otelCodes "go.opentelemetry.io/otel/codes"
2829
)
2930

31+
const (
32+
safetensorFilePath = "model.safetensors.index.json"
33+
safetensorFileExt = ".safetensors"
34+
)
35+
3036
type PullHook interface {
3137
BeforePullLayer(desc ocispec.Descriptor, manifest ocispec.Manifest)
3238
AfterPullLayer(desc ocispec.Descriptor, err error)
@@ -50,7 +56,7 @@ func NewHook(ctx context.Context, progressCb func(progress status.Progress)) *Ho
5056
}
5157

5258
type Puller interface {
53-
Pull(ctx context.Context, reference, targetDir string) error
59+
Pull(ctx context.Context, reference, targetDir string, excludeModelWeights bool) error
5460
}
5561

5662
var NewPuller = func(ctx context.Context, pullCfg *config.PullConfig, hook *Hook, diskQuotaChecker *DiskQuotaChecker) Puller {
@@ -201,7 +207,50 @@ func (h *Hook) GetProgress() status.Progress {
201207
return h.getProgress()
202208
}
203209

204-
func (p *puller) Pull(ctx context.Context, reference, targetDir string) error {
210+
func isSafetensorFile(layer backend.InspectedModelArtifactLayer) bool {
211+
// Check file path
212+
if layer.Filepath == safetensorFilePath {
213+
return true
214+
}
215+
// Compatibility for old model artifact format
216+
if filepath.Ext(layer.Filepath) == safetensorFileExt {
217+
return true
218+
}
219+
return false
220+
}
221+
222+
func isWeightLayer(layer backend.InspectedModelArtifactLayer) bool {
223+
// Check media type
224+
if layer.MediaType == modelspec.MediaTypeModelWeightRaw ||
225+
layer.MediaType == modelspec.MediaTypeModelWeight ||
226+
layer.MediaType == modelspec.MediaTypeModelWeightGzip ||
227+
layer.MediaType == modelspec.MediaTypeModelWeightZstd {
228+
return true
229+
}
230+
if isSafetensorFile(layer) {
231+
return true
232+
}
233+
return false
234+
}
235+
236+
func getPatternsWithoutWeights(ctx context.Context, layers []backend.InspectedModelArtifactLayer) []string {
237+
paths := []string{}
238+
for idx := range layers {
239+
layer := layers[idx]
240+
if layer.Filepath == "" {
241+
logger.Logger().WithContext(ctx).Warnf(
242+
"layer %s has no file path, skip", layer.Digest,
243+
)
244+
continue
245+
}
246+
if !isWeightLayer(layer) {
247+
paths = append(paths, layer.Filepath)
248+
}
249+
}
250+
return paths
251+
}
252+
253+
func (p *puller) Pull(ctx context.Context, reference, targetDir string, excludeModelWeights bool) error {
205254
keyChain, err := auth.GetKeyChainByRef(reference)
206255
if err != nil {
207256
return errors.Wrapf(err, "get auth for model: %s", reference)
@@ -223,28 +272,55 @@ func (p *puller) Pull(ctx context.Context, reference, targetDir string) error {
223272
return errors.Wrapf(err, "create model dir: %s", targetDir)
224273
}
225274

226-
if p.pullCfg.Concurrency < 1 {
227-
p.pullCfg.Concurrency = 5
228-
}
275+
go p.checkLongPulling(ctx)
229276

230-
pullConfig := modctlConfig.NewPull()
231-
pullConfig.Concurrency = int(p.pullCfg.Concurrency)
232-
pullConfig.PlainHTTP = keyChain.ServerScheme == "http"
233-
pullConfig.Proxy = p.pullCfg.ProxyURL
234-
pullConfig.DragonflyEndpoint = p.pullCfg.DragonflyEndpoint
235-
pullConfig.Insecure = true
236-
pullConfig.ExtractDir = targetDir
237-
pullConfig.ExtractFromRemote = true
238-
pullConfig.Hooks = p.hook
239-
pullConfig.ProgressWriter = io.Discard
240-
pullConfig.DisableProgress = true
277+
plainHTTP := keyChain.ServerScheme == "http"
278+
279+
if !excludeModelWeights {
280+
pullConfig := modctlConfig.NewPull()
281+
pullConfig.Concurrency = int(p.pullCfg.Concurrency)
282+
pullConfig.PlainHTTP = plainHTTP
283+
pullConfig.Proxy = p.pullCfg.ProxyURL
284+
pullConfig.DragonflyEndpoint = p.pullCfg.DragonflyEndpoint
285+
pullConfig.Insecure = true
286+
pullConfig.ExtractDir = targetDir
287+
pullConfig.ExtractFromRemote = true
288+
pullConfig.Hooks = p.hook
289+
pullConfig.ProgressWriter = io.Discard
290+
pullConfig.DisableProgress = true
291+
292+
if err := b.Pull(ctx, reference, pullConfig); err != nil {
293+
logger.WithContext(ctx).WithError(err).Errorf("failed to pull model image: %s", reference)
294+
return errors.Wrap(err, "pull model image")
295+
}
241296

242-
go p.checkLongPulling(ctx)
297+
return nil
298+
}
243299

244-
if err := b.Pull(ctx, reference, pullConfig); err != nil {
245-
logger.WithContext(ctx).WithError(err).Errorf("failed to pull model image: %s", reference)
246-
return errors.Wrap(err, "pull model image")
300+
start := time.Now()
301+
result, err := b.Inspect(ctx, reference, &modctlConfig.Inspect{
302+
Remote: true,
303+
Insecure: true,
304+
PlainHTTP: plainHTTP,
305+
})
306+
if err != nil {
307+
return errors.Wrap(err, "inspect model")
247308
}
309+
logger.WithContext(ctx).Infof("inspected model %s, duration: %s", reference, time.Since(start))
310+
modelArtifact, ok := result.(*backend.InspectedModelArtifact)
311+
if !ok {
312+
return errors.Errorf("invalid inspected result: %s", reference)
313+
}
314+
315+
patterns := getPatternsWithoutWeights(ctx, modelArtifact.Layers)
316+
317+
fetchConfig := modctlConfig.NewFetch()
318+
fetchConfig.Concurrency = int(p.pullCfg.Concurrency)
319+
fetchConfig.PlainHTTP = plainHTTP
320+
fetchConfig.Proxy = p.pullCfg.ProxyURL
321+
fetchConfig.Insecure = true
322+
fetchConfig.Output = targetDir
323+
fetchConfig.Patterns = patterns
248324

249325
return nil
250326
}

pkg/service/quota.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"os"
66
"path/filepath"
77
"syscall"
8+
"time"
89

910
"github.com/dustin/go-humanize"
1011
"github.com/modelpack/modctl/pkg/backend"
@@ -111,10 +112,12 @@ func (d *DiskQuotaChecker) Check(ctx context.Context, b backend.Backend, referen
111112
availSize = int64(st.Bavail) * int64(st.Bsize)
112113
}
113114

115+
start := time.Now()
114116
modelSize, err := d.getModelSize(ctx, b, reference, plainHTTP)
115117
if err != nil {
116118
return errors.Wrap(err, "get model size")
117119
}
120+
logger.WithContext(ctx).Infof("inspected model %s, size: %s, duration: %s", reference, humanizeBytes(modelSize), time.Since(start))
118121

119122
logger.WithContext(ctx).Infof(
120123
"root dir maximum limit size: %s, available: %s, model: %s",

pkg/service/request.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
package service
22

33
type MountRequest struct {
4-
MountID string `json:"mount_id"`
5-
Reference string `json:"reference"`
6-
CheckDiskQuota bool `json:"check_disk_quota"`
4+
MountID string `json:"mount_id"`
5+
Reference string `json:"reference"`
6+
CheckDiskQuota bool `json:"check_disk_quota"`
7+
ExcludeModelWeights bool `json:"exclude_model_weights"`
78
}

pkg/service/worker.go

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -114,11 +114,19 @@ func (worker *Worker) DeleteModel(ctx context.Context, isStaticVolume bool, volu
114114
return err
115115
}
116116

117-
func (worker *Worker) PullModel(ctx context.Context, isStaticVolume bool, volumeName, mountID, reference, modelDir string, checkDiskQuota bool) error {
117+
func (worker *Worker) PullModel(
118+
ctx context.Context,
119+
isStaticVolume bool,
120+
volumeName, mountID,
121+
reference,
122+
modelDir string,
123+
checkDiskQuota bool,
124+
excludeModelWeights bool,
125+
) error {
118126
start := time.Now()
119127

120128
statusPath := filepath.Join(filepath.Dir(modelDir), "status.json")
121-
err := worker.pullModel(ctx, statusPath, volumeName, mountID, reference, modelDir, checkDiskQuota)
129+
err := worker.pullModel(ctx, statusPath, volumeName, mountID, reference, modelDir, checkDiskQuota, excludeModelWeights)
122130
metrics.NodeOpObserve("pull_image", start, err)
123131

124132
if err != nil && !errors.Is(err, ErrConflict) {
@@ -130,7 +138,7 @@ func (worker *Worker) PullModel(ctx context.Context, isStaticVolume bool, volume
130138
return err
131139
}
132140

133-
func (worker *Worker) pullModel(ctx context.Context, statusPath, volumeName, mountID, reference, modelDir string, checkDiskQuota bool) error {
141+
func (worker *Worker) pullModel(ctx context.Context, statusPath, volumeName, mountID, reference, modelDir string, checkDiskQuota, excludeModelWeights bool) error {
134142
setStatus := func(state status.State, progress status.Progress) (*status.Status, error) {
135143
status, err := worker.sm.Set(statusPath, status.Status{
136144
VolumeName: volumeName,
@@ -188,7 +196,7 @@ func (worker *Worker) pullModel(ctx context.Context, statusPath, volumeName, mou
188196
if err != nil {
189197
return nil, errors.Wrapf(err, "set status before pull model")
190198
}
191-
if err := puller.Pull(ctx, reference, modelDir); err != nil {
199+
if err := puller.Pull(ctx, reference, modelDir, excludeModelWeights); err != nil {
192200
if errors.Is(err, context.Canceled) {
193201
err = errors.Wrapf(err, "pull model canceled")
194202
if _, err2 := setStatus(status.StatePullCanceled, hook.GetProgress()); err2 != nil {

0 commit comments

Comments
 (0)