Skip to content

Commit e604980

Browse files
authored
fixup: check safetensors files by file extension (#13)
fixup: check weight files by file extension. We noticed that during modctl builds, some non-weight files might be mistakenly identified as weight files (application/vnd.cnai.model.doc.v1.raw), such as the `tiktoken.model` file in kimi k2. This could cause these files to be excluded when `excludeWeights = true`, affecting inference engine startup. This PR fixes it by checking file extensions to accurately identify weight files, preventing the above problem. Currently, this solution applies only to .safetensors files; if additional weight file formats are introduced in the future, the detection logic will need to be further extended. Signed-off-by: imeoer <[email protected]>
1 parent 0658ae6 commit e604980

File tree

3 files changed

+9
-23
lines changed

3 files changed

+9
-23
lines changed

pkg/service/model.go

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
package service
22

33
import (
4+
"path/filepath"
45
"sync"
56
"time"
67

7-
legacymodelspec "github.com/dragonflyoss/model-spec/specs-go/v1"
88
"github.com/modelpack/modctl/pkg/backend"
99
modctlConfig "github.com/modelpack/modctl/pkg/config"
1010
"github.com/modelpack/model-csi-driver/pkg/logger"
1111
"github.com/modelpack/model-csi-driver/pkg/utils"
12-
modelspec "github.com/modelpack/model-spec/specs-go/v1"
1312
"github.com/pkg/errors"
1413
"golang.org/x/net/context"
1514
)
@@ -24,26 +23,17 @@ type ModelArtifact struct {
2423
artifact *backend.InspectedModelArtifact
2524
}
2625

27-
func isSafetensorIndexFile(layer backend.InspectedModelArtifactLayer) bool {
28-
return layer.Filepath == safetensorIndexFilePath
29-
}
30-
3126
func isWeightLayer(layer backend.InspectedModelArtifactLayer) bool {
32-
if layer.MediaType == modelspec.MediaTypeModelWeightRaw ||
33-
layer.MediaType == modelspec.MediaTypeModelWeight ||
34-
layer.MediaType == modelspec.MediaTypeModelWeightGzip ||
35-
layer.MediaType == modelspec.MediaTypeModelWeightZstd {
36-
return true
37-
}
38-
if layer.MediaType == legacymodelspec.MediaTypeModelWeightRaw ||
39-
layer.MediaType == legacymodelspec.MediaTypeModelWeight ||
40-
layer.MediaType == legacymodelspec.MediaTypeModelWeightGzip ||
41-
layer.MediaType == legacymodelspec.MediaTypeModelWeightZstd {
27+
// For *.safetensors files
28+
if filepath.Ext(layer.Filepath) == ".safetensors" {
4229
return true
4330
}
44-
if isSafetensorIndexFile(layer) {
31+
32+
// For safetensors index file
33+
if layer.Filepath == "model.safetensors.index.json" {
4534
return true
4635
}
36+
4737
return false
4838
}
4939

pkg/service/model_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ func TestModelArtifact(t *testing.T) {
4141
MediaType: modelspec.MediaTypeModelWeightRaw,
4242
Digest: "sha256:layer1",
4343
Size: 3 * 1024 * 1024,
44-
Filepath: "zoo.safetensors",
44+
Filepath: "bar.zoo.safetensors",
4545
},
4646
},
4747
}, nil
@@ -60,7 +60,7 @@ func TestModelArtifact(t *testing.T) {
6060

6161
paths, err := modelArtifact.GetPatterns(ctx, false)
6262
require.NoError(t, err)
63-
require.Equal(t, []string{"foo.safetensors", "README.md", "zoo.safetensors"}, paths)
63+
require.Equal(t, []string{"foo.safetensors", "README.md", "bar.zoo.safetensors"}, paths)
6464

6565
paths, err = modelArtifact.GetPatterns(ctx, true)
6666
require.NoError(t, err)

pkg/service/puller.go

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,6 @@ import (
1616
"github.com/pkg/errors"
1717
)
1818

19-
const (
20-
safetensorIndexFilePath = "model.safetensors.index.json"
21-
)
22-
2319
type PullHook interface {
2420
BeforePullLayer(desc ocispec.Descriptor, manifest ocispec.Manifest)
2521
AfterPullLayer(desc ocispec.Descriptor, err error)

0 commit comments

Comments
 (0)