Skip to content

Commit ba0d92e

Browse files
committed
refactor: separate kmeans from color
1 parent 59288a1 commit ba0d92e

File tree

9 files changed

+271
-134
lines changed

9 files changed

+271
-134
lines changed

color.go

Lines changed: 12 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@ package gg
33
import (
44
"image"
55
"image/color"
6-
"math"
7-
"math/rand"
86
"unsafe"
97
)
108

@@ -30,60 +28,15 @@ var (
3028
//
3129
// TakeThemeColorsKMeans 使用 k-means 算法从图像中提取 k 个主色。
3230
func TakeThemeColorsKMeans(img image.Image, k int) []color.RGBA {
33-
rgbaimg := ImageToRGBA(img)
34-
pixels := unsafe.Slice(
35-
(*color.RGBA)(unsafe.Pointer(unsafe.SliceData(rgbaimg.Pix))),
36-
uintptr(len(rgbaimg.Pix))/unsafe.Sizeof(color.RGBA{}),
37-
)
38-
39-
// 初始化k个聚类中心
40-
clusters := make([]color.RGBA, k)
41-
for i := range k {
42-
clusters[i] = pixels[rand.Intn(len(pixels))]
43-
}
44-
45-
// 迭代聚类
31+
ki := newKMeansImage(img, k) // 初始化k个聚类中心
4632
for {
47-
// 将每个像素点分配到最近的聚类中心
48-
clusterAssignments := make([]int, len(pixels))
49-
for i, pixel := range pixels {
50-
minDistance := math.MaxFloat64
51-
for j, cluster := range clusters {
52-
distance := distance(pixel, cluster)
53-
if distance < minDistance {
54-
minDistance = distance
55-
clusterAssignments[i] = j
56-
}
57-
}
58-
}
59-
60-
// 计算每个聚类的新中心
61-
newClusters := make([]color.RGBA, k)
62-
for currentCluster := range k {
63-
var r, g, b uint32
64-
n := 0
65-
for j, pixelCluster := range clusterAssignments {
66-
if pixelCluster == currentCluster {
67-
pixel := pixels[j]
68-
r += uint32(pixel.R)
69-
g += uint32(pixel.G)
70-
b += uint32(pixel.B)
71-
n++
72-
}
73-
}
74-
if n != 0 {
75-
newClusters[currentCluster] = color.RGBA{uint8(r / uint32(n)), uint8(g / uint32(n)), uint8(b / uint32(n)), 255}
76-
}
77-
}
78-
79-
// 如果聚类中心没有变化,则停止迭代
80-
if isArrayRGBAEqual(clusters, newClusters) {
33+
ki.assign()
34+
ki.update()
35+
if ki.epilogue() {
8136
break
8237
}
83-
clusters = newClusters
8438
}
85-
86-
return clusters
39+
return ki.result()
8740
}
8841

8942
// isArrayRGBAEqual compares two []color.RGBA is equal fastly.
@@ -115,12 +68,11 @@ func isArrayRGBAEqual(a, b []color.RGBA) bool {
11568
return true
11669
}
11770

118-
// 计算两个颜色之间的距离
119-
func distance(a, b color.RGBA) float64 {
120-
return math.Sqrt(sq(float64(a.R)-float64(b.R)) + sq(float64(a.G)-float64(b.G)) + sq(float64(a.B)-float64(b.B)))
121-
}
122-
123-
// 计算平方
124-
func sq(n float64) float64 {
125-
return n * n
71+
// distanceRGBAsq calc between two color.RGBAs (RGB only, no sqrt to speedup)
72+
//
73+
// distanceRGBAsq 计算两个 color.RGBA 颜色之间的距离(仅 RGB,忽略开方以加速)
74+
func distanceRGBAsq(a, b color.RGBA) float64 {
75+
return sq(float64(a.R)-float64(b.R)) +
76+
sq(float64(a.G)-float64(b.G)) +
77+
sq(float64(a.B)-float64(b.B))
12678
}

color_test.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ func twoColorImage(w, h int, c1, c2 color.RGBA) image.Image {
3939
// colorInSlice 判断给定颜色是否在切片中(允许 tolerance 误差)
4040
func colorInSlice(c color.RGBA, slice []color.RGBA, tolerance float64) bool {
4141
for _, s := range slice {
42-
if distance(c, s) <= tolerance {
42+
if distanceRGBAsq(c, s) <= tolerance {
4343
return true
4444
}
4545
}
@@ -100,15 +100,15 @@ func TestSq_LargeValue(t *testing.T) {
100100

101101
func TestDistance_SameColor(t *testing.T) {
102102
c := color.RGBA{100, 150, 200, 255}
103-
if got := distance(c, c); got != 0 {
103+
if got := distanceRGBAsq(c, c); got != 0 {
104104
t.Errorf("distance(same, same) = %v, want 0", got)
105105
}
106106
}
107107

108108
func TestDistance_BlackAndWhite(t *testing.T) {
109109
// sqrt(255^2 * 3) = 255 * sqrt(3)
110110
want := 255 * math.Sqrt(3)
111-
got := distance(Black, White)
111+
got := math.Sqrt(distanceRGBAsq(Black, White))
112112
if math.Abs(got-want) > 1e-9 {
113113
t.Errorf("distance(black, white) = %v, want %v", got, want)
114114
}
@@ -119,7 +119,7 @@ func TestDistance_SingleChannel(t *testing.T) {
119119
b := color.RGBA{3, 4, 0, 255}
120120
// sqrt(9 + 16) = 5
121121
want := 5.0
122-
got := distance(a, b)
122+
got := math.Sqrt(distanceRGBAsq(a, b))
123123
if math.Abs(got-want) > 1e-9 {
124124
t.Errorf("distance(%v, %v) = %v, want %v", a, b, got, want)
125125
}
@@ -128,7 +128,7 @@ func TestDistance_SingleChannel(t *testing.T) {
128128
func TestDistance_Symmetry(t *testing.T) {
129129
a := color.RGBA{10, 20, 30, 255}
130130
b := color.RGBA{50, 80, 110, 255}
131-
if d1, d2 := distance(a, b), distance(b, a); math.Abs(d1-d2) > 1e-9 {
131+
if d1, d2 := distanceRGBAsq(a, b), distanceRGBAsq(b, a); math.Abs(d1-d2) > 1e-9 {
132132
t.Errorf("distance not symmetric: d(a,b)=%v, d(b,a)=%v", d1, d2)
133133
}
134134
}
@@ -138,7 +138,7 @@ func TestDistance_NonNegative(t *testing.T) {
138138
for range 100 {
139139
a := color.RGBA{uint8(rng.Intn(256)), uint8(rng.Intn(256)), uint8(rng.Intn(256)), 255}
140140
b := color.RGBA{uint8(rng.Intn(256)), uint8(rng.Intn(256)), uint8(rng.Intn(256)), 255}
141-
got := distance(a, b)
141+
got := distanceRGBAsq(a, b)
142142
if got < 0 {
143143
t.Errorf("distance returned negative value %v for %v, %v", got, a, b)
144144
}
@@ -150,7 +150,7 @@ func TestDistance_IgnoresAlpha(t *testing.T) {
150150
a2 := color.RGBA{100, 150, 200, 255}
151151
b := color.RGBA{50, 50, 50, 128}
152152
// Alpha 不参与计算,结果应相同
153-
if d1, d2 := distance(a1, b), distance(a2, b); math.Abs(d1-d2) > 1e-9 {
153+
if d1, d2 := distanceRGBAsq(a1, b), distanceRGBAsq(a2, b); math.Abs(d1-d2) > 1e-9 {
154154
t.Errorf("distance should ignore alpha: d(a1,b)=%v, d(a2,b)=%v", d1, d2)
155155
}
156156
}
@@ -320,7 +320,7 @@ func BenchmarkDistance(b *testing.B) {
320320
c := color.RGBA{50, 80, 30, 255}
321321
b.ResetTimer()
322322
for range b.N {
323-
distance(a, c)
323+
distanceRGBAsq(a, c)
324324
}
325325
}
326326

effects.go

Lines changed: 5 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,17 @@
11
package gg
22

3-
func gateN100P100(per int) int {
4-
if per > 100 {
5-
per = 100
6-
}
7-
if per < -100 {
8-
per = -100
9-
}
10-
return per
11-
}
12-
13-
func gate0P255(n int) int {
14-
if n < 0 {
15-
n = 0
16-
}
17-
if n > 255 {
18-
n = 255
19-
}
20-
return n
21-
}
22-
233
// Brightness 调整亮度 范围:±100%
244
func (dc *Context) Brightness(per int) {
255
if per == 0 {
266
return
277
}
28-
per = gateN100P100(per)
8+
per = clamp(per, -100, 100)
299
gain := 255 * per / 100
3010
for i, v := range dc.im.Pix {
3111
if i%4 == 3 { // alpha
3212
continue
3313
}
34-
dc.im.Pix[i] = uint8(gate0P255(int(v) + gain))
14+
dc.im.Pix[i] = uint8(clamp(int(v)+gain, 0, 255))
3515
}
3616
}
3717

@@ -40,15 +20,15 @@ func (dc *Context) Contrast(per int) {
4020
if per == 0 {
4121
return
4222
}
43-
per = gateN100P100(per) + 100
23+
per = clamp(per, -100, 100) + 100
4424
switch {
4525
case 0 <= per && per < 100: // 损益
4626
gain := per
4727
for i, v := range dc.im.Pix {
4828
if i%4 == 3 { // alpha
4929
continue
5030
}
51-
dc.im.Pix[i] = uint8(gate0P255(int(v) * gain / 100))
31+
dc.im.Pix[i] = uint8(clamp(int(v)*gain/100, 0, 255))
5232
}
5333
case 100 < per && per <= 200: // 增益
5434
gain := 200 - per
@@ -59,7 +39,7 @@ func (dc *Context) Contrast(per int) {
5939
if i%4 == 3 { // alpha
6040
continue
6141
}
62-
dc.im.Pix[i] = uint8(gate0P255(int(v) * 100 / gain))
42+
dc.im.Pix[i] = uint8(clamp(int(v)*100/gain, 0, 255))
6343
}
6444
default:
6545
panic("unreachable")

examples/theme-colors_test.go

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
package main
2+
3+
import (
4+
"image/color"
5+
"math"
6+
"testing"
7+
8+
"github.com/FloatTech/gg"
9+
"github.com/FloatTech/gg/fio"
10+
)
11+
12+
func TestThemeColorsGopher(t *testing.T) {
13+
im, err := fio.LoadPNG("gopher.png")
14+
if err != nil {
15+
t.Fatal(err)
16+
}
17+
18+
expected := []color.RGBA{
19+
{105, 213, 226, 255}, {249, 240, 227, 255}, {4, 6, 6, 255},
20+
}
21+
22+
const (
23+
k = 3
24+
maxAttempts = 30
25+
tolerance = 600.0 // 允许的颜色距离平方
26+
)
27+
28+
var result []color.RGBA
29+
found := false
30+
for range maxAttempts {
31+
result = gg.TakeThemeColorsKMeans(im, k)
32+
if len(result) != k {
33+
t.Fatalf("expected %d colors, got %d", k, len(result))
34+
}
35+
if allColorsMatch(expected, result, tolerance) {
36+
found = true
37+
break
38+
}
39+
}
40+
if !found {
41+
t.Errorf("theme colors did not match expected in %d attempts\nexpected: %v\ngot: %v", maxAttempts, expected, result)
42+
}
43+
44+
// 绘制提取到的主题色方块
45+
const (
46+
blockW = 120
47+
blockH = 120
48+
pad = 20
49+
)
50+
w := pad + k*(blockW+pad)
51+
h := pad + blockH + pad
52+
dc := gg.NewContext(w, h)
53+
54+
// 白色背景
55+
dc.SetColor(color.White)
56+
dc.Clear()
57+
58+
for i, c := range result {
59+
x := float64(pad + i*(blockW+pad))
60+
y := float64(pad)
61+
dc.SetColor(c)
62+
dc.DrawRectangle(x, y, blockW, blockH)
63+
dc.Fill()
64+
}
65+
66+
if err := dc.SavePNG(GetFileName() + ".png"); err != nil {
67+
t.Fatal(err)
68+
}
69+
}
70+
71+
// allColorsMatch 检查 expected 中每个颜色都能在 result 中找到匹配(无序)
72+
func allColorsMatch(expected, result []color.RGBA, tolerance float64) bool {
73+
if len(expected) != len(result) {
74+
return false
75+
}
76+
used := make([]bool, len(result))
77+
for _, e := range expected {
78+
matched := false
79+
for j, r := range result {
80+
if !used[j] && colorDistSq(e, r) <= tolerance {
81+
used[j] = true
82+
matched = true
83+
break
84+
}
85+
}
86+
if !matched {
87+
return false
88+
}
89+
}
90+
return true
91+
}
92+
93+
func colorDistSq(a, b color.RGBA) float64 {
94+
dr := float64(a.R) - float64(b.R)
95+
dg := float64(a.G) - float64(b.G)
96+
db := float64(a.B) - float64(b.B)
97+
return math.Abs(dr*dr) + math.Abs(dg*dg) + math.Abs(db*db)
98+
}

factory/img.go

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ import (
1111
"image/gif"
1212
"io"
1313
"net/http"
14-
"os"
1514
"strings"
1615

1716
"github.com/FloatTech/gg"
@@ -89,15 +88,7 @@ func LoadAllFrames(path string, w, h int) ([]*Factory, error) {
8988
return nil, err
9089
}
9190
} else {
92-
file, err := os.Open(path)
93-
if err != nil {
94-
return nil, err
95-
}
96-
im, err = gif.DecodeAll(file)
97-
_ = file.Close()
98-
if err != nil {
99-
return nil, err
100-
}
91+
im, err = fio.LoadGIF(path)
10192
}
10293
img, err := Load(path)
10394
if err != nil {
@@ -127,15 +118,7 @@ func LoadAllTrueFrames(path string, w, h int) ([]*Factory, error) {
127118
return nil, err
128119
}
129120
} else {
130-
file, err := os.Open(path)
131-
if err != nil {
132-
return nil, err
133-
}
134-
im, err = gif.DecodeAll(file)
135-
_ = file.Close()
136-
if err != nil {
137-
return nil, err
138-
}
121+
im, err = fio.LoadGIF(path)
139122
}
140123
imgWidth, imgHeight := getGifDimensions(im)
141124
overpaintImage := image.NewRGBA(image.Rect(0, 0, imgWidth, imgHeight))

0 commit comments

Comments
 (0)