Skip to content

Commit 4442981

Browse files
authored
feat: Zenzaiによる予測変換を行う際、ローマ字入力中の未確定suffixの処理を改善 (#322)
* feat: add ip command
* feat: efficient prediction
* feat: support zenzai based ip in normal conversion
* feat: roman input時のzenzaiによる予測入力を改善
1 parent 8e9784c commit 4442981

File tree

4 files changed

+124
-20
lines changed

4 files changed

+124
-20
lines changed

Sources/CliTool/Subcommands/SessionCommand.swift

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -232,18 +232,25 @@ extension Subcommands {
232232
let predictCount = max(1, min(requestedCount, 50))
233233
let predictMinLength = max(1, min(minLength, predictCount))
234234
let ipStart = Date()
235-
let predictedText = converter.predictNextInputText(
235+
let (predictedText, suffixCount) = converter.predictNextInputText(
236236
leftSideContext: leftSideContext,
237-
composingText: composingText.convertTarget,
237+
composingText: composingText,
238238
count: predictCount,
239239
minLength: predictMinLength,
240240
maxEntropy: maxEntropy,
241-
options: requestOptions(learningType: learningType, memoryDirectory: memoryDirectory, leftSideContext: leftSideContext)
241+
options: requestOptions(learningType: learningType, memoryDirectory: memoryDirectory, leftSideContext: leftSideContext),
242+
inputStyle: inputStyle,
243+
debugPossibleNexts: true
242244
)
243245
guard !predictedText.isEmpty else {
244246
continue
245247
}
246248
print("\(bold: "Time (ip):") \(-ipStart.timeIntervalSinceNow)")
249+
250+
if suffixCount > 0 {
251+
composingText.deleteBackwardFromCursorPosition(count: suffixCount)
252+
}
253+
247254
let insertText = (inputStyle == .roman2kana) ? predictedText.toHiragana() : predictedText
248255
composingText.insertAtCursorPosition(insertText, inputStyle: inputStyle)
249256
input = insertText

Sources/KanaKanjiConverterModule/ConversionAlgorithms/Zenzai/Zenz/Zenz.swift

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,15 @@ package final class Zenz {
6767
return zenzContext.predict_next_input_character(leftSideContext: leftSideContext, composingText: composingText, count: count, versionDependentConfig: versionDependentConfig)
6868
}
6969

70-
func predictNextInputText(leftSideContext: String, composingText: String, count: Int, minLength: Int = 1, maxEntropy: Float?, versionDependentConfig: ConvertRequestOptions.ZenzaiVersionDependentMode) -> String {
70+
func predictNextInputText(
71+
leftSideContext: String,
72+
composingText: String,
73+
count: Int,
74+
minLength: Int = 1,
75+
maxEntropy: Float?,
76+
versionDependentConfig: ConvertRequestOptions.ZenzaiVersionDependentMode,
77+
possibleNexts: [String] = []
78+
) -> String {
7179
guard let zenzContext else {
7280
return ""
7381
}
@@ -77,7 +85,8 @@ package final class Zenz {
7785
count: count,
7886
minLength: minLength,
7987
maxEntropy: maxEntropy,
80-
versionDependentConfig: versionDependentConfig
88+
versionDependentConfig: versionDependentConfig,
89+
possibleNexts: possibleNexts
8190
)
8291
}
8392

Sources/KanaKanjiConverterModule/ConversionAlgorithms/Zenzai/Zenz/ZenzContext.swift

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -434,7 +434,15 @@ final class ZenzContext {
434434
return minHeap.unordered.sorted { $0.value > $1.value }.map { ($0.character, $0.value / exp_sum) }
435435
}
436436

437-
func predict_next_input_text(leftSideContext: String, composingText: String, count: Int, minLength: Int = 1, maxEntropy: Float?, versionDependentConfig: ConvertRequestOptions.ZenzaiVersionDependentMode) -> String {
437+
func predict_next_input_text(
438+
leftSideContext: String,
439+
composingText: String,
440+
count: Int,
441+
minLength: Int = 1,
442+
maxEntropy: Float?,
443+
versionDependentConfig: ConvertRequestOptions.ZenzaiVersionDependentMode,
444+
possibleNexts: [String] = []
445+
) -> String {
438446
guard count > 0 else {
439447
return ""
440448
}
@@ -466,6 +474,18 @@ final class ZenzContext {
466474
let maxLeftSideContextLength = mode.maxLeftSideContextLength ?? 40
467475
let trimmedLeftContext = leftSideContext.isEmpty ? "" : String(leftSideContext.suffix(maxLeftSideContextLength))
468476
let input = composingText.toKatakana()
477+
let allowedPrefixes: [String] = possibleNexts.filter { !$0.isEmpty }
478+
479+
@inline(__always)
480+
func isAllowedPrefix(_ candidate: String) -> Bool {
481+
guard !allowedPrefixes.isEmpty else {
482+
return true
483+
}
484+
let normalized = candidate.toKatakana()
485+
return allowedPrefixes.contains(where: {
486+
$0.hasPrefix(normalized)
487+
})
488+
}
469489

470490
let prompt: String = if trimmedLeftContext.isEmpty {
471491
conditions.joined(separator: "") + inputTag + input
@@ -478,6 +498,7 @@ final class ZenzContext {
478498
let stopCharacters: Set<Character> = ["", "", "", ""]
479499
var predictedCharacters: [Character] = []
480500
predictedCharacters.reserveCapacity(count)
501+
var predictedText = ""
481502

482503
for _ in 0..<count {
483504
let startOffset = prompt_tokens.count - 1
@@ -498,6 +519,7 @@ final class ZenzContext {
498519
var sumexpX: Float = 0
499520
var bestValue: Float = -Float.infinity
500521
var bestCharacter: Character?
522+
var bestNextText: String = ""
501523
for index in startIndex..<endIndex {
502524
let token = llama_token(index - startIndex)
503525
let repeat_penalty = Float(1.0 + token_to_penalty_weight[token, default: 0])
@@ -513,8 +535,15 @@ final class ZenzContext {
513535
guard let validCharacter = String(data: tokenPieceData, encoding: .utf8), let c = validCharacter.first else {
514536
continue
515537
}
538+
539+
let nextText = predictedText + String(c)
540+
guard isAllowedPrefix(nextText) else {
541+
continue
542+
}
543+
516544
bestValue = value
517545
bestCharacter = c
546+
bestNextText = nextText
518547
}
519548

520549
if let maxEntropy, predictedCharacters.count >= minLength, sumexp > 0 {
@@ -532,7 +561,12 @@ final class ZenzContext {
532561
break
533562
}
534563

564+
if !isAllowedPrefix(bestNextText) {
565+
break
566+
}
567+
535568
predictedCharacters.append(bestCharacter)
569+
predictedText = bestNextText
536570
let appendedTokens = self.tokenize(text: self.preprocessText(text: String(bestCharacter)), add_bos: false, add_eos: false)
537571
if appendedTokens.isEmpty {
538572
break

Sources/KanaKanjiConverterModule/ConverterAPI/KanaKanjiConverter.swift

Lines changed: 68 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -120,22 +120,59 @@ public final class KanaKanjiConverter {
120120
)
121121
}
122122

123-
public func predictNextInputText(leftSideContext: String, composingText: String, count: Int, minLength: Int = 1, maxEntropy: Float?, options: ConvertRequestOptions) -> String {
123+
package func predictNextInputText(
124+
leftSideContext: String,
125+
composingText: ComposingText,
126+
count: Int,
127+
minLength: Int = 1,
128+
maxEntropy: Float?,
129+
options: ConvertRequestOptions,
130+
inputStyle: InputStyle = .direct,
131+
debugPossibleNexts: Bool = false
132+
) -> (predictedText: String, suffixCount: Int) {
124133
guard let zenz = self.getModel(modelURL: options.zenzaiMode.weightURL) else {
125134
print("zenz-v3 model unavailable")
126-
return ""
135+
return ("", 0)
127136
}
128137
guard options.zenzaiMode.versionDependentMode.version == .v3 else {
129138
print("input prediction requires zenz-v3 models")
130-
return ""
139+
return ("", 0)
131140
}
132-
return zenz.predictNextInputText(
133-
leftSideContext: leftSideContext,
134-
composingText: composingText,
135-
count: count,
136-
minLength: minLength,
137-
maxEntropy: maxEntropy,
138-
versionDependentConfig: options.zenzaiMode.versionDependentMode
141+
let (baseComposeText, resolvedPossibleNexts, suffixCount): (
142+
baseConvertTarget: String,
143+
resolvedPossibleNexts: [String],
144+
droppedSuffixCount: Int
145+
) = {
146+
if inputStyle == .direct {
147+
return (composingText.convertTarget, [], 0)
148+
}
149+
let table: InputTable
150+
if case .roman2kana = inputStyle {
151+
table = InputStyleManager.shared.table(for: .defaultRomanToKana)
152+
} else if case .mapped(let id) = inputStyle {
153+
table = InputStyleManager.shared.table(for: id)
154+
} else {
155+
return (composingText.convertTarget, [], 0)
156+
}
157+
if let suffixInfo = self.romanSuffixAndPossibleNexts(composingText: composingText, table: table) {
158+
return (suffixInfo.baseConvertTarget, suffixInfo.possibleNexts, composingText.convertTarget.count - suffixInfo.baseConvertTarget.count)
159+
}
160+
return (composingText.convertTarget, [], 0)
161+
}()
162+
if debugPossibleNexts {
163+
print("possibleNexts:", resolvedPossibleNexts)
164+
}
165+
return (
166+
zenz.predictNextInputText(
167+
leftSideContext: leftSideContext,
168+
composingText: baseComposeText,
169+
count: count,
170+
minLength: minLength,
171+
maxEntropy: maxEntropy,
172+
versionDependentConfig: options.zenzaiMode.versionDependentMode,
173+
possibleNexts: resolvedPossibleNexts
174+
),
175+
suffixCount
139176
)
140177
}
141178

@@ -432,21 +469,26 @@ public final class KanaKanjiConverter {
432469
case .v3(let mode):
433470
mode.leftSideContext ?? ""
434471
}
435-
let predictedText = self.predictNextInputText(
472+
473+
let inputStyle = composingText.input.last?.inputStyle ?? .direct
474+
let (predictedText, suffixCount) = self.predictNextInputText(
436475
leftSideContext: leftSideContext,
437-
composingText: composingText.convertTarget,
476+
composingText: composingText,
438477
count: 10,
439478
minLength: 1,
440479
maxEntropy: 3.0,
441-
options: options
480+
options: options,
481+
inputStyle: inputStyle
442482
)
443483
guard !predictedText.isEmpty else {
444484
return []
445485
}
446486

447-
let inputStyle = composingText.input.last?.inputStyle ?? .direct
448487
let insertText = (inputStyle == .roman2kana) ? predictedText.toHiragana() : predictedText
449488
var predictedComposingText = composingText
489+
if suffixCount > 0 {
490+
predictedComposingText.deleteBackwardFromCursorPosition(count: suffixCount)
491+
}
450492
predictedComposingText.insertAtCursorPosition(insertText, inputStyle: inputStyle)
451493

452494
var fallbackOptions = options
@@ -471,6 +513,18 @@ public final class KanaKanjiConverter {
471513
return [firstCandidate]
472514
}
473515

516+
private func romanSuffixAndPossibleNexts(composingText: ComposingText, table: InputTable) -> (baseConvertTarget: String, possibleNexts: [String])? {
517+
let romanSuffix = composingText.convertTarget.suffix(while: {String($0).onlyRomanAlphabet})
518+
guard !romanSuffix.isEmpty else {
519+
return nil
520+
}
521+
let possibleNexts = table.possibleNexts[String(romanSuffix), default: []]
522+
guard !possibleNexts.isEmpty else {
523+
return nil
524+
}
525+
return (String(composingText.convertTarget.dropLast(romanSuffix.count)), possibleNexts)
526+
}
527+
474528
/// トップレベルに追加する付加的な変換候補を生成する関数
475529
/// - Parameters:
476530
/// - inputData: 変換対象のInputData。

0 commit comments

Comments
 (0)