Skip to content

Commit bce1eb3

Browse files
[REFACT]: Validation Method 수정
- 기존 DBI에서는 K값이 2일 때 optimal하게 나오는 경우가 있었음. - 이에 따라 silhouette로 변경
1 parent 0bfd616 commit bce1eb3

File tree

4 files changed

+84
-28
lines changed

4 files changed

+84
-28
lines changed

Sources/Clustering/Cluster.swift

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ final class Cluster<T: ClusterData> {
1616
var group: LinkedList<T>
1717

1818
private var sumOfLocation: Location
19+
20+
var size: Int { group.size }
1921

2022
// MARK: - Initializers
2123
init(centroid: Location) {
@@ -58,12 +60,18 @@ extension Cluster {
5860
if group.size == 0 { return }
5961
centroid = sumOfLocation / Double(group.size)
6062
}
61-
63+
}
64+
65+
// MARK: - Validation 로직
66+
extension Cluster {
6267
/// Cluster의 Centroid로부터 분산을 리턴합니다.
63-
func deviation() -> Double {
68+
func deviation(from data: T) -> Double {
69+
// 해당 데이터가 클러스터내에 위치한다면, -1을 해줍니다.
70+
let divisor = group.contain(data) ? group.size - 1 : group.size
71+
6472
return group.allValues()
6573
.map { $0.location }
66-
.reduce(0) { $0 + centroid.distance(with: $1) } / Double(group.size)
74+
.reduce(0) { $0 + data.location.distance(with: $1) } / Double(divisor)
6775
}
6876
}
6977

Sources/Clustering/Clustering.swift

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@ extension Clustering {
3737
) {
3838
queue.cancelAllOperations()
3939

40-
let kMeansResults = kRange.map { k -> KMeans in
40+
let kMeansResults = kRange
41+
.filter { $0 <= data.count && $0 >= 2 }
42+
.map { k -> KMeans in
4143
let kMeans = KMeans(k: k, data: data, maxIterations: maxIterations)
4244
queue.addOperation(kMeans)
4345
return kMeans
@@ -59,7 +61,7 @@ extension Clustering {
5961

6062
/// Optimal한 Clustering을 리턴합니다.
6163
private func getOptimalClustering(_ kMeansResults: [KMeans<DataType>]) -> KMeans<DataType>? {
62-
kMeansResults.min(by: { $0.dbi < $1.dbi })
64+
kMeansResults.max(by: { $0.silhouetteScore < $1.silhouetteScore })
6365
}
6466

6567
private func convertToClusterResults(_ clusters: [Cluster<DataType>]) -> [ClusterResult<DataType>] {

Sources/Clustering/KMeans.swift

Lines changed: 62 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,10 @@ final class KMeans<T: ClusterData>: Operation {
2020

2121
/// Cluster들의 Centroids
2222
var centroids: [Location] { clusters.map { $0.centroid } }
23-
24-
private(set) var dbi = Double.greatestFiniteMagnitude
25-
23+
24+
/// Clustering의 Silhouette 점수
25+
private(set) var silhouetteScore: Double = 0
26+
2627
override var isAsynchronous: Bool { true }
2728

2829
// MARK: - Initializers
@@ -39,7 +40,7 @@ final class KMeans<T: ClusterData>: Operation {
3940
guard !isCancelled else { return }
4041

4142
run()
42-
daviesBouldInIndex()
43+
setSilhouetteScore()
4344
}
4445
}
4546

@@ -157,25 +158,65 @@ private extension KMeans {
157158
}
158159
}
159160

160-
// MARK: DBI Method
161-
extension KMeans {
162-
func daviesBouldInIndex() {
163-
var sum: Double = 0
164-
let deviations = clusters.map { $0.deviation() }
165-
166-
for i in 0..<clusters.count {
167-
var maxValue: Double = 0
168-
for j in 0..<clusters.count where i != j {
169-
let sumOfDevations = deviations[i] + deviations[j]
170-
171-
let distanceCenters = clusters[i].centroid.distance(with: clusters[j].centroid)
172-
173-
maxValue = max(maxValue, sumOfDevations / distanceCenters)
161+
// MARK: - Validation Method
162+
private extension KMeans {
163+
/// silhouetteScore을 계산하여 silhouetteScore 변수에 대입합니다.
164+
/// 시간 복잡도 : `O(n^2)`
165+
func setSilhouetteScore() {
166+
self.silhouetteScore = -1
167+
168+
clusters.forEach { cluster in
169+
if cluster.size == 0 {
170+
silhouetteScore = max(silhouetteScore, 0)
171+
} else {
172+
silhouetteScore = max(silhouetteScore, meanSilhouetteCoefficient(at: cluster))
174173
}
175-
176-
sum += maxValue
177174
}
175+
}
176+
177+
/// cluster의 평균 Silhouette 계수를 리턴합니다.
178+
func meanSilhouetteCoefficient(at cluster: Cluster<T>) -> Double {
179+
let cohesionCoefficients = cohesionCoefficients(cluster)
180+
var minSeparationCoefficients = [Double](
181+
repeating: .greatestFiniteMagnitude,
182+
count: cluster.size
183+
)
178184

179-
dbi = sum / Double(clusters.count)
185+
clusters.filter { $0 != cluster }
186+
.forEach { otherCluster in
187+
minSeparationCoefficients = zip(
188+
seperationCoefficients(cluster, with: otherCluster),
189+
minSeparationCoefficients
190+
)
191+
.map { min($0.0, $0.1) }
192+
}
193+
194+
let zipCoefficients = zip(cohesionCoefficients, minSeparationCoefficients)
195+
196+
let dividend = zipCoefficients.map { $1 - $0 }
197+
let divisor = zipCoefficients.map { max($0, $1) }
198+
199+
let sumOfSilhouetteIndex = zip(dividend, divisor)
200+
.map { $0 / $1 }
201+
.reduce(0) { $0 + $1 }
202+
203+
return sumOfSilhouetteIndex / Double(cluster.size)
204+
}
205+
206+
/// 해당 클러스터 내의 모든점의 cohesion 계수를 리턴합니다.
207+
/// 각 점이 속한 클러스터간의 응집도를 판별합니다.
208+
func cohesionCoefficients(_ cluster: Cluster<T>) -> [Double] {
209+
cluster.group.allValues()
210+
.map { cluster.deviation(from: $0) }
211+
}
212+
213+
/// 클러스터내의 모든점의 separation 계수를 리턴합니다.
214+
/// `with` 파라미터로 전달한 클러스터와의 분리도를 판별합니다.
215+
func seperationCoefficients(
216+
_ cluster: Cluster<T>,
217+
with other: Cluster<T>
218+
) -> [Double] {
219+
cluster.group.allValues()
220+
.map { other.deviation(from: $0) }
180221
}
181222
}

Sources/Clustering/LinkedList/LinkedList.swift

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ final class LinkedList<T: Equatable> {
1616
private var tail: Node<T>?
1717

1818
/// 연결리스트의 현재 노드를 리턴합니다.
19-
var now: Node<T>?
20-
19+
private var now: Node<T>?
20+
2121
/// 연결리스트의 사이즈를 리턴합니다.
2222
private(set) var size: Int
2323

@@ -158,6 +158,11 @@ extension LinkedList {
158158
return -1
159159
}
160160

161+
/// 해당 값을 가진 노드가 있는지 여부를 리턴합니다.
162+
func contain(_ target: T) -> Bool {
163+
indexOf(target) != -1 ? true : false
164+
}
165+
161166
/// `other`을 현재 링크 리스트 뒤에 병합합니다.
162167
func merge(other: LinkedList<T>) {
163168
self.tail?.moveBetween(tail?.prev, and: other.head)

0 commit comments

Comments
 (0)