|
| 1 | +package service |
| 2 | + |
| 3 | +import ( |
| 4 | + "os" |
| 5 | + "path/filepath" |
| 6 | + "testing" |
| 7 | + |
| 8 | + "github.com/modelpack/model-csi-driver/pkg/config" |
| 9 | + "github.com/modelpack/model-csi-driver/pkg/metrics" |
| 10 | + "github.com/modelpack/model-csi-driver/pkg/status" |
| 11 | + "github.com/prometheus/client_golang/prometheus" |
| 12 | + "github.com/prometheus/client_golang/prometheus/testutil" |
| 13 | + dto "github.com/prometheus/client_model/go" |
| 14 | + "github.com/stretchr/testify/require" |
| 15 | +) |
| 16 | + |
| 17 | +func TestCacheManagerScanUpdatesMetrics(t *testing.T) { |
| 18 | + tempDir := t.TempDir() |
| 19 | + |
| 20 | + rawCfg := &config.RawConfig{ServiceName: "test", RootDir: tempDir} |
| 21 | + cfg := config.NewWithRaw(rawCfg) |
| 22 | + |
| 23 | + sm, err := status.NewStatusManager() |
| 24 | + require.NoError(t, err) |
| 25 | + |
| 26 | + // Create a static volume status |
| 27 | + staticStatusPath := filepath.Join(tempDir, "volumes", "pvc-static", "status.json") |
| 28 | + _, err = sm.Set(staticStatusPath, status.Status{Reference: "ref-static", MountID: "m-static"}) |
| 29 | + require.NoError(t, err) |
| 30 | + |
| 31 | + // Create a dynamic volume status under models/<mountID>/status.json |
| 32 | + dynamicStatusPath := filepath.Join(tempDir, "volumes", "csi-dyn", "models", "mount-1", "status.json") |
| 33 | + _, err = sm.Set(dynamicStatusPath, status.Status{Reference: "ref-dyn", MountID: "mount-1"}) |
| 34 | + require.NoError(t, err) |
| 35 | + |
| 36 | + // An extra file to ensure cache size covers arbitrary files under RootDir. |
| 37 | + extraPath := filepath.Join(tempDir, "extra.bin") |
| 38 | + require.NoError(t, os.WriteFile(extraPath, []byte("abc"), 0o644)) |
| 39 | + |
| 40 | + paths := []string{staticStatusPath, dynamicStatusPath, extraPath} |
| 41 | + var expectedSize int64 |
| 42 | + for _, p := range paths { |
| 43 | + st, statErr := os.Stat(p) |
| 44 | + require.NoError(t, statErr) |
| 45 | + expectedSize += st.Size() |
| 46 | + } |
| 47 | + |
| 48 | + cm := &CacheManager{cfg: cfg, sm: sm} |
| 49 | + require.NoError(t, cm.Scan()) |
| 50 | + |
| 51 | + require.Equal(t, float64(expectedSize), testutil.ToFloat64(metrics.NodeCacheSizeInBytes)) |
| 52 | + require.Equal(t, float64(1), testutil.ToFloat64(metrics.NodeMountedStaticImages)) |
| 53 | + require.Equal(t, float64(0), testutil.ToFloat64(metrics.NodeMountedInlineImages)) |
| 54 | + require.Equal(t, float64(1), testutil.ToFloat64(metrics.NodeMountedDynamicImages)) |
| 55 | + |
| 56 | + // Verify mount item metrics are exported as a snapshot without Reset/Delete races. |
| 57 | + reg := prometheus.NewRegistry() |
| 58 | + reg.MustRegister(metrics.MountItems) |
| 59 | + |
| 60 | + mfs, err := reg.Gather() |
| 61 | + require.NoError(t, err) |
| 62 | + |
| 63 | + mf := findMetricFamily(t, mfs, metrics.Prefix+"mount_item") |
| 64 | + require.Len(t, mf.Metric, 2) |
| 65 | + |
| 66 | + staticLabels := map[string]string{ |
| 67 | + "reference": "ref-static", |
| 68 | + "type": "static", |
| 69 | + "volume_name": "pvc-static", |
| 70 | + "mount_id": "m-static", |
| 71 | + } |
| 72 | + dynamicLabels := map[string]string{ |
| 73 | + "reference": "ref-dyn", |
| 74 | + "type": "dynamic", |
| 75 | + "volume_name": "csi-dyn", |
| 76 | + "mount_id": "mount-1", |
| 77 | + } |
| 78 | + |
| 79 | + var foundStatic, foundDynamic bool |
| 80 | + for _, m := range mf.Metric { |
| 81 | + if hasLabels(m, staticLabels) { |
| 82 | + foundStatic = true |
| 83 | + } |
| 84 | + if hasLabels(m, dynamicLabels) { |
| 85 | + foundDynamic = true |
| 86 | + } |
| 87 | + } |
| 88 | + require.True(t, foundStatic, "static mount item metric not found") |
| 89 | + require.True(t, foundDynamic, "dynamic mount item metric not found") |
| 90 | +} |
| 91 | + |
| 92 | +func findMetricFamily(t *testing.T, mfs []*dto.MetricFamily, name string) *dto.MetricFamily { |
| 93 | + t.Helper() |
| 94 | + for _, mf := range mfs { |
| 95 | + if mf.GetName() == name { |
| 96 | + return mf |
| 97 | + } |
| 98 | + } |
| 99 | + require.FailNow(t, "metric family not found", name) |
| 100 | + return nil |
| 101 | +} |
| 102 | + |
| 103 | +func hasLabels(m *dto.Metric, want map[string]string) bool { |
| 104 | + labels := map[string]string{} |
| 105 | + for _, lp := range m.GetLabel() { |
| 106 | + labels[lp.GetName()] = lp.GetValue() |
| 107 | + } |
| 108 | + for k, v := range want { |
| 109 | + if labels[k] != v { |
| 110 | + return false |
| 111 | + } |
| 112 | + } |
| 113 | + return true |
| 114 | +} |
0 commit comments