Skip to content

Commit 8e46927

Browse files
committed
Fix eos token logic
1 parent c0cc2a3 commit 8e46927

File tree

3 files changed

+23
-11
lines changed

3 files changed

+23
-11
lines changed

Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -974,6 +974,7 @@ import Foundation
974974
batch: batchPointer,
975975
position: initialPosition,
976976
maximumTokens: maxTokens,
977+
endTokens: [],
977978
tokenToTextFn: { [self] token in self.tokenToText(vocab: vocab, token: llama_token(token)) }
978979
)
979980
var generator = try ConstrainedJSONGenerator(backend: backend, schema: schema)
@@ -1004,6 +1005,7 @@ import Foundation
10041005
batch: UnsafeMutablePointer<llama_batch>,
10051006
position: Int32,
10061007
maximumTokens: Int,
1008+
endTokens: Set<Int>? = nil,
10071009
tokenToTextFn: @escaping (Int) -> String?
10081010
) {
10091011
self.context = context
@@ -1016,9 +1018,13 @@ import Foundation
10161018
self.totalTokenBudget = maximumTokens
10171019
self.eosToken = Int(llama_vocab_eos(vocab))
10181020

1019-
let eotTokenValue = llama_vocab_eot(vocab)
1020-
let endOfTurnToken = eotTokenValue != LLAMA_TOKEN_NULL ? Int(eotTokenValue) : eosToken
1021-
self.endTokens = [self.eosToken, endOfTurnToken]
1021+
if let endTokens {
1022+
self.endTokens = endTokens
1023+
} else {
1024+
let eotTokenValue = llama_vocab_eot(vocab)
1025+
let endOfTurnToken = eotTokenValue != LLAMA_TOKEN_NULL ? Int(eotTokenValue) : eosToken
1026+
self.endTokens = [self.eosToken, endOfTurnToken]
1027+
}
10221028

10231029
self.tokenToTextFn = tokenToTextFn
10241030
self.tokensExcludedFromRepetitionPenalty = Self.buildTokensExcludedFromRepetitionPenalty(

Sources/AnyLanguageModel/Models/MLXLanguageModel.swift

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -726,7 +726,8 @@ import Foundation
726726
context: context,
727727
input: lmInput,
728728
parameters: generateParameters,
729-
maximumTokens: maxTokens
729+
maximumTokens: maxTokens,
730+
endTokens: []
730731
)
731732

732733
var generator = try ConstrainedJSONGenerator(backend: backend, schema: schema)
@@ -805,7 +806,8 @@ import Foundation
805806
context: ModelContext,
806807
input: MLXLMCommon.LMInput,
807808
parameters: MLXLMCommon.GenerateParameters,
808-
maximumTokens: Int
809+
maximumTokens: Int,
810+
endTokens: Set<Int>? = nil
809811
) throws {
810812
self.model = context.model
811813
self.tokenizer = context.tokenizer
@@ -819,11 +821,15 @@ import Foundation
819821
throw StructuredGenerationError.invalidVocabSize
820822
}
821823
self.eosToken = eosTokenId
822-
self.endTokens = Self.buildEndTokens(
823-
eosTokenId: eosTokenId,
824-
tokenizer: context.tokenizer,
825-
configuration: context.configuration
826-
)
824+
if let endTokens {
825+
self.endTokens = endTokens
826+
} else {
827+
self.endTokens = Self.buildEndTokens(
828+
eosTokenId: eosTokenId,
829+
tokenizer: context.tokenizer,
830+
configuration: context.configuration
831+
)
832+
}
827833

828834
self.tokensExcludedFromRepetitionPenalty = Self.buildTokensExcludedFromRepetitionPenalty(
829835
tokenizer: context.tokenizer

Sources/AnyLanguageModel/StructuredGeneration.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ struct ConstrainedJSONGenerator<Backend: TokenBackend> {
5151
let quoteToken = try Self.singleToken(for: "\"", backend: backend)
5252
self.quoteToken = quoteToken
5353

54-
self.stringTerminators = [quoteToken]
54+
self.stringTerminators = backend.endTokens.union([quoteToken])
5555

5656
var structuralTerminators = Set<Int>()
5757
for structuralText in [",", "}", "]", ":"] {

0 commit comments

Comments
 (0)