Skip to content

Commit c521290

Browse files
authored
refactor: Zenz関連の処理の変更を容易にするためリファクタリングを実施 (#325)
* refactor: 内部的にリファクタの概念を導入 * feat: split file * refactor: remove predictNextCharacter-related features as it is no longer used * refactor: remove zenz-v1 support * refactor: remove unused api * feat: refactor to make ZenzContext more simple
1 parent 42e03d6 commit c521290

File tree

16 files changed

+691
-816
lines changed

16 files changed

+691
-816
lines changed

Sources/CliTool/Anco.swift

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ struct Anco: AsyncParsableCommand {
1111
Subcommands.Evaluate.self,
1212
Subcommands.ZenzEvaluate.self,
1313
Subcommands.Session.self,
14-
Subcommands.ExperimentalPredict.self,
1514
Subcommands.NGram.self
1615
],
1716
defaultSubcommand: Subcommands.Run.self

Sources/CliTool/Subcommands/ExperimentalPredict.swift

Lines changed: 0 additions & 47 deletions
This file was deleted.

Sources/CliTool/Subcommands/SessionCommand.swift

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,6 @@ extension Subcommands {
4242
var configZenzaiProfile: String?
4343
@Option(name: [.customLong("config_topic")], help: "enable topic prompting for zenz-v3 and later.")
4444
var configZenzaiTopic: String?
45-
@Flag(name: [.customLong("zenz_v1")], help: "Use zenz_v1 model.")
46-
var zenzV1 = false
4745
@Flag(name: [.customLong("zenz_v2")], help: "Use zenz_v2 model.")
4846
var zenzV2 = false
4947
@Flag(name: [.customLong("zenz_v3")], help: "Use zenz_v3 model.")
@@ -86,13 +84,13 @@ extension Subcommands {
8684
}
8785

8886
@MainActor mutating func run() async {
89-
if self.zenzV1 || self.zenzV2 {
87+
if self.zenzV2 {
9088
print("\(bold: "We strongly recommend to use zenz-v3 models")")
9189
}
92-
if (self.zenzV1 || self.zenzV2 || self.zenzV3) && self.zenzWeightPath.isEmpty {
90+
if (self.zenzV2 || self.zenzV3) && self.zenzWeightPath.isEmpty {
9391
preconditionFailure("\(bold: "zenz version is specified but --zenz weight is not specified")")
9492
}
95-
if !self.zenzWeightPath.isEmpty && (!self.zenzV1 && !self.zenzV2 && !self.zenzV3) {
93+
if !self.zenzWeightPath.isEmpty && (!self.zenzV2 && !self.zenzV3) {
9694
print("zenz version is not specified. By default, zenz-v3 will be used.")
9795
}
9896

@@ -193,17 +191,6 @@ extension Subcommands {
193191
print("anything should not be saved because the learning type is not for update memory")
194192
}
195193
continue
196-
case ":p", ":pred":
197-
// 次の文字の予測を取得する
198-
let results = converter.predictNextCharacter(
199-
leftSideContext: leftSideContext,
200-
count: 10,
201-
options: requestOptions(learningType: learningType, memoryDirectory: memoryDirectory, leftSideContext: leftSideContext)
202-
)
203-
if let firstCandidate = results.first {
204-
leftSideContext.append(firstCandidate.character)
205-
}
206-
continue
207194
case let command where command == ":ip" || command.hasPrefix(":ip "):
208195
// 入力中の次の文字の予測を取得する (zenz-v3)
209196
let parts = command.split(separator: " ")
@@ -263,7 +250,6 @@ extension Subcommands {
263250
\(bold: ":d, :del") - delete one character
264251
\(bold: ":n, :next") - see more candidates
265252
\(bold: ":s, :save") - save memory to temporary directory
266-
\(bold: ":p, :pred") - predict next one character
267253
\(bold: ":ip [n] [max_entropy=F] [min_length=N]") - predict next input character(s) (zenz-v3)
268254
\(bold: ":%d") - select candidate at that index (like :3 to select 3rd candidate)
269255
\(bold: ":ctx %s") - set the string as context
@@ -350,9 +336,7 @@ extension Subcommands {
350336
}
351337

352338
func requestOptions(learningType: LearningType, memoryDirectory: URL, leftSideContext: String?) -> ConvertRequestOptions {
353-
let zenzaiVersionDependentMode: ConvertRequestOptions.ZenzaiVersionDependentMode = if self.zenzV1 {
354-
.v1
355-
} else if self.zenzV2 {
339+
let zenzaiVersionDependentMode: ConvertRequestOptions.ZenzaiVersionDependentMode = if self.zenzV2 {
356340
.v2(.init(profile: self.configZenzaiProfile, leftSideContext: leftSideContext))
357341
} else {
358342
.v3(.init(profile: self.configZenzaiProfile, topic: self.configZenzaiTopic, leftSideContext: leftSideContext))
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import Foundation
2+
import HeapModule
3+
4+
struct FixedSizeHeap<Element: Comparable> {
5+
private var size: Int
6+
private var heap: Heap<Element>
7+
8+
init(size: Int) {
9+
self.size = size
10+
self.heap = []
11+
}
12+
13+
mutating func removeMax() {
14+
self.heap.removeMax()
15+
}
16+
17+
mutating func removeMin() {
18+
self.heap.removeMin()
19+
}
20+
21+
@discardableResult
22+
mutating func insertIfPossible(_ element: Element) -> Bool {
23+
if self.heap.count < self.size {
24+
self.heap.insert(element)
25+
return true
26+
} else if let min = self.heap.min, element > min {
27+
self.heap.replaceMin(with: element)
28+
return true
29+
} else {
30+
return false
31+
}
32+
}
33+
34+
var unordered: [Element] {
35+
self.heap.unordered
36+
}
37+
38+
var max: Element? {
39+
self.heap.max
40+
}
41+
42+
var min: Element? {
43+
self.heap.min
44+
}
45+
46+
var isEmpty: Bool {
47+
self.heap.isEmpty
48+
}
49+
}

Sources/KanaKanjiConverterModule/ConversionAlgorithms/Zenzai/Zenz/Zenz.swift

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ package final class Zenz {
2626
}
2727

2828
package func endSession() {
29-
try? self.zenzContext?.reset_context()
29+
try? self.zenzContext?.resetContext()
3030
}
3131

3232
func candidateEvaluate(
@@ -36,12 +36,13 @@ package final class Zenz {
3636
prefixConstraint: Kana2Kanji.PrefixConstraint,
3737
personalizationMode: (mode: ConvertRequestOptions.ZenzaiMode.PersonalizationMode, base: EfficientNGram, personal: EfficientNGram)?,
3838
versionDependentConfig: ConvertRequestOptions.ZenzaiVersionDependentMode
39-
) -> ZenzContext.CandidateEvaluationResult {
39+
) -> CandidateEvaluationResult {
4040
guard let zenzContext else {
4141
return .error
4242
}
4343
for candidate in candidates {
44-
return zenzContext.evaluate_candidate(
44+
return ZenzCandidateEvaluator.evaluate(
45+
context: zenzContext,
4546
input: convertTarget.toKatakana(),
4647
candidate: candidate,
4748
requestRichCandidates: requestRichCandidates,
@@ -53,20 +54,6 @@ package final class Zenz {
5354
return .error
5455
}
5556

56-
func predictNextCharacter(leftSideContext: String, count: Int) -> [(character: Character, value: Float)] {
57-
guard let zenzContext else {
58-
return []
59-
}
60-
return zenzContext.predict_next_character(leftSideContext: leftSideContext, count: count)
61-
}
62-
63-
func predictNextInputCharacter(leftSideContext: String, composingText: String, count: Int, versionDependentConfig: ConvertRequestOptions.ZenzaiVersionDependentMode) -> [(character: Character, value: Float)] {
64-
guard let zenzContext else {
65-
return []
66-
}
67-
return zenzContext.predict_next_input_character(leftSideContext: leftSideContext, composingText: composingText, count: count, versionDependentConfig: versionDependentConfig)
68-
}
69-
7057
func predictNextInputText(
7158
leftSideContext: String,
7259
composingText: String,
@@ -79,7 +66,8 @@ package final class Zenz {
7966
guard let zenzContext else {
8067
return ""
8168
}
82-
return zenzContext.predict_next_input_text(
69+
return ZenzInputTextGenerator.generate(
70+
context: zenzContext,
8371
leftSideContext: leftSideContext,
8472
composingText: composingText,
8573
count: count,
@@ -91,6 +79,9 @@ package final class Zenz {
9179
}
9280

9381
package func pureGreedyDecoding(pureInput: String, maxCount: Int = .max) -> String {
94-
self.zenzContext?.pure_greedy_decoding(leftSideContext: pureInput, maxCount: maxCount) ?? ""
82+
guard let zenzContext else {
83+
return ""
84+
}
85+
return ZenzPureGreedyDecoder.decode(context: zenzContext, leftSideContext: pureInput, maxCount: maxCount)
9586
}
9687
}

0 commit comments

Comments
 (0)