From 9be5810ee0b38a70ea16d3d75f91ffbf3768824e Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Wed, 10 Dec 2025 05:45:05 -0800 Subject: [PATCH 1/6] Update LlamaLanguageModel to replace implementation-specific properties with custom generation options --- .../Models/LlamaLanguageModel.swift | 509 ++++++++++++------ .../LlamaLanguageModelTests.swift | 91 +++- 2 files changed, 433 insertions(+), 167 deletions(-) diff --git a/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift b/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift index ca91638f..8f7ed740 100644 --- a/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift @@ -79,6 +79,27 @@ import Foundation /// ) /// ``` public struct CustomGenerationOptions: AnyLanguageModel.CustomGenerationOptions, Codable { + /// Context size to allocate for the model. + public var contextSize: UInt32? + + /// Batch size to use when evaluating tokens. + public var batchSize: UInt32? + + /// Number of threads to use for computation. + public var threads: Int32? + + /// Random seed for deterministic sampling. + public var seed: UInt32? + + /// Sampling temperature. + public var temperature: Float? + + /// Top-K sampling parameter. + public var topK: Int32? + + /// Top-P (nucleus) sampling parameter. + public var topP: Float? + /// The penalty applied to repeated tokens. /// /// Values greater than 1.0 discourage repetition, while values less than 1.0 @@ -117,49 +138,145 @@ import Foundation /// Creates custom generation options for llama.cpp. public init( + contextSize: UInt32? = nil, + batchSize: UInt32? = nil, + threads: Int32? = nil, + seed: UInt32? = nil, + temperature: Float? = nil, + topK: Int32? = nil, + topP: Float? = nil, repeatPenalty: Float? = nil, repeatLastN: Int32? = nil, frequencyPenalty: Float? = nil, presencePenalty: Float? = nil, mirostat: MirostatMode? = nil ) { + self.contextSize = contextSize + self.batchSize = batchSize + self.threads = threads + self.seed = seed + self.temperature = temperature + self.topK = topK + self.topP = topP self.repeatPenalty = repeatPenalty self.repeatLastN = repeatLastN self.frequencyPenalty = frequencyPenalty self.presencePenalty = presencePenalty self.mirostat = mirostat } + + /// Default llama.cpp options used when none are provided at runtime. + public static func defaults( + seed: UInt32 = UInt32.random(in: 0 ... UInt32.max) + ) -> Self { + .init( + contextSize: 2048, + batchSize: 512, + threads: Int32(ProcessInfo.processInfo.processorCount), + seed: seed, + temperature: 0.8, + topK: 40, + topP: 0.95, + repeatPenalty: 1.1, + repeatLastN: 64, + frequencyPenalty: 0.0, + presencePenalty: 0.0, + mirostat: nil + ) + } + + /// Returns a new options struct with values from `overrides` when provided. + func merging(overrides: CustomGenerationOptions?) -> CustomGenerationOptions { + guard let overrides else { return self } + return CustomGenerationOptions( + contextSize: overrides.contextSize ?? self.contextSize, + batchSize: overrides.batchSize ?? self.batchSize, + threads: overrides.threads ?? self.threads, + seed: overrides.seed ?? self.seed, + temperature: overrides.temperature ?? self.temperature, + topK: overrides.topK ?? self.topK, + topP: overrides.topP ?? self.topP, + repeatPenalty: overrides.repeatPenalty ?? self.repeatPenalty, + repeatLastN: overrides.repeatLastN ?? self.repeatLastN, + frequencyPenalty: overrides.frequencyPenalty ?? self.frequencyPenalty, + presencePenalty: overrides.presencePenalty ?? 
self.presencePenalty, + mirostat: overrides.mirostat ?? self.mirostat + ) + } } /// The path to the GGUF model file. public let modelPath: String /// The context size for the model. - public let contextSize: UInt32 + /// + /// - Important: This property is deprecated. Use ``GenerationOptions`` with + /// custom options instead: + /// ```swift + /// var options = GenerationOptions() + /// options[custom: LlamaLanguageModel.self] = .init(contextSize: 4096) + /// ``` + @available(*, deprecated, message: "Use GenerationOptions custom options instead") + public var contextSize: UInt32 { legacyDefaults.contextSize } /// The batch size for processing. - public let batchSize: UInt32 + /// + /// - Important: This property is deprecated. Use ``GenerationOptions`` with + /// custom options instead. + @available(*, deprecated, message: "Use GenerationOptions custom options instead") + public var batchSize: UInt32 { legacyDefaults.batchSize } /// The number of threads to use. - public let threads: Int32 + /// + /// - Important: This property is deprecated. Use ``GenerationOptions`` with + /// custom options instead. + @available(*, deprecated, message: "Use GenerationOptions custom options instead") + public var threads: Int32 { legacyDefaults.threads } /// The random seed for generation. - public let seed: UInt32 + /// + /// - Important: This property is deprecated. Use ``GenerationOptions`` with + /// custom options instead. + @available(*, deprecated, message: "Use GenerationOptions custom options instead") + public var seed: UInt32 { legacyDefaults.seed } /// The temperature for sampling. - public let temperature: Float + /// + /// - Important: This property is deprecated. Use ``GenerationOptions`` with + /// custom options instead. + @available(*, deprecated, message: "Use GenerationOptions custom options instead") + public var temperature: Float { legacyDefaults.temperature } /// The top-K sampling parameter. - public let topK: Int32 + /// + /// - Important: This property is deprecated. Use ``GenerationOptions`` with + /// custom options instead. + @available(*, deprecated, message: "Use GenerationOptions custom options instead") + public var topK: Int32 { legacyDefaults.topK } /// The top-P (nucleus) sampling parameter. - public let topP: Float + /// + /// - Important: This property is deprecated. Use ``GenerationOptions`` with + /// custom options instead. + @available(*, deprecated, message: "Use GenerationOptions custom options instead") + public var topP: Float { legacyDefaults.topP } /// The repeat penalty for generation. - public let repeatPenalty: Float + /// + /// - Important: This property is deprecated. Use ``GenerationOptions`` with + /// custom options instead. + @available(*, deprecated, message: "Use GenerationOptions custom options instead") + public var repeatPenalty: Float { legacyDefaults.repeatPenalty } /// The number of tokens to consider for repeat penalty. - public let repeatLastN: Int32 + /// + /// - Important: This property is deprecated. Use ``GenerationOptions`` with + /// custom options instead. + @available(*, deprecated, message: "Use GenerationOptions custom options instead") + public var repeatLastN: Int32 { legacyDefaults.repeatLastN } + + /// Normalized legacy defaults used for deprecated properties. + private let legacyDefaults: ResolvedGenerationOptions /// The minimum log level for llama.cpp output. /// @@ -172,6 +289,111 @@ import Foundation } } + /// Resolved, non-optional defaults for llama.cpp runtime parameters. 
+ internal struct ResolvedGenerationOptions: Sendable { + var contextSize: UInt32 + var batchSize: UInt32 + var threads: Int32 + var seed: UInt32 + var temperature: Float + var topK: Int32 + var topP: Float + var repeatPenalty: Float + var repeatLastN: Int32 + var frequencyPenalty: Float + var presencePenalty: Float + var mirostat: CustomGenerationOptions.MirostatMode? + var sampling: GenerationOptions.SamplingMode? + var maximumResponseTokens: Int? + + init( + contextSize: UInt32 = 2048, + batchSize: UInt32 = 512, + threads: Int32 = Int32(ProcessInfo.processInfo.processorCount), + seed: UInt32 = UInt32.random(in: 0 ... UInt32.max), + temperature: Float = 0.8, + topK: Int32 = 40, + topP: Float = 0.95, + repeatPenalty: Float = 1.1, + repeatLastN: Int32 = 64, + frequencyPenalty: Float = 0.0, + presencePenalty: Float = 0.0, + mirostat: CustomGenerationOptions.MirostatMode? = nil, + sampling: GenerationOptions.SamplingMode? = nil, + maximumResponseTokens: Int? = nil + ) { + self.contextSize = contextSize + self.batchSize = batchSize + self.threads = threads + self.seed = seed + self.temperature = temperature + self.topK = topK + self.topP = topP + self.repeatPenalty = repeatPenalty + self.repeatLastN = repeatLastN + self.frequencyPenalty = frequencyPenalty + self.presencePenalty = presencePenalty + self.mirostat = mirostat + self.sampling = sampling + self.maximumResponseTokens = maximumResponseTokens + } + + init( + from options: CustomGenerationOptions?, + sampling: GenerationOptions.SamplingMode? = nil, + maximumResponseTokens: Int? = nil + ) { + self.init( + base: ResolvedGenerationOptions(), + overrides: options, + sampling: sampling, + maximumResponseTokens: maximumResponseTokens + ) + } + + init( + base: ResolvedGenerationOptions = .init(), + overrides options: CustomGenerationOptions?, + sampling: GenerationOptions.SamplingMode? = nil, + maximumResponseTokens: Int? = nil + ) { + guard let options else { + self = ResolvedGenerationOptions( + contextSize: base.contextSize, + batchSize: base.batchSize, + threads: base.threads, + seed: base.seed, + temperature: base.temperature, + topK: base.topK, + topP: base.topP, + repeatPenalty: base.repeatPenalty, + repeatLastN: base.repeatLastN, + frequencyPenalty: base.frequencyPenalty, + presencePenalty: base.presencePenalty, + mirostat: base.mirostat, + sampling: sampling ?? base.sampling, + maximumResponseTokens: maximumResponseTokens ?? base.maximumResponseTokens + ) + return + } + + self.contextSize = options.contextSize ?? base.contextSize + self.batchSize = options.batchSize ?? base.batchSize + self.threads = options.threads ?? base.threads + self.seed = options.seed ?? base.seed + self.temperature = options.temperature ?? base.temperature + self.topK = options.topK ?? base.topK + self.topP = options.topP ?? base.topP + self.repeatPenalty = options.repeatPenalty ?? base.repeatPenalty + self.repeatLastN = options.repeatLastN ?? base.repeatLastN + self.frequencyPenalty = options.frequencyPenalty ?? base.frequencyPenalty + self.presencePenalty = options.presencePenalty ?? base.presencePenalty + self.mirostat = options.mirostat ?? base.mirostat + self.sampling = sampling ?? base.sampling + self.maximumResponseTokens = maximumResponseTokens ?? base.maximumResponseTokens + } + } + /// The loaded model instance private var model: OpaquePointer? @@ -185,16 +407,22 @@ import Foundation /// /// - Parameters: /// - modelPath: The path to the GGUF model file. - /// - contextSize: The context size for the model. Defaults to 2048. 
- /// - batchSize: The batch size for processing. Defaults to 512. - /// - threads: The number of threads to use. Defaults to the number of processors. - /// - seed: The random seed for generation. Defaults to a random value. - /// - temperature: The temperature for sampling. Defaults to 0.8. - /// - topK: The top-K sampling parameter. Defaults to 40. - /// - topP: The top-P (nucleus) sampling parameter. Defaults to 0.95. - /// - repeatPenalty: The repeat penalty for generation. Defaults to 1.1. - /// - repeatLastN: The number of tokens to consider for repeat penalty. Defaults to 64. - public init( + public init(modelPath: String) { + self.modelPath = modelPath + self.legacyDefaults = ResolvedGenerationOptions() + } + + /// Creates a Llama language model using legacy parameter defaults. + /// + /// - Important: This initializer is deprecated. Use + /// `init(modelPath:defaultOptions:)` and configure per-request values via + /// ``GenerationOptions`` custom options instead. + @available( + *, + deprecated, + message: "Use init(modelPath:) and pass values via GenerationOptions custom options" + ) + public convenience init( modelPath: String, contextSize: UInt32 = 2048, batchSize: UInt32 = 512, @@ -206,16 +434,9 @@ import Foundation repeatPenalty: Float = 1.1, repeatLastN: Int32 = 64 ) { - self.modelPath = modelPath - self.contextSize = contextSize - self.batchSize = batchSize - self.threads = threads - self.seed = seed - self.temperature = temperature - self.topK = topK - self.topP = topP - self.repeatPenalty = repeatPenalty - self.repeatLastN = repeatLastN + // Deprecated: prefer setting these via GenerationOptions custom options. + // We intentionally ignore legacy parameters to avoid storing model-level state. + self.init(modelPath: modelPath) } deinit { @@ -241,7 +462,8 @@ import Foundation try await ensureModelLoaded() - let contextParams = createContextParams(from: options) + let runtimeOptions = resolvedOptions(from: options) + let contextParams = createContextParams(from: runtimeOptions) // Try to create context with error handling guard let context = llama_init_from_model(model!, contextParams) else { @@ -252,15 +474,15 @@ import Foundation llama_set_causal_attn(context, true) llama_set_warmup(context, false) - llama_set_n_threads(context, threads, threads) + llama_set_n_threads(context, runtimeOptions.threads, runtimeOptions.threads) - let maxTokens = options.maximumResponseTokens ?? 100 + let maxTokens = runtimeOptions.maximumResponseTokens ?? 100 let text = try await generateText( context: context, model: model!, prompt: prompt.description, maxTokens: maxTokens, - options: options + options: runtimeOptions ) return LanguageModelSession.Response( @@ -293,15 +515,15 @@ import Foundation ) } - let maxTokens = options.maximumResponseTokens ?? 100 - let stream: AsyncThrowingStream.Snapshot, any Error> = AsyncThrowingStream { continuation in let task = Task { do { try await ensureModelLoaded() - let contextParams = createContextParams(from: options) + let runtimeOptions = resolvedOptions(from: options) + let maxTokens = runtimeOptions.maximumResponseTokens ?? 
100 + let contextParams = createContextParams(from: runtimeOptions) guard let context = llama_init_from_model(model!, contextParams) else { throw LlamaLanguageModelError.contextInitializationFailed } @@ -310,7 +532,7 @@ import Foundation // Stabilize runtime behavior per-context llama_set_causal_attn(context, true) llama_set_warmup(context, false) - llama_set_n_threads(context, self.threads, self.threads) + llama_set_n_threads(context, runtimeOptions.threads, runtimeOptions.threads) var accumulatedText = "" @@ -320,7 +542,7 @@ import Foundation model: model!, prompt: prompt.description, maxTokens: maxTokens, - options: options + options: runtimeOptions ) { accumulatedText += tokenText @@ -390,21 +612,88 @@ import Foundation return params } - private func createContextParams(from options: GenerationOptions) -> llama_context_params { + private func resolvedOptions(from options: GenerationOptions) -> ResolvedGenerationOptions { + ResolvedGenerationOptions( + base: legacyDefaults, + overrides: options[custom: LlamaLanguageModel.self], + sampling: options.sampling, + maximumResponseTokens: options.maximumResponseTokens + ) + } + + private func createContextParams(from options: ResolvedGenerationOptions) -> llama_context_params { var params = llama_context_default_params() - params.n_ctx = contextSize - params.n_batch = batchSize - params.n_threads = threads - params.n_threads_batch = threads + params.n_ctx = options.contextSize + params.n_batch = options.batchSize + params.n_threads = options.threads + params.n_threads_batch = options.threads return params } + private func applySampling( + sampler: UnsafeMutablePointer, + effectiveTemperature: Float, + options: ResolvedGenerationOptions + ) { + if let mirostat = options.mirostat { + llama_sampler_chain_add(sampler, llama_sampler_init_temp(effectiveTemperature)) + + switch mirostat { + case .v1(let tau, let eta): + llama_sampler_chain_add( + sampler, + llama_sampler_init_mirostat( + Int32(options.contextSize), + options.seed, + tau, + eta, + 100 + ) + ) + case .v2(let tau, let eta): + llama_sampler_chain_add(sampler, llama_sampler_init_mirostat_v2(options.seed, tau, eta)) + } + return + } + + if let sampling = options.sampling { + switch sampling.mode { + case .greedy: + llama_sampler_chain_add(sampler, llama_sampler_init_top_k(1)) + llama_sampler_chain_add(sampler, llama_sampler_init_top_p(1.0, 1)) + llama_sampler_chain_add(sampler, llama_sampler_init_greedy()) + case .topK(let k, let seed): + llama_sampler_chain_add(sampler, llama_sampler_init_top_k(Int32(k))) + llama_sampler_chain_add(sampler, llama_sampler_init_top_p(1.0, 1)) + llama_sampler_chain_add(sampler, llama_sampler_init_temp(effectiveTemperature)) + let samplingSeed = seed.map(UInt32.init) ?? options.seed + llama_sampler_chain_add(sampler, llama_sampler_init_dist(samplingSeed)) + case .nucleus(let threshold, let seed): + llama_sampler_chain_add(sampler, llama_sampler_init_top_k(0)) + llama_sampler_chain_add(sampler, llama_sampler_init_top_p(Float(threshold), 1)) + llama_sampler_chain_add(sampler, llama_sampler_init_temp(effectiveTemperature)) + let samplingSeed = seed.map(UInt32.init) ?? 
options.seed + llama_sampler_chain_add(sampler, llama_sampler_init_dist(samplingSeed)) + } + return + } + + if options.topK > 0 { + llama_sampler_chain_add(sampler, llama_sampler_init_top_k(options.topK)) + } + if options.topP < 1.0 { + llama_sampler_chain_add(sampler, llama_sampler_init_top_p(options.topP, 1)) + } + llama_sampler_chain_add(sampler, llama_sampler_init_temp(effectiveTemperature)) + llama_sampler_chain_add(sampler, llama_sampler_init_dist(options.seed)) + } + private func generateText( context: OpaquePointer, model: OpaquePointer, prompt: String, maxTokens: Int, - options: GenerationOptions + options: ResolvedGenerationOptions ) async throws -> String { @@ -418,7 +707,7 @@ import Foundation throw LlamaLanguageModelError.tokenizationFailed } - var batch = llama_batch_init(Int32(batchSize), 0, 1) + var batch = llama_batch_init(Int32(options.batchSize), 0, 1) defer { llama_batch_free(batch) } batch.n_tokens = Int32(promptTokens.count) @@ -446,19 +735,19 @@ import Foundation throw LlamaLanguageModelError.decodingFailed } defer { llama_sampler_free(sampler) } + let samplerPtr = UnsafeMutablePointer(sampler) - // Get custom options if provided - let customOptions = options[custom: LlamaLanguageModel.self] + let effectiveTemperature = Float(options.temperature) // Apply repeat/frequency/presence penalties from custom options - let effectiveRepeatPenalty = customOptions?.repeatPenalty ?? repeatPenalty - let effectiveRepeatLastN = customOptions?.repeatLastN ?? repeatLastN - let effectiveFrequencyPenalty = customOptions?.frequencyPenalty ?? 0.0 - let effectivePresencePenalty = customOptions?.presencePenalty ?? 0.0 + let effectiveRepeatPenalty = options.repeatPenalty + let effectiveRepeatLastN = options.repeatLastN + let effectiveFrequencyPenalty = options.frequencyPenalty + let effectivePresencePenalty = options.presencePenalty if effectiveRepeatPenalty != 1.0 || effectiveFrequencyPenalty != 0.0 || effectivePresencePenalty != 0.0 { llama_sampler_chain_add( - sampler, + samplerPtr, llama_sampler_init_penalties( effectiveRepeatLastN, effectiveRepeatPenalty, @@ -468,52 +757,7 @@ import Foundation ) } - // Check for mirostat sampling (takes precedence over standard sampling) - if let mirostat = customOptions?.mirostat { - let temp = Float(options.temperature ?? 
Double(temperature)) - llama_sampler_chain_add(sampler, llama_sampler_init_temp(temp)) - - switch mirostat { - case .v1(let tau, let eta): - llama_sampler_chain_add( - sampler, - llama_sampler_init_mirostat(Int32(contextSize), seed, tau, eta, 100) - ) - case .v2(let tau, let eta): - llama_sampler_chain_add(sampler, llama_sampler_init_mirostat_v2(seed, tau, eta)) - } - } else if let sampling = options.sampling { - // Use standard sampling parameters from options - switch sampling.mode { - case .greedy: - llama_sampler_chain_add(sampler, llama_sampler_init_top_k(1)) - llama_sampler_chain_add(sampler, llama_sampler_init_top_p(1.0, 1)) - llama_sampler_chain_add(sampler, llama_sampler_init_greedy()) - case .topK(let k, let seed): - llama_sampler_chain_add(sampler, llama_sampler_init_top_k(Int32(k))) - llama_sampler_chain_add(sampler, llama_sampler_init_top_p(1.0, 1)) - if let temperature = options.temperature { - llama_sampler_chain_add(sampler, llama_sampler_init_temp(Float(temperature))) - } - if let seed = seed { - llama_sampler_chain_add(sampler, llama_sampler_init_dist(UInt32(seed))) - } - case .nucleus(let threshold, let seed): - llama_sampler_chain_add(sampler, llama_sampler_init_top_k(0)) // Disable top-k - llama_sampler_chain_add(sampler, llama_sampler_init_top_p(Float(threshold), 1)) - if let temperature = options.temperature { - llama_sampler_chain_add(sampler, llama_sampler_init_temp(Float(temperature))) - } - if let seed = seed { - llama_sampler_chain_add(sampler, llama_sampler_init_dist(UInt32(seed))) - } - } - } else { - // Use model's default sampling parameters - if topK > 0 { llama_sampler_chain_add(sampler, llama_sampler_init_top_k(topK)) } - if topP < 1.0 { llama_sampler_chain_add(sampler, llama_sampler_init_top_p(topP, 1)) } - llama_sampler_chain_add(sampler, llama_sampler_init_dist(seed)) - } + applySampling(sampler: samplerPtr, effectiveTemperature: effectiveTemperature, options: options) // Generate tokens one by one var generatedText = "" @@ -560,10 +804,8 @@ import Foundation model: OpaquePointer, prompt: String, maxTokens: Int, - options: GenerationOptions - ) - -> AsyncThrowingStream - { + options: ResolvedGenerationOptions + ) -> AsyncThrowingStream { return AsyncThrowingStream { continuation in self.performTextGeneration( context: context, @@ -581,7 +823,7 @@ import Foundation model: OpaquePointer, prompt: String, maxTokens: Int, - options: GenerationOptions, + options: ResolvedGenerationOptions, continuation: AsyncThrowingStream.Continuation ) { do { @@ -598,7 +840,7 @@ import Foundation } // Initialize batch - var batch = llama_batch_init(Int32(batchSize), 0, 1) + var batch = llama_batch_init(Int32(options.batchSize), 0, 1) defer { llama_batch_free(batch) } // Evaluate the prompt @@ -626,20 +868,20 @@ import Foundation throw LlamaLanguageModelError.decodingFailed } defer { llama_sampler_free(sampler) } + let samplerPtr = UnsafeMutablePointer(sampler) - // Get custom options if provided - let customOptions = options[custom: LlamaLanguageModel.self] + let effectiveTemperature = Float(options.temperature) // Apply repeat/frequency/presence penalties from custom options - let effectiveRepeatPenalty = customOptions?.repeatPenalty ?? self.repeatPenalty - let effectiveRepeatLastN = customOptions?.repeatLastN ?? self.repeatLastN - let effectiveFrequencyPenalty = customOptions?.frequencyPenalty ?? 0.0 - let effectivePresencePenalty = customOptions?.presencePenalty ?? 
0.0 + let effectiveRepeatPenalty = options.repeatPenalty + let effectiveRepeatLastN = options.repeatLastN + let effectiveFrequencyPenalty = options.frequencyPenalty + let effectivePresencePenalty = options.presencePenalty if effectiveRepeatPenalty != 1.0 || effectiveFrequencyPenalty != 0.0 || effectivePresencePenalty != 0.0 { llama_sampler_chain_add( - sampler, + samplerPtr, llama_sampler_init_penalties( effectiveRepeatLastN, effectiveRepeatPenalty, @@ -650,50 +892,7 @@ import Foundation } // Check for mirostat sampling (takes precedence over standard sampling) - if let mirostat = customOptions?.mirostat { - let temp = Float(options.temperature ?? Double(self.temperature)) - llama_sampler_chain_add(sampler, llama_sampler_init_temp(temp)) - - switch mirostat { - case .v1(let tau, let eta): - llama_sampler_chain_add( - sampler, - llama_sampler_init_mirostat(Int32(self.contextSize), self.seed, tau, eta, 100) - ) - case .v2(let tau, let eta): - llama_sampler_chain_add(sampler, llama_sampler_init_mirostat_v2(self.seed, tau, eta)) - } - } else if let sampling = options.sampling { - // Use standard sampling parameters from options - switch sampling.mode { - case .greedy: - llama_sampler_chain_add(sampler, llama_sampler_init_top_k(1)) - llama_sampler_chain_add(sampler, llama_sampler_init_top_p(1.0, 1)) - case .topK(let k, let seed): - llama_sampler_chain_add(sampler, llama_sampler_init_top_k(Int32(k))) - llama_sampler_chain_add(sampler, llama_sampler_init_top_p(1.0, 1)) - if let temperature = options.temperature { - llama_sampler_chain_add(sampler, llama_sampler_init_temp(Float(temperature))) - } - if let seed = seed { - llama_sampler_chain_add(sampler, llama_sampler_init_dist(UInt32(seed))) - } - case .nucleus(let threshold, let seed): - llama_sampler_chain_add(sampler, llama_sampler_init_top_k(0)) // Disable top-k - llama_sampler_chain_add(sampler, llama_sampler_init_top_p(Float(threshold), 1)) - if let temperature = options.temperature { - llama_sampler_chain_add(sampler, llama_sampler_init_temp(Float(temperature))) - } - if let seed = seed { - llama_sampler_chain_add(sampler, llama_sampler_init_dist(UInt32(seed))) - } - } - } else { - // Use model's default sampling parameters - if self.topK > 0 { llama_sampler_chain_add(sampler, llama_sampler_init_top_k(self.topK)) } - if self.topP < 1.0 { llama_sampler_chain_add(sampler, llama_sampler_init_top_p(self.topP, 1)) } - llama_sampler_chain_add(sampler, llama_sampler_init_dist(self.seed)) - } + applySampling(sampler: samplerPtr, effectiveTemperature: effectiveTemperature, options: options) // Generate tokens one by one var n_cur = batch.n_tokens diff --git a/Tests/AnyLanguageModelTests/LlamaLanguageModelTests.swift b/Tests/AnyLanguageModelTests/LlamaLanguageModelTests.swift index eab1b3ed..b5fecdaa 100644 --- a/Tests/AnyLanguageModelTests/LlamaLanguageModelTests.swift +++ b/Tests/AnyLanguageModelTests/LlamaLanguageModelTests.swift @@ -11,22 +11,82 @@ import Testing ) struct LlamaLanguageModelTests { let model = LlamaLanguageModel( - modelPath: ProcessInfo.processInfo.environment["LLAMA_MODEL_PATH"]!, - contextSize: 2048, - temperature: 0.8 + modelPath: ProcessInfo.processInfo.environment["LLAMA_MODEL_PATH"]! 
) @Test func initialization() { - let customModel = LlamaLanguageModel( - modelPath: "/path/to/model.gguf", - contextSize: 4096, - temperature: 0.7, - topK: 50 - ) + let customModel = LlamaLanguageModel(modelPath: "/path/to/model.gguf") #expect(customModel.modelPath == "/path/to/model.gguf") - #expect(customModel.contextSize == 4096) - #expect(customModel.temperature == 0.7) - #expect(customModel.topK == 50) + #expect(customModel.contextSize == 2048) + #expect(customModel.batchSize == 512) + #expect(customModel.threads == Int32(ProcessInfo.processInfo.processorCount)) + #expect(customModel.temperature == 0.8) + #expect(customModel.topK == 40) + #expect(customModel.topP == 0.95) + #expect(customModel.repeatPenalty == 1.1) + #expect(customModel.repeatLastN == 64) + } + + @Test func customGenerationOptionsRoundTrip() { + var options = GenerationOptions( + temperature: 0.6, + maximumResponseTokens: 25 + ) + + let custom = LlamaLanguageModel.CustomGenerationOptions( + contextSize: 1024, + batchSize: 256, + threads: 1, + seed: 42, + temperature: 0.55, + topK: 25, + topP: 0.85, + repeatPenalty: 1.15, + repeatLastN: 48, + frequencyPenalty: 0.05, + presencePenalty: 0.05, + mirostat: .v2(tau: 5.0, eta: 0.2) + ) + options[custom: LlamaLanguageModel.self] = custom + + let retrieved = options[custom: LlamaLanguageModel.self] + #expect(retrieved?.contextSize == 1024) + #expect(retrieved?.batchSize == 256) + #expect(retrieved?.threads == 1) + #expect(retrieved?.seed == 42) + #expect(retrieved?.temperature == 0.55) + #expect(retrieved?.topK == 25) + #expect(retrieved?.topP == 0.85) + #expect(retrieved?.repeatPenalty == 1.15) + #expect(retrieved?.repeatLastN == 48) + #expect(retrieved?.frequencyPenalty == 0.05) + #expect(retrieved?.presencePenalty == 0.05) + #expect(retrieved?.mirostat == .v2(tau: 5.0, eta: 0.2)) + } + + @Test func deprecatedInitializerFallback() { + let legacy = LlamaLanguageModel( + modelPath: "/legacy/model.gguf", + contextSize: 1024, + batchSize: 128, + threads: 3, + seed: 7, + temperature: 0.65, + topK: 32, + topP: 0.88, + repeatPenalty: 1.02, + repeatLastN: 24 + ) + + // Deprecated initializer ignores parameters; defaults are used. 
+ #expect(legacy.contextSize == 2048) + #expect(legacy.batchSize == 512) + #expect(legacy.threads == Int32(ProcessInfo.processInfo.processorCount)) + #expect(legacy.temperature == 0.8) + #expect(legacy.topK == 40) + #expect(legacy.topP == 0.95) + #expect(legacy.repeatPenalty == 1.1) + #expect(legacy.repeatLastN == 64) } @Test func logLevelConfiguration() { @@ -169,6 +229,13 @@ import Testing // Set llama.cpp-specific custom options options[custom: LlamaLanguageModel.self] = .init( + contextSize: 1024, + batchSize: 256, + threads: 2, + seed: 123, + temperature: 0.75, + topK: 30, + topP: 0.9, repeatPenalty: 1.2, repeatLastN: 128, frequencyPenalty: 0.1, From ca8bc2b2de0aaea25a39936133ca6eefd154a56a Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Wed, 10 Dec 2025 08:37:59 -0800 Subject: [PATCH 2/6] Incorporate feedback from review --- .../Models/LlamaLanguageModel.swift | 27 +++++-------------- .../LlamaLanguageModelTests.swift | 11 ++++++++ 2 files changed, 18 insertions(+), 20 deletions(-) diff --git a/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift b/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift index 8f7ed740..8ab413fc 100644 --- a/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift @@ -185,24 +185,6 @@ import Foundation ) } - /// Returns a new options struct with values from `overrides` when provided. - func merging(overrides: CustomGenerationOptions?) -> CustomGenerationOptions { - guard let overrides else { return self } - return CustomGenerationOptions( - contextSize: overrides.contextSize ?? self.contextSize, - batchSize: overrides.batchSize ?? self.batchSize, - threads: overrides.threads ?? self.threads, - seed: overrides.seed ?? self.seed, - temperature: overrides.temperature ?? self.temperature, - topK: overrides.topK ?? self.topK, - topP: overrides.topP ?? self.topP, - repeatPenalty: overrides.repeatPenalty ?? self.repeatPenalty, - repeatLastN: overrides.repeatLastN ?? self.repeatLastN, - frequencyPenalty: overrides.frequencyPenalty ?? self.frequencyPenalty, - presencePenalty: overrides.presencePenalty ?? self.presencePenalty, - mirostat: overrides.mirostat ?? self.mirostat - ) - } } /// The path to the GGUF model file. 
@@ -613,8 +595,13 @@ import Foundation } private func resolvedOptions(from options: GenerationOptions) -> ResolvedGenerationOptions { - ResolvedGenerationOptions( - base: legacyDefaults, + var base = legacyDefaults + if let temp = options.temperature { + base.temperature = Float(temp) + } + + return ResolvedGenerationOptions( + base: base, overrides: options[custom: LlamaLanguageModel.self], sampling: options.sampling, maximumResponseTokens: options.maximumResponseTokens diff --git a/Tests/AnyLanguageModelTests/LlamaLanguageModelTests.swift b/Tests/AnyLanguageModelTests/LlamaLanguageModelTests.swift index b5fecdaa..105a443d 100644 --- a/Tests/AnyLanguageModelTests/LlamaLanguageModelTests.swift +++ b/Tests/AnyLanguageModelTests/LlamaLanguageModelTests.swift @@ -64,6 +64,17 @@ import Testing #expect(retrieved?.mirostat == .v2(tau: 5.0, eta: 0.2)) } + @Test func customGenerationOptionsDefaults() { + let defaults = LlamaLanguageModel.CustomGenerationOptions.defaults() + #expect(defaults.contextSize == 2048) + #expect(defaults.batchSize == 512) + #expect(defaults.temperature == 0.8) + #expect(defaults.topK == 40) + #expect(defaults.topP == 0.95) + #expect(defaults.repeatPenalty == 1.1) + #expect(defaults.repeatLastN == 64) + } + @Test func deprecatedInitializerFallback() { let legacy = LlamaLanguageModel( modelPath: "/legacy/model.gguf", From cb7e45e706e05e4d9979e3d6945afe25a4a8eca6 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Wed, 10 Dec 2025 08:42:21 -0800 Subject: [PATCH 3/6] Change spelling to CustomGenerationOptions.default --- Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift | 6 ++---- Tests/AnyLanguageModelTests/LlamaLanguageModelTests.swift | 2 +- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift b/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift index 8ab413fc..e1c40835 100644 --- a/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift @@ -166,14 +166,12 @@ import Foundation } /// Default llama.cpp options used when none are provided at runtime. - public static func defaults( - seed: UInt32 = UInt32.random(in: 0 ... UInt32.max) - ) -> Self { + public static var `default`: Self { .init( contextSize: 2048, batchSize: 512, threads: Int32(ProcessInfo.processInfo.processorCount), - seed: seed, + seed: UInt32.random(in: 0 ... 
UInt32.max), temperature: 0.8, topK: 40, topP: 0.95, diff --git a/Tests/AnyLanguageModelTests/LlamaLanguageModelTests.swift b/Tests/AnyLanguageModelTests/LlamaLanguageModelTests.swift index 105a443d..1e8f57c6 100644 --- a/Tests/AnyLanguageModelTests/LlamaLanguageModelTests.swift +++ b/Tests/AnyLanguageModelTests/LlamaLanguageModelTests.swift @@ -65,7 +65,7 @@ import Testing } @Test func customGenerationOptionsDefaults() { - let defaults = LlamaLanguageModel.CustomGenerationOptions.defaults() + let defaults = LlamaLanguageModel.CustomGenerationOptions.default #expect(defaults.contextSize == 2048) #expect(defaults.batchSize == 512) #expect(defaults.temperature == 0.8) From 81d1bc21ac4a978724073f05a8448c383f7eedc4 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Wed, 10 Dec 2025 09:15:03 -0800 Subject: [PATCH 4/6] Randomize seed on each request by default --- Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift b/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift index e1c40835..4abc8ec0 100644 --- a/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift @@ -166,12 +166,15 @@ import Foundation } /// Default llama.cpp options used when none are provided at runtime. + /// + /// The `seed` is `nil` by default, meaning a random seed will be generated + /// for each generation request. public static var `default`: Self { .init( contextSize: 2048, batchSize: 512, threads: Int32(ProcessInfo.processInfo.processorCount), - seed: UInt32.random(in: 0 ... UInt32.max), + seed: nil, temperature: 0.8, topK: 40, topP: 0.95, From 25c2e55cdad625921ca0db3beb1b2e3b9cc85743 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Wed, 10 Dec 2025 09:15:15 -0800 Subject: [PATCH 5/6] Fix documentation comment --- Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift b/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift index 4abc8ec0..1764505e 100644 --- a/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift @@ -398,7 +398,7 @@ import Foundation /// Creates a Llama language model using legacy parameter defaults. /// /// - Important: This initializer is deprecated. Use - /// `init(modelPath:defaultOptions:)` and configure per-request values via + /// `init(modelPath:)` and configure per-request values via /// ``GenerationOptions`` custom options instead. 
@available( *, From e91764eddff46772a8845fb8a81828c88301612e Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Wed, 10 Dec 2025 09:15:27 -0800 Subject: [PATCH 6/6] Exercise all of default properties --- Tests/AnyLanguageModelTests/LlamaLanguageModelTests.swift | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/Tests/AnyLanguageModelTests/LlamaLanguageModelTests.swift b/Tests/AnyLanguageModelTests/LlamaLanguageModelTests.swift index 1e8f57c6..b6fa0ccd 100644 --- a/Tests/AnyLanguageModelTests/LlamaLanguageModelTests.swift +++ b/Tests/AnyLanguageModelTests/LlamaLanguageModelTests.swift @@ -68,11 +68,16 @@ import Testing let defaults = LlamaLanguageModel.CustomGenerationOptions.default #expect(defaults.contextSize == 2048) #expect(defaults.batchSize == 512) + #expect(defaults.threads == Int32(ProcessInfo.processInfo.processorCount)) + #expect(defaults.seed == nil) #expect(defaults.temperature == 0.8) #expect(defaults.topK == 40) #expect(defaults.topP == 0.95) #expect(defaults.repeatPenalty == 1.1) #expect(defaults.repeatLastN == 64) + #expect(defaults.frequencyPenalty == 0.0) + #expect(defaults.presencePenalty == 0.0) + #expect(defaults.mirostat == nil) } @Test func deprecatedInitializerFallback() {
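
For reference, a minimal usage sketch of the options flow these patches introduce. The `GenerationOptions` subscript, `CustomGenerationOptions` fields, and `LlamaLanguageModel(modelPath:)` initializer come from the diffs above; the `LanguageModelSession(model:)` and `respond(to:options:)` calls and the model path are assumptions about the surrounding AnyLanguageModel API and are not part of this series.

```swift
import AnyLanguageModel

// Per-request llama.cpp parameters are configured via GenerationOptions custom
// options rather than on the model itself (hypothetical model path below).
func generateHaiku() async throws -> String {
    let model = LlamaLanguageModel(modelPath: "/path/to/model.gguf")

    var options = GenerationOptions(temperature: 0.6, maximumResponseTokens: 64)
    options[custom: LlamaLanguageModel.self] = .init(
        contextSize: 4096,    // context to allocate for this request
        seed: 42,             // fixed seed for deterministic sampling
        repeatPenalty: 1.15   // discourage repeated tokens
    )

    // Assumed session API, shown only to illustrate where `options` is passed.
    let session = LanguageModelSession(model: model)
    let response = try await session.respond(to: "Write a haiku about Swift.", options: options)
    return response.content
}
```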