Skip to content

Commit 56bf830

Browse files
committed
feat(kmeans): cpu parallel compute
1 parent a1afaed commit 56bf830

File tree

3 files changed

+40
-11
lines changed

3 files changed

+40
-11
lines changed

internal/gpu/cmd.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ func CommandListCreate() (ze.CommandListHandle, error) {
77
return ctx.CommandListCreate(dev)
88
}
99

10-
// ExecuteCommandLists submits the command list for execution on the command queue.
10+
// ExecCommandLists submits the given command lists for execution on the command queue.
// It accepts any number of handles, forwards them unchanged to q.ExecuteCommandLists,
// and returns that call's error.
1111
func ExecCommandLists(hCommandList ...ze.CommandListHandle) error {
1212
return q.ExecuteCommandLists(hCommandList...)
1313
}

kmeans.go

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ import (
55
"image/color"
66
"math"
77
"math/rand"
8+
"runtime"
9+
"sync"
810
"unsafe"
911

1012
"github.com/disintegration/imaging"
@@ -61,15 +63,14 @@ func newKMeansImage(img image.Image, k uint16) kmeansImage {
6163
width := img.Bounds().Dx()
6264
height := img.Bounds().Dy()
6365
dstw, dsth := width, height
64-
ratio := 0.
6566
if dstw > 512 {
6667
dstw = 512
67-
ratio = float64(dstw) / float64(width)
68+
ratio := float64(dstw) / float64(width)
6869
dsth *= int(float64(height) * ratio)
6970
}
7071
if dsth > 512 {
7172
dsth = 512
72-
ratio = float64(dsth) / float64(height)
73+
ratio := float64(dsth) / float64(height)
7374
dstw = int(float64(width) * ratio)
7475
}
7576
ki.bounds = image.Rect(0, 0, dstw, dsth)
@@ -94,7 +95,33 @@ func (ki *kmeansImage) assign() {
9495
ki.gpuDestroy(true)
9596
}
9697

97-
for i, pixel := range ki.pixels {
98+
n := runtime.NumCPU()
99+
batchcnt := len(ki.pixels) / n
100+
rem := len(ki.pixels) % n
101+
wg := sync.WaitGroup{}
102+
wg.Add(n)
103+
if rem < 0 {
104+
wg.Add(1)
105+
}
106+
for batch := range n {
107+
go func(batch int) {
108+
base := batch * batchcnt
109+
for i, pixel := range ki.pixels[base : base+batchcnt] {
110+
minDistance := math.MaxFloat64
111+
assign := uint16(math.MaxUint16)
112+
for j, cluster := range ki.clusters {
113+
distance := distanceRGBAsq(pixel, cluster)
114+
if distance < minDistance {
115+
minDistance = distance
116+
assign = uint16(j)
117+
}
118+
}
119+
ki.clusterAssignments[base+i] = assign
120+
}
121+
}(batch)
122+
}
123+
base := n * batchcnt
124+
for i, pixel := range ki.pixels[n*batchcnt:] {
98125
minDistance := math.MaxFloat64
99126
assign := uint16(math.MaxUint16)
100127
for j, cluster := range ki.clusters {
@@ -104,7 +131,7 @@ func (ki *kmeansImage) assign() {
104131
assign = uint16(j)
105132
}
106133
}
107-
ki.clusterAssignments[i] = assign
134+
ki.clusterAssignments[base+i] = assign
108135
}
109136
}
110137

kmeans_ocl.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,14 @@ func (ki *kmeansImage) gpuInit() {
4141
width := ki.bounds.Dx()
4242
height := ki.bounds.Dy()
4343
dstw, dsth := width, height
44-
ratio := 0.
4544
if dstw > 512 {
4645
dstw = 512
47-
ratio = float64(dstw) / float64(width)
46+
ratio := float64(dstw) / float64(width)
4847
dsth *= int(float64(height) * ratio)
4948
}
5049
if dsth > 512 {
5150
dsth = 512
52-
ratio = float64(dsth) / float64(height)
51+
ratio := float64(dsth) / float64(height)
5352
dstw = int(float64(width) * ratio)
5453
}
5554
ki.bounds = image.Rect(0, 0, dstw, dsth)
@@ -211,13 +210,13 @@ func (ki *kmeansImage) gpuInit() {
211210
ki.gpuDestroy(true)
212211
return
213212
}
214-
err = krn1st.SetGroupSize(uint32(gX), uint32(gY), 1)
213+
err = krn1st.SetGroupSize(gX, gY, 1)
215214
if err != nil {
216215
canUseKmeansKernel = false
217216
ki.gpuDestroy(true)
218217
return
219218
}
220-
err = krnrem.SetGroupSize(uint32(gX), uint32(gY), 1)
219+
err = krnrem.SetGroupSize(gX, gY, 1)
221220
if err != nil {
222221
canUseKmeansKernel = false
223222
ki.gpuDestroy(true)
@@ -297,6 +296,9 @@ func (ki *kmeansImage) gpuAssign() error {
297296
err = lst.AppendLaunchKernel(ki.krn1st, &gozel.ZeGroupCount{
298297
Groupcountx: ki.gcx, Groupcounty: ki.gcy, Groupcountz: 1,
299298
}, kev, inpcpev, cluscpev)
299+
if err != nil {
300+
return err
301+
}
300302

301303
smpcpev, cl, err := gpu.EventCreate(gozel.ZE_EVENT_SCOPE_FLAG_HOST, 0)
302304
if err != nil {

0 commit comments

Comments
 (0)