Skip to content

Commit 0a1b507

Browse files
committed
AI: Allow users to use a custom source for TensorFlow labels photoprism#5011 photoprism#5232
Signed-off-by: Michael Mayer <[email protected]>
1 parent 46b2197 commit 0a1b507

File tree

5 files changed

+29
-14
lines changed

5 files changed

+29
-14
lines changed

internal/ai/vision/config.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,13 +92,13 @@ func (c *ConfigValues) Load(fileName string) error {
9292

9393
switch model.Type {
9494
case ModelTypeLabels:
95-
c.Models[i] = NasnetModel
95+
c.Models[i] = NasnetModel.Clone()
9696
case ModelTypeNsfw:
97-
c.Models[i] = NsfwModel
97+
c.Models[i] = NsfwModel.Clone()
9898
case ModelTypeFace:
99-
c.Models[i] = FacenetModel
99+
c.Models[i] = FacenetModel.Clone()
100100
case ModelTypeCaption:
101-
c.Models[i] = CaptionModel
101+
c.Models[i] = CaptionModel.Clone()
102102
}
103103

104104
if runType != RunAuto {

internal/ai/vision/labels.go

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ import (
1616
var labelsFunc = labelsInternal
1717

1818
// SetLabelsFunc overrides the labels generator. Intended for tests.
19-
func SetLabelsFunc(fn func(Files, media.Src, string) (classify.Labels, error)) {
19+
func SetLabelsFunc(fn func(Files, media.Src, entity.Src) (classify.Labels, error)) {
2020
if fn == nil {
2121
labelsFunc = labelsInternal
2222
return
@@ -28,11 +28,11 @@ func SetLabelsFunc(fn func(Files, media.Src, string) (classify.Labels, error)) {
2828
// GenerateLabels finds matching labels for the specified image.
2929
// Caller must pass the appropriate metadata source string (e.g., entity.SrcOllama, entity.SrcOpenAI)
3030
// so that downstream indexing can record where the labels originated.
31-
func GenerateLabels(images Files, mediaSrc media.Src, labelSrc string) (classify.Labels, error) {
31+
func GenerateLabels(images Files, mediaSrc media.Src, labelSrc entity.Src) (classify.Labels, error) {
3232
return labelsFunc(images, mediaSrc, labelSrc)
3333
}
3434

35-
func labelsInternal(images Files, mediaSrc media.Src, labelSrc string) (result classify.Labels, err error) {
35+
func labelsInternal(images Files, mediaSrc media.Src, labelSrc entity.Src) (result classify.Labels, err error) {
3636
// Return if no thumbnail filenames were given.
3737
if len(images) == 0 {
3838
return result, errors.New("at least one image required")
@@ -127,7 +127,7 @@ func labelsInternal(images Files, mediaSrc media.Src, labelSrc string) (result c
127127
return result, err
128128
}
129129

130-
result = mergeLabels(result, labels)
130+
result = mergeLabels(result, labels, labelSrc)
131131
}
132132
} else {
133133
return result, errors.New("invalid labels model configuration")
@@ -141,15 +141,19 @@ func labelsInternal(images Files, mediaSrc media.Src, labelSrc string) (result c
141141
return result, nil
142142
}
143143

144-
// mergeLabels combines existing labels with newly detected labels and returns the result.
145-
func mergeLabels(result, labels classify.Labels) classify.Labels {
144+
// mergeLabels combines existing labels with newly detected labels, applies a custom source, and returns the result.
145+
func mergeLabels(result, labels classify.Labels, labelSrc entity.Src) classify.Labels {
146146
if len(labels) == 0 {
147147
return result
148148
}
149149

150150
for j := range labels {
151151
found := false
152152

153+
if labelSrc != entity.SrcAuto {
154+
labels[j].Source = labelSrc
155+
}
156+
153157
for k := range result {
154158
if labels[j].Name == result[k].Name {
155159
found = true

internal/ai/vision/labels_test.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,17 @@ func TestGenerateLabels(t *testing.T) {
4949
assert.InDelta(t, 60, result[0].Uncertainty, 10)
5050
assert.InDelta(t, float32(0.4), result[0].Confidence(), 0.1)
5151
})
52+
t.Run("CustomSourceLocal", func(t *testing.T) {
53+
labels, err := GenerateLabels(Files{examplesPath + "/cat_224.jpeg"}, media.SrcLocal, entity.SrcManual)
54+
if err != nil {
55+
t.Fatalf("GenerateLabels error: %v", err)
56+
}
57+
for _, label := range labels {
58+
if label.Source != entity.SrcManual {
59+
t.Fatalf("expected custom source %q, got %q", entity.SrcManual, label.Source)
60+
}
61+
}
62+
})
5263
t.Run("InvalidFile", func(t *testing.T) {
5364
_, err := GenerateLabels(Files{examplesPath + "/notexisting.jpg"}, media.SrcLocal, entity.SrcAuto)
5465
assert.Error(t, err)

internal/photoprism/index_vision_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ func TestIndexLabelsSource(t *testing.T) {
8383

8484
t.Run("AutoUsesModelSource", func(t *testing.T) {
8585
var captured string
86-
vision.SetLabelsFunc(func(files vision.Files, mediaSrc media.Src, src string) (classify.Labels, error) {
86+
vision.SetLabelsFunc(func(files vision.Files, mediaSrc media.Src, src entity.Src) (classify.Labels, error) {
8787
captured = src
8888
return classify.Labels{{Name: "stub", Source: src, Uncertainty: 0}}, nil
8989
})
@@ -96,7 +96,7 @@ func TestIndexLabelsSource(t *testing.T) {
9696

9797
t.Run("CustomSource", func(t *testing.T) {
9898
var captured string
99-
vision.SetLabelsFunc(func(files vision.Files, mediaSrc media.Src, src string) (classify.Labels, error) {
99+
vision.SetLabelsFunc(func(files vision.Files, mediaSrc media.Src, src entity.Src) (classify.Labels, error) {
100100
captured = src
101101
return classify.Labels{{Name: "stub", Source: src, Uncertainty: 0}}, nil
102102
})

internal/photoprism/mediafile_vision_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ func TestMediaFile_GenerateLabels(t *testing.T) {
8686

8787
t.Run("AutoUsesModelSource", func(t *testing.T) {
8888
var captured string
89-
vision.SetLabelsFunc(func(files vision.Files, mediaSrc media.Src, src string) (classify.Labels, error) {
89+
vision.SetLabelsFunc(func(files vision.Files, mediaSrc media.Src, src entity.Src) (classify.Labels, error) {
9090
captured = src
9191
return classify.Labels{{Name: "stub", Source: src}}, nil
9292
})
@@ -98,7 +98,7 @@ func TestMediaFile_GenerateLabels(t *testing.T) {
9898

9999
t.Run("CustomSourceOverrides", func(t *testing.T) {
100100
var captured string
101-
vision.SetLabelsFunc(func(files vision.Files, mediaSrc media.Src, src string) (classify.Labels, error) {
101+
vision.SetLabelsFunc(func(files vision.Files, mediaSrc media.Src, src entity.Src) (classify.Labels, error) {
102102
captured = src
103103
return classify.Labels{{Name: "stub", Source: src}}, nil
104104
})

0 commit comments

Comments
 (0)