Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions .github/workflows/pull_request.yml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,20 @@ jobs:
run: |
pre-commit run --all || (echo "Style checks failed, please install pre-commit and run pre-commit run --all and push the change"; echo ""; git --no-pager diff; exit 1)

linux_build:
needs: lint
if: github.repository == 'ml-explore/mlx-swift-lm'
runs-on: ubuntu-24.04
container:
image: swift:6.2.3-noble
steps:
- uses: actions/checkout@v6
with:
submodules: recursive

- name: Build
run: swift build

mac_build_and_test:
needs: lint
if: github.repository == 'ml-explore/mlx-swift-lm'
Expand Down
119 changes: 65 additions & 54 deletions Libraries/MLXLLM/LLMModelFactory.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ import Tokenizers

/// Creates a function that decodes configuration data and instantiates a model with the proper configuration
private func create<C: Codable, M>(
_ configurationType: C.Type, _ modelInit: @escaping (C) -> M
) -> (Data) throws -> M {
_ configurationType: C.Type, _ modelInit: @escaping @Sendable (C) -> M
) -> @Sendable (Data) throws -> M {
{ data in
let configuration = try JSONDecoder().decode(C.self, from: data)
return modelInit(configuration)
Expand All @@ -22,58 +22,69 @@ private func create<C: Codable, M>(
public enum LLMTypeRegistry {

/// Shared instance with default model types.
public static let shared: ModelTypeRegistry = .init(creators: [
"mistral": create(LlamaConfiguration.self, LlamaModel.init),
"llama": create(LlamaConfiguration.self, LlamaModel.init),
"phi": create(PhiConfiguration.self, PhiModel.init),
"phi3": create(Phi3Configuration.self, Phi3Model.init),
"phimoe": create(PhiMoEConfiguration.self, PhiMoEModel.init),
"gemma": create(GemmaConfiguration.self, GemmaModel.init),
"gemma2": create(Gemma2Configuration.self, Gemma2Model.init),
"gemma3": create(Gemma3TextConfiguration.self, Gemma3TextModel.init),
"gemma3_text": create(Gemma3TextConfiguration.self, Gemma3TextModel.init),
"gemma3n": create(Gemma3nTextConfiguration.self, Gemma3nTextModel.init),
"qwen2": create(Qwen2Configuration.self, Qwen2Model.init),
"qwen3": create(Qwen3Configuration.self, Qwen3Model.init),
"qwen3_moe": create(Qwen3MoEConfiguration.self, Qwen3MoEModel.init),
"qwen3_next": create(Qwen3NextConfiguration.self, Qwen3NextModel.init),
"minicpm": create(MiniCPMConfiguration.self, MiniCPMModel.init),
"starcoder2": create(Starcoder2Configuration.self, Starcoder2Model.init),
"cohere": create(CohereConfiguration.self, CohereModel.init),
"openelm": create(OpenElmConfiguration.self, OpenELMModel.init),
"internlm2": create(InternLM2Configuration.self, InternLM2Model.init),
"deepseek_v3": create(DeepseekV3Configuration.self, DeepseekV3Model.init),
"granite": create(GraniteConfiguration.self, GraniteModel.init),
"granitemoehybrid": create(
GraniteMoeHybridConfiguration.self, GraniteMoeHybridModel.init),
"mimo": create(MiMoConfiguration.self, MiMoModel.init),
"mimo_v2_flash": create(MiMoV2FlashConfiguration.self, MiMoV2FlashModel.init),
"minimax": create(MiniMaxConfiguration.self, MiniMaxModel.init),
"glm4": create(GLM4Configuration.self, GLM4Model.init),
"glm4_moe": create(GLM4MoEConfiguration.self, GLM4MoEModel.init),
"glm4_moe_lite": create(GLM4MoELiteConfiguration.self, GLM4MoELiteModel.init),
"acereason": create(Qwen2Configuration.self, Qwen2Model.init),
"falcon_h1": create(FalconH1Configuration.self, FalconH1Model.init),
"bitnet": create(BitnetConfiguration.self, BitnetModel.init),
"smollm3": create(SmolLM3Configuration.self, SmolLM3Model.init),
"ernie4_5": create(Ernie45Configuration.self, Ernie45Model.init),
"lfm2": create(LFM2Configuration.self, LFM2Model.init),
"baichuan_m1": create(BaichuanM1Configuration.self, BaichuanM1Model.init),
"exaone4": create(Exaone4Configuration.self, Exaone4Model.init),
"gpt_oss": create(GPTOSSConfiguration.self, GPTOSSModel.init),
"lille-130m": create(Lille130mConfiguration.self, Lille130mModel.init),
"olmoe": create(OlmoEConfiguration.self, OlmoEModel.init),
"olmo2": create(Olmo2Configuration.self, Olmo2Model.init),
"olmo3": create(Olmo3Configuration.self, Olmo3Model.init),
"bailing_moe": create(BailingMoeConfiguration.self, BailingMoeModel.init),
"lfm2_moe": create(LFM2MoEConfiguration.self, LFM2MoEModel.init),
"nanochat": create(NanoChatConfiguration.self, NanoChatModel.init),
"nemotron_h": create(NemotronHConfiguration.self, NemotronHModel.init),
"afmoe": create(AfMoEConfiguration.self, AfMoEModel.init),
"jamba_3b": create(JambaConfiguration.self, JambaModel.init),
"mistral3": create(Mistral3TextConfiguration.self, Mistral3TextModel.init),
"apertus": create(ApertusConfiguration.self, ApertusModel.init),
])
public static let shared: ModelTypeRegistry = .init(creators: all())

/// All predefined model types.
private static func all() -> [String: @Sendable (Data) throws -> any LanguageModel] {
var models: [String: @Sendable (Data) throws -> any LanguageModel] = [
"mistral": create(LlamaConfiguration.self, LlamaModel.init),
"llama": create(LlamaConfiguration.self, LlamaModel.init),
"phi": create(PhiConfiguration.self, PhiModel.init),
"phi3": create(Phi3Configuration.self, Phi3Model.init),
"phimoe": create(PhiMoEConfiguration.self, PhiMoEModel.init),
"gemma": create(GemmaConfiguration.self, GemmaModel.init),
"gemma2": create(Gemma2Configuration.self, Gemma2Model.init),
"gemma3": create(Gemma3TextConfiguration.self, Gemma3TextModel.init),
"gemma3_text": create(Gemma3TextConfiguration.self, Gemma3TextModel.init),
"gemma3n": create(Gemma3nTextConfiguration.self, Gemma3nTextModel.init),
"qwen2": create(Qwen2Configuration.self, Qwen2Model.init),
"qwen3": create(Qwen3Configuration.self, Qwen3Model.init),
"qwen3_moe": create(Qwen3MoEConfiguration.self, Qwen3MoEModel.init),
"qwen3_next": create(Qwen3NextConfiguration.self, Qwen3NextModel.init),
"minicpm": create(MiniCPMConfiguration.self, MiniCPMModel.init),
"starcoder2": create(Starcoder2Configuration.self, Starcoder2Model.init),
"cohere": create(CohereConfiguration.self, CohereModel.init),
"openelm": create(OpenElmConfiguration.self, OpenELMModel.init),
"internlm2": create(InternLM2Configuration.self, InternLM2Model.init),
"deepseek_v3": create(DeepseekV3Configuration.self, DeepseekV3Model.init),
"granite": create(GraniteConfiguration.self, GraniteModel.init),
"granitemoehybrid": create(
GraniteMoeHybridConfiguration.self, GraniteMoeHybridModel.init),
"mimo": create(MiMoConfiguration.self, MiMoModel.init),
"mimo_v2_flash": create(MiMoV2FlashConfiguration.self, MiMoV2FlashModel.init),
"minimax": create(MiniMaxConfiguration.self, MiniMaxModel.init),
"glm4": create(GLM4Configuration.self, GLM4Model.init),
"glm4_moe": create(GLM4MoEConfiguration.self, GLM4MoEModel.init),
"glm4_moe_lite": create(GLM4MoELiteConfiguration.self, GLM4MoELiteModel.init),
"acereason": create(Qwen2Configuration.self, Qwen2Model.init),
"falcon_h1": create(FalconH1Configuration.self, FalconH1Model.init),
"smollm3": create(SmolLM3Configuration.self, SmolLM3Model.init),
"ernie4_5": create(Ernie45Configuration.self, Ernie45Model.init),
"lfm2": create(LFM2Configuration.self, LFM2Model.init),
"baichuan_m1": create(BaichuanM1Configuration.self, BaichuanM1Model.init),
"exaone4": create(Exaone4Configuration.self, Exaone4Model.init),
"gpt_oss": create(GPTOSSConfiguration.self, GPTOSSModel.init),
"lille-130m": create(Lille130mConfiguration.self, Lille130mModel.init),
"olmoe": create(OlmoEConfiguration.self, OlmoEModel.init),
"olmo2": create(Olmo2Configuration.self, Olmo2Model.init),
"olmo3": create(Olmo3Configuration.self, Olmo3Model.init),
"bailing_moe": create(BailingMoeConfiguration.self, BailingMoeModel.init),
"lfm2_moe": create(LFM2MoEConfiguration.self, LFM2MoEModel.init),
"nanochat": create(NanoChatConfiguration.self, NanoChatModel.init),
"nemotron_h": create(NemotronHConfiguration.self, NemotronHModel.init),
"afmoe": create(AfMoEConfiguration.self, AfMoEModel.init),
"jamba_3b": create(JambaConfiguration.self, JambaModel.init),
"mistral3": create(Mistral3TextConfiguration.self, Mistral3TextModel.init),
"apertus": create(ApertusConfiguration.self, ApertusModel.init),
]

// Bitnet requires Metal custom kernels and is only available on Apple platforms
#if os(macOS) || os(iOS) || os(tvOS) || os(watchOS) || os(visionOS)
models["bitnet"] = create(BitnetConfiguration.self, BitnetModel.init)
#endif

return models
}
}

/// Registry of models and any overrides that go with them, e.g. prompt augmentation.
Expand Down
3 changes: 1 addition & 2 deletions Libraries/MLXLLM/Lora+Data.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@ enum LoRADataError: LocalizedError {
var errorDescription: String? {
switch self {
case .fileNotFound(let directory, let name):
return String(
localized: "Could not find data file '\(name)' in directory '\(directory.path())'.")
return "Could not find data file '\(name)' in directory '\(directory.path())'."
}
}
}
Expand Down
Loading
Loading